# **Imports**

In [None]:
import os
import random
import time
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageEnhance, ImageFilter
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from skimage import exposure
import torchvision.transforms as T

In [None]:
# Mode flag; nothing in the visible code reads it.
TEST = False

# Dataset root and the paths derived from it.
data_directory = '../input/sartorius-cell-instance-segmentation'
TRAIN_CSV = f"{data_directory}/train.csv"
TRAIN_PATH = f"{data_directory}/train"
TEST_PATH = f"{data_directory}/test"

# Training setup: use the GPU when one is available, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
NUM_EPOCHS = 20

# Image size in pixels, used when decoding and allocating masks.
HEIGHT = 520
WIDTH = 704

# Per-channel mean / std handed to F.normalize by the Normalize transform.
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)

# **Utils**

In [None]:
# ref: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length-encoded mask string into a 2-D array.

    mask_rle: run-length string formatted as "start length start length ...",
              where each start is a 1-based position in the flattened array
    shape: (height, width) of the array to return
    color: value written into the pixels covered by a run
    Returns a float32 numpy array: `color` inside the runs, 0 elsewhere
    '''
    tokens = mask_rle.split()
    # Even tokens are run starts (shifted to 0-based), odd tokens are run lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, length in zip(run_starts, run_lengths):
        flat[begin:begin + length] = color
    return flat.reshape(shape)


def visualize(**images):
    '''
    Show the given images side by side in a single row.

    Each keyword argument becomes one panel: the keyword, with underscores
    turned into spaces and title-cased, is the panel title and its value is
    passed to plt.imshow. Axis ticks are hidden.
    '''
    panel_count = len(images)
    plt.figure(figsize=(16, 12))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

# **Robust Augmentation Utils**

In [None]:
class Compose:
    '''Chain several (image, target) transforms into a single callable.'''

    def __init__(self, transforms):
        # Sequence of callables, each taking and returning (image, target).
        self.transforms = transforms

    def __call__(self, image, target):
        '''Run every transform in order, feeding each one the previous output.'''
        for step in self.transforms:
            image, target = step(image, target)
        return image, target


class VerticalFlip:
    '''Flip image, boxes and masks along the height axis with probability `prob`.'''

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        '''
        image: tensor whose second-to-last dimension is the height
        target: dict with "boxes" and "masks" entries; "boxes" is updated in place
        Returns (image, target), flipped or untouched depending on the random draw
        '''
        if random.random() < self.prob:
            img_height = image.shape[-2]
            image = image.flip(-2)

            boxes = target["boxes"]
            # Columns 1 and 3 are mirrored against the image height and swapped,
            # so the smaller coordinate stays in column 1.
            boxes[:, [1, 3]] = img_height - boxes[:, [3, 1]]
            target["boxes"] = boxes

            target["masks"] = target["masks"].flip(-2)
        return image, target


class HorizontalFlip:
    '''Flip image, boxes and masks along the width axis with probability `prob`.'''

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        '''
        image: tensor whose last dimension is the width
        target: dict with "boxes" and "masks" entries; "boxes" is updated in place
        Returns (image, target), flipped or untouched depending on the random draw
        '''
        if random.random() < self.prob:
            img_width = image.shape[-1]
            image = image.flip(-1)

            boxes = target["boxes"]
            # Columns 0 and 2 are mirrored against the image width and swapped,
            # so the smaller coordinate stays in column 0.
            boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
            target["boxes"] = boxes

            target["masks"] = target["masks"].flip(-1)
        return image, target


class RandomRatation:
    '''
    Apply one random geometric augmentation jointly to an image and its masks.

    With probability `prob` the image and every mask are passed through the
    same torchvision pipeline: random horizontal flip, random vertical flip,
    then one of random affine / random perspective / random resized crop.

    NOTE: the global `random` and `torch` RNGs are reseeded on every
    augmented call (the seed itself is drawn from numpy's global RNG), so
    code that runs afterwards in the same process sees the reseeded state.
    '''

    def __init__(self, prob):
        self.prob = prob
        # The pipeline only holds configuration, so it is built once here
        # instead of on every call.
        self._geometric = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomChoice([
                T.RandomAffine(degrees=45, scale=(.9, 2.), translate=(0.1, 0.5), fill=0),  # random affine transformation
                T.RandomPerspective(distortion_scale=0.32, fill=0),                        # random perspective
                T.RandomResizedCrop(size=(HEIGHT, WIDTH), scale=(.9, 1.4)),                # random resized crop
                # ... (feel free to add your own augmentation :) )
            ]),
        ])

    @staticmethod
    def _reseed(seed):
        ''' Reset the Python and torch global RNGs to `seed`. '''
        random.seed(seed)
        torch.manual_seed(seed)

    def __call__(self, image, masks):
        '''
        image: image accepted by the torchvision transforms
        masks: array-like of 2-D masks, one per instance
        Returns (image, masks). When the augmentation fires, masks come back
        as a new uint8 array of shape (n_masks, HEIGHT, WIDTH); otherwise both
        inputs are returned unchanged.
        '''
        if random.random() < self.prob:
            # One seed per call; reseeding before every pipeline run is what
            # lines the masks up with the image.
            seed = np.random.randint(2147483647)

            # Transform the image
            self._reseed(seed)
            image = self._geometric(image)

            # Transform the masks one by one. The output buffer is allocated
            # as HEIGHT x WIDTH, so the pipeline is assumed to emit that size.
            masks = np.array(masks)
            new_masks = np.zeros((masks.shape[0], HEIGHT, WIDTH), dtype=np.uint8)
            for i, mask in enumerate(masks):
                mask = Image.fromarray(mask.astype(np.uint8))
                self._reseed(seed)
                new_masks[i, :, :] = self._geometric(mask)

            return image, new_masks
        return image, masks


class Normalize:
    '''Normalize the image with RESNET_MEAN / RESNET_STD; the target passes through.'''

    def __call__(self, image, target):
        normalized = F.normalize(image, RESNET_MEAN, RESNET_STD)
        return normalized, target


class ToTensor:
    '''Convert the image with F.to_tensor; the target passes through.'''

    def __call__(self, image, target):
        return F.to_tensor(image), target


def get_transform(train):
    '''
    Build the transform pipeline for a dataset.

    train: when truthy, the random augmentation (always applied, prob 1.0)
           is appended after tensor conversion and normalization
    Returns a Compose over the selected transforms
    '''
    pipeline = [ToTensor(), Normalize()]
    if train:
        pipeline += [RandomRatation(1.0)]
    return Compose(pipeline)


# **Dataset + Augmentation**

In [None]:
class CellDataset(Dataset):
    '''
    Instance-segmentation dataset built from run-length-encoded annotations.

    Each item is (image, target), where target is a dict with the keys
    'boxes', 'labels', 'masks', 'image_id', 'area' and 'iscrowd'. Every
    per-instance entry describes the same set of instances: those whose mask
    is still non-empty and non-degenerate after the transforms ran.
    '''

    def __init__(self, image_dir, df, transforms=None, resize=False):
        '''
        image_dir: directory holding the "<id>.png" images
        df: DataFrame with an 'id' column and an 'annotation' column
            (one RLE string per row)
        transforms: optional callable (image, masks) -> (image, masks)
        resize: accepted for backward compatibility; not used
        '''
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df
        self.height = HEIGHT
        self.width = WIDTH

        # One entry per image id, keyed by its position after grouping.
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby('id')['annotation'].agg(list).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                'image_id': row['id'],
                'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
                'annotations': row["annotation"]
            }

    def get_box(self, a_mask):
        ''' Get the bounding box [xmin, ymin, xmax, ymax] of a non-empty mask '''
        rows, cols = np.where(a_mask)[:2]
        return [np.min(cols), np.min(rows), np.max(cols), np.max(rows)]

    def _build_target(self, idx, masks):
        ''' Build the target dict from the masks that are still usable '''
        kept_masks = []
        boxes = []
        for mask in masks:
            if not np.any(mask != 0):
                continue  # instance left the frame entirely
            box = self.get_box(mask)
            # Skip instances whose box has collapsed to a line or a point.
            if box[0] != box[2] and box[1] != box[3]:
                kept_masks.append(mask)
                boxes.append(box)

        n_kept = len(boxes)
        if n_kept:
            mask_array = np.array(kept_masks)
        else:
            # Keep the (N, H, W) layout even when nothing is left.
            mask_array = np.zeros((0, self.height, self.width), dtype=np.uint8)
        # reshape keeps boxes two-dimensional, (0, 4), for an empty list.
        box_array = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)

        boxes = torch.as_tensor(box_array, dtype=torch.float32)
        # Labels and crowd flags are sized by the instances actually kept, so
        # they always line up with boxes and masks.
        labels = torch.ones((n_kept,), dtype=torch.int64)
        iscrowd = torch.zeros((n_kept,), dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        return {
            'boxes': boxes,
            'labels': labels,
            'masks': torch.as_tensor(mask_array, dtype=torch.uint8),
            'image_id': torch.tensor([idx]),
            'area': area,
            'iscrowd': iscrowd
        }

    def __getitem__(self, idx):
        ''' Get the image and the target'''
        # image_info is a defaultdict: looking up a missing key would insert
        # an empty entry and change len(self), so check membership first.
        if idx not in self.image_info:
            raise IndexError(f"index {idx} is out of range for a dataset of size {len(self)}")
        info = self.image_info[idx]
        img = Image.open(info["image_path"]).convert("RGB")

        annotations = info['annotations']
        masks = np.zeros((len(annotations), self.height, self.width), dtype=np.uint8)
        for i, annotation in enumerate(annotations):
            masks[i, :, :] = rle_decode(annotation, (self.height, self.width)) > 0

        ################################ AUGMENTATION PART ################################

        if self.transforms is not None:
            img, masks = self.transforms(img, masks)

        ###################################################################################

        return img, self._build_target(idx, masks)

    def __len__(self):
        return len(self.image_info)

In [None]:
# Training dataset over the annotated images, with augmentation enabled.
ds_train = CellDataset(
    TRAIN_PATH,
    pd.read_csv(TRAIN_CSV),
    transforms=get_transform(train=True),
)

# **Results**

In [None]:
# Fetch sample 0 from the training set and inspect the result.
a, b = ds_train[0]
masks = np.array(b["masks"])
# Overlay of all instance masks. Summing into a new array counts every
# instance once and leaves the per-instance `masks` untouched.
mask = masks.sum(axis=0, dtype=masks.dtype)
image = a[0, :, :]  # first channel of the image

print(a.max(), a.min())
visualize(
    image=image,
    mask=mask,
    mask2=masks[2, :, :],
)

print(mask.max())
print(mask.shape)

In [None]:
# Fetch sample 20 from the training set and inspect the result.
a, b = ds_train[20]
masks = np.array(b["masks"])
# Overlay of all instance masks. Summing into a new array counts every
# instance once and leaves the per-instance `masks` untouched.
mask = masks.sum(axis=0, dtype=masks.dtype)
image = a[0, :, :]  # first channel of the image

print(a.max(), a.min())
visualize(
    image=image,
    mask=mask,
    mask2=masks[2, :, :],
)

print(mask.max())
print(mask.shape)

In [None]:
# Fetch sample 20 again; the train pipeline is random, so the result can differ.
a, b = ds_train[20]
masks = np.array(b["masks"])
# Overlay of all instance masks. Summing into a new array counts every
# instance once and leaves the per-instance `masks` untouched.
mask = masks.sum(axis=0, dtype=masks.dtype)
image = a[0, :, :]  # first channel of the image

print(a.max(), a.min())
visualize(
    image=image,
    mask=mask,
    mask2=masks[2, :, :],
)

print(mask.max())
print(mask.shape)

In [None]:
# Fetch sample 20 once more; the train pipeline is random, so the result can differ.
a, b = ds_train[20]
masks = np.array(b["masks"])
# Overlay of all instance masks. Summing into a new array counts every
# instance once and leaves the per-instance `masks` untouched.
mask = masks.sum(axis=0, dtype=masks.dtype)
image = a[0, :, :]  # first channel of the image

print(a.max(), a.min())
visualize(
    image=image,
    mask=mask,
    mask2=masks[2, :, :],
)

print(mask.max())
print(mask.shape)

# **Enjoy! :)**