# Cell Segmentation using 3 Mask R-CNNs and Resnet18

Here, we use Mask R-CNN (one of the state-of-the-art models for instance segmentation), which is based on Faster R-CNN, to perform instance segmentation. This model, developed by the Facebook AI Research team, follows an *instance first* strategy instead of the *segmentation first* strategy used by other similar models, and has outperformed other approaches on the instance segmentation task.

The Mask R-CNN paper can be found [here](https://arxiv.org/pdf/1703.06870.pdf).

It has mainly two phases:
* Region Proposal Network: Here multiple Regions-of-Interest (RoI) are generated by the model.
* Predicting class, box, and masks: From each RoI, features are extracted which are used to make the predictions.
 

The approach we take here is to first use a Resnet18 Classification model to predict the `cell_type`. We also train 3 Mask R-CNNs for the different `cell_type`s and based on the previous prediction, use one of them to get the masks.

The main intuition behind this is that the shape and number of instances differ for each cell type. This approach scored 0.283, an improvement over the 0.275 obtained with a single Mask R-CNN.
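The two-stage routing described above can be sketched as follows. This is only an illustration: `route_prediction`, `fake_classifier` and the lambda "segmenters" are hypothetical stand-ins for the fastai learner and the three Mask R-CNN models used later in the notebook.

```python
def route_prediction(image, classify, segmenters):
    """Predict the cell type first, then run only the matching segmenter."""
    cell_type = classify(image)
    if cell_type not in segmenters:
        raise KeyError(f"No segmentation model registered for '{cell_type}'")
    return cell_type, segmenters[cell_type](image)


# Hypothetical stand-ins so that the sketch runs without any trained model.
def fake_classifier(image):
    return "cort" if sum(image) < 10 else "astro"


segmenters = {
    "astro": lambda image: ["astro-mask"],
    "cort": lambda image: ["cort-mask-1", "cort-mask-2"],
    "shsy5y": lambda image: ["shsy5y-mask"],
}

cell_type, masks = route_prediction([1, 2, 3], fake_classifier, segmenters)
print(cell_type, len(masks))
```

Keeping the segmenters in a dictionary keyed by `cell_type` means a misclassified image is still segmented, just by the wrong specialist, so the classifier's accuracy bounds how much the split can help.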

# Imports

In [None]:
import os
import time
import random
import collections
import cv2

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from torchvision.models import resnet34
from torch.nn import CrossEntropyLoss
from albumentations import Normalize, Resize, Compose
from albumentations.pytorch import ToTensorV2

import torch.nn as nn

import wandb

import fastai
from fastai.vision.all import *

In [None]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("wandb_key")
wandb.login(key=wandb_key)

In [None]:
# run = wandb.init(
#     project="sartorius-cell-segmentation", 
#     entity="manikya", 
#     job_type="train",
#     name="maskRCNN_full_14/12_v5_OneCycleSched",
#     reinit=True,
#     resume=True)

In [None]:
SEED = 3011

def fix_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
fix_seeds(SEED)

## Configuration

In [None]:
TRAIN_CSV = "../input/sartorius-cell-instance-segmentation/train.csv"
TRAIN_PATH = "../input/sartorius-cell-instance-segmentation/train"
TEST_PATH = "../input/sartorius-cell-instance-segmentation/test"

In [None]:
train_df = pd.read_csv(TRAIN_CSV)

In [None]:
train_df.width.unique(), train_df.height.unique()

In [None]:
WIDTH = 704
HEIGHT = 520

TEST = False

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

BATCH_SIZE = 2

MOMENTUM = 0.9
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005

MASK_THRESHOLD = 0.6

NUM_EPOCHS = 10

BOX_DETECTIONS_PER_IMG = 539

MIN_SCORE = 0.59

In [None]:
df_train = pd.read_csv(TRAIN_CSV, nrows=3000 if TEST else None)

In [None]:
df_train.head()

# Loading the Classification model

We load a `fastai` learner for a resnet18 model trained to classify the cell type present in the images. This is then used later in the pipeline along with multiple R-CNNs to make mask predictions.

In [None]:
learn = load_learner("../input/cell-classification-helper-notebook/cell-classification-learner.pkl")

## Utilities


### Transformations

Some of the transformations are adapted from [mask r cnn utils](https://www.kaggle.com/abhishek/maskrcnn-utils) (see its [transforms.py](https://www.kaggle.com/abhishek/maskrcnn-utils?select=transforms.py)), as it was not possible to directly use the transformations from `torchvision`. The instance segmentation task requires the target bounding boxes and masks to be transformed along with the image, so custom transformations are needed.

Here, all the transformations take in `target` as well.
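To see why the target has to be transformed together with the image, here is a small NumPy sketch of a horizontal flip. The helper `hflip_boxes` is a hypothetical name; it uses the same `width - x` rule as the `HorizontalFlip` class in this notebook, and the toy box treats `xmax` as an exclusive end.

```python
import numpy as np


def hflip_boxes(boxes, width):
    """Mirror [xmin, ymin, xmax, ymax] boxes about the vertical centre line."""
    flipped = boxes.copy()
    # The old xmax becomes the new xmin and vice versa, so xmin <= xmax still holds.
    flipped[:, [0, 2]] = width - boxes[:, [2, 0]]
    return flipped


# Toy 4x6 image with one object covering rows 1-2 and columns 0-1.
image = np.zeros((4, 6))
image[1:3, 0:2] = 1
boxes = np.array([[0.0, 1.0, 2.0, 3.0]])

flipped_image = image[:, ::-1]
flipped_boxes = hflip_boxes(boxes, width=6)
print(flipped_boxes)
```

If only the image were flipped, the box would still point at the now empty left edge.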

In [None]:
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class VerticalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-2)
            bbox = target["boxes"]
            bbox[:, [1, 3]] = height - bbox[:, [3, 1]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-2)
        return image, target

class HorizontalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-1)
        return image, target

class ToTensorNew:
    def __call__(self, image, target):
        image = torchvision.transforms.functional.to_tensor(image)
        return image, target
    

def get_transform(train):
    transforms = [ToTensorNew()]

    if train: 
        transforms.append(HorizontalFlip(0.5))
        transforms.append(VerticalFlip(0.5))

    return Compose(transforms)

In [None]:
torchvision.transforms.functional.to_tensor

In [None]:
def runlen_decoding(mask_rl, shape, color=1):
    '''
    mask_rl: run-length encoded mask, a string of 'start length' pairs (starts are 1-indexed)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rl.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo : hi] = color
    return img.reshape(shape)

In [None]:
train_df.iloc[0].annotation

In [None]:
runlen_decoding(train_df.iloc[0].annotation, (HEIGHT, WIDTH)).shape
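The inverse operation is useful for checking the decoder. The sketch below is an assumption about how an encoder could look (`runlen_encoding` and `runlen_decode` are illustrative names, not part of the original notebook); it uses the same conventions as `runlen_decoding`: row-major flattening and 1-indexed starts.

```python
import numpy as np


def runlen_encoding(mask):
    """Encode a binary (height, width) mask as a 'start length' string."""
    pixels = mask.flatten()
    # Pad with zeros so runs touching either end still produce two changes.
    padded = np.concatenate([[0], pixels, [0]])
    changes = np.where(padded[1:] != padded[:-1])[0] + 1  # 1-indexed positions
    starts = changes[0::2]
    lengths = changes[1::2] - starts
    return " ".join(f"{s} {n}" for s, n in zip(starts, lengths))


def runlen_decode(rle, shape):
    """Compact decoder mirroring the notebook's runlen_decoding."""
    values = np.asarray(rle.split(), dtype=int)
    starts, lengths = values[0::2] - 1, values[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, n in zip(starts, lengths):
        flat[lo:lo + n] = 1
    return flat.reshape(shape)


mask = np.zeros((3, 4), dtype=np.uint8)
mask[0, 1:3] = 1
mask[1, 0] = 1
encoded = runlen_encoding(mask)
print(encoded)
```

A round trip through `runlen_decode(encoded, mask.shape)` should reproduce `mask` exactly.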

## Training Dataset and DataLoader

In [None]:
# Select multiple columns with a list; recent pandas versions reject the bare 'a', 'b' form.
temp = train_df.groupby('id')[['annotation', 'cell_type']].agg(lambda x: list(x)).reset_index()

In [None]:
((np.unique(temp.cell_type[0]), len(temp.iloc[0].annotation)),
(np.unique(temp.cell_type[1]), len(temp.iloc[1].annotation)),
(np.unique(temp.cell_type[2]), len(temp.iloc[2].annotation)))

For defining the custom dataset, `torchvision`'s tutorial notebook (available [here](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html#)) for the PennFudan dataset was helpful.
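The dataset below derives each bounding box from its decoded mask. A minimal NumPy sketch of that step follows; `mask_to_box` and `box_area` are hypothetical helper names that use the same min/max rule as the dataset class.

```python
import numpy as np


def mask_to_box(mask):
    """Return [xmin, ymin, xmax, ymax] of the non-zero pixels of a 2D mask."""
    ys, xs = np.where(mask)
    return [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]


def box_area(box):
    xmin, ymin, xmax, ymax = box
    return (xmax - xmin) * (ymax - ymin)


mask = np.zeros((5, 7), dtype=bool)
mask[1:4, 2:6] = True          # rows 1-3, columns 2-5
box = mask_to_box(mask)
print(box, box_area(box))

thin = np.zeros((5, 7), dtype=bool)
thin[2, 3] = True              # a single pixel
thin_box = mask_to_box(thin)
print(thin_box, box_area(thin_box))
```

Note that a mask only one pixel wide or tall yields a box with zero width or height. To the best of our knowledge, torchvision's detection models reject such degenerate boxes during training, so it may be worth filtering them or padding `xmax`/`ymax` by one pixel.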

In [None]:
class NeuronalCellDataset(Dataset):
    def __init__(self, image_dir, df, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df
        
        self.height = HEIGHT
        self.width = WIDTH
        
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby('id')['annotation'].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                'image_id': row['id'],
                'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
                'annotations': row["annotation"]
                }

    def __getitem__(self, idx):
        
        img_path = self.image_info[idx]["image_path"]
        img = Image.open(img_path).convert("RGB")

        info = self.image_info[idx]

        n_objects = len(info['annotations'])
        masks = np.zeros((n_objects, self.height, self.width), dtype=np.uint8)
        boxes = []
        
        for i, annotation in enumerate(info['annotations']):
            a_mask = runlen_decoding(annotation, (HEIGHT, WIDTH))
            a_mask = Image.fromarray(a_mask)
            
            a_mask = np.array(a_mask) > 0
            masks[i, :, :] = a_mask
            
            pos = np.where(a_mask)
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            
            boxes.append([xmin, ymin, xmax, ymax])

        
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        labels = torch.ones((n_objects,), dtype=torch.int64) # As there is only 1 class.
        
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        # Required target for the Mask R-CNN model
        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.image_info)

In [None]:
ds_train = NeuronalCellDataset(TRAIN_PATH, df_train, transforms=get_transform(train=True))
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, 
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

In [None]:
ds_train[0]

In [None]:
df_train.shape

In [None]:
df_train_shsy5y = df_train[df_train.cell_type=='shsy5y'].reset_index()
ds_train_shsy5y = NeuronalCellDataset(TRAIN_PATH, df_train_shsy5y, transforms=get_transform(train=True))
dl_train_shsy5y = DataLoader(ds_train_shsy5y, batch_size=BATCH_SIZE, shuffle=True, 
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

In [None]:
df_train_cort = df_train[df_train.cell_type=='cort'].reset_index()
ds_train_cort = NeuronalCellDataset(TRAIN_PATH, df_train_cort, transforms=get_transform(train=True))
dl_train_cort = DataLoader(ds_train_cort, batch_size=BATCH_SIZE, shuffle=True, 
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

In [None]:
df_train_astro = df_train[df_train.cell_type=='astro'].reset_index()
ds_train_astro = NeuronalCellDataset(TRAIN_PATH, df_train_astro, transforms=get_transform(train=True))
dl_train_astro = DataLoader(ds_train_astro, batch_size=BATCH_SIZE, shuffle=True, 
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

In [None]:
df_train_cort.shape, df_train_shsy5y.shape, df_train_astro.shape

# Train loop

## Model

In [None]:
# Override pytorch checkpoint with an "offline" version of the file
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
def get_model():
    
    # 1 class for the background, and one for the cell type. Here, each image will only have neurons of one type (cell_type).
    NUM_CLASSES = 2 
    
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                  box_detections_per_img=BOX_DETECTIONS_PER_IMG)


    in_features = model.roi_heads.box_predictor.cls_score.in_features

    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES)


    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256

    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, NUM_CLASSES)
    return model

In [None]:
# ['shsy5y', 'astro', 'cort']

In [None]:
# model_0_shsy5y = get_model()
# model_0_shsy5y.to(DEVICE);

# model_1_astro = get_model()
# model_1_astro.to(DEVICE);

# model_2_cort = get_model()
# model_2_cort.to(DEVICE);

In [None]:
model_0_shsy5y = get_model()
model_0_shsy5y.to(DEVICE);
model_0_shsy5y.load_state_dict(torch.load("../input/r-cnn-models-for-cell-segmentation/model_shsy5y-e10.bin", map_location=DEVICE))
model_0_shsy5y.eval()

model_1_astro = get_model()
model_1_astro.to(DEVICE);
model_1_astro.load_state_dict(torch.load("../input/r-cnn-models-for-cell-segmentation/model_astro-e10.bin", map_location=DEVICE))
model_1_astro.eval()

model_2_cort = get_model()
model_2_cort.to(DEVICE);
model_2_cort.load_state_dict(torch.load("../input/r-cnn-models-for-cell-segmentation/model_cort-e10.bin", map_location=DEVICE))
model_2_cort.eval();

In [None]:
temp_count = 0
for i, param in enumerate(model_2_cort.parameters()):
    if (not param.requires_grad):
        temp_count+=1
        print(i)
print(f'Count of frozen layers is {temp_count}')

In [None]:
for param in model_2_cort.parameters():
    param.requires_grad = True
    
model_2_cort.train();

for param in model_0_shsy5y.parameters():
    param.requires_grad = True
    
model_0_shsy5y.train();

for param in model_1_astro.parameters():
    param.requires_grad = True
    
model_1_astro.train();

The model has multiple heads for predicting the bounding boxes, classification, and instance masks.
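In training mode, torchvision's Mask R-CNN returns one loss per head as a dictionary, and the training loops below simply sum them. The sketch uses the key names torchvision emits, but the numbers are made up for illustration.

```python
# Hypothetical loss values; only the key names correspond to torchvision's output.
loss_dict = {
    "loss_classifier": 0.25,    # RoI head: class of each proposal
    "loss_box_reg": 0.5,        # RoI head: box refinement
    "loss_mask": 1.0,           # RoI head: per-instance mask
    "loss_objectness": 0.125,   # RPN: object vs. background
    "loss_rpn_box_reg": 0.125,  # RPN: proposal box regression
}

total_loss = sum(loss_dict.values())
mask_share = loss_dict["loss_mask"] / total_loss
print(f"total={total_loss:.3f}, mask share={mask_share:.0%}")
```

Logging `loss_mask` separately, as the training loops do, shows how much of the total is attributable to the mask head.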

In [None]:
model_1_astro.roi_heads

## Training
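The training loops in this section step a `OneCycleLR` scheduler after every batch. As a rough guide to what the learning rate does, here is a simplified pure-Python sketch of a two-phase cosine one-cycle schedule. The function `one_cycle_lr` is illustrative, and its default arguments are assumed to match PyTorch's defaults; it is not the library implementation.

```python
import math


def one_cycle_lr(step, total_steps, max_lr,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step` (0-based); assumes pct_start * total_steps > 1."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1
    if step <= warmup_end:
        # Phase 1: rise from initial_lr to max_lr.
        pct = step / warmup_end
        start, end = initial_lr, max_lr
    else:
        # Phase 2: anneal from max_lr down to min_lr.
        pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
        start, end = max_lr, min_lr
    return end + (start - end) / 2.0 * (1.0 + math.cos(math.pi * pct))


total_steps = 100  # e.g. 10 epochs of 10 batches each
schedule = [one_cycle_lr(s, total_steps, max_lr=0.01) for s in range(total_steps)]
print(schedule[0], max(schedule), schedule[-1])
```

Because the schedule is defined over `epochs * steps_per_epoch` steps, the scheduler has to be stepped once per batch rather than once per epoch.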

In [None]:
# wandb.config = {
#     "learning_rate": LEARNING_RATE,
#     "epochs": NUM_EPOCHS,
#     "batch_size": BATCH_SIZE,
#     "momentum": MOMENTUM,
#     "learning_rate": LEARNING_RATE,
#     "weight_decay": WEIGHT_DECAY,
#     "num_box_preds": BOX_DETECTIONS_PER_IMG,
#     "seed": SEED,
#     "scheduler": "OneCyclePolicy"
# }

In [None]:
params_1_astro = torch.nn.ParameterList([p for p in model_1_astro.parameters() if p.requires_grad])
optimizer_1_astro = torch.optim.SGD(params_1_astro, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [None]:
params_0_shsy5y = torch.nn.ParameterList([p for p in model_0_shsy5y.parameters() if p.requires_grad])
optimizer_0_shsy5y = torch.optim.SGD(params_0_shsy5y, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [None]:
params_2_cort = torch.nn.ParameterList([p for p in model_2_cort.parameters() if p.requires_grad])
optimizer_2_cort = torch.optim.SGD(params_2_cort, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [None]:
num_batches_cort = len(dl_train_cort)
lr_scheduler_cort = torch.optim.lr_scheduler.OneCycleLR(optimizer_2_cort, max_lr=0.01, steps_per_epoch=num_batches_cort, epochs=NUM_EPOCHS)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"Starting epoch {epoch} of {NUM_EPOCHS}")
    
    time_start = time.time()
    loss_accum = 0.0
    loss_mask_accum = 0.0
    
    for batch_idx, (images, targets) in enumerate(dl_train_cort, 1):
    
        # Predict
        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model_2_cort(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        
        # Backprop
        optimizer_2_cort.zero_grad()
        loss.backward()
        optimizer_2_cort.step()
        lr_scheduler_cort.step()
        
        # Logging
        loss_mask = loss_dict['loss_mask'].item()
        loss_accum += loss.item()
        loss_mask_accum += loss_mask
        
        if batch_idx % 50 == 0:
            print(f"    [Batch {batch_idx:3d} / {num_batches_cort:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}")
            
#         wandb.log({"batch_train_loss": loss.item(), "mask_loss":loss_mask})
#         wandb.watch(model)
    
    # Train losses
    train_loss = loss_accum / num_batches_cort
    train_loss_mask = loss_mask_accum / num_batches_cort
    
    elapsed = time.time() - time_start
    
    torch.save(model_2_cort.state_dict(), f"/kaggle/working/model_cort-e{epoch}.bin")
    
    prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
    print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}")
    print(f"{prefix} Train loss: {train_loss:7.3f}. [{elapsed:.0f} secs]")

In [None]:
num_batches_shsy5y = len(dl_train_shsy5y)
lr_scheduler_shsy5y = torch.optim.lr_scheduler.OneCycleLR(optimizer_0_shsy5y, max_lr=0.01, steps_per_epoch=num_batches_shsy5y, epochs=NUM_EPOCHS)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"Starting epoch {epoch} of {NUM_EPOCHS}")
    
    time_start = time.time()
    loss_accum = 0.0
    loss_mask_accum = 0.0
    
    for batch_idx, (images, targets) in enumerate(dl_train_shsy5y, 1):
    
        # Predict
        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model_0_shsy5y(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        
        # Backprop
        optimizer_0_shsy5y.zero_grad()
        loss.backward()
        optimizer_0_shsy5y.step()
        lr_scheduler_shsy5y.step()
        
        # Logging
        loss_mask = loss_dict['loss_mask'].item()
        loss_accum += loss.item()
        loss_mask_accum += loss_mask
        
        if batch_idx % 50 == 0:
            print(f"    [Batch {batch_idx:3d} / {num_batches_shsy5y:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}")
            
#         wandb.log({"batch_train_loss": loss.item(), "mask_loss":loss_mask})
#         wandb.watch(model)
    
    # Train losses
    train_loss = loss_accum / num_batches_shsy5y
    train_loss_mask = loss_mask_accum / num_batches_shsy5y
    
    elapsed = time.time() - time_start
    
    torch.save(model_0_shsy5y.state_dict(), f"/kaggle/working/model_shsy5y-e{epoch}.bin")
    
    prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
    print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}")
    print(f"{prefix} Train loss: {train_loss:7.3f}. [{elapsed:.0f} secs]")

In [None]:
num_batches_astro = len(dl_train_astro)
lr_scheduler_astro = torch.optim.lr_scheduler.OneCycleLR(optimizer_1_astro, max_lr=0.01, steps_per_epoch=num_batches_astro, epochs=NUM_EPOCHS)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"Starting epoch {epoch} of {NUM_EPOCHS}")
    
    time_start = time.time()
    loss_accum = 0.0
    loss_mask_accum = 0.0
    
    for batch_idx, (images, targets) in enumerate(dl_train_astro, 1):
    
        # Predict
        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model_1_astro(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        
        # Backprop
        optimizer_1_astro.zero_grad()
        loss.backward()
        optimizer_1_astro.step()
        lr_scheduler_astro.step()
        
        # Logging
        loss_mask = loss_dict['loss_mask'].item()
        loss_accum += loss.item()
        loss_mask_accum += loss_mask
        
        if batch_idx % 50 == 0:
            print(f"    [Batch {batch_idx:3d} / {num_batches_astro:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}")
            
#         wandb.log({"batch_train_loss": loss.item(), "mask_loss":loss_mask})
#         wandb.watch(model)
    
    # Train losses
    train_loss = loss_accum / num_batches_astro
    train_loss_mask = loss_mask_accum / num_batches_astro
    
    elapsed = time.time() - time_start
    
    torch.save(model_1_astro.state_dict(), f"/kaggle/working/model_astro-e{epoch}.bin")
    
    prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
    print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}")
    print(f"{prefix} Train loss: {train_loss:7.3f}. [{elapsed:.0f} secs]")

In [None]:
# run.finish()

In [None]:
img, targets = ds_train[2]
masks = np.zeros((HEIGHT, WIDTH))
for mask in targets['masks']:
    masks = np.logical_or(masks, mask)
plt.imshow(img.numpy().transpose((1,2,0)))
plt.imshow(masks, alpha=0.3)

In [None]:
learn.predict("../input/sartorius-cell-instance-segmentation/train/0140b3c8f445.png")

In [None]:
learn.predict(Path(TRAIN_PATH)/'0140b3c8f445.png')[0]

In [None]:
df_train[df_train.cell_type=='astro'].head(3)

In [None]:
learn.model(torch.unsqueeze(img, dim=0))

In [None]:
models_dict={
    'astro': model_1_astro,
    'cort': model_2_cort,
    'shsy5y': model_0_shsy5y}

In [None]:
ds_train.image_info[2]["image_path"]

# Analyze prediction results for train set
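The analysis function in this section thresholds each predicted soft mask and takes their union for display. The NumPy sketch below shows the same idea with an additional score filter. Using `MASK_THRESHOLD` and `MIN_SCORE` from the configuration this way is an assumption (the function below uses a fixed 0.8 and no score filter), and `combine_masks` is a hypothetical helper.

```python
import numpy as np


def combine_masks(masks, scores, mask_threshold=0.6, min_score=0.59):
    """Union of confident instance masks.

    masks:  (N, 1, H, W) array of per-pixel probabilities
    scores: (N,) array of detection confidences
    """
    keep = scores >= min_score
    binary = masks[keep, 0] > mask_threshold  # shape (K, H, W)
    return np.any(binary, axis=0)             # shape (H, W), all False if K == 0


# Two toy detections on a 2x3 image; the second has a low confidence score.
masks = np.zeros((2, 1, 2, 3))
masks[0, 0, 0, 0] = 0.9
masks[1, 0, 1, 2] = 0.95
scores = np.array([0.9, 0.3])

combined = combine_masks(masks, scores)
print(combined)
```

Raising either threshold trades recall for precision, so both are worth tuning per cell type.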

In [None]:
def analyze_train_sample(models_dict, learn, ds_train, sample_index):
    
    img, targets = ds_train[sample_index]
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.title("Image")
    plt.axis('off')
    plt.show()
    
    masks = np.zeros((HEIGHT, WIDTH))
    for mask in targets['masks']:
        masks = np.logical_or(masks, mask)
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.imshow(masks, alpha=0.3)
    plt.title("Ground truth")
    plt.axis('off')
    plt.show()
    
    for model in models_dict.values():
        model.eval()
    cell_type = learn.predict(ds_train.image_info[sample_index]["image_path"])[0]
    with torch.no_grad():
        preds = models_dict[cell_type]([img.to(DEVICE)])[0]

    plt.imshow(img.cpu().numpy().transpose((1,2,0)))
    all_preds_masks = np.zeros((HEIGHT, WIDTH))
    for mask in preds['masks'].cpu().detach().numpy():
        all_preds_masks = np.logical_or(all_preds_masks, mask[0] > 0.8)
    plt.imshow(all_preds_masks, alpha=0.4)
    plt.title(f"Predictions using {cell_type} model")
    plt.axis('off')
    plt.show()

In [None]:
analyze_train_sample(models_dict, learn, ds_train, 200)

In [None]:
analyze_train_sample(models_dict, learn, ds_train, 399)

In [None]:
analyze_train_sample(models_dict, learn, ds_train, 395)

## References and Important Links

For creating and working with the R-CNN model, `torchvision`'s object detection tutorial notebook (available [here](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html#)) was particularly helpful.  
For the instance segmentation task, we referred to an entry from a previous Kaggle challenge on segmenting fashion images (available [here](https://www.kaggle.com/abhishek/mask-rcnn-using-torchvision-0-17/notebook)).  
We also used PyTorch utils for visualizing the masks and understanding the R-CNN output. [Check them here](https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html#instance-seg-output).  
Some helper functions and the Mask R-CNN model were adapted from [this](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273) notebook as well.  
The various reference scripts ([available here](https://github.com/pytorch/vision/tree/main/references/detection)) in torchvision were also helpful.

The Mask R-CNN model was first introduced in [this](https://arxiv.org/abs/1703.06870) paper.  
The model used for classification was trained in [this notebook](https://www.kaggle.com/manikyab/cell-classification-helper-notebook). An interpretation of that model can be found [here](https://www.kaggle.com/manikyab/cell-classification-captum-interpretation).

The weights of the trained models can be found [here](https://www.kaggle.com/manikyab/r-cnn-models-for-cell-segmentation).