In [None]:
import numpy as np
import pandas as pd
import random
import os
import time
import collections
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold, train_test_split
from tqdm.auto import tqdm

import torch
import torchvision
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
!pip install pycocotools
import pycocotools.mask as mask_util

In [None]:
# ---------------------------------------------------------------------------
# Notebook configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Data split / loading
VAL_SPLIT = 0.1               # fraction of image ids held out for validation
BATCH_SIZE = 4
NUM_EPOCHS = 20
BOX_DETECTIONS_PER_IMG = 540  # forwarded to Mask R-CNN as box_detections_per_img
MIN_SCORE = 0.5               # NOTE(review): not referenced by any visible cell
TEST_FLAG = False             # True -> only the first 5000 CSV rows are read

# Optimiser
MOMENTUM = 0.9
LEARNING_RATE = 0.001
#WEIGHT_DECAY = 0.0005
WEIGHT_DECAY = 0.0
USE_SCHEDULER = True          # step the StepLR scheduler once per epoch

# Input locations
TRAIN_CSV = '../input/sartorius-cell-instance-segmentation/train.csv'
TRAIN_PATH = '../input/sartorius-cell-instance-segmentation/train'
TEST_PATH = '../input/sartorius-cell-instance-segmentation/test'

# Image geometry: annotations are decoded at ORIG_* size; training samples are
# random horizontal crops that are WIDTH pixels wide.
ORIG_WIDTH = 704
ORIG_HEIGHT = 520
WIDTH = 520
HEIGHT = 520
MIN_BOX_SIDE = 5              # objects whose box is smaller than this are dropped

MASK_THRESHOLD = 0.5          # global mask cut-off used when plotting predictions
SEED = 2021

# Per-cell-type thresholds used when computing the validation metric.
# These were previously defined twice in this cell, so the first definition was
# dead code silently shadowed by the second. Only the effective values are kept.
# The shadowed (earlier) values were:
#   SCORE_THRESHOLDS: cort 0.75, shsy5y 0.50, astro 0.55
#   MASK_THRESHOLDS:  cort 0.75, shsy5y 0.60, astro 0.55
SCORE_THRESHOLDS = {
    'cort': 0.50,
    'shsy5y': 0.50,
    'astro': 0.50
}

MASK_THRESHOLDS = {
    'cort': 0.50,
    'shsy5y': 0.50,
    'astro': 0.50
}

MIN_PIXELS = 75               # predicted masks with fewer pixels are discarded

In [None]:
# Pre-populate the PyTorch hub checkpoint cache with an "offline" copy of the pretrained Mask R-CNN weights
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
def fix_all_seeds(seed):
    """Seed every random number generator used in this notebook.

    Exports ``PYTHONHASHSEED`` and seeds Python's ``random`` module, NumPy's
    global generator and PyTorch (CPU generator, current CUDA device and all
    CUDA devices) with the same value.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed,
                       torch.cuda.manual_seed,
                       torch.cuda.manual_seed_all):
        seed_torch(seed)
    
# Seed all generators once, before any randomised step of the notebook runs.
fix_all_seeds(SEED)

In [None]:
# These are slight redefinitions of the torchvision.transforms classes.
# The difference is that they also update the target (boxes and masks), not just the image.
# Copied from Abishek, with new ones added.
class Compose:
    """Apply a sequence of ``(image, target)`` transforms one after another."""

    def __init__(self, transforms):
        # Sequence of callables, each taking and returning (image, target).
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed ``image`` and ``target`` through every transform, in order."""
        for step in self.transforms:
            result = step(image, target)
            image, target = result
        return image, target

class VerticalFlip:
    """Randomly flip the image, its boxes and its masks upside down."""

    def __init__(self, prob):
        # Probability with which the flip is applied on each call.
        self.prob = prob

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped along the height axis with probability ``prob``.

        ``target["boxes"]`` is updated in place; boxes are ``[x0, y0, x1, y1]``.
        """
        if not random.random() < self.prob:
            return image, target

        height, _ = image.shape[-2:]
        boxes = target["boxes"]
        # The old bottom edge becomes the new top edge and vice versa.
        mirrored_y = height - boxes[:, [3, 1]]
        boxes[:, [1, 3]] = mirrored_y
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-2)
        return image.flip(-2), target

class HorizontalFlip:
    """Randomly mirror the image, its boxes and its masks left-to-right."""

    def __init__(self, prob):
        # Probability with which the flip is applied on each call.
        self.prob = prob

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped along the width axis with probability ``prob``.

        ``target["boxes"]`` is updated in place; boxes are ``[x0, y0, x1, y1]``.
        """
        if not random.random() < self.prob:
            return image, target

        _, width = image.shape[-2:]
        boxes = target["boxes"]
        # The old right edge becomes the new left edge and vice versa.
        mirrored_x = width - boxes[:, [2, 0]]
        boxes[:, [0, 2]] = mirrored_x
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-1)
        return image.flip(-1), target

class Normalize:
    """Channel-wise normalisation of a tensor image; the target is passed through.

    The statistics used to come from the module-level names ``RESNET_MEAN`` and
    ``RESNET_STD``, which no visible cell defines, so using this transform
    raised ``NameError``. They are now constructor arguments.

    Args:
        mean: per-channel means. Defaults to the ImageNet statistics commonly
            used with torchvision's pretrained ResNet backbones.
        std: per-channel standard deviations (same default convention).
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        """Return ``(normalised image, unchanged target)``."""
        image = F.normalize(image, self.mean, self.std)
        return image, target

class ToTensor:
    """Convert the image with torchvision's ``to_tensor``; the target is untouched."""

    def __call__(self, image, target):
        """Return ``(tensor image, unchanged target)``."""
        return F.to_tensor(image), target
    
class ContrastFromMiddle:
    """Remap intensities with a sine-shaped curve centred on 0.5.

    Values above 0.5 map to ``sin((v - 0.5) * pi) / 2 + 0.5`` and values at or
    below 0.5 map to ``0.5 - sin((0.5 - v) * pi) / 2``. The target is returned
    unchanged.
    """

    def __call__(self, img, target):
        """Return ``(remapped image, unchanged target)``."""
        above = img > 0.5
        below = img <= 0.5
        # Distance from the midpoint, zeroed outside each half.
        dist_above = above * (img - 0.5)
        dist_below = below * (0.5 - img)
        upper_part = above * (np.sin(dist_above * np.pi) / 2 + 0.5)
        lower_part = below * (0.5 - np.sin(dist_below * np.pi) / 2)
        return upper_part + lower_part, target

class Rotate:
    """Randomly rotate image, boxes and masks by one, two or three quarter turns."""

    def __init__(self, prob):
        # Probability with which a rotation is applied on each call.
        self.prob = prob

    def __call__(self, image, target):
        """Return ``(image, target)``, rotated with probability ``prob``.

        The number of quarter turns is drawn uniformly from {1, 2, 3} and
        applied with ``rot90`` over dims 1 and 2 of the image and of the masks.
        ``target["boxes"]`` (``[x0, y0, x1, y1]`` rows) is updated in place.
        """
        if not random.random() < self.prob:
            return image, target

        turns = random.randint(1, 3)
        height, width = image.shape[-2:]
        boxes = target["boxes"]
        # Work from a snapshot so the in-place write below cannot alias its inputs.
        x_lo, y_lo, x_hi, y_hi = boxes.clone().unbind(dim=1)
        if turns == 1:
            corners = (y_lo, width - x_hi, y_hi, width - x_lo)
        elif turns == 2:
            corners = (width - x_hi, height - y_hi, width - x_lo, height - y_lo)
        else:
            corners = (height - y_hi, x_lo, height - y_lo, x_hi)
        boxes.copy_(torch.stack(corners, dim=1))

        target["boxes"] = boxes
        target["masks"] = target["masks"].rot90(turns, [1, 2])
        return image.rot90(turns, [1, 2]), target

class ColorChange:
    """Random brightness / contrast / saturation / hue jitter; target is untouched.

    Each strength that is greater than zero enables the matching adjustment.
    Brightness, contrast and saturation factors are drawn uniformly from
    ``[1 - strength, 1 + strength]``; the hue shift from ``[-hue, hue]``.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __call__(self, image, target):
        """Return ``(jittered image, unchanged target)``."""
        # (strength, centre of the sampling interval, torchvision functional name)
        schedule = (
            (self.brightness, 1, "adjust_brightness"),
            (self.contrast, 1, "adjust_contrast"),
            (self.saturation, 1, "adjust_saturation"),
            (self.hue, 0, "adjust_hue"),
        )
        for strength, centre, op_name in schedule:
            if strength > 0:
                factor = random.uniform(centre - strength, centre + strength)
                image = getattr(F, op_name)(image, factor)
        return image, target
    
def get_transform(train):
    """Build the transform pipeline for a dataset.

    Tensor conversion is always applied. When ``train`` is truthy, random
    flips, a random quarter-turn rotation and colour jitter are appended.
    """
    pipeline = [ToTensor()]
    if train:
        # Data augmentation, training only.
        pipeline += [
            HorizontalFlip(0.50),
            VerticalFlip(0.50),
            Rotate(0.75),
            ColorChange(0.2, 0.2, 0.2, 0.1),
        ]
    return Compose(pipeline)

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length encoded mask.

    mask_rle: run-length string formatted as "start length start length ...",
              with 1-based starts into the row-major flattened image
    shape: (height, width) of the array to return
    color: value written into the mask pixels
    Returns a float32 numpy array: `color` inside the mask, 0 for background
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based offsets
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_stops = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, stop in zip(run_starts, run_stops):
        flat[begin:stop] = color
    return flat.reshape(shape)

def rle_encode(x):
    """Run-length encode the pixels of ``x`` that equal 1.

    The array is flattened in row-major order. Returns a string of
    space-separated "start length" pairs with 1-based starts, or an empty
    string when no pixel equals 1.
    """
    hits = np.flatnonzero(x.flatten() == 1)
    if hits.size == 0:
        return ''
    # A new run begins wherever consecutive hit positions are not adjacent.
    gap_after = np.flatnonzero(np.diff(hits) > 1)
    run_starts = hits[np.concatenate(([0], gap_after + 1))]
    run_ends = hits[np.concatenate((gap_after, [hits.size - 1]))]
    pairs = np.stack((run_starts + 1, run_ends - run_starts + 1), axis=1)
    return ' '.join(map(str, pairs.ravel()))

def remove_overlapping_pixels(mask, other_masks):
    """Clear from ``mask`` every pixel that is also set in any of ``other_masks``.

    ``mask`` is modified in place and returned.
    """
    for occupied in other_masks:
        shared = np.logical_and(mask, occupied)
        if np.sum(shared) > 0:
            mask[shared] = 0
    return mask

In [None]:
def precision_at(threshold, iou):
    """Count matches in a 2-D IoU matrix at a given IoU threshold.

    An entry counts as a match when it is strictly greater than ``threshold``.

    Returns a tuple ``(tp, fp, fn)`` where
      tp: number of rows with exactly one match,
      fp: number of columns with no match,
      fn: number of rows with no match.
    """
    hits = iou > threshold
    tp = np.sum(np.sum(hits, axis=1) == 1)
    fp = np.sum(np.sum(hits, axis=0) == 0)
    fn = np.sum(np.sum(hits, axis=1) == 0)
    return tp, fp, fn

def image_metric(pred, targ):
    """Mean precision over IoU thresholds 0.50, 0.55, ..., 0.95 for one image.

    Args:
        pred: sequence of 2-D binary masks.
        targ: sequence of 2-D binary masks to compare ``pred`` against.

    Returns:
        Mean over the thresholds of ``tp / (tp + fp + fn)``, with the counts
        taken from ``precision_at`` on the pairwise IoU matrix.

    Edge cases (these used to raise):
        * exactly one side empty -> 0.0 (nothing can match);
        * both sides empty -> 1.0 (by convention: nothing to find, nothing
          reported);
        * a threshold where tp + fp + fn == 0 contributes 0.0 instead of 0/0.
    """
    n_pred, n_targ = len(pred), len(targ)
    if n_pred == 0 or n_targ == 0:
        # pycocotools returns an empty list rather than a matrix here, which
        # cannot be reduced along axis=1 in precision_at.
        return 1.0 if n_pred == n_targ else 0.0

    enc_preds = [mask_util.encode(np.asarray(p, order='F')) for p in pred]
    enc_targs = [mask_util.encode(np.asarray(p, order='F')) for p in targ]
    ious = mask_util.iou(enc_preds, enc_targs, [0] * len(enc_targs))

    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(t, ious)
        total = tp + fp + fn
        prec.append(tp / total if total > 0 else 0.0)
    return np.mean(prec)

In [None]:
class CellDataset(Dataset):
    """Dataset of ``(image, target)`` pairs for Mask R-CNN training/validation.

    Args:
        image_dir: directory containing ``<id>.png`` images.
        df: annotation table with (at least) the columns ``id``,
            ``annotation`` (run-length encoded mask) and ``cell_type``.
        trainFlag: when true, every sample is a random WIDTH-wide horizontal
            crop and the training augmentations are applied; otherwise the
            full image is returned with tensor conversion only.

    Each target is a dict with the keys ``boxes``, ``labels``, ``masks``,
    ``image_id``, ``area`` and ``iscrowd``.
    """

    def __init__(self, image_dir, df, trainFlag):
        self.transforms = get_transform(trainFlag)
        self.image_dir = image_dir
        self.df = df
        self.height = HEIGHT
        self.width = WIDTH
        self.trainFlag = trainFlag

        # One pass over the frame for the per-image cell type, instead of
        # re-filtering the whole frame once per image.
        # Note: groupby().first() takes the first non-null value of each group.
        cell_types = self.df.groupby('id')['cell_type'].first()

        # image_info[i] describes the i-th image (ids in sorted order).
        self.image_info = {}
        temp_df = self.df.groupby('id')['annotation'].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                    'image_id': row['id'],
                    'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
                    'annotations': row["annotation"],
                    'cell_type': cell_types[row['id']]
                    }

    def get_box(self, a_mask):
        ''' Get the bounding box [xmin, ymin, xmax, ymax] of a given mask.

        Coordinates are the min/max indices of the non-zero pixels; the mask
        must contain at least one non-zero pixel.
        '''
        pos = np.where(a_mask)
        xmin = np.min(pos[1])
        xmax = np.max(pos[1])
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        return [xmin, ymin, xmax, ymax]

    def __getitem__(self, idx):
        """Load image ``idx`` and build its Mask R-CNN target.

        Objects that vanish after the crop, or whose box is smaller than
        MIN_BOX_SIDE on either side, are dropped. A sample with no remaining
        objects yields empty but correctly shaped tensors (``boxes`` is
        ``(0, 4)``, ``masks`` is ``(0, H, W)``) instead of raising.
        """
        info = self.image_info[idx]
        img = Image.open(info["image_path"]).convert("RGB")

        if self.trainFlag:
            # Random horizontal crop; the same shift is applied to every mask.
            shift = random.randint(0, ORIG_WIDTH - WIDTH)
            img = Image.fromarray(np.array(img)[:, shift:shift + WIDTH, :])
            mask_shape = (HEIGHT, WIDTH)
        else:
            mask_shape = (ORIG_HEIGHT, ORIG_WIDTH)

        boxes = []
        kept_masks = []
        for annotation in info['annotations']:
            a_mask = rle_decode(annotation, (ORIG_HEIGHT, ORIG_WIDTH)) > 0
            if self.trainFlag:
                a_mask = a_mask[:, shift:shift + WIDTH]
            if a_mask.sum() == 0:
                # mask fell entirely outside the crop
                continue
            box = self.get_box(a_mask)
            if (box[2] - box[0] >= MIN_BOX_SIDE) and (box[3] - box[1] >= MIN_BOX_SIDE):
                # ignore very small boxes
                boxes.append(box)
                kept_masks.append(a_mask)

        n_objects = len(boxes)
        # Stack once at the end; growing an array with np.concatenate inside
        # the loop copies every previous mask again on each iteration.
        if n_objects:
            masks = np.stack(kept_masks)
        else:
            masks = np.zeros((0,) + mask_shape, dtype=np.uint8)

        # reshape keeps the (N, 4) layout even when N == 0
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # dummy labels: every object is class 1
        labels = torch.ones((n_objects,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        # This is the required target for the Mask R-CNN
        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of distinct images."""
        return len(self.image_info)

In [None]:
class CellTestDataset(Dataset):
    """Dataset over the ``.png`` images of a directory, for inference.

    Args:
        image_dir: directory to scan.
        transforms: optional callable invoked as
            ``transforms(image=..., target=None)`` and returning
            ``(image, target)``.

    Each item is a dict ``{'image': ..., 'image_id': ...}`` where ``image_id``
    is the file name without its ``.png`` extension.
    """

    def __init__(self, image_dir, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        # Keep only PNG files (other entries would yield ids whose '<id>.png'
        # does not exist) and sort, because os.listdir order is arbitrary.
        self.image_ids = sorted(
            os.path.splitext(f)[0]
            for f in os.listdir(self.image_dir)
            if f.endswith('.png')
        )

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the optional transforms."""
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.png')
        image = Image.open(image_path).convert("RGB")

        if self.transforms is not None:
            image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}

    def __len__(self):
        """Number of PNG images found in ``image_dir``."""
        return len(self.image_ids)

In [None]:
def get_model():
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) with fresh 2-class heads.

    The box and mask predictors are replaced so that they output two classes.
    CellDataset labels every object as 1, so this is effectively a dummy
    class count for the classification head.
    """
    n_classes = 2
    mask_hidden_channels = 256

    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True,
        box_detections_per_img=BOX_DETECTIONS_PER_IMG,
    )
    heads = model.roi_heads

    # New box head, sized from the input features of the pretrained one.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, n_classes)

    # New mask head, sized from the input channels of the pretrained one.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, n_classes)
    return model

In [None]:
def validate_model(model, ds_valid):
    """Compute per-image losses over ``ds_valid`` without updating the model.

    The model is switched to train mode because in eval mode it returns
    predictions (boxes, labels, scores, masks) instead of a loss dict.
    Gradients are disabled for the forward passes. The model is left in train
    mode afterwards.

    Returns:
        (mean total loss, mean mask loss,
         {cell_type: [total losses]}, {cell_type: [mask losses]})
    """
    model.train()

    total_losses = []
    mask_losses = []
    total_by_type = collections.defaultdict(list)
    mask_by_type = collections.defaultdict(list)

    for index in range(len(ds_valid)):
        image, target = ds_valid[index]
        # The model expects lists: wrap the single sample as a batch of one.
        batch_images = [image.to(DEVICE)]
        batch_targets = [{key: value.to(DEVICE) for key, value in target.items()}]

        with torch.no_grad():
            loss_dict = model(batch_images, batch_targets)
            total = sum(loss_dict.values()).item()
            mask_only = loss_dict['loss_mask'].item()

        total_losses.append(total)
        mask_losses.append(mask_only)

        cell_type = ds_valid.image_info[index]['cell_type']
        total_by_type[cell_type].append(total)
        mask_by_type[cell_type].append(mask_only)

    return np.mean(total_losses), np.mean(mask_losses), total_by_type, mask_by_type

In [None]:
def calc_metric(model, ds_valid):
    """Mean ``image_metric`` of the model's predictions over ``ds_valid``.

    For every image, predictions are visited in the order the model returns
    them. A prediction is dropped when its score is below the cell type's
    SCORE_THRESHOLDS entry, or when fewer than MIN_PIXELS pixels remain after
    binarising at the MASK_THRESHOLDS entry and removing pixels already
    claimed by previously kept masks. The model is left in eval mode.
    """
    model.eval()
    per_image_scores = []

    for index in range(len(ds_valid)):
        image, target = ds_valid[index]
        cell_type = ds_valid.image_info[index]['cell_type']

        with torch.no_grad():
            result = model([image.to(DEVICE)])[0]
            kept_masks = []
            for position, soft_mask in enumerate(result["masks"]):
                confidence = result["scores"][position].cpu().item()
                if confidence < SCORE_THRESHOLDS[cell_type]:
                    continue

                # Keep only highly likely pixels
                candidate = soft_mask.cpu().numpy()[0] > MASK_THRESHOLDS[cell_type]
                candidate = remove_overlapping_pixels(candidate, kept_masks)
                if candidate.sum() < MIN_PIXELS:
                    continue
                kept_masks.append(candidate)

        per_image_scores.append(image_metric(target['masks'], kept_masks))

    return np.mean(per_image_scores)

In [None]:
def show_sample(ds_valid, sample_index):
    """Plot one dataset sample: the image alone, then with its ground-truth overlay.

    Args:
        ds_valid: dataset exposing ``image_info`` and returning
            ``(image tensor, target)`` items, where ``target['masks']`` is an
            ``(N, H, W)`` tensor.
        sample_index: index of the sample to display.

    A sample without any mask is shown with an empty overlay (indexing the
    first mask used to fail in that case).
    """
    cell_type = ds_valid.image_info[sample_index]['cell_type']
    print('Cell type:', cell_type)
    img, target = ds_valid[sample_index]

    # (C, H, W) tensor -> (H, W, C) array for matplotlib
    picture = img.numpy().transpose((1, 2, 0))

    plt.imshow(picture)
    plt.title("Image")
    plt.show()

    # Union of all instance masks; all-False when there are none.
    all_masks = target['masks'].numpy().any(axis=0)
    plt.imshow(picture)
    plt.imshow(all_masks, alpha=0.3)
    plt.title("Ground truth")
    plt.show()

In [None]:
# Load the annotation table; in TEST_FLAG mode only the first 5000 rows are read.
df_all_train = pd.read_csv(TRAIN_CSV, nrows=(5000 if TEST_FLAG else None))
print('Samples loaded:', len(df_all_train))
# remove images with bad masks: 
# https://www.kaggle.com/tolgadincer/sartorius-eda-general-overview-and-outliers
#df_all_train = df_all_train[~df_all_train['id'].isin(['ce5d0de993bd', 'a9fc5e872671', 'db5260527117'])].copy()
df_all_train = df_all_train.loc[~df_all_train['id'].isin(['ce5d0de993bd'])].copy()
print('Samples remaining:', len(df_all_train))

In [None]:
# One row per image id with its cell type, used to stratify the split.
df_all_ids = df_all_train.groupby('id')['cell_type'].max().reset_index()
train_ids, valid_ids = train_test_split(
    df_all_ids['id'].values,
    test_size=VAL_SPLIT,
    random_state=SEED,
    stratify=df_all_ids['cell_type'].values,
)
# The split is done on image ids, so all annotations of an image stay together.
df_train = df_all_train[df_all_train['id'].isin(train_ids)]
df_valid = df_all_train[df_all_train['id'].isin(valid_ids)]

ds_train = CellDataset(TRAIN_PATH, df_train, trainFlag=True)
# Samples differ in their number of objects, so a batch is kept as a tuple of
# images and a tuple of targets rather than stacked tensors.
dl_train = DataLoader(
    ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=lambda x: tuple(zip(*x)),
)

ds_valid = CellDataset(TRAIN_PATH, df_valid, trainFlag=False)

# These counts are annotation rows, not images.
print('Train samples:', len(df_train))
print('Validation samples:', len(df_valid))

In [None]:
# Visual sanity check: plot training sample 1 and its ground-truth masks.
show_sample(ds_train, 1)

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on a machine without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model()
model.to(DEVICE);

In [None]:
# Train every parameter of the network.
for param in model.parameters():
    param.requires_grad = True

params = list(filter(lambda p: p.requires_grad, model.parameters()))

optimizer = torch.optim.SGD(
    params,
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
# Multiply the learning rate by 0.4 every 4 scheduler steps
# (the training loop steps it once per epoch when USE_SCHEDULER is set).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.4)

# Number of batches per epoch, used to average the training losses.
n_batches = len(dl_train)

In [None]:
# Train for NUM_EPOCHS epochs. After each epoch the validation metric is
# computed and the weights are saved whenever it improves.
# NOTE: a checkpoint is only written once the metric exceeds 0.
best_metric = 0

for epoch in range(NUM_EPOCHS):
    print(f"Starting epoch {epoch} of {NUM_EPOCHS}")
    model.train()    
    
    time_start = time.time()
    loss_accum = 0.0
    loss_mask_accum = 0.0
    
    for batch_idx, (images, targets) in enumerate(dl_train, 1):
        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        # In train mode the model returns a dict of losses; optimise their sum.
        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Logging
        loss_mask = loss_dict['loss_mask'].item()
        loss_accum += loss.item()
        loss_mask_accum += loss_mask
        
#        if batch_idx % 50 == 0:
#            print(f"  [Batch {batch_idx:3d} / {n_batches:3d}] Batch train loss: {loss.item():5.3f}. Mask-only loss: {loss_mask:5.3f}")
    
    if USE_SCHEDULER:
        lr_scheduler.step()
    
    # Train losses (averaged over the batches of this epoch)
    train_loss = loss_accum / n_batches
    train_loss_mask = loss_mask_accum / n_batches
    elapsed = time.time() - time_start
    # These were computed but never reported; print the epoch summary.
    print(f"  Train loss: {train_loss:.4f}. Train mask-only loss: {train_loss_mask:.4f}. Epoch time: {elapsed:.0f} s")
    
#    val_loss, val_mask_loss, loss_by_type, loss_by_type_mask = validate_model(model, ds_valid)

#    print('Mask-only_loss: {:.4f}, total loss: {:.4f}'.format(val_mask_loss, val_loss))
#    for cell_type in loss_by_type.keys():
#        print('Cell type: {:6}; mask-only_loss: {:.4f}, total loss: {:.4f}'.format(cell_type, np.mean(loss_by_type_mask[cell_type]), np.mean(loss_by_type[cell_type])))

    metric = calc_metric(model, ds_valid)
    print('Validation MAP IoU: {:.4f}'.format(metric))
    if metric > best_metric:
        best_metric = metric
        print('Saving a better model at epoch:', epoch)
        torch.save(model.state_dict(), 'pytorch_model.bin')

print('\nBest validation MAP IoU: {:.4f}'.format(best_metric))   

In [None]:
#torch.save(model.state_dict(), 'pytorch_model.bin')
# Persist the train/validation image ids so the same split can be recovered later.
with open('train_valid_ids.npz', 'wb') as outfile:
    np.savez(outfile, train_ids=train_ids, valid_ids=valid_ids)

In [None]:
# Plots: the image, the image + the ground truth mask, the image + the predicted mask
def analyze_train_sample(model, ds_train, sample_index):
    """Show one sample, its ground-truth masks and the model's predicted masks.

    Args:
        model: detection model; it is switched to eval mode and left there.
        ds_train: dataset returning ``(image tensor, target)`` items with an
            ``(N, H, W)`` tensor under ``target['masks']``.
        sample_index: index of the sample to display.

    The overlays take their size from the masks themselves rather than from
    the (HEIGHT, WIDTH) constants, so uncropped samples work too. Predicted
    masks are binarised at MASK_THRESHOLD; no score filtering is applied.
    """
    img, targets = ds_train[sample_index]
    # (C, H, W) tensor -> (H, W, C) array for matplotlib
    picture = img.numpy().transpose((1, 2, 0))

    plt.imshow(picture)
    plt.title("Image")
    plt.show()

    # Union of all ground-truth instance masks (all-False when there are none).
    gt_union = targets['masks'].numpy().any(axis=0)
    plt.imshow(picture)
    plt.imshow(gt_union, alpha=0.3)
    plt.title("Ground truth")
    plt.show()

    model.eval()
    with torch.no_grad():
        preds = model([img.to(DEVICE)])[0]

    # Predicted masks are (N, 1, H, W) soft masks: binarise, then take the union.
    soft_masks = preds['masks'].cpu().detach().numpy()
    pred_union = (soft_masks[:, 0] > MASK_THRESHOLD).any(axis=0)
    plt.imshow(picture)
    plt.imshow(pred_union, alpha=0.4)
    plt.title("Predictions")
    plt.show()