In [None]:
from itertools import groupby
import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
import os
import pickle
import cv2
from multiprocessing import Pool
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset, sampler
from albumentations import (HorizontalFlip, VerticalFlip, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise)
from albumentations.pytorch import ToTensorV2
import random

In [None]:
# Base directory of the competition data; every path below is derived from it
# so the location only has to be changed in one place.
DATA_DIR = '../input/sartorius-cell-instance-segmentation'

SAMPLE_SUBMISSION = f'{DATA_DIR}/sample_submission.csv'
TRAIN_CSV = f'{DATA_DIR}/train.csv'
TRAIN_PATH = f'{DATA_DIR}/train'
TEST_PATH = f'{DATA_DIR}/test'

In [None]:
# Size tuple handed to the Resize transform in CellDataset, and the
# per-channel mean / std handed to its Normalize transform.
IMAGE_RESIZE = (704, 520)
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)

# Reuse the path constants from the configuration cell instead of repeating
# the string literals, so the data location is defined in one place only.
train_csv = pd.read_csv(TRAIN_CSV)
train_path = TRAIN_PATH

# Kept as the last expression of the cell: a notebook only renders the value
# of the final expression, so a preview placed mid-cell is never shown.
train_csv.head()

In [None]:
def rle_decode(mask_rle, color=1, shape=(520, 704)):
    '''
    Decode a run-length-encoded mask string into a 2-D array.

    mask_rle: run-length string formatted as "start length start length ...";
              starts are 1-based offsets into the flattened mask
    color: value written into every pixel covered by a run
    shape: (height, width) of the array to return; defaults to (520, 704),
           the size that used to be hard-coded in this function
    Returns a float32 numpy array of the given shape: `color` - mask, 0 - background

    Raises ValueError if the string does not hold complete (start, length) pairs.
    '''
    height, width = shape
    tokens = mask_rle.split()
    if len(tokens) % 2:
        # Without this check numpy would broadcast the mismatched start/length
        # arrays and silently decode the wrong runs.
        raise ValueError(
            "RLE string must contain an even number of values, got %d" % len(tokens)
        )
    # Even positions are starts (shifted to 0-based), odd positions are lengths.
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths
    img = np.zeros(height * width, dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = color
    # The runs index the flattened mask, so fill flat and reshape at the end.
    return img.reshape((height, width))

def build_masks(df_train, image_id):
    '''
    Merge every annotation of one image into a single binary mask.

    df_train: DataFrame with "id" and "annotation" columns
    image_id: value matched against the "id" column
    Returns a (520, 704) float array clipped to the range [0, 1].
    '''
    rows_for_image = df_train["id"] == image_id
    encoded_masks = df_train.loc[rows_for_image, "annotation"].tolist()

    combined = np.zeros((520, 704))
    for encoded in encoded_masks:
        combined = combined + rle_decode(encoded)

    # Overlapping annotations sum above 1; clip so the result stays binary.
    return np.clip(combined, 0, 1)

In [None]:
class CellDataset(Dataset):
    """Dataset of (image, binary mask) pairs, one item per unique image id.

    Images are read from TRAIN_PATH as "<id>.png"; the mask for an item is
    the union of all annotations that the supplied DataFrame holds for that id.
    """

    def __init__(self, df):
        """
        df: DataFrame with an "id" column and an "annotation" column
            (RLE strings, decoded through build_masks).
        """
        self.df = df
        self.base_path = TRAIN_PATH
        # NOTE(review): albumentations' Resize takes (height, width), so
        # IMAGE_RESIZE[0] is used as the output height and IMAGE_RESIZE[1] as
        # the width. Confirm IMAGE_RESIZE is really meant to be (height, width);
        # the order is left untouched here to keep the output shape unchanged.
        self.transforms = Compose([Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1]),
                                   Normalize(mean=RESNET_MEAN, std=RESNET_STD, p=1),
                                   HorizontalFlip(p=0.5),
                                   VerticalFlip(p=0.5),
                                   ToTensorV2()])
        self.gb = self.df.groupby('id')
        self.image_ids = df.id.unique().tolist()

    def __getitem__(self, idx):
        """Return (image, mask) for the idx-th unique image id.

        The mask is reshaped to (1, IMAGE_RESIZE[0], IMAGE_RESIZE[1]).
        Raises FileNotFoundError if the image file cannot be loaded.
        """
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.base_path, image_id + ".png")
        # NOTE(review): cv2.imread yields channels in BGR order while the
        # normalisation constants are applied per channel - confirm that no
        # BGR -> RGB conversion is needed for these images.
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(
                "Image is missing or unreadable: " + image_path
            )
        # Build the mask from this dataset's own rows (not from a global
        # DataFrame), so subsets / splits passed to __init__ are respected.
        # Passing only the rows of this id also avoids rescanning the full
        # frame for every item.
        mask = build_masks(self.gb.get_group(image_id), image_id)
        mask = (mask >= 1).astype('float32')
        augmented = self.transforms(image=image, mask=mask)
        image = augmented['image']
        mask = augmented['mask']
        return image, mask.reshape((1, IMAGE_RESIZE[0], IMAGE_RESIZE[1]))

    def __len__(self):
        """Number of unique image ids in the supplied DataFrame."""
        return len(self.image_ids)

In [None]:
# Build the training dataset and fetch one item as a quick sanity check
# of the shapes that come out of the transform pipeline.
ds_train = CellDataset(df=train_csv)
image, mask = ds_train[1]
tuple(tensor.shape for tensor in (image, mask))

In [None]:
def img_plot(i):
    """Show item `i` of `ds_train` as three figures.

    The figures are, in order: the first image channel alone, the mask alone,
    and the mask drawn semi-transparently on top of the image channel.
    """
    image, mask = ds_train[i]
    image_channel = image[0]
    mask_channel = mask[0]

    # One (draw_image, draw_mask) entry per figure, in display order.
    figure_layers = ((True, False), (False, True), (True, True))
    for draw_image, draw_mask in figure_layers:
        if draw_image:
            plt.imshow(image_channel, cmap='bone', aspect='auto')
        if draw_mask:
            plt.imshow(mask_channel, alpha=0.3, aspect='auto')
        plt.show()

In [None]:
# Pick a random valid index. randrange excludes the upper bound and is tied to
# the real dataset length, whereas the previous randint(0, 606) included 606
# and could therefore index past the end of a dataset with 606 or fewer items.
n = random.randrange(len(ds_train))
img_plot(n)