# Mask R-CNN w/ ResNet-50 FPN Pre-Trained on LIVECell

# Libraries

In [None]:
!pip install pycocotools
!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

# Imports

In [None]:
import os
import json
import torch
import torchvision
import numpy as np
import pandas as pd
import albumentations as A

from PIL import Image
from pathlib import Path
from pycocotools.coco import COCO
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Constants & Config
The LIVECell train annotations are stored not as a `list` of `dict`s but as a `dict` whose keys are the image ids. COCO expects a `list` of `dict`s. The function `fix_annotations` loads an annotations file and fixes this: if the `annotations` entry is already a `list`, it returns the path of the original `json` file unchanged; otherwise it writes a new file in which `annotations` has been converted to a `list`, and returns the path of that file so it can be loaded by COCO.

In [None]:
def fix_annotations(anns_file):
    """Return the path of a COCO-loadable version of an annotation file.

    If the file's "annotations" entry is already a list, nothing is written
    and `anns_file` is returned as given. Otherwise the entry is treated as
    a mapping: its values (in key-iteration order) become a list, the whole
    document is re-dumped under the same file name in the current working
    directory, and the absolute path of that new file is returned.

    Args:
        anns_file (str): Path to the annotation JSON, using '/' separators.

    Returns:
        str: Path of the file to hand to COCO.
    """
    with open(anns_file, 'r') as src:
        payload = json.load(src)

    annotations = payload["annotations"]
    if isinstance(annotations, list):
        print("data['annotations'] is a list so using original file")
        return anns_file

    # Only the values matter; the keys of the mapping are discarded.
    payload["annotations"] = [annotations[key] for key in annotations]

    # Same base name as the input, but written next to the notebook.
    out_name = f"./{anns_file.split('/')[-1]}"
    with open(out_name, 'w') as dst:
        json.dump(payload, dst)
    return os.path.abspath(out_name)

# Annotation file for each LIVECell split, as shipped with the dataset.
DATASET_PATHS = dict(
    train="../input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json",
    val="../input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_val.json",
    test="../input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json",
)

# Swap every entry for a path that COCO can load (the original file when it
# is already fine, a rewritten local copy otherwise).
for split_name, raw_path in list(DATASET_PATHS.items()):
    DATASET_PATHS[split_name] = fix_annotations(raw_path)

# Root directory of the LIVECell image folders.
LIVECELL_IMAGES_ROOT = "../input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/images"

In [None]:
# Size every image is resized to; passed as (IMAGE_RESIZE[0], IMAGE_RESIZE[1])
# to A.Resize and as a whole to torchvision's transforms.Resize.
IMAGE_RESIZE = (224, 224)
# Per-channel normalisation statistics, used by A.Normalize and handed to the
# Mask R-CNN constructor as image_mean / image_std.
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)
# NOTE(review): rle_decode and prepare_image_mask read a module-level
# IMAGE_SHAPE that is not defined anywhere in this notebook - confirm where it
# is supposed to come from before calling them.

# Number of (image, target) pairs per DataLoader batch.
BATCH_SIZE = 2
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Optimizer
# NOTE(review): MOMENTUM is not referenced by the Adam optimizer built in the
# training cell - presumably left over from an SGD setup; confirm.
MOMENTUM = 0.9
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.0005

# Number of passes over the training DataLoader.
NUM_EPOCHS = 50

In [None]:
def rle_decode(mask_rle, shape=None):
    """Decode a run-length-encoded string into a binary mask.

    Args:
        mask_rle (str): Whitespace-separated "start length start length ..."
            pairs. Starts are 1-indexed positions in the flattened mask.
        shape (tuple, optional): (rows, cols) of the mask to build. Defaults
            to the module-level IMAGE_SHAPE, as before.

    Returns:
        np array: Float array of the requested shape, 1 inside the encoded
        runs and 0 everywhere else.

    Raises:
        ValueError: If the string holds an odd number of values.
    """
    if shape is None:
        shape = IMAGE_SHAPE

    # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
    runs = np.array(mask_rle.split(), dtype=int).reshape(-1, 2)
    starts = runs[:, 0] - 1  # RLE starts are 1-indexed
    lengths = runs[:, 1]

    mask = np.zeros(shape[0] * shape[1])
    for start, length in zip(starts, lengths):
        mask[start:start + length] = 1
    return mask.reshape(shape)

def prepare_image_mask(mask_annotations):
    """Merge several RLE-encoded instance masks into one binary mask.

    Args:
        mask_annotations (iterable of str): One RLE string per instance, each
            decoded with `rle_decode`.

    Returns:
        np array: Array of shape IMAGE_SHAPE with 1 wherever at least one
        instance covers the pixel and 0 elsewhere.
    """
    combined = np.zeros(IMAGE_SHAPE)
    for rle in mask_annotations:
        np.add(combined, rle_decode(rle), out=combined)

    # Overlapping instances sum above 1; cap so the result stays binary.
    return np.clip(combined, 0, 1)

def compute_iou(labels, y_pred):
    """
    Builds the pairwise IoU table between labelled and predicted objects.

    The number of distinct values in each array is used as the number of
    histogram bins, so this presumably expects object ids to be consecutive
    integers starting at 0 (background) - verify against the caller.

    Args:
        labels (np array): Ground-truth instance labels.
        y_pred (np array): Predicted instance labels.

    Returns:
        np array: IoU matrix of size true_objects x pred_objects, with the
        first (background) row and column removed.
    """
    n_true = np.unique(labels).size
    n_pred = np.unique(y_pred).size

    # Joint histogram: entry (i, j) counts pixels labelled i and predicted j.
    overlap, _, _ = np.histogram2d(
        labels.ravel(), y_pred.ravel(), bins=(n_true, n_pred)
    )

    # Pixel count of every object on each side.
    true_sizes, _ = np.histogram(labels, bins=n_true)
    pred_sizes, _ = np.histogram(y_pred, bins=n_pred)

    # |A or B| = |A| + |B| - |A and B|, broadcast over all pairs.
    union = true_sizes[:, np.newaxis] + pred_sizes[np.newaxis, :] - overlap
    scores = overlap / union

    return scores[1:, 1:]  # drop the background row and column

def precision_at(threshold, iou):
    """
    Counts matches and mismatches at a given IoU threshold.

    Args:
        threshold (float): A pair counts as a match when its IoU is strictly
            greater than this value.
        iou (np array): IoU matrix; rows are reduced with axis=1 and columns
            with axis=0 (true objects x predicted objects when it comes from
            `compute_iou`).

    Returns:
        int: Number of true positives (rows with exactly one match),
        int: Number of false positives (columns with no match),
        int: Number of false negatives (rows with no match).
    """
    hits = iou > threshold
    hits_per_row = hits.sum(axis=1)
    hits_per_col = hits.sum(axis=0)

    tp = np.sum(hits_per_row == 1)  # rows matched exactly once
    fp = np.sum(hits_per_col == 0)  # columns that matched nothing
    fn = np.sum(hits_per_row == 0)  # rows that matched nothing
    return tp, fp, fn

def iou_map(truths, preds, verbose=0):
    """
    Computes the metric for the competition.

    For every IoU threshold in 0.5, 0.55, ..., 0.95 the TP / FP / FN counts
    from `precision_at` are summed over all images and turned into
    TP / (TP + FP + FN); the result is the mean of those values.
    Masks contain the segmented pixels where each object has one value
    associated, and 0 is the background.

    Args:
        truths (list of masks): Ground truths.
        preds (list of masks): Predictions, paired with `truths` by position.
        verbose (int, optional): Whether to print a per-threshold table.
            Defaults to 0.

    Returns:
        float: mAP.
    """
    ious = [compute_iou(truth, pred) for truth, pred in zip(truths, preds)]

    if verbose:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    precisions = []
    for thresh in np.arange(0.5, 1.0, 0.05):
        tps = fps = fns = 0
        for tp, fp, fn in (precision_at(thresh, iou) for iou in ious):
            tps, fps, fns = tps + tp, fps + fp, fns + fn

        score = tps / (tps + fps + fns)
        precisions.append(score)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(thresh, tps, fps, fns, score))

    mean_precision = np.mean(precisions)
    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}".format(mean_precision))

    return mean_precision

# Data Augmentations

In [None]:
# Both pipelines receive `bboxes` and `class_labels` from the dataset, so both
# need bbox_params; albumentations rejects bounding boxes passed to a Compose
# that has none.
#
# NOTE(review): get_model() also passes RESNET_MEAN / RESNET_STD to the Mask
# R-CNN constructor, so images that go through A.Normalize here would
# presumably be normalised a second time inside the model - confirm which of
# the two is intended before enabling these pipelines.

# Training pipeline. Random flips are currently disabled:
#   A.HorizontalFlip(p=0.5),
#   A.VerticalFlip(p=0.5),
transforms_train = A.Compose([
    A.Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1]),
    A.Normalize(mean=RESNET_MEAN, std=RESNET_STD),
    ToTensorV2(),
], bbox_params=A.BboxParams(min_visibility=0.1, format="pascal_voc", label_fields=["class_labels"]))

# Validation pipeline: same deterministic steps and box handling as training.
transforms_valid = A.Compose([
    A.Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1]),
    A.Normalize(mean=RESNET_MEAN, std=RESNET_STD),
    ToTensorV2(),
], bbox_params=A.BboxParams(min_visibility=0.1, format="pascal_voc", label_fields=["class_labels"]))


def get_transform(train):
    """Return the training pipeline when `train` is truthy, else the validation one."""
    if train:
        return transforms_train
    return transforms_valid

In [None]:
from torchvision import transforms

class SartoriusDataset(Dataset):
    """COCO-annotated cell images served as (image, target) pairs.

    Each item is an image plus a target dict with the keys "boxes", "masks",
    "labels", "image_id", "area" and "iscrowd". Every object gets label 1 and
    iscrowd 0. "area" is copied from the annotation and is NOT rescaled when
    the image is resized.

    Args:
        root (str): Directory holding one sub-folder per image-name prefix
            (the part of the file name before the first '_').
        anns_file (str): Path of a COCO annotation file.
        transforms (callable, optional): Pipeline called as
            transforms(image=, masks=, bboxes=, class_labels=) and returning
            a dict with the same keys (albumentations style). When omitted the
            image is converted to a tensor and resized to IMAGE_RESIZE, and
            masks and boxes are rescaled to match.
    """

    def __init__(self, root, anns_file, transforms=None):
        self.root = root
        self.coco = COCO(anns_file)
        self.transforms = transforms
        self.img_ids = list(sorted(self.coco.imgs.keys()))

    def __len__(self):
        return len(self.img_ids)

    @staticmethod
    def _stack_masks(masks, height, width):
        """Stack per-object masks (arrays or tensors) into one (N, H, W) uint8 tensor."""
        if len(masks) == 0:
            return torch.zeros((0, height, width), dtype=torch.uint8)
        return torch.stack([torch.as_tensor(m, dtype=torch.uint8) for m in masks])

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        anns_ids = self.coco.getAnnIds(imgIds=img_id)
        img_name = self.coco.loadImgs(ids=img_id)[0]["file_name"]
        img_folder = img_name.split('_')[0]
        img = Image.open(os.path.join(self.root, img_folder, img_name)).convert("RGB")
        img_anns = self.coco.loadAnns(ids=anns_ids)

        bboxes, areas, masks = [], [], []
        for ann in img_anns:
            # COCO boxes are [x, y, width, height]; convert to corner format.
            xmin, ymin, box_w, box_h = ann["bbox"][:4]
            bboxes.append([xmin, ymin, xmin + box_w, ymin + box_h])
            areas.append(ann["area"])
            masks.append(self.coco.annToMask(ann))

        if self.transforms:
            # albumentations works on numpy arrays, not PIL images. Object
            # indices are passed as labels so that boxes dropped by the
            # pipeline can be dropped from masks and areas as well.
            augmented = self.transforms(
                image=np.array(img),
                masks=masks,
                bboxes=bboxes,
                class_labels=list(range(len(bboxes))),
            )
            image = augmented["image"]
            kept = [int(i) for i in augmented["class_labels"]]
            bboxes = torch.as_tensor(augmented["bboxes"], dtype=torch.float32).reshape(-1, 4)
            areas = [areas[i] for i in kept]
            # CHW when the pipeline ends in ToTensorV2, HWC numpy otherwise.
            if torch.is_tensor(image):
                height, width = image.shape[-2:]
            else:
                height, width = image.shape[:2]
            out_masks = augmented["masks"]
            masks = self._stack_masks([out_masks[i] for i in kept], height, width)
        else:
            image = transforms.ToTensor()(img)
            src_h, src_w = image.shape[-2:]
            dst_h, dst_w = IMAGE_RESIZE
            image = transforms.Resize(IMAGE_RESIZE)(image)

            # The image was resized, so masks and boxes must follow; otherwise
            # the targets describe a different coordinate system.
            masks = self._stack_masks(masks, src_h, src_w)
            if masks.shape[0] > 0:
                masks = torch.nn.functional.interpolate(
                    masks[None].float(), size=(dst_h, dst_w), mode="nearest"
                )[0].to(torch.uint8)
            else:
                masks = torch.zeros((0, dst_h, dst_w), dtype=torch.uint8)

            scale_x = dst_w / src_w
            scale_y = dst_h / src_h
            bboxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4)
            bboxes = bboxes * torch.tensor([scale_x, scale_y, scale_x, scale_y])

        num_objs = bboxes.shape[0]
        target = {}
        target["boxes"] = bboxes
        target["masks"] = masks
        target["labels"] = torch.ones((num_objs,), dtype=torch.int64)
        target["image_id"] = torch.tensor([img_id])
        target["area"] = torch.as_tensor(areas, dtype=torch.float32)
        target["iscrowd"] = torch.zeros((num_objs,), dtype=torch.int64)

        return image, target

In [None]:
def _collate_as_tuples(batch):
    """Turn [(image, target), ...] into ((image, ...), (target, ...))."""
    return tuple(zip(*batch))


# Training set: LIVECell train/val images with the training annotations and no
# transform pipeline (the dataset's built-in tensor conversion is used).
train_images_dir = os.path.join(LIVECELL_IMAGES_ROOT, "livecell_train_val_images")
dataset = SartoriusDataset(train_images_dir, DATASET_PATHS["train"])
dl_train = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=_collate_as_tuples,
)

In [None]:
# Override the PyTorch checkpoint with an "offline" version of the file
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
def get_model():
    """Build a pretrained Mask R-CNN (ResNet-50 FPN) with fresh two-class heads.

    The detector is created with `pretrained=True` and with RESNET_MEAN /
    RESNET_STD as its image_mean / image_std. Its box predictor and mask
    predictor are then replaced by new, untrained heads with two outputs.

    Returns:
        The torchvision Mask R-CNN model with both heads swapped.
    """
    # Two outputs per head; the dataset only ever emits label 1.
    num_classes = 2
    mask_hidden_channels = 256

    net = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True, image_mean=RESNET_MEAN, image_std=RESNET_STD
    )
    heads = net.roi_heads

    # New box classifier/regressor, sized from the head it replaces.
    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in, num_classes)

    # New mask head, sized from the input channels of the old one.
    mask_in = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in, mask_hidden_channels, num_classes)
    return net


# Build the Mask R-CNN model and move it to the training device.
# The model predicts classes, bounding boxes and per-instance masks together;
# only the masks are of interest here.
model = get_model()
model = model.to(DEVICE)

# Make every weight trainable.
# TODO (original note, truncated): try removing this
for weight in model.parameters():
    weight.requires_grad_(True)

# Switch to training mode.
model.train()

In [None]:
import time
import gc
from warmup_scheduler import GradualWarmupScheduler

torch.cuda.empty_cache()
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Learning-rate schedule: 5 warm-up epochs, then cosine annealing.
# `lr_scheduler` is owned by `scheduler_warmup` (after_scheduler=...), which
# steps it itself once warm-up is over, so ONLY `scheduler_warmup` is stepped
# in the loop below. Stepping `lr_scheduler` by hand as well advances the
# cosine schedule during warm-up and twice per epoch afterwards.
# NOTE(review): with T_max=5 the cosine schedule reaches its minimum 5 epochs
# after warm-up and presumably climbs again for the rest of the NUM_EPOCHS
# run - confirm that this is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5, after_scheduler=lr_scheduler)
n_batches = len(dl_train)
# Dummy optimizer step before the first scheduler step; removes PyTorch's
# warning about calling a scheduler before the optimizer.
optimizer.zero_grad()
optimizer.step()
for epoch in range(1, NUM_EPOCHS + 1):
    print(f"Starting epoch {epoch} of {NUM_EPOCHS}")

    time_start = time.time()
    loss_accum = 0.0
    loss_mask_accum = 0.0
    # Sets the learning rate for this epoch (warm-up first, cosine afterwards).
    scheduler_warmup.step()
    for batch_idx, (images, targets) in enumerate(dl_train, 1):

        # Predict: in training mode the model returns a dict of losses.
        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        loss_mask = loss_dict['loss_mask'].item()
        loss_accum += loss.item()
        loss_mask_accum += loss_mask

        if batch_idx % 50 == 0:
            print(f"    [Batch {batch_idx:3d} / {n_batches:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}")
        # Drop the batch references before the next one is loaded.
        del images
        del targets
        gc.collect()

    # Train losses, averaged over the batches of this epoch
    train_loss = loss_accum / n_batches
    train_loss_mask = loss_mask_accum / n_batches

    elapsed = time.time() - time_start

    # One checkpoint per epoch.
    torch.save(model.state_dict(), f"pytorch_model-e{epoch}.bin")
    prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
    print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}")
    print(f"{prefix} Train loss: {train_loss:7.3f}. [{elapsed:.0f} secs]")