In [None]:
pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [None]:
!pip -q install pycocotools

#### Copy the helper scripts (e.g. `engine.py`, imported below) into the current directory.
#### These files are adapted from: https://github.com/pytorch/vision/tree/main/references/detection

In [None]:
# Copy the training helper scripts (including `engine`, which provides train_one_epoch/evaluate) next to the notebook.
!cp ../input/train-utils/* ./

In [None]:
import os
import torch
import gc
import cv2
import time
import random
from torch import nn
import pandas as pd
import numpy as np
import collections
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import torchvision
from engine import train_one_epoch, evaluate
import torchvision.transforms as T
from tqdm import tqdm_notebook as tqdm
from skimage.color import label2rgb
from sklearn.model_selection import GroupKFold
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
%matplotlib inline

In [None]:
# train.csv holds one row per annotated cell instance; rows are grouped by image `id` later.
df = pd.read_csv('../input/sartorius-cell-instance-segmentation/train.csv')
df.shape

In [None]:
# Assign every annotation row to a CV fold. Grouping by image `id` keeps all rows
# of one image in the same fold. n_splits should match CFG.n_folds (defined below).
folds = df.copy()
splitter = GroupKFold(n_splits=5)
for fold_no, (_, val_positions) in enumerate(splitter.split(df, groups=df.id.values)):
    folds.loc[val_positions, 'fold'] = int(fold_no)
folds['fold'] = folds['fold'].astype(int)

In [None]:
class CFG:
    """Central configuration: data paths, image size, optimizer and training settings."""
    debug = False                 # NOTE(review): not referenced in this notebook
    num_workers = 0               # DataLoader worker processes (0 = load in the main process)
    precision = 16                # NOTE(review): not referenced in this notebook
    # Fall back to CPU so the notebook can still run on a machine without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img_dir = '../input/sartorius-cell-instance-segmentation/train/'
    epochs = 10
    patience = 4                  # NOTE(review): not referenced in this notebook
    height = 520                  # image height in pixels, used to decode RLE masks
    width = 704                   # image width in pixels, used to decode RLE masks
    T_max = 5                     # only used by the commented-out CosineAnnealingLR
    momentum = 0.9                # SGD momentum
    eta_min = 1e-7                # only used by the commented-out CosineAnnealingLR
    weight_decay = 5e-4           # SGD weight decay
    mask_threshold = 0.5          # NOTE(review): not referenced in this notebook
    mean = np.array([0.485, 0.456, 0.406])  # per-channel RGB mean (ImageNet statistics)
    std = np.array([0.229, 0.224, 0.225])   # per-channel RGB std (ImageNet statistics)
    box_detections_per_img = 539  # NOTE(review): get_model currently hard-codes its own value
    min_score = 0.59              # NOTE(review): not referenced in this notebook
    lr = 1e-3
    batch_size = 2
    seed = 42
    n_folds = 5

In [None]:
def fix_all_seeds(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) with `seed`.

    PYTHONHASHSEED is exported as well. It only affects interpreters started
    afterwards (e.g. subprocesses), not hash randomization of the current one.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
# Seed every RNG once with the configured seed before any data or model is created.
fix_all_seeds(CFG.seed)

In [None]:
# Map each `cell_type` string from train.csv to an integer class id.
# NOTE(review): torchvision detection models reserve class id 0 for background, so
# 'shsy5y' = 0 must be shifted (e.g. +1, with num_classes = len(label_dict) + 1)
# before it is used as a Mask R-CNN target label.
label_dict = {'shsy5y':0, 'astro':1, 'cort':2}

In [None]:
def rle_decode(mask_rle, shape, color=1):
    """Decode a run-length-encoded mask into a dense array.

    Args:
        mask_rle: space-separated "start length start length ..." string with
            1-based start positions into the flattened image.
        shape: (height, width) of the array to return.
        color: value written into every masked pixel.

    Returns:
        float32 array of `shape` holding `color` on masked pixels and 0 elsewhere.
        The flat buffer is reshaped in NumPy's default row-major (C) order.
    """
    tokens = np.asarray(mask_rle.split(), dtype=int)
    starts = tokens[0::2] - 1  # convert 1-based starts to 0-based offsets
    lengths = tokens[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, end in zip(starts, starts + lengths):
        flat[begin:end] = color
    return flat.reshape(shape)

## Dataset

In [None]:
# https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273
class SartoriusDataset(Dataset):
    """Image-level dataset producing Mask R-CNN style (image, target) pairs.

    Rows of `df` are grouped by image `id`; each item is one RGB image plus a
    target dict with one instance (mask, box, label) per RLE annotation.

    Args:
        df: DataFrame with at least `id`, `annotation` and `cell_type` columns.
        transforms: optional callable `(img, target) -> (img, target)`.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby('id')[['annotation', 'cell_type']].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                    'image_id': row['id'],
                    # presumably every annotation of an image has the same cell type; the first is used
                    'label': label_dict[row['cell_type'][0]],
                    'image_path': os.path.join(CFG.img_dir, row['id'] + '.png'),
                    'annotations': row["annotation"]
                    }

    def get_box(self, a_mask):
        """Return the bounding box [xmin, ymin, xmax, ymax] of a non-empty binary mask.

        The max coordinates are exclusive (last pixel index + 1), so the box spans
        pixel edges. This guarantees positive width and height even for one-pixel
        wide/tall masks (torchvision rejects degenerate boxes). It also makes the
        `size - coord` box update used by HorizontalFlip/VerticalFlip exact.

        Raises:
            ValueError: if the mask is empty (raised by the min/max reductions).
        """
        ys, xs = np.where(a_mask)
        return [xs.min(), ys.min(), xs.max() + 1, ys.max() + 1]

    def __getitem__(self, idx):
        """Load image `idx` and build its Mask R-CNN target dict."""
        info = self.image_info[idx]
        img = Image.open(info["image_path"]).convert("RGB")
        label = info["label"]

        n_objects = len(info['annotations'])
        masks = np.zeros((n_objects, CFG.height, CFG.width), dtype=np.uint8)
        boxes = []

        for i, annotation in enumerate(info['annotations']):
            # rle_decode already returns a numeric array, so threshold it directly
            a_mask = rle_decode(annotation, (CFG.height, CFG.width)) > 0
            masks[i] = a_mask
            boxes.append(self.get_box(a_mask))

        # reshape keeps the (N, 4) layout even when there are no annotations
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.full((n_objects,), label, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        # This is the required target for the Mask R-CNN
        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of distinct images."""
        return len(self.image_info)
    
def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples.

    For detection data this turns [(img, target), ...] into
    ((img, ...), (target, ...)), avoiding default stacking of
    variable-sized images and targets.
    """
    per_field = zip(*batch)
    return tuple(per_field)

## Augmentations

In [None]:
# https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273
class Compose:
    """Chain `(image, target)` transforms, applying them in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            image, target = step(image, target)
        return image, target

class VerticalFlip:
    """Randomly flip an image tensor, its boxes and its masks upside down.

    Args:
        prob: probability of applying the flip on each call.
    """

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() >= self.prob:
            return image, target
        img_height = image.shape[-2]
        boxes = target["boxes"]
        # mirror y: new y_min = H - old y_max, new y_max = H - old y_min
        boxes[:, [1, 3]] = img_height - boxes[:, [3, 1]]
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-2)
        return image.flip(-2), target

class HorizontalFlip:
    """Randomly mirror an image tensor, its boxes and its masks left-to-right.

    Args:
        prob: probability of applying the flip on each call.
    """

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() >= self.prob:
            return image, target
        img_width = image.shape[-1]
        boxes = target["boxes"]
        # mirror x: new x_min = W - old x_max, new x_max = W - old x_min
        boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-1)
        return image.flip(-1), target

class Normalize:
    """Normalize an image tensor channel-wise; the target is passed through unchanged.

    Args:
        mean: per-channel mean; defaults to CFG.mean.
        std: per-channel standard deviation; defaults to CFG.std.
    """

    def __init__(self, mean=None, std=None):
        # The previous code referenced undefined RESNET_MEAN / RESNET_STD names.
        self.mean = CFG.mean if mean is None else mean
        self.std = CFG.std if std is None else std

    def __call__(self, image, target):
        # F.normalize expects a float tensor, so apply this after ToTensor.
        image = F.normalize(image, self.mean, self.std)
        return image, target

class ToTensor:
    """Convert the image (PIL or HWC ndarray) to a CHW tensor; the target passes through."""

    def __call__(self, image, target):
        return F.to_tensor(image), target

In [None]:
def get_transform(train):
    """Build the `(image, target)` transform pipeline.

    The image is always converted to a tensor. When `train` is true, random
    horizontal and vertical flips (p=0.5 each) are appended. No Normalize step
    is added, because torchvision's Mask R-CNN normalizes its inputs internally.
    """
    steps = [ToTensor()]
    if train:
        steps.extend([HorizontalFlip(0.5), VerticalFlip(0.5)])
    return Compose(steps)

In [None]:
# https://www.kaggle.com/blondinka/how-to-do-augmentations-for-instance-segmentation
def visualize_bbox(img, bbox, color=(255, 0, 255), thickness=2):
    """Draw one bounding box onto an image and return the image.

    Args:
        img: image as an OpenCV-compatible numpy array; cv2.rectangle draws on it in place.
        bbox: box as a list/array/tensor in pascal_voc format [x_min, y_min, x_max, y_max];
            coordinates are truncated to int.
        color: box color, default (255, 0, 255).
        thickness: box line thickness in pixels, default 2.

    Returns:
        The same `img` array with the rectangle drawn.
    """
    x_min, y_min, x_max, y_max = (int(v) for v in bbox)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    return img

def plot_image_aug(image, masks, boxes):
    """Show an image with its boxes drawn, next to a color-coded instance mask.

    Args:
        image: HxWx3 uint8 array; boxes are drawn onto it (in place if it is
            already C-contiguous).
        masks: (N, H, W) array of per-instance masks (N may be 0).
        boxes: iterable of [x_min, y_min, x_max, y_max] boxes.
    """
    # Label image: instance i gets id i + 1 (0 = background). Where masks overlap,
    # later instances overwrite earlier ones. int32 leaves room for hundreds of
    # instances without the wrap-around a uint8 sum of per-mask weights suffers.
    one_mask = np.zeros(masks.shape[1:], dtype=np.int32)
    for i, mask in enumerate(masks):
        one_mask[mask > 0] = i + 1

    image = np.ascontiguousarray(image)  # cv2 drawing needs a contiguous buffer
    for box in boxes:
        image = visualize_bbox(image, box)

    # label image -> RGB for display
    mask_rgb = label2rgb(one_mask, bg_label=0)

    # explicit axes instead of plt.figure(...) + plt.figure(1), which could target another figure
    _, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 12))
    ax_img.imshow(image)
    ax_img.set_title('image')
    ax_mask.imshow(mask_rgb)
    ax_mask.set_title('mask')
    plt.show()

In [None]:
# Preview one batch of augmented training samples with their boxes and masks.
sample_dataset = SartoriusDataset(folds, transforms=get_transform(True))
sample_dataloader = DataLoader(sample_dataset, batch_size=4, collate_fn=collate_fn, shuffle=False)
# next(iter(...)) is the portable spelling; the `.next()` method on DataLoader
# iterators is not available in newer PyTorch releases.
images, targets = next(iter(sample_dataloader))
for image, target in zip(images, targets):
    # CHW float tensor in [0, 1] -> HWC uint8 array for OpenCV / matplotlib
    plot_image_aug((image.permute(1, 2, 0) * 255.0).numpy().astype(np.uint8),
                   target['masks'].numpy(), target['boxes'])

## Model

In [None]:
# def get_model():
#     backbone = torchvision.models.efficientnet_b4(pretrained=True).features   #1792
#     backbone.out_channels = 1792
#     anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),))
#     roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)
#     mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=14, sampling_ratio=2)
#     # put the pieces together inside a MaskRCNN model
#     model = MaskRCNN(backbone, num_classes=3, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, mask_roi_pool=mask_roi_pooler)
#     return model

In [None]:
def get_model(num_classes=3, box_detections_per_img=600, hidden_layer=256):
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) with new prediction heads.

    Args:
        num_classes: number of output classes *including* background; torchvision
            detection models reserve class index 0 for background.
        box_detections_per_img: maximum number of detections kept per image.
        hidden_layer: channel width of the new mask predictor's hidden layer.

    Returns:
        A torchvision MaskRCNN whose box and mask predictors are freshly
        initialized for `num_classes` classes.
    """
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True, box_detections_per_img=box_detections_per_img)

    # replace the box classifier/regressor head
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # replace the mask predictor head
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

In [None]:
def _seed_worker(worker_id):
    """DataLoader worker_init_fn: seed NumPy and `random` from the worker's torch seed."""
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class _OneBasedLabels:
    """Wrap an `(image, target)` transform and shift `target['labels']` by +1.

    torchvision detection models reserve class index 0 for background, but
    `label_dict` maps 'shsy5y' to 0. Without this shift those cells would be
    trained as background and never predicted.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        image, target = self.transforms(image, target)
        target['labels'] = target['labels'] + 1
        return image, target


def train_loop(fold):
    """Train Mask R-CNN on every fold except `fold` and validate on `fold`.

    After each epoch the model is evaluated with `evaluate` from the local
    `engine` module. Whenever the segmentation mAP@[.50:.95]
    (`coco_eval['segm'].stats[0]`) improves, the weights are saved to
    `./best_{metric}_fold{fold}.pth`. Target labels are shifted to 1..len(label_dict)
    and the model has len(label_dict) + 1 classes, so predicted label k
    corresponds to `label_dict` value k - 1.

    Args:
        fold: value of the `fold` column in `folds` used for validation.

    Returns:
        The best validation segmentation mAP seen (-1.0 if none was recorded).
    """
    # ------------
    # data
    # ------------
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)

    train_dataset = SartoriusDataset(train_folds, transforms=_OneBasedLabels(get_transform(True)))
    valid_dataset = SartoriusDataset(valid_folds, transforms=_OneBasedLabels(get_transform(False)))

    loader_kwargs = dict(batch_size=CFG.batch_size,
                         collate_fn=collate_fn,
                         worker_init_fn=_seed_worker,
                         num_workers=CFG.num_workers,
                         pin_memory=True,
                         drop_last=False)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

    # ------------
    # model
    # ------------
    # one extra output for the background class (index 0)
    model = get_model(num_classes=len(label_dict) + 1)
    # make every parameter trainable (torchvision may freeze early backbone layers when pretrained=True)
    for param in model.parameters():
        param.requires_grad = True
    model.to(CFG.device)
    optimizer = torch.optim.SGD(model.parameters(), lr=CFG.lr, momentum=CFG.momentum,
                                weight_decay=CFG.weight_decay)
    # divide the learning rate by 10 every 5 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    best_metric = -1.
    for epoch in range(CFG.epochs):
        # train for one epoch, logging every 200 iterations
        train_one_epoch(model, optimizer, train_loader, CFG.device, epoch, print_freq=200)
        scheduler.step()
        # evaluate on the validation fold
        coco_evaluator = evaluate(model, valid_loader, device=CFG.device)
        # COCO segmentation AP averaged over IoU thresholds 0.50:0.95
        metric = coco_evaluator.coco_eval['segm'].stats[0]
        if metric > best_metric:
            best_metric = metric
            torch.save(model.state_dict(), f'./best_{best_metric}_fold{fold}.pth')
    return best_metric

In [None]:
# Only the first fold is trained: the slice keeps a single fold id.
# Drop the `[:1]` to run the full CFG.n_folds-fold cross-validation.
for fold_idx in range(CFG.n_folds)[:1]:
    train_loop(fold_idx)

==========================================================================================================

In [None]:
## Please upvote if you find this notebook useful.
## Suggestions for improving it are welcome — I haven't been able to get a good score with these settings.