In [None]:
!pip install monai

# Table of contents
1. importing libraries
2. importing data
3. exploratory data analysis
4. build ML pipeline
    - define torch.Dataset & split train/val/test data
    - define model architecture
    - define hyperparameters, loss, and optimization method
5. Result analysis
6. Submission

# Import libraries

In [None]:
import os, glob, json, cv2, tqdm, copy
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image

import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

from monai.networks.nets import UNet
from monai.losses import Dice

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Import data

In [None]:
# Competition root; train.csv holds one RLE annotation per row, keyed by image id.
data_dir = '../input/sartorius-cell-instance-segmentation/'
train_csv_path = os.path.join(data_dir, 'train.csv')
df_train = pd.read_csv(train_csv_path)

Collect annotations

In [None]:
# Group every RLE annotation string by image id: {image_id: [rle, rle, ...]},
# keeping the row order of train.csv within each image.
dict_annotation = {}
for sample_id, annotation in zip(df_train['id'], df_train['annotation']):
    dict_annotation.setdefault(sample_id, []).append(annotation)

Define function for encoding/decoding annotations

In [None]:
def decode_annotation(annotations, shape = (520, 704)):
    """Decode one or more run-length-encoded (RLE) masks into a single binary mask.

    Each annotation is a whitespace-separated string of ``start length`` pairs,
    where ``start`` is a 1-indexed position in the row-major flattened image.
    All annotations are OR-ed together into one mask.

    Args:
        annotations: an iterable of RLE strings, or a single RLE string.
        shape: (height, width) of the output mask.

    Returns:
        np.ndarray of dtype uint8 with the given shape; 1 on annotated pixels, 0 elsewhere.
    """
    # A bare string would otherwise be iterated character by character and
    # silently decode to an all-zero mask.
    if isinstance(annotations, str):
        annotations = [annotations]
    mask = np.zeros(np.prod(shape), dtype = np.uint8)
    for annotation in annotations:
        tokens = annotation.split()
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1  # RLE starts are 1-indexed; NumPy is 0-indexed
            mask[begin:begin + int(length)] = 1
    return mask.reshape(shape)

def encode_annotation(mask):
    """Run-length encode a binary mask into a space-separated "start length" string.

    The mask is flattened in row-major order and starts are 1-indexed, i.e. the
    same format that decode_annotation reads. An all-zero mask encodes to ''.
    """
    # Pad with a 0 on each side so every run of 1s has a rising and a falling edge.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    edges[1::2] = edges[1::2] - edges[0::2]  # falling-edge index -> run length
    edges[0::2] += 1                         # 0-indexed start -> 1-indexed start
    return ' '.join(map(str, edges))

# Exploratory Data Analysis (EDA)
1. sample cases
2. distributional analysis
    - how many unique ids are there?
    - image shape
    - cell types
3. cell types
    - how do they look different?

## Sample cases

In [None]:
# Column names, dtypes and non-null counts of the raw training table.
df_train.info()

Each row has metadata including id, annotation, image resolution (height and width), cell_type, plate_time, sample_date, sample_id, and elapsed_timedelta.<br>
The most helpful fields seem to be `id`, `annotation`, the image resolution (`height`, `width`) and `cell_type`.<br>
Apart from id and annotation, the column names are self-explanatory.
1. `id`: a unique identifier for each cell sample image
2. `annotation`: a run-length encoding of the pixels covered by one cell (note that each image can have a different number of annotations — each unique id spans multiple rows, one annotation per row)

Let's look at what an annotation looks like in detail.

In [None]:
# Take the first row's image id, then the first RLE string recorded for that image.
# .iloc is positional, so this does not depend on the DataFrame's index labels.
sample_id = df_train['id'].iloc[0]
sample_annotation = df_train.loc[df_train['id'] == sample_id, 'annotation'].iloc[0]
sample_annotation

In [None]:
# RLE tokens alternate: start, length, start, length, ...
rle_tokens = sample_annotation.split(' ')
starts = [int(tok) for tok in rle_tokens[0::2]]
lengths = [int(tok) for tok in rle_tokens[1::2]]
list_pair = list(zip(starts, lengths))
list_pair[0]

This means that the first run starts at pixel 118145 (RLE positions are 1-indexed) and has a length of 6. From this, we can get (start, end) pairs instead of (start, length) pairs.

In [None]:
# Turn each (start, length) run into a half-open (start, end) interval.
ends = [run_start + run_length for run_start, run_length in list_pair]
new_list_pair = list(zip(starts, ends))
new_list_pair[0]

From this new list of pairs, we can build the mask by assigning 1s to the designated pixel locations. Note that start and end index into the flattened 1D mask, and that RLE start positions are 1-indexed while NumPy arrays are 0-indexed, so 1 must be subtracted before indexing.

In [None]:
sample_image = np.array(Image.open(os.path.join(data_dir, 'train', f'{sample_id}.png')))
sample_mask = np.zeros(np.prod(sample_image.shape), dtype = np.uint8)
for start, end in new_list_pair:
    # RLE starts are 1-indexed but NumPy indexing is 0-indexed, so shift both ends
    # by one (the same convention decode_annotation uses).
    sample_mask[start - 1:end - 1] = 1

Then we need to reshape the mask into 2d.

In [None]:
# Restore the flat mask to the image's 2D shape, then overlay it on the raw image.
sample_mask = np.reshape(sample_mask, sample_image.shape)

plt.imshow(sample_image, cmap = 'gray')
plt.imshow(sample_mask, alpha = 0.3)

It looks like the yellowish area represents a cell. Furthermore, each row seems to hold the segmentation of a single cell. Because this image contains more than one cell, let's bring in all of its segmentations.

In [None]:
# Overlay every cell annotation of this image, not only the first one.
image = Image.open(os.path.join(data_dir, 'train', f'{sample_id}.png'))
image_np = np.array(image)

annotations = dict_annotation[sample_id]  # all RLE strings for this image id
mask = decode_annotation(annotations)     # union of all cells as one binary mask

fig, axes = plt.subplots(1, 2, figsize = (30,15))
for ax in axes:
    ax.imshow(image_np, cmap = 'gray')
axes[1].imshow(mask, alpha = 0.1)

## Distributional analysis
1. How many unique ids are there?

In [None]:
# Number of distinct images (an image id is repeated on each of its annotation rows).
df_train['id'].nunique()

2. Image shape

In [None]:
# Count resolutions per image, not per annotation row: height/width are repeated
# on every annotation row of the same image.
df_train.drop_duplicates('id')[['height', 'width']].value_counts()

3. Cell types

In [None]:
# Distinct cell types present in the training table.
df_train['cell_type'].unique()

In [None]:
# Number of images (one row per id after de-duplication) per cell type.
df_train.drop_duplicates('id')['cell_type'].value_counts()

Overall, there are 606 cell images, all with a resolution of 520x704 (height x width). Each image belongs to one of the cell types cort, shsy5y, and astro.

## Cell types
1. how do they look different?

In [None]:
# For each cell type, show up to three random images: raw (with an intensity
# colorbar) and with all cell annotations overlaid. A fixed-seed generator keeps
# the selection reproducible across runs.
rng = np.random.default_rng(42)
n_examples = 3
for cell_type in df_train['cell_type'].unique():
    type_ids = df_train.loc[df_train['cell_type'] == cell_type, 'id'].unique()
    sample_ids = rng.choice(type_ids, size = min(n_examples, len(type_ids)), replace = False)
    fig, axes = plt.subplots(len(sample_ids), 2, figsize = (10, 10), squeeze = False)
    for i, sample_id in enumerate(sample_ids):
        image_np = np.array(Image.open(os.path.join(data_dir, 'train', f'{sample_id}.png')))
        # dict_annotation avoids re-filtering the whole DataFrame for every image.
        annotation = decode_annotation(dict_annotation[sample_id], image_np.shape[:2])
        im0 = axes[i, 0].imshow(image_np, cmap = 'gray')
        cax = make_axes_locatable(axes[i, 0]).append_axes('right', size = '5%', pad = 0.05)
        fig.colorbar(im0, cax = cax, orientation = 'vertical')
        axes[i, 1].imshow(image_np, cmap = 'gray')
        axes[i, 1].imshow(annotation, alpha = 0.2)
        axes[i, 0].set_title(f'Raw Image (id: {sample_id})')
        axes[i, 1].set_title('With Annotation')
    fig.suptitle(cell_type, fontsize = 20)
    fig.tight_layout()
    plt.show()

To sum up,
1. There are 606 input samples (images) with annotations. 
2. Each sample belongs to specific cell types among cort, shsy5y, and astro. 
3. It seems like cell images have different characteristics in cell shape and pixel intensity by each cell types
    - this may be important when I do feature engineering (image pre-processing) such as normalization in the future.

In conclusion, from a visual inspection, all samples have the same image resolution and the cells look distinguishable enough from the background by eye.<br>
For now, we can just go ahead and build a model, feed it the raw images and annotations, and see how it works.

## Build dictionary for cell type

In [None]:
# Map each image id to its cell type (if an id repeats, its last row wins).
dict_cell_type = dict(zip(df_train['id'], df_train['cell_type']))

# Integer channel per cell type; channel 0 is used for background in the targets.
cell_type_mapping = {
    'shsy5y': 1,
    'astro': 2,
    'cort': 3
}
dict_cell_type_encoded = {
    image_id: cell_type_mapping[type_name]
    for image_id, type_name in dict_cell_type.items()
}

# Build ML Pipeline
(1) define data pipeline (2) define model architecture (3) prepare for training experiment configurations (# of epoch, batch size, optimizer, loss function...)

## Define data pipeline

In [None]:
class CellDataset(torch.utils.data.Dataset):
    """Training images with per-class segmentation targets.

    Each item is ``(image, target)``:
      * image: ``<data_dir>/train/<id>.png`` passed through ``image_trans`` and
        resized to ``target_image_res``.
      * target: float tensor of shape ``(4, *target_image_res)``. Channel 0 is the
        background (1 - union of all cell masks). The channel
        ``dict_cell_type_encoded[id]`` holds the union of all cell masks. The
        remaining channels are zero.

    Args:
        list_ids: image ids to serve, in order.
        data_dir: competition root directory (must contain ``train/``).
        dict_annotation: image id -> list of RLE strings.
        dict_cell_type_encoded: image id -> class channel index (1..3).
        target_shape: currently unused; the mask shape is taken from the loaded
            image. Kept for backward compatibility.
        image_trans: transform applied to the PIL image before resizing.
        target_image_res: (height, width) that both image and target are resized to.
    """
    def __init__(self, 
                 list_ids,
                 data_dir = data_dir,
                 dict_annotation = dict_annotation,
                 dict_cell_type_encoded = dict_cell_type_encoded,
                 target_shape = (520, 704),
                 image_trans = transforms.ToTensor(),
                 target_image_res = (224,224)
                ):
        self.list_ids = list_ids
        self.data_dir = data_dir
        self.dict_annotation = dict_annotation
        self.dict_cell_type_encoded = dict_cell_type_encoded
        self.target_shape = target_shape
        self.image_trans = image_trans
        self.resize_trans = transforms.Resize(target_image_res)
        self.target_image_res = target_image_res

    def __len__(self):
        return len(self.list_ids)

    def __getitem__(self, idx):
        cell_id = self.list_ids[idx]
        path = os.path.join(self.data_dir, 'train', f'{cell_id}.png')
        with Image.open(path) as pil_image:  # close the file handle promptly
            width, height = pil_image.size  # PIL reports (width, height)
            image = self.resize_trans(self.image_trans(pil_image))
        # Use this instance's annotations (not the module-level global), so a
        # dataset built with a custom dict_annotation actually uses it.
        mask = decode_annotation(self.dict_annotation[cell_id], (height, width))
        mask = torch.from_numpy(mask).float().unsqueeze(0)
        # Resizing can yield non-binary values, so re-binarize afterwards.
        mask = (self.resize_trans(mask) > 0.5).float()[0]
        target = torch.zeros(4, *self.target_image_res)  # 3 cell types + background
        target[0] = 1 - mask
        target[self.dict_cell_type_encoded[cell_id]] = mask
        return image, target

In [None]:
# Image-level 80/10/10 split, reproducible through random_state.
random_state = 42

list_ids = df_train['id'].unique().tolist()
train_ids, val_test_ids = train_test_split(list_ids, test_size = 0.2, random_state = random_state)
# Halve the held-out 20% into validation and test.
val_ids, test_ids = train_test_split(val_test_ids, test_size = 0.5, random_state = random_state)

split_ids = {'train': train_ids, 'val': val_ids, 'test': test_ids}
dataset = {split: CellDataset(ids) for split, ids in split_ids.items()}

for split, ds in dataset.items():
    print(f'{split}: {len(ds)}')

In [None]:
# Sanity-check one training item: overlay its cell-type channel on the image.
idx = 100
train_set = dataset['train']
cell_id = train_set.list_ids[idx]
slice_idx = train_set.dict_cell_type_encoded[cell_id]  # channel holding this image's cells
image, target = train_set[idx]
plt.imshow(image[0], cmap = 'gray')
plt.imshow(target[slice_idx], alpha = 0.2)

## Define model architecture
I used the pre-defined UNet architecture from [Monai](https://monai.io/) with 4 output channels (3 different cell types plus 1 for background). Then, I added a final output layer that produces the binary output. Overall, the model has 2 outputs (one for multi-class prediction and the other for binary prediction).

### What is MONAI?
Quoting from the official website...
>"The MONAI framework is the open-source foundation being created by Project MONAI. MONAI is a freely available, community-supported, PyTorch-based framework for deep learning in healthcare imaging. It provides domain-optimized foundational capabilities for developing healthcare imaging training workflows in a native PyTorch paradigm."

So MONAI provides many useful tools built on PyTorch, mostly for medical image analysis, so that we don't have to define model architectures, loss functions, and so on manually. It's open source, so we can also modify and debug the code ourselves.

In [None]:
class UNet2Step(nn.Module):
    """MONAI 2D UNet followed by a small convolutional binary head.

    forward(x) returns ``(multi_logits, binary_logits)``:
      * multi_logits: the UNet's 4-channel output (3 cell types + background).
      * binary_logits: 1 channel computed from multi_logits by a 3x3 convolution
        with stride 1 and padding 1, so the spatial size is unchanged.
    The UNet is configured for single-channel input (in_channels = 1).
    """
    def __init__(self):
        super().__init__()
        unet_config = dict(
            spatial_dims = 2,
            in_channels = 1,
            out_channels = 4,  # 3 cell-type classes + 1 background
            channels = (16, 32, 64, 128, 256),
            kernel_size = 3,
            strides = (2, 2, 2, 2),
            num_res_units = 2,
            act = 'PRELU',
            norm = 'INSTANCE',
            dropout = 0,
        )
        self.unet = UNet(**unet_config)
        self.output_conv = nn.Conv2d(
            in_channels = 4, out_channels = 1, kernel_size = 3, stride = 1, padding = 1
        )

    def forward(self, x):
        multi_logits = self.unet(x)
        binary_logits = self.output_conv(multi_logits)
        return multi_logits, binary_logits

In [None]:
# Instantiate the UNet + binary-head model (still on CPU at this point).
net = UNet2Step()

## Training configuration
For the loss, I combined 2 different losses (IoU loss for the multi-class prediction and IoU loss for the binary prediction). To monitor model performance, I tracked (1) the binary IoU score and (2) the binary Dice score.

In [None]:
EPOCH = 50
BATCH_SIZE = 5
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42
# Seed torch and NumPy so DataLoader shuffling and later randomness are repeatable.
# NOTE: the model weights were already initialised in an earlier cell, before this seed.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# DataLoaders: only the training split is shuffled.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset[split],
        batch_size = BATCH_SIZE,
        shuffle = (split == 'train'),
        num_workers = 2,
    )
    for split in ('train', 'val', 'test')
}

# Losses: MONAI Dice with jaccard=True gives a soft IoU (Jaccard) loss.
criterion_softmax = Dice(softmax = True, jaccard = True)  # multi-class IoU loss on the 4-channel head
criterion_binary = Dice(sigmoid = True, jaccard = True)   # binary IoU loss; also monitoring metric 1
criterion_binary_dice = Dice(sigmoid = True)              # binary Dice loss; monitoring metric 2 only

# Move the model to DEVICE before constructing the optimizer over its parameters.
net = net.to(DEVICE)
optimizer = torch.optim.Adam(net.parameters())

# Training

In [None]:
def run_epoch(split, epoch):
    """Run one pass over ``dataloaders[split]``; update weights only for 'train'.

    Returns a dict with the batch-mean 'loss', 'loss_softmax', 'dice' and 'iou'.
    Metric lists are created fresh per call, so splits never mix their scores.
    """
    is_train = split == 'train'
    net.train(is_train)
    metrics = {'loss': [], 'loss_softmax': [], 'dice': [], 'iou': []}
    pbar = tqdm.tqdm(total = len(dataloaders[split]), position = 0)
    pbar.set_description(f'Epoch ({split}): {epoch + 1}/{EPOCH}')
    with torch.set_grad_enabled(is_train):
        for data, target in dataloaders[split]:
            data = data.to(DEVICE)
            target = target.to(DEVICE)
            # Binary target: union of the three cell-type channels (channel 0 is background).
            target_binary = target[:, 1:].sum(dim = 1, keepdim = True)
            pred_multi, pred = net(data)
            loss_multi = criterion_softmax(pred_multi, target)
            loss_binary = criterion_binary(pred, target_binary)
            loss = loss_multi + loss_binary
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Monitoring only: no autograd graph needed.
            with torch.no_grad():
                dice_score = 1 - criterion_binary_dice(pred, target_binary)
            metrics['loss'].append(loss.item())
            metrics['loss_softmax'].append(loss_multi.item())
            metrics['iou'].append(1 - loss_binary.item())  # IoU = 1 - Jaccard loss
            metrics['dice'].append(dice_score.item())
            pbar.update(1)
            pbar.set_postfix({
                'Loss': f"{np.mean(metrics['loss']):.2f}",
                'Loss_softmax': f"{np.mean(metrics['loss_softmax']):.2f}",
                'Dice': f"{np.mean(metrics['dice']):.2f}",
                'IOU': f"{np.mean(metrics['iou']):.2f}",
            })
    pbar.close()
    return {name: np.mean(values) for name, values in metrics.items()}


history = {
    split: {'loss_softmax': [], 'dice': [], 'loss': [], 'iou': []}
    for split in ['train', 'val', 'test']
}
best_state_dict = None
best_val_dice = -float('inf')
for epoch in range(EPOCH):
    # 'train' updates weights; 'val' and 'test' run in eval mode without gradients.
    for split in ['train', 'val', 'test']:
        for name, value in run_epoch(split, epoch).items():
            history[split][name].append(value)
    # Keep a copy of the weights with the best validation Dice so far (ties -> latest).
    # The `is None` check guarantees a checkpoint even if the first val Dice is NaN.
    val_dice = history['val']['dice'][-1]
    if best_state_dict is None or val_dice >= best_val_dice:
        best_val_dice = val_dice
        best_state_dict = copy.deepcopy(net.state_dict())
        print(f"Best model appeared with DICE score of {val_dice:.3f}")
    model_dict = {
        'curr_epoch': epoch,
        'best_state_dict': best_state_dict,
        'curr_state_dict': net.state_dict(),
        'history': history
    }
    torch.save(model_dict, 'model.pt')

In [None]:
# Restore the best-validation-Dice weights from the in-memory checkpoint dict
# built by the training loop (the same dict that was written to model.pt).
best_state_dict = model_dict['best_state_dict']
net.load_state_dict(best_state_dict)

# Evaluation

## Training progress

In [None]:
# Learning curves: one figure per metric, one line per split.
history = model_dict['history']
for metric in ['loss', 'dice', 'iou']:
    fig, ax = plt.subplots()
    for split, split_history in history.items():
        ax.plot(split_history[metric], label = f'{split}-{metric}')
    ax.set(title = f'{metric} per epoch', xlabel = 'epoch', ylabel = metric)
    ax.legend()
    plt.show()

## Case-level sample predictions

In [None]:
# Show the first N_EXAMPLES test images: ground-truth union mask (left) vs.
# sigmoid of the binary head (right).
N_EXAMPLES = 10
net.eval()
with torch.no_grad():
    for idx in range(min(N_EXAMPLES, len(dataset['test']))):
        data, target = dataset['test'][idx]
        data = data.unsqueeze(0).to(DEVICE)
        target = target.unsqueeze(0).to(DEVICE)
        pred_multi, pred = net(data)
        pred_sig = torch.sigmoid(pred)
        # Union of the three cell-type channels (channel 0 is background).
        target = target[:, 1:].max(dim = 1, keepdim = True).values

        fig, axes = plt.subplots(1, 2, figsize = (10, 5))
        axes[0].imshow(data[0, 0].cpu().numpy(), cmap = 'gray')
        axes[0].imshow(target[0, 0].cpu().numpy(), alpha = 0.3)
        axes[0].set_title('Ground truth')
        axes[1].imshow(data[0, 0].cpu().numpy(), cmap = 'gray')
        axes[1].imshow(pred_sig[0, 0].cpu().numpy(), alpha = 0.3)
        axes[1].set_title('Prediction (sigmoid)')
        plt.show()