# 🦠 Sartorius - Starter Torch Mask R-CNN

_Forked from [julian3833](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-173):_

> Following [this discussion thread](https://www.kaggle.com/c/sartorius-cell-instance-segmentation/discussion/279790), in this notebook we build a base starter Mask R-CNN with PyTorch.
>
> The code is an adapted version of [this notebook](https://www.kaggle.com/abhishek/mask-rcnn-using-torchvision-0-17/) by the first quadruple Kaggle grandmaster [Abhishek](https://www.kaggle.com/abhishek).
>
> The [previous U-net model](https://www.kaggle.com/julian3833/sartorius-starter-baseline-torch-u-net), which I was expecting to enter a steep improvement regime with quick-wins, hit a ceiling at `0.03`, no matter what changes I performed 🥲.
> Data augmentation, changes in the architecture, and other changes didn't work. The suggestion that semantic segmentation doesn't work seems reasonable, since the individuals cannot be split by connected components, as they overlap heavily.
>
> This is a follow up notebook with a Mask R-CNN, which was proposed by one of the top competitors ([Inoichan](https://www.kaggle.com/inoueu1)) as a more suitable architecture for this task.
>
> I'm not very familiar with the architecture, but it seems to be the state of the art for "instance segmentation".
> It classifies individuals, gets bounding boxes around them and, most importantly, provides a separate mask for each of them.
>
> You can read more about it [here](https://viso.ai/deep-learning/mask-r-cnn/).
>
> This model predicts a separate mask for each individual, rather than a single mask for the whole picture, and is thus better suited to the problem at hand.
>
> At the end, any overlapping pixel is removed, to ensure the non-overlapping policy. That wasn't required with the U-net, since the output was only one unique mask and therefore no overlap could have happened.

## Notes from Inoichan

[Tips in submission and baseline](https://www.kaggle.com/c/sartorius-cell-instance-segmentation/discussion/279790):

1. This is an "Instance Segmentation" problem, not a "semantic segmentation". [Mask R-CNN](https://viso.ai/deep-learning/mask-r-cnn/) is a good writeup on the difference.
2. The mask encoding runs left to right, then top to bottom (row-major), not top to bottom then left to right.
3. The masks cannot overlap or the submission will fail.

_There are some helper functions in this notebook for dealing with the RLE of the masks._
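
To make the encoding order in note 2 concrete, here is a minimal sketch of decoding a run-length string on a toy 3x4 grid. The function name `decode_rle_row_major` and the toy string are illustrative assumptions, not part of the competition tooling; the only facts relied on are the 1-based `start length` pairs and the row-major pixel order.

```python
import numpy as np

def decode_rle_row_major(rle, shape):
    """Decode 'start length' pairs (1-based starts) into a binary mask."""
    tokens = np.asarray(rle.split(), dtype=int)
    starts, lengths = tokens[0::2] - 1, tokens[1::2]  # convert to 0-based starts
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start:start + length] = 1
    # reshape() fills row by row, which is the left-to-right, top-to-bottom order
    return flat.reshape(shape)

toy_mask = decode_rle_row_major("2 3 9 1", (3, 4))
print(toy_mask)
```

The run `2 3` covers the last three pixels of the first row and `9 1` is the first pixel of the third row, so the runs walk along rows, not down columns.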

## Evan's Changes

These are changes I'm making from the original notebook.

1. Adding [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), using [this tutorial](https://www.aicrowd.com/showcase/tutorial-with-pytorch-torchvision-and-pytorch-lightning) from [aicrowd.com](https://www.aicrowd.com). (2021-10-20)
2. Update module to use [Logging](https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#logging-from-a-lightningmodule) API to log metrics.
3. Work on training just the head first, then training the whole stack.
  a. First just the head. Hopefully.
4. Also actually use epochs limit.
5. Update to work on colab as well as kaggle.
6. Implement [mean of IOUs](https://www.kaggle.com/c/sartorius-cell-instance-segmentation/overview/evaluation), use it for `val_accuracy` instead of a `val_loss`.
7. Monitor `val_loss` for checkpoints, optimizers
8. Add [CLAHE](https://scikit-image.org/docs/dev/api/skimage.exposure.html?highlight=clahe#skimage.exposure.equalize_adapthist) preprocessing.

## Notes

- Batch size 8 -> OOM

### ADAM

- currently sort of just hovers around 1.3 loss, trying different things to get out of the hole.
- lr 0.01 -> loss nan
- switch to adam. I don't know why it wasn't adam
- Increase batch size to 8, OOM
- Batch size to 4, so far so good
- lr -> 0.005: terrible convergence
- Using `val_loss` as something to maximize isn't great...
- 6 epochs instead of 3 or 9
- Turn wandb into a context manager to reuse the session across training runs
- AMSGRAD flag, lr scheduler, seems to do a bit better than SGD even.
- Batch size 6 -> OOM

### SGD

- Trying [SGD](https://shaoanlu.wordpress.com/2017/05/29/sgd-all-which-one-is-the-best-optimizer-dogs-vs-cats-toy-experiment/) again, with lr 0.0001
- LR -> 0.001 w SGD - With SGD the masks visually look better. ADAM had weird block artifacts, but SGD just has masks.
- LR -> 0.01 - Let's see if it'll converge faster. (Basically has the same loss curve)
- Try [CosineAnnealingWarmRestarts](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html) scheduler
- Use [levbszabo's SGD params](https://github.com/levbszabo/mask-rcnn)
- Qualitatively, SGD produces much sharper masks.
- lr -> 0.005 nan loss
- lr -> 0.001
- clahe the images before train/eval in the loader
- Add option for including difference image

### Configuration

Here's the overall configuration for this run. It's sent to wandb as part of the report so this run can be compared to others.

In [None]:
config = dict (
    project = "sartiorius-cell-instance-segmentation",
    architecture = "maskrcnn_resnet50_fpn",
    dataset_id = "sartorius-cell-instance-segmentation",
    infra = "kaggle",
    lr=0.01,
    min_lr=0.0000001,
    epochs=15,
    batch_size=4,
    nesterov=True,
    momentum=0.9,
    weight_decay=0.0005,
    clip_limit=0.25,
    difference=False,
    notes="Clahe."
)

In [None]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

INTERNET = config["infra"] == "colab"
if INTERNET:
    !pip install wandb clahe --upgrade
    
    import wandb

## Input Configuration

Since Colab and Kaggle have different input paths, we configure them appropriately here. Always use [`pathlib`](https://docs.python.org/3/library/pathlib.html).

In [None]:
if config["infra"] == "colab":
    from google.colab import drive
    drive.mount("/content/gdrive")
    INPUT_ROOT = Path("/content/gdrive/MyDrive/kaggle/input")
    !pip install pytorch-lightning
    with (INPUT_ROOT / Path("wandb.txt")).open("r") as wf:
      wandb_key = wf.read().strip()

    os.environ["WANDB_API_KEY"] = wandb_key
else:
  INPUT_ROOT = Path("../input")

!nvidia-smi 
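
As a small illustration of the `pathlib` advice, the sketch below joins path segments with `/`. It uses `PurePosixPath` so that nothing touches the file system; the variable names are just for the example. Note that plain strings can be joined directly, so wrapping each segment in `Path(...)` is optional.

```python
from pathlib import PurePosixPath

# PurePosixPath only manipulates the path text; no file has to exist.
input_root = PurePosixPath("../input")

# Strings can be appended with "/" without wrapping them in Path().
train_csv = input_root / "sartorius-cell-instance-segmentation" / "train.csv"

print(train_csv)         # the joined path
print(train_csv.name)    # final component
print(train_csv.suffix)  # file extension
```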

In [None]:
from contextlib import contextmanager

@contextmanager
def wandb_context(configuration=config):
  if INTERNET:
    # Use the configuration that was passed in rather than the global config
    run = wandb.init(reinit=True, config=configuration, project=configuration["project"])
    try:
      yield run
    finally:
      wandb.finish()
  else:
    yield None
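
The pattern used here, a context manager that yields a live session when one is available and `None` otherwise, can be sketched without wandb. Everything in this example (`optional_session`, the `events` list, the `"session"` placeholder) is hypothetical and only stands in for the real run object.

```python
from contextlib import contextmanager

@contextmanager
def optional_session(enabled, events):
    """Yield a session when enabled, otherwise yield None and do nothing."""
    if not enabled:
        yield None
        return
    events.append("start")
    try:
        yield "session"
    finally:
        # Runs even if the body of the with-block raises
        events.append("finish")

events = []
with optional_session(True, events) as session:
    events.append(f"using {session}")
with optional_session(False, events) as session:
    events.append(f"using {session}")
print(events)
```

Callers only need an `if run:` check inside the `with` block, which keeps the offline and online code paths identical.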


# Imports and constants

First define some transforms for the images. This used to be in a separate file but then I edited it and didn't want to recreate a dataset.

In [None]:
import random
import torch
from torchvision.transforms import functional as F
from skimage.exposure import equalize_adapthist
import numpy as np

def _flip_coco_person_keypoints(kps, width):
    flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    flipped_data = kps[:, flip_inds]
    flipped_data[..., 0] = width - flipped_data[..., 0]
    # Maintain COCO convention that if visibility == 0, then x, y = 0
    inds = flipped_data[..., 2] == 0
    flipped_data[inds] = 0
    return flipped_data

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class ImageEqualize:    
    def __init__(self, clip_limit, difference):
        self._clip_limit = clip_limit
        self._difference = difference

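    def equalize(self, img):
        img = np.array(img)
        # CLAHE: contrast-limited adaptive histogram equalization
        eimg = equalize_adapthist(img, clip_limit=self._clip_limit)

        if not self._difference:
            return eimg

        # Difference image: original minus equalized, rescaled to [0, 1]
        img = img - eimg
        img = img - img.min()
        return img / img.max()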

    def __call__(self, image, target):
        return self.equalize(image), target

class RandomHorizontalFlip(object):
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)
            if "keypoints" in target:
                keypoints = target["keypoints"]
                keypoints = _flip_coco_person_keypoints(keypoints, width)
                target["keypoints"] = keypoints
        return image, target

class ToTensor(object):
    def __call__(self, image, target):
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype
        image = F.to_tensor(np.array(image, dtype=np.float64))
        return image, target
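
The box arithmetic inside `RandomHorizontalFlip` is easy to get wrong, so here is a small NumPy sketch of the same idea. The helper name `hflip_boxes` is an assumption for illustration; the notebook's transform does the equivalent on torch tensors.

```python
import numpy as np

def hflip_boxes(boxes, width):
    """Mirror [xmin, ymin, xmax, ymax] boxes around the vertical centre line."""
    flipped = boxes.copy()
    # The old right edge becomes the new left edge and vice versa,
    # which keeps xmin <= xmax after the flip.
    flipped[:, [0, 2]] = width - boxes[:, [2, 0]]
    return flipped

boxes = np.array([[10.0, 5.0, 30.0, 25.0]])
flipped = hflip_boxes(boxes, width=704)
print(flipped)  # xmin becomes 704 - 30, xmax becomes 704 - 10
```

Swapping the two x columns while subtracting is what keeps the box valid; subtracting without the swap would produce `xmin > xmax`.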


In [None]:
import sys
import os
import random
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection._utils import Matcher
from torchvision.ops.boxes import box_iou

from PIL import Image, ImageFile
from pathlib import Path

MASKRCNN_UTILS_PATH = INPUT_ROOT / Path("maskrcnn-utils/")

# We only use 3 transformations from this package
sys.path.append(str(MASKRCNN_UTILS_PATH))

def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
fix_all_seeds(2021)

In [None]:
import torch

SAMPLE_SUBMISSION  = str(INPUT_ROOT / Path('sartorius-cell-instance-segmentation/sample_submission.csv'))
TRAIN_CSV = str(INPUT_ROOT / Path("sartorius-cell-instance-segmentation/train.csv"))
TRAIN_PATH = str(INPUT_ROOT / Path("sartorius-cell-instance-segmentation/train"))
TEST_PATH = str(INPUT_ROOT / Path("sartorius-cell-instance-segmentation/test"))


NUM_EPOCHS = config["epochs"]

# Training Dataset

## Utilities

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    mask_rle: run-length encoding as a string of "start length" pairs
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background

    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo : hi] = color
    return img.reshape(shape)


def get_transform(train):
    transforms = [ImageEqualize(clip_limit=config["clip_limit"], difference=config["difference"]),
                  ToTensor()]

    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        transforms.append(RandomHorizontalFlip(0.5))
    return Compose(transforms)

## Training Dataset and DataLoader

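The dataset decodes each image's RLE annotations into one binary mask per cell and derives a bounding box from each mask. Contrast enhancement (CLAHE) is applied by the `ImageEqualize` transform passed in through `transforms`.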

In [None]:
class CellDataset(Dataset):
    def __init__(self, image_dir, df_path, height, width, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = pd.read_csv(df_path)
        self.height = height
        self.width = width
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby('id')['annotation'].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                    'image_id': row['id'],
                    'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
                    'annotations': row["annotation"]
                    }
            
    def __getitem__(self, idx):
        # load images and masks
        img_path = self.image_info[idx]["image_path"]
        img = Image.open(img_path).convert("RGB")
        img = np.array(img, dtype=np.float64)
        img = img - img.min()
        img = img / img.max()
        #img = img.resize((self.width, self.height), resample=Image.BILINEAR)

        info = self.image_info[idx]

        mask = np.zeros((len(info['annotations']), self.width, self.height), dtype=np.uint8)
        labels = []
        
        for m, annotation in enumerate(info['annotations']):
            sub_mask = rle_decode(annotation, (520, 704))
            sub_mask = Image.fromarray(sub_mask)
            #sub_mask = sub_mask.resize((self.width, self.height), resample=Image.BILINEAR)
            sub_mask = np.array(sub_mask) > 0
            mask[m, :, :] = sub_mask
            labels.append(1)

        num_objs = len(labels)
        boxes = []
        new_labels = []
        new_masks = []

        for i in range(num_objs):
            try:
                pos = np.where(mask[i, :, :])
                xmin = np.min(pos[1])
                xmax = np.max(pos[1])
                ymin = np.min(pos[0])
                ymax = np.max(pos[0])
                boxes.append([xmin, ymin, xmax, ymax])
                new_labels.append(labels[i])
                new_masks.append(mask[i, :, :])
            except ValueError:
                print("Error in xmax xmin")
                pass

        if len(new_labels) == 0:
            boxes.append([0, 0, 20, 20])
            new_labels.append(0)
            new_masks.append(mask[0, :, :])

        nmx = np.zeros((len(new_masks), self.width, self.height), dtype=np.uint8)
        for i, n in enumerate(new_masks):
            nmx[i, :, :] = n

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(new_labels, dtype=torch.int64)
        masks = torch.as_tensor(nmx, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Size iscrowd by the instances actually kept, in case some were dropped above
        iscrowd = torch.zeros((len(new_labels),), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        img = img.float()
        return img, target

    def __len__(self):
        return len(self.image_info)
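
The bounding-box step in `__getitem__` can be sketched in isolation as follows. `mask_to_box` is a hypothetical helper; unlike the dataset, which relies on the `ValueError` raised by `np.min` on an empty array, this version returns `None` for an empty mask.

```python
import numpy as np

def mask_to_box(mask):
    """Return [xmin, ymin, xmax, ymax] of the nonzero pixels, or None if empty."""
    rows, cols = np.where(mask)  # axis 0 is y (rows), axis 1 is x (columns)
    if rows.size == 0:
        return None
    return [int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())]

mask = np.zeros((6, 8), dtype=np.uint8)
mask[2:4, 3:7] = 1  # rows 2..3, columns 3..6
box = mask_to_box(mask)
print(box)
```

Because `np.where` returns row indices first, the x coordinates come from the second array, which is why the dataset reads `pos[1]` for `xmin`/`xmax`.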

We create the val dataloader with `shuffle=False` for consistent val runs. Both loaders below use `num_workers=2`; if that triggers bus errors, setting `num_workers` to 0 [prevents them](https://stackoverflow.com/questions/51536114/pytorch-dataloader-killed).

In [None]:
dataset = CellDataset(TRAIN_PATH, TRAIN_CSV, 704, 520, transforms=get_transform(train=True))
train_size = int(len(dataset)*0.9)
val_size = len(dataset)-train_size
train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size]) # We sample 10% of the images as a validation dataset

dl_train = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True, 
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))
dl_val = DataLoader(val_set, batch_size=config["batch_size"], shuffle=False, 
                    num_workers=2, collate_fn=lambda x: tuple(zip(*x)))
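
The `collate_fn` above turns a list of `(image, target)` pairs into a tuple of images and a tuple of targets, which suits detection models because each image has a different number of instances. A pure-Python sketch of that transposition, using placeholder strings instead of tensors (`collate_as_tuples` is a hypothetical name):

```python
def collate_as_tuples(batch):
    """Transpose [(img, tgt), (img, tgt), ...] into ((img, ...), (tgt, ...))."""
    return tuple(zip(*batch))

batch = [
    ("img0", {"boxes": [[0, 0, 5, 5]]}),
    ("img1", {"boxes": [[1, 1, 4, 4], [2, 2, 9, 9]]}),
]
images, targets = collate_as_tuples(batch)
print(images)
print([len(t["boxes"]) for t in targets])
```

A named function like this one can also be pickled, which a `lambda` cannot; that may matter on platforms where worker processes are spawned rather than forked.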

A function to combine all of the mask labels into one array with the sequence numbers in place of just the "present" flag.
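
As a toy illustration of that labeling (the helper name `label_map` and the tiny arrays are assumptions for the example), each mask writes its 1-based sequence number into a shared array, so a later mask overwrites an earlier one wherever they overlap:

```python
import numpy as np

def label_map(masks):
    """Merge a stack of binary masks into one array of 1-based instance ids."""
    labels = np.zeros(masks.shape[1:], dtype=np.int32)
    for i, mask in enumerate(masks, start=1):
        labels[mask.astype(bool)] = i  # later masks win on overlapping pixels
    return labels

masks = np.zeros((2, 3, 4), dtype=np.uint8)
masks[0, 0, 0:3] = 1  # instance 1: first row, columns 0..2
masks[1, 0, 2:4] = 1  # instance 2: first row, columns 2..3
labels = label_map(masks)
print(labels)
```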

In [None]:
from skimage.color import label2rgb

def combine_masks(masks):
    """Combine the masks labeled with their sequence number."""

    masks = masks.numpy()
    all_masks = np.zeros((520, 704))
    for i, mask in enumerate(masks, 1):
        all_masks[mask == True] = i
    return all_masks

def visualize_masks(img, masks, title, model=None):
    """Show the original image, then superimpose it with the provided masks.
    
    1   2
    3   4
    5   6
    
    1: Original Image
    2: Ground Truth masks
    3: Ground Truth masks
    4: Predicted masks (if applicable)
    5: Predicted masks
    """

    fig, axs = plt.subplots(3, 2)

    fig.suptitle(title)

    # Unfinished: only the first panel is drawn; analyze() below is the version in use.
    axs[0][0].imshow(img)


In [None]:
enhance = ImageEqualize(config["clip_limit"], config["difference"])

def analyze(img, masks, model=None):
    masks = combine_masks(masks)
    fig, axs = plt.subplots(3, 2, figsize=(20, 20))
    axs[0][0].imshow(img)

    axs[0][1].imshow(label2rgb(masks, img, bg_label=0))

    eimg = enhance.equalize(img)
    print(eimg.shape)
    axs[1][0].imshow(eimg)
    axs[1][1].imshow(label2rgb(masks, eimg, bg_label=0))
 
    if config["difference"]:
        dimg = img - eimg
        dimg = dimg - dimg.min()
        dimg = dimg / dimg.max()
    else:
        dimg = img
    axs[2][0].imshow(dimg)
    axs[2][1].imshow(label2rgb(masks, dimg, bg_label=0))

    plt.show()
    return  # Everything below this line is unreachable legacy code, kept for reference.
    plt.title("Image")
    plt.imshow(img)
    plt.show()

    all_masks = combine_masks(masks)    
    plt.imshow(all_masks, alpha=0.3)
    plt.title("GT Masks")
    plt.show()

    return
    plt.title("Mask and Enhance Visualization")
    fig, axs = plt.subplots(2, 2, figsize=(30, 30))
    axs[0][0].imshow(img)

    masks = b['masks'].numpy().astype(np.int)
    labels = b['labels']

    print(labels[labels > 1].any())
    all_masks = np.zeros_like(masks[0], np.int)
    for i, m in zip(labels, masks):
        all_masks[m != 0] = i

    axs[1][1].imshow(all_masks)

    if model is not None:
        model.eval()
        with torch.no_grad():
            preds = model([img])[0]

        all_preds_masks = np.zeros((520, 704))
        for mask in preds['masks'].cpu().detach().numpy():
            all_preds_masks = np.logical_or(all_masks, mask[0])
        plt.imshow(all_preds_masks, alpha=0.8)
        plt.title("Predictions")
        plt.show()

def analyze_sample(ds, sample_index, model=None):
    img, targets = ds[sample_index]

    img = img.numpy().transpose((1, 2, 0))
    analyze(img, targets["masks"], model)

In [None]:
analyze_sample(train_set, 8)
#img, d = train_set[2]
#img = img.numpy().transpose(1,2,0)
#plt.imshow(img)
#masks = combine_masks(d["masks"])
#plt.imshow(masks, alpha=0.3)
#plt.show()

# Train loop

## Models

First we copy the pretrained COCO checkpoint into torch's hub cache, so the model can be loaded without internet access.



In [None]:
# Override the PyTorch checkpoint download with an "offline" copy of the file
os.environ["HOME"] = os.environ.get("HOME", os.getcwd())
HOME = Path(os.environ["HOME"])
chkpt_path = HOME / ".cache/torch/hub/checkpoints"
chkpt_path.mkdir(parents=True, exist_ok=True)

model_src_path = INPUT_ROOT / "cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth"
model_tgt_path = chkpt_path / "maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth"
print(model_src_path.exists())
# Copy the weights into torch's hub cache so that pretrained=True finds them locally
with model_src_path.open("rb") as s:
    with model_tgt_path.open("wb+") as t:
        t.write(s.read())

In many places we have to clear prior GPU memory allocations. We first run a Python GC pass to sweep dangling references, so the objects release their GPU memory on destruction, and we tell torch to empty its CUDA cache.

We define a decorator to handle this and wrap the relevant methods.

In [None]:
from functools import wraps
import gc

def flush_and_gc(f):
  @wraps(f)
  def g(*args, **kwargs):
    torch.cuda.empty_cache()
    gc.collect()
    return f(*args, **kwargs)
  return g

### IOU Accuracy Metric

This is mainly based on the [competition accuracy metric](https://www.kaggle.com/c/sartorius-cell-instance-segmentation/overview/evaluation).

$ \frac{1}{|\mathbf{t}|} \sum_{t \in \mathbf{t}} \frac{\mathrm{TP}(t)}{\mathrm{TP}(t) + \mathrm{FP}(t) + \mathrm{FN}(t)} $

where:

$ \mathbf{t} = \{0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95\} $
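
To see how the formula behaves, here is a small pure-Python worked example. It is a simplified sketch under stated assumptions: it matches boxes greedily (each prediction takes its best unmatched ground-truth box), treats a match as valid only when the IoU is strictly greater than the threshold, and scores an empty-versus-empty image as 1. The helper names `pair_iou`, `precision_at` and `mean_precision` are hypothetical, and the official evaluator works on masks rather than boxes.

```python
THRESHOLDS = [.5, .55, .6, .65, .7, .75, .8, .85, .9, .95]

def pair_iou(a, b):
    """IoU of two [xmin, ymin, xmax, ymax] boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def precision_at(threshold, gt_boxes, pred_boxes):
    """TP / (TP + FP + FN) at one IoU threshold, with greedy one-to-one matching."""
    matched, tp = set(), 0
    for pred in pred_boxes:
        best_iou, best_gt = 0.0, None
        for gt_idx, gt in enumerate(gt_boxes):
            if gt_idx in matched:
                continue
            value = pair_iou(gt, pred)
            if value > best_iou:
                best_iou, best_gt = value, gt_idx
        if best_gt is not None and best_iou > threshold:
            matched.add(best_gt)
            tp += 1
    fp = len(pred_boxes) - tp
    fn = len(gt_boxes) - tp
    total = tp + fp + fn
    return tp / total if total else 1.0

def mean_precision(gt_boxes, pred_boxes):
    """Average the per-threshold score over the number of thresholds."""
    return sum(precision_at(t, gt_boxes, pred_boxes) for t in THRESHOLDS) / len(THRESHOLDS)

gt = [(0, 0, 10, 10), (20, 20, 30, 30)]
pred = [(0, 0, 10, 10), (20, 20, 30, 28), (50, 50, 60, 60)]
score = mean_precision(gt, pred)
print(score)
```

In this toy case the second prediction has an IoU of 0.8, so it counts as a hit for the six thresholds below 0.8 (score 2/3 each) and as a miss for the remaining four (score 1/4 each), giving (6 x 2/3 + 4 x 1/4) / 10 = 0.5.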

In [None]:
iou_thresholds = [.5, .55, .6, .65, .7, .75, .8, .85, .9, .95]

def sartorius_iou(src_boxes, pred_boxes):
    """
    Mean over the IoU thresholds of TP / (TP + FP + FN), computed on boxes.
    The accuracy method is not the one used in the evaluator but very similar.
    """
    total_gt = len(src_boxes)
    total_pred = len(pred_boxes)
    device = pred_boxes.device

    # Degenerate cases: nothing to find and nothing predicted is a perfect
    # score; any other combination with an empty side scores zero.
    if total_gt == 0 or total_pred == 0:
        return torch.tensor(1. if total_gt == total_pred else 0., device=device)

    # IoU between every ground-truth box (rows) and every predicted box (columns)
    match_quality_matrix = box_iou(src_boxes, pred_boxes)

    def iou(threshold):
        # For each prediction, Matcher returns the matched ground-truth index, or -1
        matcher = Matcher(threshold, threshold, allow_low_quality_matches=False)
        results = matcher(match_quality_matrix)

        matched_elements = results[results > -1]
        # Each distinct matched ground-truth box counts once as a true positive
        true_positive = len(torch.unique(matched_elements))

        # Unmatched predictions, plus surplus predictions matched to an
        # already-matched ground-truth box, count as false positives
        false_positive = torch.count_nonzero(results == -1) + (len(matched_elements) - true_positive)
        false_negative = total_gt - true_positive

        return true_positive / (true_positive + false_positive + false_negative)

    # Average the per-threshold scores over the number of thresholds
    return torch.stack([iou(t) for t in iou_thresholds]).mean()

Next we create the model. Its base is a pretrained [Mask RCNN](https://pytorch.org/vision/stable/models.html#id63). We replace the box and mask predictor heads with new ones sized for our `num_classes`.

In [None]:
import pytorch_lightning as pl


class Model(pl.LightningModule):
    def __init__(self, num_classes=2, hidden_layer=256):
        super().__init__()

        # We don't want any of the pretrained model layers trainable at first.
        self.detector = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True, trainable_backbone_layers=0)

        # get the number of input features for the classifier
        in_features = self.detector.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        self.detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # now get the number of input features for the mask classifier
        in_features_mask = self.detector.roi_heads.mask_predictor.conv5_mask.in_channels

        # and replace the mask predictor with a new one
        self.detector.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

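    def full_train(self):
        # requires_grad_ (with trailing underscore) updates every parameter in the
        # detector; plain attribute assignment would leave the frozen layers frozen.
        self.detector.requires_grad_(True)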
    
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=config["lr"], momentum=config["momentum"], weight_decay=config["weight_decay"], nesterov=config["nesterov"])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3, 1, config["min_lr"], verbose=True)
        return {"optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val_accuracy"
                }}

    def forward(self, x):
        self.detector.eval()
        return self.detector(x)

    @flush_and_gc
    def training_step(self, batch, batch_idx):
        images, targets = batch
        
        loss_dict = self.detector(images, targets)
    
        preds = self.forward(images)
        self.detector.train()
        
        loss_dict["loss_sartorius"] = 1 - torch.mean(torch.stack([sartorius_iou(b["boxes"],pb["boxes"]) for b,pb in zip(targets,preds)]))
        #loss_dict = {k:v for k, v in loss_dict.items()  if k in ["loss_sartorius", "loss_classifier", "loss_mask"]}
        loss = sum(loss_dict.values())
        
        loss_dict = {k:(v.detach() if hasattr(v, "detach") else v) for k, v in loss_dict.items()}        
        self.log("loss", loss)
        self.log_dict(loss_dict)

        return {"loss": loss, "log": loss_dict}
    
    @flush_and_gc
    def validation_step(self, batch, batch_idx):
        img, boxes = batch
        pred_boxes = self.forward(img)

        self.val_accuracy = torch.mean(torch.stack([sartorius_iou(b["boxes"],pb["boxes"]) for b,pb in zip(boxes,pred_boxes)]))
    
        self.log("val_accuracy", self.val_accuracy)
        return self.val_accuracy

    @flush_and_gc
    def test_step(self, batch, batch_idx):
        img, boxes, metadata = batch
        pred_boxes = self.forward(img) # in validation, faster rcnn return the boxes
        self.test_accuracy = torch.mean(torch.stack([sartorius_iou(b,pb["boxes"]) for b,pb in zip(boxes,pred_boxes)]))
        r = {"accuracy_test": self.test_accuracy}
        self.log_dict(r)
        return r
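
For intuition about the scheduler configured in `configure_optimizers`, the sketch below evaluates the cosine annealing formula from the PyTorch documentation by hand, using the values from this run's config (`lr=0.01`, `min_lr=1e-7`, `T_0=3`, `T_mult=1`). It assumes the scheduler is stepped once per epoch, and `cosine_warm_restart_lr` is a hypothetical helper, not a library function.

```python
import math

def cosine_warm_restart_lr(epoch, base_lr=0.01, min_lr=1e-7, t_0=3):
    """Learning rate at a given epoch when every cycle lasts t_0 epochs (T_mult=1)."""
    t_cur = epoch % t_0  # position inside the current cycle
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t_cur / t_0))

schedule = [cosine_warm_restart_lr(epoch) for epoch in range(6)]
for epoch, lr in enumerate(schedule):
    print(epoch, f"{lr:.6f}")
```

The rate decays over three epochs and then jumps back to the base value, so with 15 epochs the run goes through five such cycles.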

## Model Training

Since we train the model twice, we have a function to help us with that. It configures [Weights and Biases](https://wandb.ai) logging when a run is active, creates the trainer, runs it with checkpointing, and at the end loads the best checkpoint and returns it.

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

def train_model(model, run_name, run=None):
  wandb_logger = None
  if run:
    run.config["train_run_name"] = run_name

    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger()
    
  chkpt = ModelCheckpoint(f"/kaggle/working/chkpt-{run_name}", monitor="val_accuracy", mode="max")

  trainer = pl.Trainer(gpus=1, logger=wandb_logger, max_epochs=config["epochs"], callbacks=[chkpt])
  trainer.fit(model, dl_train, dl_val)
  
  return Model.load_from_checkpoint(chkpt.best_model_path)
  

## Training loop!

First we train just the head, and then we train the full model. This holds most of the parameters constant while getting the head to converge, and then fine tunes the entire model for the new head.

In [None]:
with wandb_context(config) as run:
  model = Model()
  if run:
    run.watch(model)
  model = train_model(model, "head", run)

  # Set the model for fine-tuning training
  model.full_train()
  model = train_model(model, "full", run)
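
The head-first, then full-model schedule boils down to toggling `requires_grad` on groups of parameters. Here is a minimal sketch on a toy module; `ToyDetector` and `trainable` are hypothetical stand-ins for the real detector, assumed only for illustration.

```python
from torch import nn

class ToyDetector(nn.Module):
    """Stand-in with a 'backbone' and a 'head', like a tiny detector."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(self.backbone(x))

def trainable(module):
    """Names of the parameters that will receive gradients."""
    return [name for name, p in module.named_parameters() if p.requires_grad]

toy = ToyDetector()

# Phase 1: freeze the backbone so only the head is updated
toy.backbone.requires_grad_(False)
head_only = trainable(toy)

# Phase 2: unfreeze everything for fine-tuning
toy.requires_grad_(True)
everything = trainable(toy)

print(head_only)
print(everything)
```

Checking the list of trainable parameter names before each phase is a cheap way to confirm that the freeze or unfreeze actually took effect.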

## Test Dataset and DataLoader

In [None]:
class CellTestDataset(Dataset):
    def __init__(self, image_dir, height, width, transforms=None):
        self.transforms = transforms
        
        self.image_dir = image_dir
        
        self.image_ids = [f[:-4] for f in os.listdir(self.image_dir)]
        self.num_samples = len(self.image_ids)
        
        self.height = height
        self.width = width
        self.image_info = collections.defaultdict(dict)
            
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.png')
        image = Image.open(image_path).convert("RGB")
        #image = image.resize((self.width, self.height), resample=Image.BILINEAR)

        if self.transforms is not None:
            image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}

    def __len__(self):
        return len(self.image_ids)

In [None]:
ds_test = CellTestDataset(TEST_PATH, 704, 520, transforms=get_transform(train=False))
dl_test = DataLoader(ds_test, batch_size=config["batch_size"], shuffle=True, 
                      num_workers=0, collate_fn=lambda x: tuple(zip(*x)))

# Analyze prediction results for train set

In [None]:
# NOTE: It puts the model in eval mode!! Revert for re-training
analyze_sample(train_set, 20, model)

In [None]:
analyze_sample(train_set, 100, model)

In [None]:
analyze_sample(train_set, 2, model)

# Prediction

## Utilities

In [None]:
# Stolen from: https://www.kaggle.com/arunamenon/cell-instance-segmentation-unet-eda
# Run-length encoding stolen from https://www.kaggle.com/rakhlin/fast-run-length-encoding-python
# Modified by me
def rle_encoding(x):
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if (b>prev+1): run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return ' '.join(map(str, run_lengths))


def does_overlap(mask, other_masks):
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            #import pdb; pdb.set_trace()
            #print("Found overlapping masks!")
            return True
    return False


def remove_overlapping_pixels(mask, other_masks):
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            mask[np.logical_and(mask, other_mask)] = 0
    return mask
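
To show what the overlap removal does, here is a toy sketch on boolean masks. `strip_overlap` is a hypothetical variant of `remove_overlapping_pixels` that works on a copy instead of modifying its input; the idea is the same: pixels already claimed by an earlier mask are cleared from the later one.

```python
import numpy as np

def strip_overlap(mask, kept_masks):
    """Return a copy of mask with every pixel already used by kept_masks cleared."""
    mask = mask.copy()
    for other in kept_masks:
        mask[np.logical_and(mask, other)] = False
    return mask

first = np.array([[1, 1, 0, 0]], dtype=bool)
second = np.array([[0, 1, 1, 0]], dtype=bool)

kept = [first]
second_clean = strip_overlap(second, kept)
kept.append(second_clean)
print(second_clean)
```

Earlier masks keep the contested pixels, so the order in which predictions are processed decides which instance wins an overlap.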

## Submission

In [None]:
sublist = []
counter = 0

width = 704
height = 520
THRESHOLD = 0.5

model.eval()

for sample in ds_test:
    img = sample['image'].float()
    image_id = sample['image_id']
    with torch.no_grad():
        result = model([img])[0]
    if len(result["masks"]) > 0:
        previous_masks = []
        for j, m in enumerate(result["masks"]):
            original_mask = result["masks"][j][0].cpu().numpy()
            original_mask = remove_overlapping_pixels(original_mask, previous_masks)
            previous_masks.append(original_mask)
            rle = rle_encoding(original_mask > THRESHOLD)
            sublist.append([image_id, rle])
    else:
        sublist.append([image_id, ""])

df_sub = pd.DataFrame(sublist, columns=['id', 'predicted'])
df_sub.to_csv("submission.csv", index=False)
df_sub.head()