# In [None]:
import os
import re

import numpy as np
import pandas as pd
import scipy
from skimage.measure import regionprops
from scipy.io import loadmat
from skimage.io import imread
from skimage.transform import resize


def get_cell_idx_partition(df):
    """
    Split the rows of df into consecutive per-FOV row ranges.

    A new FOV is considered to start at every row whose 'cellLabelInImage'
    is strictly smaller than the value in the row just before it.

    Returns a list whose k-th entry (k counted from 0) is
    [first_row, last_row] for the (k+1)-th FOV; both ends are positional
    row indices into df and are inclusive.
    """
    labels = df['cellLabelInImage'].tolist()
    # positions at which the label drops relative to the preceding row
    boundaries = [
        pos
        for pos, (previous, current) in enumerate(zip(labels, labels[1:]), start=1)
        if current < previous
    ]
    first_rows = [0] + boundaries
    # every range ends right before the next one starts; the last runs to the end of df
    last_rows = [pos - 1 for pos in boundaries] + [df.shape[0] - 1]
    return [[first, last] for first, last in zip(first_rows, last_rows)]


def add_cell_locations(
        df, path_to_segmentation, shape_of_views=(2, 17), shape_of_each_view=(1024, 1024), verbose=True
):
    """
    Add three columns to df in place: 'centroid_x', 'centroid_y' and 'field_of_view'.

    FOVs are numbered from 1 and visited with the second grid axis as the outer
    loop and the first grid axis as the inner loop, i.e.
    view = view_j * shape_of_views[0] + view_i + 1.  The rows of df belonging to
    FOV number `view` are taken from get_cell_idx_partition(df)[view - 1].

    For each FOV the label image is read from
    '<path_to_segmentation>point<view>/segmentationParams_obj.mat' (variable 'newLmod'),
    region centroids are computed with skimage's regionprops, and the FOV's
    top-left offset is added so that coordinates are global over the whole grid.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a 'cellLabelInImage' column.  Modified in place.
    path_to_segmentation : str
        Prefix that is concatenated directly with 'point<view>/...'
        (so it normally ends with a path separator).
    shape_of_views : tuple of two ints
        Number of FOVs along the first and the second image axis.
    shape_of_each_view : tuple of two ints
        Size of one FOV along the first and the second image axis.
    verbose : bool
        Print the FOV number and its top-left offset while processing.

    Notes
    -----
    'centroid_x' holds the first centroid coordinate (the row coordinate in
    skimage's convention) and 'centroid_y' the second one.

    Raises
    ------
    ValueError
        If the number of FOVs detected in df differs from
        shape_of_views[0] * shape_of_views[1].
    KeyError
        If a cell label in df has no region in the corresponding label image.
    """
    # get the splitting points of different FOVs
    partition = get_cell_idx_partition(df)
    n_views = shape_of_views[0] * shape_of_views[1]
    if len(partition) != n_views:
        # Fail before any file is read; a mismatch would otherwise surface as an
        # IndexError mid-way or as a column-length error at the very end.
        raise ValueError(
            'df splits into {} fields of view, but shape_of_views={} describes {}'.format(
                len(partition), tuple(shape_of_views), n_views
            )
        )

    # read the label column once instead of materialising a row per cell
    labels = df['cellLabelInImage'].to_numpy()
    centroid_x = []
    centroid_y = []
    cell_views = []

    # NOTE(review): the previous version carried an `ind` flag that was meant to
    # switch to a mirrored offset (part1 - start) for some columns, but it was
    # guarded by `view_j > shape_of_views[1]`, which can never hold inside
    # range(shape_of_views[1]).  The mirrored layout was therefore never applied;
    # the plain offsets below are what has always been computed.  Confirm with the
    # acquisition layout before re-introducing any flipping.
    for view_j in range(shape_of_views[1]):
        for view_i in range(shape_of_views[0]):
            view = view_j * shape_of_views[0] + view_i + 1
            topleft = [view_i * shape_of_each_view[0], view_j * shape_of_each_view[1]]

            if verbose:
                print("Now at field of view {}, top-left coordinate is {}".format(view, topleft), flush=True)
            seg = scipy.io.loadmat(
                '{}point{}/segmentationParams_obj.mat'.format(path_to_segmentation, view)
            )['newLmod']

            # Key centroids by the label regionprops reports for each region rather
            # than by position in np.unique(seg)[1:], which silently misaligns when
            # the image has no background (0) pixels.
            label_to_centroid = {prop.label: prop.centroid for prop in regionprops(seg)}

            # fill the centroids of this segment of df
            start, end = partition[view - 1]
            for label in labels[start:end + 1]:
                centroid = label_to_centroid[label]
                centroid_x.append(centroid[0] + topleft[0])
                centroid_y.append(centroid[1] + topleft[1])
                cell_views.append(view)

    # add new columns
    df['centroid_x'] = centroid_x
    df['centroid_y'] = centroid_y
    df['field_of_view'] = cell_views


def pad_to_square(img):
    """
    Zero-pad an image so that its first two axes have equal length.

    The shorter of the two leading axes is padded on both sides; when the
    amount is odd, the extra row/column goes to the end (bottom / right).
    Any further axes (e.g. a trailing channel axis) are left untouched.

    Parameters
    ----------
    img : numpy.ndarray
        Array with at least two dimensions.

    Returns
    -------
    numpy.ndarray
        `img` itself when it is already square, otherwise a padded copy.
    """
    h = img.shape[0]
    w = img.shape[1]
    if h == w:
        return img
    missing = abs(h - w)
    before = missing // 2
    after = missing - before
    # one (before, after) pair per axis; only the shorter leading axis is padded
    pad_width = [(0, 0)] * img.ndim
    pad_width[0 if h < w else 1] = (before, after)
    return np.pad(img, pad_width)


def process_each_cell(whole_image, segmentation, bounding_box, label_in_image, max_height, max_width, method='resize'):
    """
    Cut one cell out of a field of view and bring it to a common size.

    The crop whole_image[min_row:max_row, min_col:max_col] is taken, and every
    pixel whose segmentation label differs from label_in_image is set to zero.
    The inputs are not modified.

    Parameters
    ----------
    whole_image : numpy.ndarray
        Image of the whole field of view.
    segmentation : numpy.ndarray
        Label image indexed with the same bounding box as whole_image.
    bounding_box : sequence
        (min_row, min_col, max_row, max_col); the max bounds are exclusive.
    label_in_image : int
        Label of the cell to keep.
    max_height, max_width : int
        Target size, see `method`.
    method : {'resize', 'pad'}
        'resize': pad the crop to a square, then resize it to
        (max_height, max_width).
        'pad': zero-pad the crop (centred) to (max_height, max_width) and then
        pad the result to a square, so the output is square with side
        max(max_height, max_width).

    Raises
    ------
    NotImplementedError
        If method is neither 'resize' nor 'pad'.
    """
    box = (
        slice(bounding_box[0], bounding_box[2]),
        slice(bounding_box[1], bounding_box[3]),
    )
    # need to copy since we are setting some elements to zero
    img_inside_box = whole_image[box].copy()
    # blank out everything in the box that belongs to another cell or to background
    img_inside_box[segmentation[box] != label_in_image] = 0

    if method == 'resize':
        img_inside_box = pad_to_square(img_inside_box)
        res = np.array(
            resize(
                img_inside_box, (max_height, max_width), preserve_range=True
            )
        )
    elif method == 'pad':
        h_extra_top = (max_height - img_inside_box.shape[0]) // 2
        h_extra_bottom = max_height - h_extra_top - img_inside_box.shape[0]
        w_extra_left = (max_width - img_inside_box.shape[1]) // 2
        w_extra_right = max_width - w_extra_left - img_inside_box.shape[1]
        res = np.pad(img_inside_box, [(h_extra_top, h_extra_bottom), (w_extra_left, w_extra_right)])
        res = pad_to_square(res)
    else:
        raise NotImplementedError("unknown method {!r}; expected 'resize' or 'pad'".format(method))
    return res


def _iter_fov_cells(df, load_path, all_folders, partition, grid, verbose):
    """
    Yield (folder_name, label_image, good_props) for every grid position, in order.

    good_props are the regionprops entries of the FOV's label image whose label
    belongs to a row of df (within that FOV's partition range) with a
    'cluster.term' other than 'Other' or 'mix'.
    """
    labels = df['cellLabelInImage'].to_numpy()
    terms = df['cluster.term'].to_numpy()
    for fov_index, (x, y) in enumerate(grid):
        if verbose:
            print("Now at ({}, {})...".format(x, y), flush=True)
        # get folder name
        tag = 'X' + str(x).zfill(2) + '_Y' + str(y).zfill(2)
        matches = [name for name in all_folders if re.search(tag, name)]
        assert len(matches) == 1, 'expected exactly one folder matching {}, found {}'.format(tag, matches)
        curr_folder = matches[0]

        # avoid bad cells
        start, end = partition[fov_index]
        good_labels = {
            label
            for label, term in zip(labels[start:end + 1], terms[start:end + 1])
            if term not in ('Other', 'mix')
        }

        seg = scipy.io.loadmat(
            os.path.join(
                load_path, 'Images_singleChannel_0503seg',
                curr_folder, 'H3_CD45_Vim_HLA_Mesmer0111Whole_AutoHist_mpp0.6/segmentationParams.mat',
            )
        )['newLmod']

        # Select by the label regionprops reports for each region instead of by
        # position in np.unique(seg)[1:], which misaligns without background pixels.
        good_props = [prop for prop in regionprops(seg) if prop.label in good_labels]
        yield curr_folder, seg, good_props


def process_images(
        df, load_path, channels=('HOECHST1', 'mem_CD45_Vim_HLA'),
        method='resize', max_height=40, max_width=40, verbose=True,
        x_range=range(2, 11), y_range=range(5, 16)
):
    """
    Extract a fixed-size image of every usable cell for each requested channel.

    FOVs are visited over the grid x_range x y_range (x outer, y inner).  The
    k-th grid position is paired with the k-th range of
    get_cell_idx_partition(df), and with the single folder under
    '<load_path>/Images_singleChannel/' whose name matches 'X<xx>_Y<yy>'.
    Cells whose 'cluster.term' is 'Other' or 'mix' are skipped.

    For every remaining cell and every channel, process_each_cell is called:
        - the cell's bounding box is cropped and the non-cell area set to zero
        - if method='resize', the crop is resized to (max_height, max_width)
        - if method='pad', max_height / max_width are first replaced by the
          largest bounding-box height / width over all usable cells (the passed
          values are ignored) and the crop is zero-padded accordingly.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'cellLabelInImage' and 'cluster.term' columns.
    load_path : str
        Directory containing 'Images_singleChannel' and 'Images_singleChannel_0503seg'.
    channels : sequence of str
        Channel names; each must match exactly one image file name per FOV
        (the name must be followed by a non-word character in the file name).
    method : {'resize', 'pad'}
    max_height, max_width : int
        Target size for method='resize'.
    verbose : bool
    x_range, y_range : iterables of int
        Grid positions to visit; the defaults reproduce the original
        hard-coded 9 x 11 grid.

    Returns
    -------
    numpy.ndarray
        The per-channel lists of cell images stacked and transposed with
        (1, 0, 2, 3), i.e. indexed as [cell, channel, ...]
        (assumes each cell image is 2-D -- TODO confirm for multi-band inputs).
        Nothing is written to disk.

    Raises
    ------
    ValueError
        If df splits into fewer FOVs than there are grid positions.
    """
    partition = get_cell_idx_partition(df)
    grid = [(x, y) for x in x_range for y in y_range]
    if len(partition) < len(grid):
        raise ValueError(
            'df splits into {} fields of view, but the grid has {} positions'.format(
                len(partition), len(grid)
            )
        )
    all_folders = os.listdir(os.path.join(load_path, 'Images_singleChannel/'))

    if method == 'pad':
        # calculate largest bounding boxes
        max_height = float('-inf')
        max_width = float('-inf')
        if verbose:
            print("Calculating largest bounding boxes...", flush=True)
        for _, _, good_props in _iter_fov_cells(df, load_path, all_folders, partition, grid, verbose):
            for prop in good_props:
                bounding_box = prop.bbox
                max_height = max(max_height, bounding_box[2] - bounding_box[0])
                max_width = max(max_width, bounding_box[3] - bounding_box[1])
        if verbose:
            print('max_height={}, max_width={}'.format(max_height, max_width))

    if verbose:
        print('Processing images...', flush=True)

    # process images
    res = [[] for _ in channels]
    for curr_folder, seg, good_props in _iter_fov_cells(df, load_path, all_folders, partition, grid, verbose):
        image_dir = os.path.join(load_path, 'Images_singleChannel', curr_folder)
        all_img_names = os.listdir(image_dir)
        for channel_id, channel in enumerate(channels):
            # raw string: '\W' in a normal literal is an invalid escape sequence.
            # The trailing \W keeps e.g. 'HOECHST1' from matching 'HOECHST10'.
            img_filename = [name for name in all_img_names if re.search(channel + r'\W', name)]
            assert len(img_filename) == 1, 'expected exactly one image for channel {} in {}, found {}'.format(
                channel, curr_folder, img_filename
            )
            img = imread(os.path.join(image_dir, img_filename[0]))

            for prop in good_props:
                each_img = process_each_cell(
                    whole_image=img,
                    segmentation=seg,
                    bounding_box=prop.bbox,
                    label_in_image=prop.label,
                    max_height=max_height,
                    max_width=max_width,
                    method=method
                )
                res[channel_id].append(each_img)

    res = np.array(res)
    # [channel, cell, ...] -> [cell, channel, ...]
    res = np.transpose(res, (1, 0, 2, 3))
    return res


if __name__ == '__main__':
    raw_dir = '../data/tonsil/processed_raw_data'
    out_dir = '../data/tonsil/processed_data'
    n_subset = 50000

    # load the raw per-cell feature table
    print('Reading in the data...')
    df_raw = pd.read_csv(raw_dir + '/all_clusters.csv', index_col=0)

    # keep only cells whose cluster term is neither 'Other' nor 'mix',
    # then drop the columns that are not used downstream
    is_usable = ~df_raw['cluster.term'].isin(['Other', 'mix'])
    df_clean = df_raw[is_usable].drop(['PointNum', 'seurat_res1.0'], axis=1)

    print('The cleaned dataframe contain columns: {}'.format(list(df_clean.columns)))

    # attach centroid_x, centroid_y and field_of_view to every remaining cell
    print('Filling in spatial information...')
    add_cell_locations(
        df=df_clean,
        path_to_segmentation=raw_dir + '/Images_singleChannel_0503seg/',
        shape_of_each_view=(1008, 1344),
        verbose=True,
    )

    # full table plus a subset of the first n_subset rows
    df_clean.to_csv(out_dir + '/features_and_metadata.csv')
    df_clean.iloc[:n_subset, :].to_csv(out_dir + '/50k_features_and_metadata.csv')

    # extract the single-cell images (padded to the largest bounding box)
    print('Processing single-cell images...')
    processed_images = process_images(
        df=df_raw,
        load_path=raw_dir,
        channels=('HOECHST1', 'mem_CD45_Vim_HLA'),
        method='pad',
        max_height=None,
        max_width=None,
        verbose=True,
    )
    np.save(file=out_dir + '/images_pad.npy', arr=processed_images)
    np.save(file=out_dir + '/50k_images_pad.npy', arr=processed_images[:n_subset])

