# Sartorius - PyTorch Mask R-CNN training

## Libraries

In [None]:
# Notebook-wide run-mode switches read by the cells below.
# DEBUG: shortens the run (fewer epochs, smaller patience, 100 sampled rows)
#        and silences UserWarning.
DEBUG = False
# KAGGLE: selects the '../input/...' data and model paths instead of WORK_DIR.
KAGGLE = False
# COLAB: mounts Google Drive and installs albumentations from GitHub.
COLAB = True

In [None]:
from psutil import virtual_memory, cpu_count

# Report the hardware of the current runtime (GPU, RAM, CPU cores).
# `!nvidia-smi` is an IPython shell escape; its captured output is joined
# into a single string below.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# Output containing 'failed' is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)
# Total RAM in decimal gigabytes (bytes / 1e9).
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
# 20 GB is the cut-off used here to call the runtime "high-RAM".
if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')
print('No of CPU cores:', cpu_count())

In [None]:
# Colab-only setup: mount Google Drive (WORK_DIR lives under /content/drive)
# and install albumentations straight from its GitHub repository.
if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    # IPython shell escape; installs whatever is on the default branch.
    !pip install git+https://github.com/albumentations-team/albumentations.git

In [None]:
import os
import cv2
import time
import json
import random
import collections
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import GroupKFold, StratifiedKFold
import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
# Debug runs hide UserWarning messages.
if DEBUG:
    warnings.filterwarnings('ignore', category=UserWarning)
# Restrict this process to the first CUDA device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Pick the compute device once; later cells read DEVICE.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('GPU is available' if DEVICE.type == 'cuda' else 'CPU is used')

## Configs

In [None]:
# Experiment identity and locations. KAGGLE switches between the Kaggle
# '../input' layout and the Google Drive working directory.
VER = 'ver0'
FOLD_START = 0  # first fold index trained by the fold loop
WORK_DIR = '/content/drive/MyDrive/sartorius'
DATA_PATH = '../input/sartorius-cell-instance-segmentation' if KAGGLE else f'{WORK_DIR}/data'
MDLS_PATH = f'../input/sartorius-models-{VER}' if KAGGLE else f'{WORK_DIR}/models_{VER}'
# Hyper-parameters shared by the dataset, model and training cells.
CONFIG = {
    'width': 704,
    'height': 520,
    'resize': None,
    'batch_size': 2,
    'grad_accum': 1,
    'workers': 2,
    'folds': 5,
    'epochs': 4 if DEBUG else 50,
    'lr': 5e-4,
    'mask_th': .5,
    'normalize': False,
    'scheduler': True,
    'num_boxes': 100,
    'min_score': .5,
    'patience': 2 if DEBUG else 5,
    'verbose': 30,
    'seed': 2021
}
# makedirs(exist_ok=True) also creates missing parent folders and does not
# fail when the directory already exists (no check-then-create race).
os.makedirs(MDLS_PATH, exist_ok=True)
# Persist the configuration next to the checkpoints it produced.
with open(f'{MDLS_PATH}/config.json', 'w') as file:
    json.dump(CONFIG, file)
# Per-channel statistics used by Normalize and by cell_model when
# CONFIG['normalize'] is set.
RESNET_MEAN = (.485, .456, .406)
RESNET_STD = (.229, .224, .225)

def seed_all(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.

    Returns:
        A fresh ``np.random.RandomState`` initialised with ``seed``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning for reproducible kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return np.random.RandomState(seed)

# Seed every RNG once with CONFIG['seed'] and keep the returned RandomState.
random_state = seed_all(CONFIG['seed'])
# Wall-clock reference taken right after setup.
start_time = time.time()

## Train / validation split (GroupKFold by image id)

In [None]:
# Load the annotation table and assign every row to a validation fold.
df = pd.read_csv(f'{DATA_PATH}/train.csv')
if DEBUG:
    # Cap the sample at the frame size so tiny inputs do not raise, and drop
    # the shuffled index instead of leaking it as an extra 'index' column.
    df = df.sample(min(100, len(df))).reset_index(drop=True)
# Grouping by 'id' keeps all rows that share an image id in the same fold.
gkf = GroupKFold(n_splits=CONFIG['folds'])
df['fold'] = -1
for i, (train_idxs, val_idxs) in enumerate(gkf.split(df, groups=df['id'])):
    # split() yields positional indices; map them to labels before using .loc.
    df.loc[df.index[val_idxs], 'fold'] = i
display(df.head())

In [None]:
# Cell-type histograms for fold 0 (validation) versus all other folds (train).
plt.figure(figsize=(16, 4))
for position, (label, subset) in enumerate(
        (('train', df.loc[df.fold != 0]), ('val', df.loc[df.fold == 0])),
        start=1):
    plt.subplot(1, 3, position)
    plt.title(f'{label} data, {len(subset.id.unique())} unique imgs')
    subset.cell_type.hist()
plt.show()

## Utilities

In [None]:
class Compose:
    """Chain paired (image, target) transforms and apply them in order."""

    def __init__(self, transforms):
        # Sequence of callables with signature (img, target) -> (img, target).
        self.transforms = transforms

    def __call__(self, img, target):
        """Feed ``img``/``target`` through every transform and return the pair."""
        for step in self.transforms:
            img, target = step(img, target)
        return img, target

class VerticalFlip:
    """Flip image, boxes and masks top-to-bottom with probability ``prob``."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, img, target):
        """Return the (possibly flipped) image and its target dict.

        ``target['boxes']`` is updated in place; ``target['masks']`` is
        replaced by its flipped copy.
        """
        if not random.random() < self.prob:
            return img, target
        height, _ = img.shape[-2:]
        flipped_img = img.flip(-2)
        boxes = target['boxes']
        # New ymin/ymax are the mirrored old ymax/ymin.
        boxes[:, [1, 3]] = height - boxes[:, [3, 1]]
        target['boxes'] = boxes
        target['masks'] = target['masks'].flip(-2)
        return flipped_img, target

class HorizontalFlip:
    """Flip image, boxes and masks left-to-right with probability ``prob``."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, img, target):
        """Return the (possibly flipped) image and its target dict.

        ``target['boxes']`` is updated in place; ``target['masks']`` is
        replaced by its flipped copy.
        """
        if not random.random() < self.prob:
            return img, target
        _, width = img.shape[-2:]
        flipped_img = img.flip(-1)
        boxes = target['boxes']
        # New xmin/xmax are the mirrored old xmax/xmin.
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        target['boxes'] = boxes
        target['masks'] = target['masks'].flip(-1)
        return flipped_img, target

class Normalize:
    """Normalize the image with RESNET_MEAN / RESNET_STD; target is untouched."""

    def __call__(self, img, target):
        return F.normalize(img, RESNET_MEAN, RESNET_STD), target

class ToTensor:
    """Convert the image with ``F.to_tensor``; target is passed through."""

    def __call__(self, img, target):
        return F.to_tensor(img), target
    
def transforms(train):
    """Build the paired image/target transform pipeline.

    Args:
        train: when truthy, random horizontal and vertical flips (p=0.5 each)
            are appended after tensor conversion / optional normalization.

    Returns:
        A ``Compose`` of the selected transforms.
    """
    pipeline = [ToTensor()]
    if CONFIG['normalize']:
        pipeline.append(Normalize())
    if train:
        pipeline.extend([HorizontalFlip(.5), VerticalFlip(.5)])
    return Compose(pipeline)

In [None]:
def rle_decode(mask_rle, shape, color=1):
    """Decode a run-length encoded string into a 2-D mask.

    Args:
        mask_rle: space-separated "start length" pairs; starts are 1-based
            positions in the flattened (row-major) image.
        shape: (height, width) of the array to return.
        color: value written into the mask pixels.

    Returns:
        float32 numpy array of ``shape``: ``color`` on the mask, 0 elsewhere.
    """
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based offsets
    lengths = np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, run in zip(starts, lengths):
        flat[begin : begin + run] = color
    return flat.reshape(shape)

## Datasets and loaders

In [None]:
class CellInferDataset(Dataset):
    """Inference dataset yielding every ``.png`` image found in a directory.

    Args:
        img_dir: directory that holds the images.
        transforms: optional callable invoked as
            ``transforms(img=img, target=None)`` returning ``(img, target)``.

    Each item is ``{'image': img, 'image_id': <file name without extension>}``.
    """

    def __init__(self, img_dir, transforms=None):
        self.transforms = transforms
        self.img_dir = img_dir
        # Keep only .png files (the extension __getitem__ re-appends) and sort
        # them so the item order does not depend on the file system.
        self.img_idxs = sorted(
            os.path.splitext(name)[0]
            for name in os.listdir(self.img_dir)
            if name.endswith('.png')
        )

    def __getitem__(self, idx):
        img_idx = self.img_idxs[idx]
        img_path = os.path.join(self.img_dir, img_idx + '.png')
        img = Image.open(img_path).convert('RGB')
        if self.transforms is not None:
            img, _ = self.transforms(img=img, target=None)
        return {'image': img, 'image_id': img_idx}

    def __len__(self):
        # The attribute is img_idxs; the previous name img_ids did not exist.
        return len(self.img_idxs)

class CellDataset(Dataset):
    """Instance-segmentation dataset built from RLE annotations.

    Args:
        img_dir: directory containing ``<id>.png`` images.
        df: DataFrame with an ``id`` column and an ``annotation`` column
            holding one RLE string per row.
        transforms: optional callable ``(img, target) -> (img, target)``
            applied last.
        aug: optional callable invoked as
            ``aug(image=..., masks=..., bboxes=..., class_labels=...)`` that
            returns a dict with the same four keys.

    Each item is ``(img, target)`` where ``target`` holds ``boxes``,
    ``labels``, ``masks``, ``image_id``, ``area`` and ``iscrowd`` tensors.
    """

    def __init__(self, img_dir, df, transforms=None, aug=None):
        self.transforms = transforms
        self.aug = aug
        self.img_dir = img_dir
        self.df = df
        self.height = CONFIG['height']
        self.width = CONFIG['width']
        # Plain dict: a defaultdict would silently insert an empty entry (and
        # grow __len__) whenever an out-of-range index is looked up.
        self.img_info = {}
        # One entry per image id, collecting all of its annotation strings.
        temp_df = self.df.groupby('id')['annotation'].agg(list).reset_index()
        for index, row in temp_df.iterrows():
            self.img_info[index] = {
                'image_id': row['id'],
                'image_path': os.path.join(self.img_dir, row['id'] + '.png'),
                'annotations': row["annotation"]
            }

    def bbox_from_mask(self, mask):
        """Return ``[xmin, ymin, xmax, ymax]`` of the non-zero pixels of ``mask``.

        The maxima are the last occupied pixel indices (inclusive).
        NOTE(review): a mask that is one pixel wide or tall therefore yields a
        zero-extent box - confirm the detector and augmentations accept that.
        """
        rows, cols = np.where(mask)
        return [np.min(cols), np.min(rows), np.max(cols), np.max(rows)]

    def __getitem__(self, idx):
        info = self.img_info[idx]
        img = Image.open(info['image_path']).convert("RGB")
        boxes = []
        masks = []
        for ann in info['annotations']:
            # Binarise the decoded mask directly; no PIL round-trip is needed.
            mask = rle_decode(ann, (self.height, self.width)) > 0
            boxes.append(self.bbox_from_mask(mask))
            masks.append(mask.astype(np.float32))
        # Single foreground class: every instance gets label 1.
        labels = [1] * len(masks)
        if self.aug:
            img = np.array(img).astype(np.float32) / 255
            augmented = self.aug(image=img, masks=masks,
                                 bboxes=boxes, class_labels=labels)
            img = augmented['image']
            masks = augmented['masks']
            boxes = augmented['bboxes']
            labels = augmented['class_labels']
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(np.array(masks), dtype=torch.uint8)
        img_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((len(labels), ), dtype=torch.int64)
        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': img_id,
            'area': area,
            'iscrowd': iscrowd
        }
        if self.transforms:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.img_info)

## Model

In [None]:
def cell_model(pretrained=True, pretrained_backbone=True):
    """Build a Mask R-CNN (ResNet-50 FPN) with two-class box and mask heads.

    Args:
        pretrained: forwarded to ``maskrcnn_resnet50_fpn``.
        pretrained_backbone: forwarded to ``maskrcnn_resnet50_fpn``.

    Returns:
        The model with its box and mask predictors replaced.
    """
    n_classes = 2
    build_kwargs = {
        'pretrained': pretrained,
        'pretrained_backbone': pretrained_backbone,
        'box_detections_per_img': CONFIG['num_boxes'],
    }
    if CONFIG['normalize']:
        # Only pass explicit statistics when normalization is requested.
        # NOTE(review): torchvision's defaults look like the same ImageNet
        # values, so both paths may be equivalent - confirm.
        build_kwargs.update(image_mean=RESNET_MEAN, image_std=RESNET_STD)
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(**build_kwargs)

    # Swap both heads for fresh ones sized for n_classes.
    roi_heads = model.roi_heads
    box_in = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in, n_classes)
    mask_in = roi_heads.mask_predictor.conv5_mask.in_channels
    roi_heads.mask_predictor = MaskRCNNPredictor(mask_in, 256, n_classes)
    return model

## Training

In [None]:
class CellTrainer:
    """Training loop with mixed precision, gradient accumulation and early stopping.

    The model is expected to return a dict of losses (including the key
    ``'loss_mask'``) when called as ``model(imgs, targets)``.

    Args:
        model: the network to train.
        device: device that batches are moved to.
        optimizer: optimizer over the model parameters.
        scheduler: optional LR scheduler; stepped once per epoch in ``fit``.
        grad_accum: batches accumulated per optimizer step; defaults to
            ``CONFIG['grad_accum']``.
        verbose: print running losses every ``verbose`` steps (0 disables);
            defaults to ``CONFIG['verbose']``.
    """

    def __init__(self, model, device, optimizer, scheduler=None,
                 grad_accum=None, verbose=None):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        # CONFIG is consulted only when the value is not passed explicitly.
        self.grad_accum = max(
            1, int(CONFIG['grad_accum'] if grad_accum is None else grad_accum)
        )
        self.verbose = int(CONFIG['verbose'] if verbose is None else verbose)
        # One scaler for the whole run so its loss-scale state survives epochs.
        self.scaler = torch.cuda.amp.GradScaler()
        self.best_val_loss = np.inf
        self.train_losses = []
        self.val_losses = []
        self.train_mask_losses = []
        self.val_mask_losses = []
        self.lastmodel = None

    def fit(self, epochs, train_loader, val_loader, save_name, max_patience):
        """Train for up to ``epochs`` epochs with early stopping.

        A checkpoint is saved whenever the validation loss improves; training
        stops after ``max_patience`` consecutive epochs without improvement.

        Returns:
            Dict with the per-epoch 'train losses', 'train mask losses',
            'val losses' and 'val mask losses' lists.
        """
        n_patience = 0
        for n_epoch in range(1, epochs + 1):
            self.info_message('EPOCH: {}', n_epoch)
            train_loss, train_mask_loss, train_time = self.train_epoch(train_loader)
            val_loss, val_mask_loss, val_time = self.val_epoch(val_loader)
            # The caller sizes the schedule in epochs (T_max = number of
            # epochs), so advance it once per epoch rather than once per batch.
            if self.scheduler:
                self.scheduler.step()
            self.train_losses.append(train_loss)
            self.train_mask_losses.append(train_mask_loss)
            self.val_losses.append(val_loss)
            self.val_mask_losses.append(val_mask_loss)
            self.info_message(
                'epoch train: {} | loss: {:.4f} | mask loss: {:.4f} | time: {:.2f} sec',
                n_epoch, train_loss, train_mask_loss, train_time
            )
            self.info_message(
                'epoch val: {} | loss: {:.4f} | mask loss: {:.4f} | time: {:.2f} sec',
                n_epoch, val_loss, val_mask_loss, val_time
            )
            if self.best_val_loss > val_loss:
                # Update before saving so the checkpoint records the loss of
                # the weights it actually contains.
                previous_best = self.best_val_loss
                self.best_val_loss = val_loss
                self.save_model(n_epoch, save_name)
                self.info_message(
                    'val loss improved {:.4f} -> {:.4f} | saved model to "{}"',
                    previous_best, val_loss, self.lastmodel
                )
                n_patience = 0
            else:
                n_patience += 1
            if n_patience >= max_patience:
                self.info_message(
                    '\nno improvement for last {} epochs',
                    n_patience
                )
                break
        history = {
            'train losses': self.train_losses,
            'train mask losses': self.train_mask_losses,
            'val losses': self.val_losses,
            'val mask losses': self.val_mask_losses
        }
        return history

    def train_epoch(self, train_loader):
        """Run one training epoch.

        Returns:
            (mean total loss, mean mask loss, elapsed whole seconds).
        """
        self.model.train()
        started = time.time()
        n_batches = len(train_loader)
        sum_loss = 0
        sum_loss_mask = 0
        # Start from clean gradients in case a previous epoch left some behind.
        self.optimizer.zero_grad()
        for step, (imgs, targets) in enumerate(train_loader, 1):
            imgs = [img.to(self.device) for img in imgs]
            targets = [{k: v.to(self.device)
                        for k, v in tgt.items()} for tgt in targets]
            # Only the forward pass runs under autocast.
            with torch.cuda.amp.autocast():
                loss_dict = self.model(imgs, targets)
                loss = sum(loss_dict.values())
            # Average over the accumulation window so the effective gradient
            # does not grow with grad_accum.
            self.scaler.scale(loss / self.grad_accum).backward()
            # `step` is 1-based: update every grad_accum batches and on the
            # final batch so no accumulated gradient is dropped.
            if step % self.grad_accum == 0 or step == n_batches:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
            sum_loss += loss.detach().item()
            sum_loss_mask += loss_dict['loss_mask'].detach().item()
            if self.verbose and step % self.verbose == 0:
                self.info_message(
                    'train step {}/{} | loss: {:.4f} | mask loss: {:.4f}     ',
                    step, n_batches, sum_loss / step, sum_loss_mask / step, end='\n'
                )
        return sum_loss/n_batches, sum_loss_mask/n_batches, int(time.time()-started)

    def val_epoch(self, val_loader):
        """Compute validation losses without gradient tracking.

        Returns:
            (mean total loss, mean mask loss, elapsed whole seconds).
        """
        # The loss dict is needed here; torchvision detection models return
        # predictions instead of losses in eval mode, so keep train mode
        # explicit rather than relying on the previous call.
        self.model.train()
        started = time.time()
        n_batches = len(val_loader)
        sum_loss = 0
        sum_loss_mask = 0
        for step, (imgs, targets) in enumerate(val_loader, 1):
            with torch.no_grad():
                imgs = [img.to(self.device) for img in imgs]
                targets = [{k: v.to(self.device)
                            for k, v in tgt.items()} for tgt in targets]
                loss_dict = self.model(imgs, targets)
                loss = sum(loss_dict.values())
                sum_loss += loss.detach().item()
                sum_loss_mask += loss_dict['loss_mask'].detach().item()
            if self.verbose and step % self.verbose == 0:
                self.info_message(
                    'val step {}/{} | loss: {:.4f} | mask loss: {:.4f}     ',
                    step, n_batches, sum_loss / step, sum_loss_mask / step, end='\n'
                )
        return sum_loss/n_batches, sum_loss_mask/n_batches, int(time.time()-started)

    def save_model(self, n_epoch, save_name, loss=None, mask_loss=None):
        """Write a checkpoint under MDLS_PATH and remember its path in ``lastmodel``.

        When ``loss`` is given, epoch and losses are embedded in the file
        name; otherwise the file is ``<save_name>.pth`` and is overwritten on
        every call.
        """
        if loss is not None:
            self.lastmodel = f'{MDLS_PATH}/' + \
                f'{save_name}-e{n_epoch}-loss{loss:.3f}-maskloss{mask_loss:.3f}.pth'
        else:
            self.lastmodel = f'{MDLS_PATH}/{save_name}.pth'
        dict_save = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'n_epoch': n_epoch,
            'best_val_loss': self.best_val_loss,
        }
        torch.save(dict_save, self.lastmodel)

    def display_plots(self):
        """Plot the recorded total and mask losses (one point per epoch)."""
        fig, axes = plt.subplots(figsize=(16, 4), nrows=1, ncols=2)
        axes[0].set_title('training and validation losses')
        axes[0].plot(self.val_losses, label='val')
        axes[0].plot(self.train_losses, label='train')
        axes[0].set_xlabel('iterations')
        axes[0].set_ylabel('loss')
        axes[0].legend()
        axes[1].set_title('training and validation mask losses')
        axes[1].plot(self.val_mask_losses, label='val')
        axes[1].plot(self.train_mask_losses, label='train')
        axes[1].set_xlabel('iterations')
        axes[1].set_ylabel('loss')
        axes[1].legend()
        plt.show()
        plt.close()

    @staticmethod
    def info_message(message, *args, end='\n'):
        """Print ``message`` formatted with ``args``."""
        print(message.format(*args), end=end)

In [None]:
# Training-time augmentation: photometric jitter, blur, noise and random
# flips, followed by conversion to a tensor. Boxes are given in pascal_voc
# format and their labels travel in the 'class_labels' field.
train_aug =  A.Compose([
    # With p=.25 apply exactly one of brightness/contrast jitter or gamma.
    A.OneOf([
        A.RandomBrightnessContrast(
            brightness_limit=.2, 
            contrast_limit=.2, 
            p=1), 
        A.RandomGamma(p=1)
    ], p=.25),
    A.Blur(blur_limit=3, p=.25),
    # NOTE(review): the meaning of the first positional argument depends on
    # the installed albumentations version - confirm it is the noise level.
    A.GaussNoise(.002, p=.25),
    A.HorizontalFlip(p=.5),
    A.VerticalFlip(p=.5),
    #A.ShiftScaleRotate(p=1),
    #A.RGBShift(p=.2),
    #A.Resize(CONFIG['height'], CONFIG['width'], p=1),
    ToTensorV2(p=1)
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

# Validation pipeline: tensor conversion only, same bbox conventions.
val_aug =  A.Compose([
    ToTensorV2(p=1)
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

In [None]:
def train_cell_model(df_train, df_val, fold, device, 
                     epochs, patience, batch_size):
    """Train one fold and return the path of its last saved checkpoint.

    Args:
        df_train: annotation rows used for training.
        df_val: annotation rows used for validation.
        fold: fold number, used in log output and file names.
        device: device the model is moved to.
        epochs: maximum number of epochs; also the scheduler's ``T_max``.
        patience: epochs without validation improvement before stopping.
        batch_size: batch size of both data loaders.

    Returns:
        ``trainer.lastmodel`` - the checkpoint path, or None if nothing was
        saved.

    Side effects:
        Shows the loss plots and writes ``history_f<fold>.json`` to MDLS_PATH.
    """
    print('=' * 20, f'FOLD {fold}', '=' * 20)
    print('train:', df_train.shape, '| val:', df_val.shape)

    def collate(batch):
        # Transpose a list of (img, target) samples into (imgs, targets)
        # tuples without stacking them.
        return tuple(zip(*batch))

    train_dataset = CellDataset(
        f'{DATA_PATH}/train', 
        df_train, 
        transforms=None, # transform(train=True)
        aug=train_aug
    )
    val_dataset = CellDataset(
        f'{DATA_PATH}/train', 
        df_val, 
        transforms=None, # transform(train=False)
        aug=val_aug
    )
    # Honour the batch_size argument (it used to be ignored in favour of
    # CONFIG['batch_size']).
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=CONFIG['workers'],
        collate_fn=collate,
        pin_memory=True
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    model = cell_model()
    model.to(device)
    # Make every parameter trainable.
    for param in model.parameters():
        param.requires_grad = True
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    # Respect the CONFIG['scheduler'] switch and size the schedule from the
    # `epochs` argument actually used for this run.
    scheduler = None
    if CONFIG['scheduler']:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            epochs
        )
    trainer = CellTrainer(
        model, 
        device, 
        optimizer, 
        scheduler
    )
    history = trainer.fit(
        epochs, 
        train_loader, 
        val_loader, 
        save_name=f'model-f{fold}', 
        max_patience=patience
    )
    trainer.display_plots()
    with open(f'{MDLS_PATH}/history_f{fold}.json', 'w') as file:
        json.dump(history, file)
    return trainer.lastmodel

In [None]:
# Train one model per fold, starting at FOLD_START, and collect the
# checkpoint path returned for each fold.
modelfiles = []
for fold_num in range(FOLD_START, CONFIG['folds']): 
    # Boolean masks select rows by fold value and do not depend on the
    # DataFrame index being 0..n-1.
    is_val = df['fold'] == fold_num
    df_train = df[~is_val]
    df_val = df[is_val]
    modelfiles.append(train_cell_model(
        df_train, 
        df_val, 
        fold_num,
        device=DEVICE, 
        epochs=CONFIG['epochs'],
        patience=CONFIG['patience'],
        batch_size=CONFIG['batch_size']
    ))
    # Persist progress after every fold so an interrupted session keeps the
    # list of checkpoints produced so far.
    with open(f'{MDLS_PATH}/modelfiles.json', 'w') as file:
        json.dump(modelfiles, file)
print(modelfiles)
with open(f'{MDLS_PATH}/modelfiles.json', 'w') as file:
    json.dump(modelfiles, file)

In [None]:
# Re-root the checkpoint names under MDLS_PATH, skipping folds that saved
# nothing (train_cell_model returns None in that case).
modelfiles = [
    f'{MDLS_PATH}/{os.path.basename(x)}' for x in modelfiles if x
]
if modelfiles:
    # Remove every other .pth file in MDLS_PATH.
    # NOTE(review): with FOLD_START > 0 this also removes checkpoints of
    # folds trained in earlier sessions - confirm that is intended.
    keep = set(modelfiles)
    allmodelfiles = [
        f'{MDLS_PATH}/{x}' for x in os.listdir(MDLS_PATH) if x.endswith('.pth')
    ]
    for file_path in allmodelfiles:
        if file_path not in keep:
            os.remove(file_path)
else:
    # An empty list would otherwise delete every checkpoint in the folder.
    print('no model files recorded; skipping checkpoint cleanup')

## Results

In [None]:
def plot_img_mask_sample(model, train_dataset, idx):
    """Plot one sample as three panels.

    Panels: the image, the image with the union of its ground-truth masks,
    and the image with the union of the predicted masks.

    Args:
        model: detection model called as ``model([img])``; it must already be
            in eval mode so that it returns predictions with a 'masks' entry.
        train_dataset: dataset whose items are ``(img_tensor, target)`` with
            ``target['masks']`` iterable over per-instance masks.
        idx: index of the sample to plot.
    """
    img, targets = train_dataset[idx]
    # Convert once to channels-last on the CPU and reuse it for all panels.
    img_np = img.cpu().numpy().transpose((1, 2, 0))
    canvas_shape = img_np.shape[:2]
    print()
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.imshow(img_np)
    plt.title('image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    gt_union = np.zeros(canvas_shape, dtype=bool)
    for mask in targets['masks']:
        gt_union |= np.asarray(mask).astype(bool)
    plt.imshow(img_np)
    plt.imshow(gt_union, alpha=.3)
    plt.title('ground truth mask')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    with torch.no_grad():
        preds = model([img.to(DEVICE)])[0]
    plt.imshow(img_np)
    pred_union = np.zeros(canvas_shape, dtype=bool)
    # Binarise each predicted mask with the configured threshold.
    for mask in preds['masks'].cpu().detach().numpy():
        pred_union |= mask[0] > CONFIG['mask_th']
    plt.imshow(pred_union, alpha=.4)
    plt.title('predicted masks')
    plt.axis('off')

    plt.show()

In [None]:
# Visualise the first sample with every trained fold model.
train_dataset = CellDataset(
    f'{DATA_PATH}/train', 
    df, 
    transforms=None, #transforms(train=True),
    aug=val_aug
)
for model_file in modelfiles:
    model = cell_model()
    model.to(DEVICE)
    # map_location lets checkpoints saved on a GPU load on a CPU-only runtime.
    checkpoint = torch.load(model_file, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(model_file)
    plot_img_mask_sample(model, train_dataset, idx=0)