In [None]:
import os
import torch
import gc
import cv2
import time
from torch import nn
import pandas as pd
import numpy as np
import collections
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import torchvision
from tqdm import tqdm_notebook as tqdm
from skimage.color import label2rgb
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
%matplotlib inline

In [None]:
ls ../input/sartorius-maskrcc-pytorch-exp001/

In [None]:
class CFG:
    """Inference-time configuration shared by the cells of this notebook."""

    # --- data loading ---
    img_dir = '../input/sartorius-cell-instance-segmentation/test/'
    batch_size = 4
    num_workers = 0
    # image size; NOTE(review): not referenced by the visible inference cells
    height = 520
    width = 704

    # --- model ---
    model_pth = '../input/sartorius-maskrcc-pytorch-exp001/'
    # per-channel statistics handed to Mask R-CNN as image_mean / image_std
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    box_detections_per_img = 539  # passed to Mask R-CNN as its per-image detection cap

    # --- post-processing ---
    mask_threshold = 0.5  # soft-mask values above this become foreground pixels
    min_score = 0.59  # detection-score cutoff (check the inference loop applies it)

In [None]:
class TestDataset(Dataset):
    """Unlabeled test images read from a directory.

    Each item is a dict ``{'image': img, 'image_id': id}`` where ``img`` is the RGB
    image (passed through ``transforms`` when one is given) and ``id`` is the file
    name with its final extension stripped.
    """

    def __init__(self, image_dir, transforms=None):
        """
        Args:
            image_dir: directory whose every entry is treated as an image file.
            transforms: optional albumentations-style callable, invoked as
                ``transforms(image=img)`` and expected to return a dict with an
                ``'image'`` key.
        """
        self.image_dir = image_dir
        # os.listdir order is arbitrary; sort so dataset indices are reproducible.
        self.imgs = sorted(os.listdir(self.image_dir))
        self.transforms = transforms

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        file_name = self.imgs[idx]
        img_path = os.path.join(self.image_dir, file_name)
        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f'could not read image: {img_path}')
        # OpenCV decodes to BGR; the model pipeline expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            img = self.transforms(image=img)['image']
        # splitext strips only the final extension, so ids containing dots stay intact.
        image_id = os.path.splitext(file_name)[0]

        return {'image': img, 'image_id': image_id}

In [None]:
# Test-time preprocessing: albumentations' default normalisation, then the
# HWC numpy image is converted into a CHW torch tensor.
# NOTE(review): get_model() also hands CFG.mean/CFG.std to Mask R-CNN, which
# normalises its inputs again; confirm training used the same double normalisation.
transforms = albu.Compose(
    [albu.Normalize(), ToTensorV2()],
    p=1.0,
)

In [None]:
def get_model():
    """Build the Mask R-CNN (ResNet-50 FPN) architecture used for inference.

    No weights are downloaded (``pretrained=False``, ``pretrained_backbone=False``);
    a trained state dict is expected to be loaded afterwards. The stock box and
    mask predictors are replaced by two-class versions so that parameter shapes
    match that checkpoint. The class count is only a dummy value for the
    classification head.

    Returns:
        A torchvision Mask R-CNN model with freshly initialised weights.
    """
    num_classes = 2  # dummy value for the classification head
    detector = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=False,
        box_detections_per_img=CFG.box_detections_per_img,
        image_mean=CFG.mean,
        image_std=CFG.std,
    )
    heads = detector.roi_heads

    # Box head: keep the feature width of the stock predictor, swap the class count.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: same idea, with a 256-channel hidden layer.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, 256, num_classes)

    return detector

In [None]:
# Inference is a single ordered pass over the test images. There is no shuffling
# (it only made output order nondeterministic) and no per-worker RNG seeding,
# since nothing here is random.
# NOTE(review): the default collate stacks images into one tensor, so every test
# image in a batch must share the same size.
test_dataset = TestDataset(CFG.img_dir, transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=CFG.num_workers,
    pin_memory=True,
    drop_last=False,
)

In [None]:
def rle_encoding(x):
    """Run-length encode the positions of ``x`` whose value equals 1.

    ``x`` is flattened in row-major (C) order. The result is a space-separated
    list of ``start length`` pairs with 1-based start positions, or ``''`` when
    no element equals 1.
    """
    hits = (x.flatten() == 1).astype(np.int8)
    # Pad with a zero on both sides so every run has a rising and a falling edge.
    edges = np.diff(np.concatenate(([0], hits, [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    pairs = np.column_stack((run_starts + 1, run_ends - run_starts)).ravel()
    return ' '.join(str(value) for value in pairs)

def remove_overlapping_pixels(mask, other_masks):
    """Clear every pixel of ``mask`` that is also set in any of ``other_masks``.

    ``mask`` is modified in place (boolean-index assignment) and also returned.
    """
    for claimed in other_masks:
        overlap = np.logical_and(mask, claimed)
        # Only write when there is something to clear.
        if overlap.any():
            mask[overlap] = 0
    return mask

In [None]:
# Inference: run the trained Mask R-CNN over the test set and emit one submission
# row per predicted instance, or a single empty row for an image with none.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = get_model()
# map_location lets the checkpoint load regardless of the device it was saved on.
state = torch.load(f'{CFG.model_pth}/001_best_fold0.pth', map_location=device)
model.load_state_dict(state)
model.to(device)
model.eval();

submission = []
for batch in test_loader:
    images = [image.to(device) for image in batch['image']]
    image_ids = batch['image_id']
    with torch.no_grad():
        results = model(images)

    for pred, image_id in zip(results, image_ids):
        scores = pred["scores"].cpu().numpy()
        # Union of the masks already emitted for THIS image. It is reset per image so
        # overlap removal never mixes instances from different images of a batch.
        occupied = None
        rows_for_image = 0
        for mask, score in zip(pred["masks"], scores):
            # Skip low-confidence detections.
            if score < CFG.min_score:
                continue
            # Keep only highly likely pixels
            binary_mask = mask.cpu().numpy() > CFG.mask_threshold
            if occupied is None:
                occupied = np.zeros_like(binary_mask)
            # Removing pixels covered by the running union is equivalent to removing
            # those covered by any earlier mask, but costs one comparison instead of N.
            binary_mask = remove_overlapping_pixels(binary_mask, [occupied])
            if not binary_mask.any():
                continue  # fully covered by earlier instances: nothing to encode
            occupied |= binary_mask
            submission.append((image_id, rle_encoding(binary_mask)))
            rows_for_image += 1

        # Add empty prediction if no RLE was generated for this image
        if rows_for_image == 0:
            submission.append((image_id, ""))

df_sub = pd.DataFrame(submission, columns=['id', 'predicted'])
df_sub.to_csv("submission.csv", index=False)

In [None]:
# Preview the first rows of the submission that was just written.
df_sub.head(n=5)