# Purpose
This notebook creates a new, cleaner data set from the provided images. For every cell segmented out of a sample's 4 original channel images (red, green, blue, yellow), 4 new single-cell images are written, one per channel.

For example, if sample ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0 has 3 cells, its original images:
```
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_red.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_green.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_blue.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_yellow.png
```
will become the following files (written to the output directory, `./data` by default):
```
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_red_1.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_red_2.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_red_3.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_green_1.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_green_2.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_green_3.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_blue_1.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_blue_2.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_blue_3.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_yellow_1.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_yellow_2.png
./ca035f36-bbc9-11e8-b2bc-ac1f6b6435d0_yellow_3.png
```

References:

https://www.kaggle.com/christopherworley/human-protein-atlas-segmentation
    
https://github.com/CellProfiling/HPA-Cell-Segmentation


In [None]:

import pandas as pd 
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from PIL import Image  
from IPython import display 

import os


In [None]:
!pip install 'git+https://github.com/haoxusci/pytorch_zoo@master#egg=pytorch_zoo'

In [None]:
!pip install https://github.com/CellProfiling/HPA-Cell-Segmentation/archive/master.zip

In [None]:
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei

# Parameters

In [None]:
# Root of the competition input data (train.csv and the train/ image folder).
ROOT_PATH = '/kaggle/input/hpa-single-cell-image-classification/'

# Channel names as used in the image file names, and a human readable
# description for each.  Both arrays share the same ordering.
CHANNELS = np.array(['red', 'green', 'blue', 'yellow'])
CHANNEL_DESCRIPTIONS = np.array(['Microtubule Channels', 'Protein of Interest', 'Nuclei Channels', 'Endoplasmic Reticulum'])

# Positions of the individual channels inside CHANNELS / image lists.
CHANNEL_RED = 0    # microtubule channels
CHANNEL_GREEN = 1  # protein of interest
CHANNEL_BLUE = 2   # nuclei channels
CHANNEL_YELLOW = 3 # endoplasmic reticulum 

CHANNEL_SIZE = len(CHANNELS)

# Number of training samples to process; -1 processes the whole training set.
SAMPLE_SIZE = -1

# Edge length (in pixels) of the square canvas each single cell is centered on.
CELL_IMAGE_SIZE = 2048

# Plot or do not plot images.  Plotting will slow things down
PLOT_OUTPUT = False

# Directory the single cell images are written to.  makedirs(exist_ok=True)
# is idempotent, so re-running the cell never fails on an existing directory
# and there is no window between an existence check and the creation.
OUTPUT_DIR = "./data"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Functions

In [None]:
#
# Get array of all images for given sample id
#
def get_images(id):
    """Load every channel image of one training sample.

    Parameters
    ----------
    id
        Sample identifier; it is formatted into the file name
        ``ROOT_PATH + 'train/<id>_<channel>.png'``.

    Returns
    -------
    list of numpy.ndarray
        One array per entry of ``CHANNELS``, in ``CHANNELS`` order.
    """
    images = []
    for channel in CHANNELS:
        path = ROOT_PATH + 'train/{}_{}.png'.format(id, channel)
        # Convert to an array while the file is open, then let the context
        # manager release the file handle instead of relying on garbage
        # collection to do so.
        with Image.open(path) as image:
            images.append(np.asarray(image))
    return images

In [None]:
#
# Get single image that blends all RGBY into RGB
#
def get_blended_image(id):
    """Return a single PIL image that folds the four channels of a sample into RGB.

    The yellow channel has no slot of its own in an RGB image, so it is merged
    into both the red and the green plane with an element-wise maximum; the
    blue plane is used unchanged.

    Parameters
    ----------
    id
        Sample identifier, passed straight to ``get_images``.

    Returns
    -------
    PIL.Image.Image
        Image built from the stacked planes after conversion to ``uint8``.
    """
    channels = get_images(id)
    red, green, blue, yellow = channels[0], channels[1], channels[2], channels[3]

    # Yellow contributes to red and green; blue stays as it is.
    red_plane = np.maximum(red, yellow)
    green_plane = np.maximum(green, yellow)

    # Planes become the last axis of the stacked array.
    rgb = np.stack((red_plane, green_plane, blue), axis=2)
    return Image.fromarray(np.uint8(rgb))

In [None]:
#
# Crop given image and create new image with cropped image centered
#
def center_mass(image, top, right, bottom, left, size=None):
    """Crop ``image`` and paste the crop into the middle of a blank square canvas.

    Parameters
    ----------
    image : numpy.ndarray
        2-D array to crop from.
    top, bottom : int
        Row slice of the crop; ``bottom`` is exclusive (``image[top:bottom]``).
    left, right : int
        Column slice of the crop; ``right`` is exclusive (``image[:, left:right]``).
    size : int, optional
        Edge length of the square canvas.  Defaults to ``CELL_IMAGE_SIZE``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(size, size)``, zero everywhere except for
        the centered crop.

    Raises
    ------
    ValueError
        If the requested crop is larger than the canvas.
    """
    if size is None:
        size = CELL_IMAGE_SIZE

    height = bottom - top
    width = right - left
    if height > size or width > size:
        raise ValueError(
            "crop of {}x{} pixels does not fit into a {}x{} canvas".format(
                width, height, size, size))

    # Allocate the canvas directly as uint8 rather than building a float64
    # canvas and converting it afterwards.
    canvas = np.zeros((size, size), dtype=np.uint8)

    # Rows are offset by the height and columns by the width of the crop.
    y1 = (size - height) // 2
    x1 = (size - width) // 2
    canvas[y1:y1 + height, x1:x1 + width] = image[top:bottom, left:right]
    return canvas

In [None]:
#
# Determine boundaries of object in given array
#
def get_cell_bounds(image_array):
    """Find the bounding box of the non-zero content of an image array.

    Parameters
    ----------
    image_array : numpy.ndarray
        Array with at least two dimensions; axis 0 is scanned as rows and
        axis 1 as columns.  Any further axes (for example color planes) are
        collapsed: a row or column counts as occupied when any of its values
        is non-zero.

    Returns
    -------
    tuple of int
        ``(top, right, bottom, left)``.  All four values are *inclusive*
        indices: ``bottom`` and ``right`` are the last occupied row and
        column, so add 1 to them when they are used as slice ends.  If the
        array holds no non-zero value, all four values are -1.  When only a
        single row (or column) is occupied, ``bottom == top``
        (or ``right == left``).
    """
    occupied = np.asarray(image_array) != 0

    # Collapse every axis except the one being scanned.
    row_axes = tuple(axis for axis in range(occupied.ndim) if axis != 0)
    column_axes = tuple(axis for axis in range(occupied.ndim) if axis != 1)

    top, bottom = _first_and_last_set(occupied.any(axis=row_axes))
    left, right = _first_and_last_set(occupied.any(axis=column_axes))

    return (top, right, bottom, left)


def _first_and_last_set(flags):
    """Return the first and last index where ``flags`` is true, or (-1, -1)."""
    hits = np.flatnonzero(flags)
    if hits.size == 0:
        return (-1, -1)
    return (int(hits[0]), int(hits[-1]))

In [None]:
#
# Create mask for given images
#
def create_cell_mask(images):
    """Segment one sample and return its labelled cell mask.

    Uses the module-level ``segmentator``, so that object must exist before
    this function is called.

    Parameters
    ----------
    images : sequence
        Channel images of one sample, indexable with the ``CHANNEL_*``
        constants.

    Returns
    -------
    The cell mask produced by ``label_cell`` for this sample.  The nuclei
    mask that ``label_cell`` returns alongside it is discarded.
    """
    # Nuclei are predicted from the blue channel only.
    nuclei_input = np.asarray(images[CHANNEL_BLUE])
    nuclei_predictions = segmentator.pred_nuclei([nuclei_input])

    # Whole cells are predicted from the red, yellow and blue channels, each
    # passed as its own single-image list, in that order.
    cell_inputs = [
        [np.asarray(images[channel])]
        for channel in (CHANNEL_RED, CHANNEL_YELLOW, CHANNEL_BLUE)
    ]
    cell_predictions = segmentator.pred_cells(cell_inputs)

    # Only one sample was submitted, so take the first prediction of each.
    _, cell_mask = label_cell(nuclei_predictions[0], cell_predictions[0])
    return cell_mask
    


# Read Training Data

In [None]:
# Load the training index (one row per sample) and preview the first rows.
train_csv_path = ROOT_PATH + 'train.csv'
df_train = pd.read_csv(train_csv_path)
print("Trainning data length: {}".format(df_train.shape[0]))
df_train.head()

In [None]:
# When SAMPLE_SIZE is set (anything above -1), keep only a random subset of
# that many training rows.  reset_index() renumbers the rows and keeps the
# previous index as an extra 'index' column.
if SAMPLE_SIZE > -1:
    df_train = df_train.sample(SAMPLE_SIZE).reset_index()

In [None]:
# Paths of the nuclei and cell model weight files handed to the segmentator.
NUC_MODEL = "./nuclei-model.pth"
CELL_MODEL = "./cell-model.pth"

# Keyword options for the segmentator; their exact meaning is defined by the
# hpacellseg library.
# NOTE(review): device="cuda" presumably requires a GPU runtime - confirm.
segmentator_options = dict(
    scale_factor=0.25,
    device="cuda",
    padding=False,
    multi_channel_model=True,
)

# Shared segmentator instance used by create_cell_mask.
segmentator = cellsegmentator.CellSegmentator(NUC_MODEL, CELL_MODEL, **segmentator_options)


## Main Loop Through Sampled Data

1. Loop through samples
2. Loop through all cells identified
3. Create new image with single cell centered
4. Write new cell image out

In [None]:

# Subplot grid used when PLOT_OUTPUT is enabled: the first column holds the
# blended image, the remaining columns hold one channel each (original image
# in the first row, isolated cell in the second row).
COLUMN_COUNT = CHANNEL_SIZE + 1
ROW_COUNT = 2


def _to_crop_bounds(top, right, bottom, left):
    """Turn inclusive bounds into the exclusive slice ends used for cropping.

    get_cell_bounds reports `bottom` and `right` as the index of the last
    non-empty row/column, while center_mass crops with Python slices whose
    end is exclusive.  Passing the bounds through unchanged would cut off
    the last row and the last column of every cell, so 1 is added to both.
    """
    if top == -1 or left == -1:
        # No non-zero pixel was found for this cell: use an empty crop so a
        # blank image is still written for it.
        return (0, 0, 0, 0)
    # An end bound smaller than its start bound (e.g. still -1) is treated as
    # a one-pixel extent.
    return (top, max(right, left) + 1, max(bottom, top) + 1, left)


# Running 1-based counter, only used for the progress messages
sample_count = 1

# setup figure to plot blended image and centered images for each channel
fig = plt.figure(figsize=(40,10))

for sample_index, df_sample in df_train.iterrows():
    sample_id = df_sample['ID']
    images = get_images(sample_id)

    # Blended image
    print("\rGenerating blended image (count={}, current_id={})                  ".format(sample_count, sample_id), end="")
    blended_array = np.asarray(get_blended_image(sample_id))

    # Plot blended image and the original channel images
    if PLOT_OUTPUT:
        ax = fig.add_subplot(1, COLUMN_COUNT, 1)
        ax.set_title("Blended Cell")
        plt.imshow(blended_array)

        for channel_index in range(len(CHANNELS)):
            ax = fig.add_subplot(ROW_COUNT, COLUMN_COUNT, channel_index + 2)
            ax.set_title("Original {} {}".format(sample_count, CHANNEL_DESCRIPTIONS[channel_index]))
            plt.imshow(images[channel_index], cmap=plt.get_cmap('bone'))

    # Cell mask
    print("\rCreating single cell data set (count={}, current_id={})              ".format(sample_count, sample_id), end="")
    cell_mask = create_cell_mask(images)

    # Get unique vector of segment numbers.  Label 0 is dropped by value
    # rather than by position, so a mask without any 0 keeps all its labels.
    # NOTE(review): assumes 0 marks the background in the mask - confirm.
    numbers = np.unique(cell_mask)
    numbers = numbers[numbers != 0]

    for number in numbers:
        print("\rIsolating single cell (count={}, cell_number={}, current_id={})   ".format(sample_count, number, sample_id), end="")
        # Isolate the cell within the cell mask
        isolated_mask = cell_mask == number

        # Get boundaries of the isolated blended cell; the 2-D mask is
        # broadcast over the color planes of the blended image
        blended_isolated_image = np.where(isolated_mask[..., np.newaxis], blended_array, 0)
        (top, right, bottom, left) = _to_crop_bounds(*get_cell_bounds(blended_isolated_image))

        # create new single cell images for each channel
        for channel_index in range(len(CHANNELS)):
            isolated_image = np.where(isolated_mask, images[channel_index], 0)
            centered_image = center_mass(isolated_image, top, right, bottom, left)
            Image.fromarray(centered_image).save("./{}/{}_{}_{}.png".format(OUTPUT_DIR, sample_id, CHANNELS[channel_index], number), "PNG")

            if PLOT_OUTPUT:
                # Second grid row, starting in the column right of the
                # blended image
                ax = fig.add_subplot(ROW_COUNT, COLUMN_COUNT, COLUMN_COUNT + 2 + channel_index)
                ax.set_title("Clean Cell {} {}".format(number, CHANNEL_DESCRIPTIONS[channel_index]))
                plt.imshow(centered_image, cmap=plt.get_cmap('bone'))

        if PLOT_OUTPUT:
            display.clear_output(wait=True)
            display.display(fig)
    print("\rSingle cell data set completed (count={}, current_id={})               ".format(sample_count, sample_id), end="")
    sample_count = sample_count + 1
