# Introduction
This notebook compares the validation scores of a single segmentation model trained on all cell types against separate models trained for each cell type.
1. Build cell type classifier <-- shown in <a href=https://www.kaggle.com/yoshikuwano/classified-by-cell-types-before-segmentation-1-2>notebook (1/2)</a>
1. Build segmentation model <-- this notebook

Ref. https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273

# Imports

In [None]:
import os, glob, time, random, collections
from tqdm.notebook import tqdm

import cv2
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from pprint import pprint
from sklearn.model_selection import train_test_split

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Parameters

In [None]:
# Root directory of the competition data (train.csv, train/ and test/ images are read from it)
INPUT_PATH = '../input/sartorius-cell-instance-segmentation'
# Serialized cell-type classifier; loaded further down with torch.load
CLASSIFIER_MODEL_PATH = '../input/classified-by-cell-types-before-segmentation-1-2/resnet34_crassifier.bin'

def fix_all_seeds(seed):
    """Seed every random number generator used in this notebook.

    Sets the ``PYTHONHASHSEED`` environment variable and seeds NumPy,
    Python's ``random`` module and PyTorch (CPU, current CUDA device and
    all CUDA devices) with the same value.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all random number generators once, before anything random happens
fix_all_seeds(2021)

# Image size
WIDTH = 704
HEIGHT = 520

# Number of training epochs and images per mini-batch
NUM_EPOCHS = 5
BATCH_SIZE = 2

# For optimizer
MOMENTUM = 0.9
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005
USE_SCHEDULER = False # Use a StepLR scheduler if True. Not tried yet.

MASK_THRESHOLD = 0.5 # Predicted mask pixels with a probability above this value are kept.
BOX_DETECTIONS_PER_IMG = 539 # Passed to Mask R-CNN as box_detections_per_img.
MIN_SCORE = 0.59 # Detections whose score is not above this value are discarded.
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)
NORMALIZE = False # Normalize to resnet mean and std if True.

# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('DEVICE: ', DEVICE)

# Utilities

In [None]:
# Transforms
class Compose_:
    """Chain several ``(image, target)`` transforms into one callable."""

    def __init__(self, transforms):
        # Sequence of callables, applied in the given order
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed the pair through every transform and return the final pair."""
        for step in self.transforms:
            image, target = step(image, target)
        return image, target

class VerticalFlip_:
    """Randomly flip an image tensor, its boxes and its masks upside down."""

    def __init__(self, prob):
        # Probability that a call performs the flip
        self.prob = prob

    def __call__(self, image, target):
        """Return the (possibly flipped) image together with its target."""
        if random.random() >= self.prob:
            return image, target

        height, _ = image.shape[-2:]
        flipped = image.flip(-2)
        boxes = target["boxes"]
        # Mirroring swaps the roles of ymin and ymax
        boxes[:, [1, 3]] = height - boxes[:, [3, 1]]
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-2)
        return flipped, target

class HorizontalFlip_:
    """Randomly mirror an image tensor, its boxes and its masks left-right."""

    def __init__(self, prob):
        # Probability that a call performs the flip
        self.prob = prob

    def __call__(self, image, target):
        """Return the (possibly mirrored) image together with its target."""
        if random.random() >= self.prob:
            return image, target

        _, width = image.shape[-2:]
        mirrored = image.flip(-1)
        boxes = target["boxes"]
        # Mirroring swaps the roles of xmin and xmax
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-1)
        return mirrored, target

class Normalize_:
    """Normalize the image with ``RESNET_MEAN`` / ``RESNET_STD``; the target passes through."""

    def __call__(self, image, target):
        normalized = F.normalize(image, RESNET_MEAN, RESNET_STD)
        return normalized, target

class ToTensor_:
    """Convert the image with ``F.to_tensor``; the target passes through."""

    def __call__(self, image, target):
        tensor_image = F.to_tensor(image)
        return tensor_image, target

def get_transform(train):
    """Build the transform pipeline for a dataset.

    The pipeline always converts the image to a tensor and normalizes it
    when ``NORMALIZE`` is set. With ``train`` truthy, random horizontal and
    vertical flips (probability 0.5 each) are appended as augmentation.
    """
    steps = [ToTensor_()]
    if NORMALIZE:
        steps.append(Normalize_())

    if train:
        # Data augmentation for train
        steps.extend([HorizontalFlip_(0.5), VerticalFlip_(0.5)])

    return Compose_(steps)


# RLE string -> ndarray mask
def rle_decode(mask_rle, shape, color=1):
    """Decode a run-length string into a 2-D mask.

    Args:
        mask_rle: Space-separated "start length" pairs; starts are 1-based
            positions in the flattened image.
        shape: (height, width) of the array to return.
        color: Value written into the masked pixels.

    Returns:
        float32 numpy array of the given shape: ``color`` inside the runs,
        0 elsewhere.
    """
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based offsets
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, end in zip(starts, ends):
        flat[begin:end] = color
    return flat.reshape(shape)

# Input Data

In [None]:
# Load the training table; the columns 'id', 'annotation' and 'cell_type' are used below
df_train = pd.read_csv(INPUT_PATH + '/train.csv')
display(df_train)

In [None]:
# Model labels: 'all' first, followed by every cell type present in the data.
# Later cells index this list with 0..3, so three distinct cell types are expected.
cell_types = ['all']
cell_types.extend(df_train['cell_type'].unique())
print(f'cell_types: {cell_types}')

In [None]:
from sklearn.model_selection import train_test_split

# Separate dataset by cell type
image_ids = df_train['id'].unique()
df_shsy5y = df_train[df_train['cell_type']=='shsy5y']
df_astro  = df_train[df_train['cell_type']=='astro']
df_cort   = df_train[df_train['cell_type']=='cort']

# Hold out 20% of the image ids for validation. Splitting on ids keeps all
# annotations of one image in the same subset.
# NOTE(review): no random_state is passed, so repeatability presumably relies on
# the NumPy seed set by fix_all_seeds; the split is not stratified by cell type.
train_idx, valid_idx = train_test_split(image_ids, test_size=0.2)

# The same id split is applied to the full table and to every per-cell-type table
df_train_all   = df_train[df_train['id'].isin(train_idx)]
df_valid_all   = df_train[df_train['id'].isin(valid_idx)]
df_train_shsy5y = df_shsy5y[df_shsy5y['id'].isin(train_idx)]
df_valid_shsy5y = df_shsy5y[df_shsy5y['id'].isin(valid_idx)]
df_train_astro  = df_astro[df_astro['id'].isin(train_idx)]
df_valid_astro  = df_astro[df_astro['id'].isin(valid_idx)]
df_train_cort   = df_cort[df_cort['id'].isin(train_idx)]
df_valid_cort   = df_cort[df_cort['id'].isin(valid_idx)]

# These counts are table rows (one per annotation), not images
print('Number of records for each dataframe')
print(f'train        : {len(df_train_all)}')
print(f'valid        : {len(df_valid_all)}')
print(f'train shsy5y : {len(df_train_shsy5y)}')
print(f'valid shsy5y : {len(df_valid_shsy5y)}')
print(f'train astro  : {len(df_train_astro)}')
print(f'valid astro  : {len(df_valid_astro)}')
print(f'train cort   : {len(df_train_cort)}')
print(f'valid cort   : {len(df_valid_cort)}')

##  Dataset

In [None]:
class CellDataset(Dataset):
    """Dataset of ``(image, target)`` pairs in the Mask R-CNN target format.

    The annotation rows of ``df`` are grouped by image id; every item is
    one image together with the masks and boxes of all its annotations.

    Args:
        image_dir: Directory that holds the ``<id>.png`` images.
        df: Table with the columns ``id`` and ``annotation`` (RLE strings).
        transforms: Optional callable applied as ``transforms(img, target)``.
        resize: ``False`` to keep ``HEIGHT`` x ``WIDTH``, otherwise a scale
            factor applied to both dimensions.
    """

    def __init__(self, image_dir, df, transforms=None, resize=False):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df

        self.should_resize = resize is not False
        self.height = int(HEIGHT * resize) if self.should_resize else HEIGHT
        self.width = int(WIDTH * resize) if self.should_resize else WIDTH

        # One entry per image, keyed by its position in the grouped table
        grouped = self.df.groupby('id')['annotation'].agg(lambda rows: list(rows)).reset_index()
        self.image_info = collections.defaultdict(dict)
        for position, record in grouped.iterrows():
            image_id = record['id']
            self.image_info[position] = {
                'image_id': image_id,
                'image_path': os.path.join(self.image_dir, image_id + '.png'),
                'annotations': record["annotation"],
            }

    def get_box(self, a_mask):
        """Return ``[xmin, ymin, xmax, ymax]`` of the non-zero pixels of ``a_mask``."""
        coords = np.where(a_mask)
        rows, cols = coords[0], coords[1]
        return [np.min(cols), np.min(rows), np.max(cols), np.max(rows)]

    def __getitem__(self, idx):
        """Load image ``idx`` and build its target dictionary."""
        info = self.image_info[idx]
        img = Image.open(info['image_path']).convert('RGB')

        target_size = (self.width, self.height)
        if self.should_resize:
            img = img.resize(target_size, resample=Image.BILINEAR)

        annotations = info['annotations']
        n_objects = len(annotations)
        masks = np.zeros((n_objects, self.height, self.width), dtype=np.uint8)
        boxes = []

        for slot, annotation in enumerate(annotations):
            # Annotations are encoded against the native image size
            decoded = Image.fromarray(rle_decode(annotation, (HEIGHT, WIDTH)))
            if self.should_resize:
                decoded = decoded.resize(target_size, resample=Image.BILINEAR)

            binary = np.array(decoded) > 0
            masks[slot, :, :] = binary
            boxes.append(self.get_box(binary))

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        box_heights = boxes[:, 3] - boxes[:, 1]
        box_widths = boxes[:, 2] - boxes[:, 0]

        # This is the required target for the Mask R-CNN
        target = {
            'boxes': boxes,
            # dummy labels: every object belongs to class 1
            'labels': torch.ones((n_objects,), dtype=torch.int64),
            'masks': torch.as_tensor(masks, dtype=torch.uint8),
            'image_id': torch.tensor([idx]),
            'area': box_heights * box_widths,
            'iscrowd': torch.zeros((n_objects,), dtype=torch.int64),
        }

        if self.transforms is None:
            return img, target
        return self.transforms(img, target)

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_info)

In [None]:
# Datasets
# One train/valid pair for all cell types together and one pair per cell type.
# Train datasets get the augmenting transforms, valid datasets do not.
ds_train_all    = CellDataset(INPUT_PATH + '/train', df_train_all, resize=False,
                              transforms = get_transform(train=True))
ds_valid_all    = CellDataset(INPUT_PATH + '/train', df_valid_all, resize=False,
                              transforms = get_transform(train=False))
ds_train_shsy5y = CellDataset(INPUT_PATH + '/train', df_train_shsy5y, resize=False,
                              transforms = get_transform(train=True))
ds_valid_shsy5y = CellDataset(INPUT_PATH + '/train', df_valid_shsy5y, resize=False,
                              transforms = get_transform(train=False))
ds_train_astro  = CellDataset(INPUT_PATH + '/train', df_train_astro, resize=False,
                              transforms = get_transform(train=True))
ds_valid_astro  = CellDataset(INPUT_PATH + '/train', df_valid_astro, resize=False,
                              transforms = get_transform(train=False))
ds_train_cort   = CellDataset(INPUT_PATH + '/train', df_train_cort, resize=False,
                              transforms = get_transform(train=True))
ds_valid_cort   = CellDataset(INPUT_PATH + '/train', df_valid_cort, resize=False,
                              transforms = get_transform(train=False))

# Same order as the model list built later: all, shsy5y, astro, cort
train_datasets = [ds_train_all, ds_train_shsy5y, ds_train_astro, ds_train_cort]
valid_datasets = [ds_valid_all, ds_valid_shsy5y, ds_valid_astro, ds_valid_cort]

# Dataset lengths count images, not annotation rows
print(f'Number of train dataset : {len(ds_train_all)}')
print(f'Number of valid dataset : {len(ds_valid_all)}')
print(f'Number of shsy5y train dataset : {len(ds_train_shsy5y)}')
print(f'Number of shsy5y valid dataset : {len(ds_valid_shsy5y)}')
print(f'Number of astro train dataset  : {len(ds_train_astro)}')
print(f'Number of astro valid dataset  : {len(ds_valid_astro)}')
print(f'Number of cort train dataset   : {len(ds_train_cort)}')
print(f'Number of cort valid dataset   : {len(ds_valid_cort)}')

# Model

## Modeling

In [None]:
# Override PyTorch checkpoint with an "offline" version of the file
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
from torchvision.models.detection import maskrcnn_resnet50_fpn

def get_model():
    """Create a pretrained Mask R-CNN (ResNet-50 FPN) with fresh two-class heads.

    The box and mask predictors are replaced by new ones sized for
    ``NUM_CLASSES``. When ``NORMALIZE`` is set, the ResNet mean/std are
    passed to the model as its image statistics.
    """
    # This is just a dummy value for the classification head
    NUM_CLASSES = 2
    hidden_layer = 256

    extra_kwargs = {}
    if NORMALIZE:
        extra_kwargs = {'image_mean': RESNET_MEAN, 'image_std': RESNET_STD}
    model = maskrcnn_resnet50_fpn(pretrained=True,
                                  box_detections_per_img=BOX_DETECTIONS_PER_IMG,
                                  **extra_kwargs)

    heads = model.roi_heads
    # Swap the pre-trained box head for one with NUM_CLASSES outputs
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, NUM_CLASSES)
    # Same for the mask head
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, hidden_layer, NUM_CLASSES)

    return model

# One independent model per dataset: all cell types, then one per cell type
model_all = get_model()
model_shsy5y = get_model()
model_astro  = get_model()
model_cort   = get_model()
# Same order as train_datasets / valid_datasets
models = [model_all, model_shsy5y, model_astro, model_cort]

# Make every parameter trainable (no frozen layers)
for model in models:
    for param in model.parameters():
        param.requires_grad = True

# Train 

In [None]:
def train(model, ds_train, ds_valid, cell_type=None, verbose=False):
    """Fine-tune ``model`` and save the weights of its best epoch.

    Runs ``NUM_EPOCHS`` epochs of SGD over ``ds_train``. After every epoch
    the loss on ``ds_valid`` is measured; whenever it is lower than the best
    value so far, the state dict is written to
    ``pytorch_model_{cell_type}.bin``.

    Args:
        model: Detection model that returns a dict of losses (including
            ``'loss_mask'``) when called with images and targets.
        ds_train: Dataset of ``(image, target)`` pairs used for training.
        ds_valid: Dataset of ``(image, target)`` pairs used for validation.
        cell_type: Label inserted into the checkpoint file name.
        verbose: If True, print the losses of every 20th batch.
    """

    def _collate(batch):
        # Keep images and targets as tuples; the images are not stacked
        return tuple(zip(*batch))

    model.to(DEVICE)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    ### Data  Loader ###
    # The loaders do not depend on the epoch, so they are built once;
    # the training loader still reshuffles on every pass.
    dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, collate_fn=_collate)
    dl_valid = DataLoader(ds_valid, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=0, collate_fn=_collate)

    min_valid_loss = np.inf
    min_loss_epoch = 0
    for epoch in range(1, NUM_EPOCHS + 1):

        ### TRAIN ###
        model.train()
        time_start = time.time()
        loss_accum = 0.0
        loss_mask_accum = 0.0
        for batch_idx, (images, targets) in tqdm(enumerate(dl_train, 1),
                                                 total = len(dl_train),
                                                 desc = f'[Train] Epoch ({epoch}/{NUM_EPOCHS})'):

            # Forward
            images = [image.to(DEVICE) for image in images]
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Logging
            loss_mask = loss_dict['loss_mask'].item()
            loss_accum += loss.item()
            loss_mask_accum += loss_mask
            # Print
            if verbose and batch_idx % 20 == 0:
                prefix = f'    [Batch {batch_idx:3d}/{len(dl_train):3d}]'
                print(f'{prefix} Train loss: {loss.item():7.3f}, Mask-only loss: {loss_mask:7.3f}')

        # Step the scheduler exactly once per epoch so that StepLR's
        # step_size really counts epochs.
        if USE_SCHEDULER:
            lr_scheduler.step()

        # Train losses
        train_loss = loss_accum / len(dl_train)
        train_loss_mask = loss_mask_accum / len(dl_train)
        elapsed = time.time() - time_start
        prefix = f'[Epoch {epoch:2d}/{NUM_EPOCHS:2d}]'
        print(f'{prefix} Train loss: {train_loss:7.3f}, Mean mask-only loss: {train_loss_mask:7.3f} [{elapsed:.0f} secs]')

        ### VALIDATION ###
        # The model stays in training mode here because the loss dict is read
        # from its output. NOTE(review): presumably the model returns losses
        # only in training mode - confirm against the torchvision docs.
        loss_accum = 0.0
        loss_mask_accum = 0.0
        with torch.no_grad():
            for batch_idx, (images, targets) in tqdm(enumerate(dl_valid, 1),
                                                     total = len(dl_valid),
                                                     desc = f'[Valid] Epoch ({epoch}/{NUM_EPOCHS})'):

                # Forward
                images = [image.to(DEVICE) for image in images]
                targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
                loss_dict = model(images, targets)
                loss = sum(loss_dict.values())
                # Logging
                loss_mask = loss_dict['loss_mask'].item()
                loss_accum += loss.item()
                loss_mask_accum += loss_mask

                # Print
                if verbose and batch_idx % 20 == 0:
                    prefix = f'    [Batch {batch_idx:3d}/{len(dl_valid):3d}]'
                    print(f'{prefix} Valid loss: {loss.item():7.3f}, Mask-only loss: {loss_mask:7.3f}.')

        # Loss per epoch (the elapsed time covers training plus validation)
        valid_loss = loss_accum / len(dl_valid)
        valid_loss_mask = loss_mask_accum / len(dl_valid)
        elapsed = time.time() - time_start
        prefix = f'[Epoch {epoch:2d}/{NUM_EPOCHS:2d}]'
        print(f'{prefix} Valid loss: {valid_loss:7.3f}, Mean mask-only loss: {valid_loss_mask:7.3f} [{elapsed:.0f} secs]')
        # Keep only the checkpoint of the best epoch so far
        if valid_loss < min_valid_loss:
            min_valid_loss = valid_loss
            min_loss_epoch = epoch
            torch.save(model.state_dict(), f'pytorch_model_{cell_type}.bin')

    print(f'Minimum valid loss: {min_valid_loss:7.3f} at epoch-{min_loss_epoch}\n')

In [None]:
# Train the four models in turn: index 0 is the all-cell-type model,
# indices 1-3 are the per-cell-type models (same order as cell_types)
for i in range(4):
    print('#'*50)
    print(f'Cell type: {cell_types[i]}')
    train(models[i], train_datasets[i], valid_datasets[i], cell_types[i])

# Validate
Calculate the scores of the all-cell-type model and of the individual cell-type models.

## Load trained models

In [None]:
# Rebuild each architecture and restore the best checkpoint written by train()
trained_models = []
for i in range(4):
    model = get_model()
    # map_location='cpu' lets checkpoints saved from a GPU run be restored on
    # a CPU-only machine; the models are moved to DEVICE when they are used.
    state_dict = torch.load(f'pytorch_model_{cell_types[i]}.bin', map_location='cpu')
    model.load_state_dict(state_dict)
    trained_models.append(model)

## Score functions

In [None]:
# ndarray mask -> RLE string
def rle_encode(x):
    """Encode the pixels of ``x`` that equal 1 as a run-length string.

    The array is flattened first. The result is a space-separated sequence
    of "start length" pairs with 1-based starts; an array without any such
    pixel gives an empty string.
    """
    positions = np.where(x.flatten() == 1)[0]
    runs = []  # one [start, length] entry per run
    previous = -2
    for position in positions:
        if position > previous + 1:
            # A gap since the last pixel opens a new run
            runs.append([position + 1, 0])
        runs[-1][1] += 1
        previous = position
    return ' '.join(str(value) for run in runs for value in run)

def remove_overlapping_pixels(mask, other_masks):
    """Clear, in place, every pixel of ``mask`` that is set in one of ``other_masks``.

    Returns the same (modified) ``mask`` object.
    """
    for claimed in other_masks:
        overlap = np.logical_and(mask, claimed)
        if overlap.any():
            mask[overlap] = 0
    return mask

# Score
def compute_iou(true_mask, pred_mask, verbose=False):
    """
    Build the pairwise IoU matrix between labelled objects.

    Args:
        true_mask:  ndarray (Height, Width) of ground-truth labels
        pred_mask:  ndarray (Height, Width) of predicted labels
        * 0 is the background, objects are labelled 1, 2, ..., #objects
        verbose: print the label counts when True
    Returns:
        np array: IoU matrix of size true_objects x pred_objects
        (the background row and column are dropped).
    """
    # Label counts, background included
    n_true = len(np.unique(true_mask))
    n_pred = len(np.unique(pred_mask))

    if verbose:
        print("Number of true objects: {}".format(n_true))
        print("Number of predicted objects: {}".format(n_pred))

    # A joint histogram of the two label images counts shared pixels per label pair
    intersection, _, _ = np.histogram2d(
        true_mask.flatten(), pred_mask.flatten(), bins=(n_true, n_pred)
    )

    # Per-label pixel counts, shaped so they broadcast against the intersection
    area_true = np.histogram(true_mask, bins=n_true)[0][:, np.newaxis]
    area_pred = np.histogram(pred_mask, bins=n_pred)[0][np.newaxis, :]

    union = area_true + area_pred - intersection
    union[union == 0] = 1e-9  # avoid division by zero

    # exclude background
    return (intersection / union)[1:, 1:]

def precision_at(threshold, iou):
    """
    Count detection outcomes at a given IoU threshold.

    Args:
        threshold (float): A pair matches when its IoU is strictly above this.
        iou (np array): IoU matrix, rows = true objects, columns = predictions.

    Returns:
        int: Number of true positives,
        int: Number of false positives,
        int: Number of false negatives.
    """
    matches = iou > threshold
    matches_per_true = np.sum(matches, axis=1)
    matches_per_pred = np.sum(matches, axis=0)

    tp = np.sum(matches_per_true == 1)  # true objects matched exactly once
    fp = np.sum(matches_per_pred == 0)  # predictions matching no true object (extra)
    fn = np.sum(matches_per_true == 0)  # true objects matched by nothing (missed)
    return tp, fp, fn

def iou_map(ture_masks, pred_masks, verbose=False):
    """
    Computes the metric for the competition.
    Masks contain the segmented pixels where each object has one value associated,
    and 0 is the background.

    For every IoU threshold in 0.50, 0.55, ..., 0.95 the counts of all mask
    pairs are pooled and TP / (TP + FP + FN) is computed; the result is the
    mean over the thresholds. A threshold with no objects at all (neither
    true nor predicted) contributes 0 instead of dividing by zero.

    Args:
        ture_masks (list of masks): Ground truths.
        pred_masks (list of masks): Predictions.
        verbose (bool, optional): Whether to print infos. Defaults to False.

    Returns:
        float: mAP.
    """
    ious = [compute_iou(true_mask, pred_mask, verbose) for true_mask, pred_mask in zip(ture_masks, pred_masks)]

    if verbose:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    precisions = []
    for t in np.arange(0.5, 1.0, 0.05):
        tps, fps, fns = 0, 0, 0
        for iou in ious:
            tp, fp, fn = precision_at(t, iou)
            tps += tp
            fps += fp
            fns += fn

        total = tps + fps + fns
        # Guard against empty ground truth combined with empty predictions
        p = tps / total if total > 0 else 0.0
        precisions.append(p)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tps, fps, fns, p))

    mean_precision = np.mean(precisions)
    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}".format(mean_precision))

    return mean_precision


def predict_mask(model, img, mask_threshold=MASK_THRESHOLD):
    """Predict non-overlapping binary instance masks for one image.

    The model is moved to ``DEVICE`` and switched to eval mode. Detections
    whose score is not above ``MIN_SCORE`` are dropped; the remaining
    probability masks are binarized with ``mask_threshold`` and stripped of
    pixels already claimed by an earlier kept mask.

    Returns:
        list of boolean numpy masks, in the order the model returned them.
    """
    model.to(DEVICE)
    model.eval()
    with torch.no_grad():
        output = model([img.to(DEVICE)])[0]

    kept_masks = []
    for raw_mask, raw_score in zip(output["masks"], output['scores']):
        if raw_score.cpu().item() <= MIN_SCORE:
            continue
        probability = raw_mask.cpu().numpy().squeeze()
        binary = probability > mask_threshold  # binarize
        binary = remove_overlapping_pixels(binary, kept_masks)
        kept_masks.append(binary)

    return kept_masks

def get_score(model, dataset, mask_threshold=MASK_THRESHOLD, verbose=False):
    """
    Get average IOU mAP score for a dataset

    Every image is predicted with ``predict_mask``; predicted and true
    instance masks are merged into label images (0 = background, objects
    numbered from 1) and scored with ``iou_map``.

    Args:
        model: Segmentation model passed on to ``predict_mask``.
        dataset: Dataset yielding ``(image, target)`` with ``target['masks']``.
        mask_threshold: Probability threshold for the predicted masks only.
        verbose: Passed on to ``iou_map``.

    Returns:
        Mean of the per-image scores.
    """
    score_cum = 0
    for idx in range(len(dataset)):

        img, target = dataset[idx]

        # Predicted masks
        pred_masks = predict_mask(model, img, mask_threshold)

        # combine all predicted objects into one label image
        pred_masks_combined = np.zeros((HEIGHT, WIDTH))
        for m, mask in enumerate(pred_masks, 1):
            pred_masks_combined[mask > 0.5] = m

        # True masks are already binary, so they are tested against 0 and not
        # against the prediction threshold
        true_masks_combined = np.zeros((HEIGHT, WIDTH))
        for m, mask in enumerate(target['masks'], 1):
            true_masks_combined[mask.cpu().numpy() > 0] = m

        score_cum += iou_map([true_masks_combined], [pred_masks_combined], verbose)

    return score_cum/len(dataset)

In [None]:
## Get scores

In [None]:
# Scores per cell type, for the all-cell-type model and the matching dedicated model
all_model_scores = []
individual_model_scores = []

# Index 0 is the 'all' entry, so the per-cell-type comparison starts at 1
for i in tqdm(range(1, 4)):
    print(f'Calculate the score for cell type {cell_types[i]}')
    # Both models are scored on the same validation images of this cell type
    score_all = get_score(trained_models[0], valid_datasets[i])
    score_ind = get_score(trained_models[i], valid_datasets[i])
    print(f'Score of all cell type model   : {score_all:7.3f}')
    print(f'Score of {cell_types[i]} model : {score_ind:7.3f}\n')
    all_model_scores.append(score_all)
    individual_model_scores.append(score_ind)

### Confirm predicted train mask

In [None]:
# Plots: the image, The image + the ground truth mask, The image + the predicted mask
from PIL import Image, ImageEnhance
def plot_prediction(model, ds_train, sample_index):
    """Plot one sample: the image, the ground-truth overlay and the predicted overlay.

    Args:
        model: Segmentation model passed on to ``predict_mask``.
        ds_train: Dataset providing ``image_info`` and ``(image, target)`` items.
        sample_index: Index of the sample to plot.
    """
    # Predict mask
    img, targets = ds_train[sample_index]
    # Look up the id of the plotted sample itself so the title matches the image
    image_id = ds_train.image_info[sample_index]['image_id']
    pred_masks = predict_mask(model, img)
    pred_masks_combined = np.zeros((HEIGHT, WIDTH))
    for mask in pred_masks:
        pred_masks_combined[mask>0.5] = 1

    fig, ax = plt.subplots(1, 3, figsize=(15,4))
    # Image (high contrast): invert, then boost the contrast for display
    img = img.numpy().transpose((1,2,0))
    img = (img*255).astype(np.uint8)
    img_hc = img.max() - img
    img_hc = np.asarray(ImageEnhance.Contrast(Image.fromarray(img_hc)).enhance(24))
    ax[0].imshow(img_hc)
    ax[0].set_title(f'Image: {image_id}')
    # Ground truth: union of all instance masks
    masks = np.zeros((HEIGHT, WIDTH))
    for mask in targets['masks']:
        masks = np.logical_or(masks, mask)
    ax[1].imshow(img_hc)
    ax[1].imshow(masks, alpha=0.3)
    ax[1].set_title("Ground truth")
    # Prediction
    ax[2].imshow(img_hc)
    ax[2].imshow(pred_masks_combined, alpha=0.3)
    ax[2].set_title("Prediction")

    plt.show()

In [None]:
# For each cell type, show the sample at index 1 of its validation set as
# predicted by the all-cell-type model and by the dedicated model
for i in range(1,4):
    print('#'*50)
    print(f'Prediction by all cell type model')
    plot_prediction(trained_models[0], valid_datasets[i], sample_index=1)
    print(f'Prediction by {cell_types[i]} model')
    plot_prediction(trained_models[i], valid_datasets[i], sample_index=1)

# Predict
Predict the test data with the model of each cell type.

## Test images

In [None]:
# Collect the paths of all PNG images in the test directory
test_image_paths = glob.glob(INPUT_PATH + '/test/*.png')
pprint(test_image_paths)

## Classify images by cell types

### Load classifier
Pre-trained resnet34 model is loaded

In [None]:
# NOTE: torch.load unpickles a complete model object here, which can execute
# arbitrary code - only load classifier files from a trusted source.
# map_location puts the model on the device that the images are moved to later.
classifier = torch.load(CLASSIFIER_MODEL_PATH, map_location=DEVICE)

### Dataset

In [None]:
from albumentations import Normalize, Resize, Compose
from albumentations.pytorch import ToTensorV2

class DatasetImageClassify(Dataset):
    """Dataset that prepares images for the cell-type classifier.

    Every item is a dict with the transformed image (resized to 224x224,
    normalized, converted to a tensor) under ``'image'`` and its path under
    ``'image_path'``.
    """

    def __init__(self, image_paths):
        self.image_paths = image_paths
        # Build the pipeline once instead of on every item access
        self.transforms = Compose([Resize(224, 224),
                                   Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1),
                                   ToTensorV2()])

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        # cv2.imread signals an unreadable file by returning None
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = self.transforms(image=image)['image']

        return {'image': image, 'image_path': image_path}

    def __len__(self):
        return len(self.image_paths)

In [None]:
# Loader that feeds the test images to the classifier in batches of 64.
# NOTE(review): shuffle=True only changes the row order of the resulting table,
# since every item carries its own path - confirm shuffling is intended for inference.
ds_classify = DatasetImageClassify(test_image_paths )
dl_classify = DataLoader(ds_classify, batch_size=64, num_workers=0, pin_memory=True, shuffle=True)

### Classify test images

In [None]:
# Class index -> cell type name.
# NOTE(review): presumably this order matches the label encoding used when the
# classifier was trained - confirm against notebook (1/2).
cell_list = ['shsy5y', 'astro', 'cort']
image_class = {}
# The images are moved to DEVICE below, so the classifier has to live there too
classifier.to(DEVICE)
classifier.eval()
cnt = 0
# Inference only: no gradients are needed
with torch.no_grad():
    for data in dl_classify:
        images, image_paths = data['image'], data['image_path']
        images = images.to(DEVICE)
        outputs = classifier(images)
        cell_idx = [output.argmax().item() for output in outputs]
        for path, idx in zip(image_paths, cell_idx):
            # The image id is the file name without its extension
            image_class[cnt] = {'id': path.split('/')[-1].split('.')[0], 'cell_type': cell_list[idx]}
            cnt += 1

# Naming the columns keeps 'id' and 'cell_type' available even without any test image
df_test = pd.DataFrame(list(image_class.values()), columns=['id', 'cell_type'])
display(df_test)

In [None]:
# Test dataset
class CellTestDataset(Dataset):
    """Dataset over the test images listed in the ``id`` column of ``df``.

    Every item is a dict with the (optionally transformed) RGB image under
    ``'image'`` and its id under ``'image_id'``.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def __getitem__(self, idx):
        image_id = self.df['id'].iloc[idx]
        image_path = ''.join([INPUT_PATH, '/test/', image_id, '.png'])
        image = Image.open(image_path).convert("RGB")
        if self.transforms is None:
            return {'image': image, 'image_id': image_id}
        # The transforms work on (image, target) pairs; test images have no target
        image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}

    def __len__(self):
        return len(self.df)

In [None]:
# Split the classified test images by their predicted cell type
df_test_all    = df_test.copy() # Not used for the submission below
df_test_shsy5y = df_test[df_test['cell_type']=='shsy5y']
df_test_astro  = df_test[df_test['cell_type']=='astro']
df_test_cort   = df_test[df_test['cell_type']=='cort']

In [None]:
# One test dataset per model, without augmentation
ds_test_all    = CellTestDataset(df_test_all,    transforms=get_transform(train=False))
ds_test_shsy5y = CellTestDataset(df_test_shsy5y, transforms=get_transform(train=False))
ds_test_astro  = CellTestDataset(df_test_astro,  transforms=get_transform(train=False))
ds_test_cort   = CellTestDataset(df_test_cort,   transforms=get_transform(train=False))
# Same order as trained_models; every entry is a dataset (the first one
# previously held the raw DataFrame by mistake)
test_datasets = [ds_test_all, ds_test_shsy5y, ds_test_astro, ds_test_cort]

## Run predict

In [None]:
def predict(model, dataset, mask_threshold=MASK_THRESHOLD):
    """Predict run-length encoded masks for every image of ``dataset``.

    Args:
        model: Segmentation model; it is moved to ``DEVICE`` and set to eval mode.
        dataset: Dataset whose items are dicts with ``'image'`` and ``'image_id'``.
        mask_threshold: Probability above which a mask pixel is kept.

    Returns:
        ``None`` for an empty dataset, otherwise a list of
        ``(image_id, rle)`` tuples: one per detection scoring above
        ``MIN_SCORE``, or a single ``(image_id, "")`` entry for an image
        without any such detection.
    """
    if len(dataset) == 0:
        return None

    model.eval()
    model.to(DEVICE)
    submission = []
    ids_with_rows = set()  # image ids that already have a submission row
    for idx in range(len(dataset)):
        sample = dataset[idx]
        img = sample['image']
        image_id = sample['image_id']
        with torch.no_grad():
            output = model([img.to(DEVICE)])[0]

        previous_masks = []
        for mask, score in zip(output["masks"], output["scores"]):
            if score.cpu().item() > MIN_SCORE:
                mask = mask.cpu().numpy()  # probability
                mask = mask > mask_threshold  # binarize
                # A pixel may belong to one instance only: earlier masks win
                mask = remove_overlapping_pixels(mask, previous_masks)
                previous_masks.append(mask)
                submission.append((image_id, rle_encode(mask)))
                ids_with_rows.add(image_id)

        # Add empty prediction if no RLE was generated for this image
        if image_id not in ids_with_rows:
            submission.append((image_id, ""))
            ids_with_rows.add(image_id)

    return submission

## Submission format

In [None]:
# Collect the predictions of the three per-cell-type models; every test image
# is handled by the model of its classified cell type
submissions = {}
count = 0
for i in range(1,4):
    submission = predict(trained_models[i], test_datasets[i])
    # predict returns None when no test image was assigned to this cell type
    if submission is not None:
        for image_id, rle in submission:
            submissions[count] = {'id': image_id, 'predicted': rle}
            count += 1

# Naming the columns keeps the CSV header even when nothing was predicted
pd.DataFrame(list(submissions.values()), columns=['id', 'predicted']).to_csv('submission.csv', index=False)
display(pd.read_csv('submission.csv'))