In [None]:
import torch
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# import torchvision.transforms.functional as TF

import random
import os, shutil
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import os
from os.path import join
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 18})
import cv2

import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, sampler
from albumentations import (HorizontalFlip, VerticalFlip, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise)

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

In [None]:
# Load the competition's training CSV into a DataFrame.
data_train = pd.read_csv(
    '../input/sartorius-cell-instance-segmentation/train.csv'
)

In [None]:
from torchvision.models import resnet34, resnet152, resnet50, resnet101

In [None]:
def get_net():
    """Build a ResNet-50 whose final fully connected layer has 65 outputs.

    ``resnet50`` is called with ``True`` as its first positional argument and
    its ``fc`` head is then replaced by a fresh ``nn.Linear(2048, 65)``.

    Returns:
        The modified ResNet-50 module.
    """
    backbone = resnet50(True)
    classifier_head = nn.Linear(2048, 65)
    backbone.fc = classifier_head
    return backbone

In [None]:
# Preview the first five rows of the training table.
data_train.head(n=5)

In [None]:
def imshow(num_to_show=9):
    """Plot the first ``num_to_show`` distinct training images in a 3-column grid.

    Image ids are taken from the first column of ``data_train``. The table can
    contain several rows for the same image id, so the ids are de-duplicated
    first; indexing rows directly could draw the same picture repeatedly.

    Args:
        num_to_show: Maximum number of distinct images to display. The grid
            grows beyond three rows when more than nine images are requested.
    """
    image_ids = data_train.iloc[:, 0].drop_duplicates().head(num_to_show).tolist()
    if not image_ids:
        # Nothing to draw (empty table or non-positive request).
        return

    n_cols = 3
    # Keep the original 3x3 layout for up to nine images; add rows beyond that.
    n_rows = max(3, (len(image_ids) + n_cols - 1) // n_cols)

    plt.figure(figsize=(20,20))

    for i, image_id in enumerate(image_ids):
        plt.subplot(n_rows, n_cols, i+1)
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])

        img = mpimg.imread(f'../input/sartorius-cell-instance-segmentation/train/{image_id}.png')
        plt.imshow(img, cmap='plasma')

imshow()

In [None]:
# Locations of the competition files, all resolved relative to DATA_PATH.
DATA_PATH = '../input/sartorius-cell-instance-segmentation'
# This constant used to point at the 'train' directory (copy-paste slip);
# the sample submission is a CSV file that sits next to train.csv.
SAMPLE_SUBMISSION = join(DATA_PATH,'sample_submission.csv')
TRAIN_CSV = join(DATA_PATH,'train.csv')
TRAIN_PATH = join(DATA_PATH,'train')
TEST_PATH = join(DATA_PATH,'test')

# Annotation table used by the dataset and mask-building code below.
df_train = pd.read_csv(TRAIN_CSV)
print(f'Training Set Shape: {df_train.shape} - {df_train["id"].nunique()} \
Images - Memory Usage: {df_train.memory_usage().sum() / 1024 ** 2:.2f} MB')

In [None]:
def rle_decode(mask_rle, shape, color=1):
    """Decode a run-length encoded mask string into a 2-D array.

    Args:
        mask_rle: Whitespace-separated ``start length start length ...``
            string; starts are 1-based positions in the flattened image.
        shape: ``(height, width)`` of the array to return.
        color: Value written into the pixels covered by a run.

    Returns:
        A ``float32`` numpy array of the given shape holding ``color`` on the
        encoded runs and 0 elsewhere.
    """
    tokens = mask_rle.split()
    # Even positions are run starts (converted to 0-based), odd ones lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, stop in zip(run_starts, run_ends):
        flat[begin:stop] = color
    return flat.reshape(shape)

def build_masks(df_train, image_id, input_shape):
    """Merge every annotation of one image into a single binary mask.

    Args:
        df_train: DataFrame with an ``id`` column and an ``annotation`` column
            of run-length encoded strings.
        image_id: Value of ``id`` selecting the rows to merge.
        input_shape: ``(height, width)`` of the mask to build.

    Returns:
        A ``(height, width)`` numpy array with 1 where at least one
        annotation covers the pixel and 0 elsewhere.
    """
    height, width = input_shape
    annotations = df_train.loc[df_train["id"] == image_id, "annotation"].tolist()

    combined = np.zeros((height, width))
    for rle in annotations:
        combined += rle_decode(rle, shape=(height, width))

    # Overlapping instances add up to more than 1; clamp back to {0, 1}.
    return np.array(combined.clip(0, 1))

In [None]:
class CellDataset(Dataset):
    """Binary segmentation dataset built from a table of RLE annotations.

    The unique image ids of ``df`` are shuffled with a fixed seed and split
    90/10; ``train`` selects which side of the split this instance serves.
    Each item is an ``(image, mask)`` pair of numpy arrays: the image is
    channel-first ``(3, 224, 224)`` and normalised with the ResNet mean/std,
    the mask is ``(1, 224, 224)`` with values 0/1.

    Args:
        df: Annotation table with at least the columns ``id`` and
            ``annotation``.
        train: ``True`` for the 90 % training split (random flips enabled),
            ``False`` for the 10 % held-out split (no random augmentation, so
            evaluation is deterministic).
    """

    def __init__(self, df: pd.DataFrame, train: bool):
        self.IMAGE_RESIZE = (224, 224)
        self.RESNET_MEAN = (0.485, 0.456, 0.406)
        self.RESNET_STD = (0.229, 0.224, 0.225)
        self.df = df
        self.base_path = TRAIN_PATH
        self.gb = self.df.groupby('id')

        steps = [Resize(self.IMAGE_RESIZE[0], self.IMAGE_RESIZE[1]),
                 Normalize(mean=self.RESNET_MEAN, std=self.RESNET_STD, p=1)]
        if train:
            # Random flips are augmentation: only meaningful while training.
            steps += [HorizontalFlip(p=0.5), VerticalFlip(p=0.5)]
        self.transforms = Compose(steps)

        # Split on the frame that was passed in, not on a notebook global.
        self.image_ids = self._split_ids(self.df['id'].unique(), train)

    @staticmethod
    def _split_ids(all_image_ids, train, train_fraction=0.9, seed=42):
        """Return the train or held-out part of a seeded shuffle of the ids.

        A private ``RandomState`` is used so the global numpy RNG is left
        untouched; for the same seed it yields the same permutation as
        ``np.random.seed(seed)`` followed by ``np.random.permutation``.
        """
        all_image_ids = np.asarray(all_image_ids)
        order = np.random.RandomState(seed).permutation(len(all_image_ids))
        num_train_samples = int(len(all_image_ids) * train_fraction)
        chosen = order[:num_train_samples] if train else order[num_train_samples:]
        return all_image_ids[chosen]

    def __getitem__(self, idx: int) -> tuple:
        """Load image ``idx`` and its merged mask, transformed and channel-first."""
        image_id = self.image_ids[idx]
        annotations = self.gb.get_group(image_id)

        # Read image
        image_path = os.path.join(self.base_path, image_id + ".png")
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        # Create the mask at the image's own resolution, from this image's
        # rows only (avoids rescanning the whole annotation table).
        mask = build_masks(annotations, image_id, input_shape=image.shape[:2])
        mask = (mask >= 1).astype('float32')

        augmented = self.transforms(image=image, mask=mask)
        image = np.moveaxis(np.array(augmented['image']), 2, 0)
        mask = augmented['mask'].reshape((1, self.IMAGE_RESIZE[0], self.IMAGE_RESIZE[1]))
        return image, mask

    def __len__(self):
        """Number of images in this split."""
        return len(self.image_ids)

In [None]:
# Training split and its loader. Batches are reshuffled every epoch so the
# optimiser (and the batch-norm statistics) do not see the same fixed batches.
ds_train = CellDataset(df_train, train = True)
dl_train = DataLoader(ds_train, batch_size = 16,
                     num_workers = 2, pin_memory = True,
                     shuffle = True)

In [None]:
# Held-out split (train=False) and its loader; order is kept fixed.
ds_test = CellDataset(df_train, train=False)
dl_test = DataLoader(
    ds_test,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Visualization of the images and masks

In [None]:
# Pull one batch from the training loader, report its shapes, then show a
# sample image, its mask, and the mask overlaid on the image.
batch = next(iter(dl_train))
images, masks = batch
print(f"image shape: {images.shape},\nmask shape:{masks.shape},\nbatch len: {len(batch)}")

sample_image = images[1][1]  # item 1 of the batch, channel index 1
sample_mask = masks[1][0]    # item 1 of the batch, its only mask channel

fig, (ax_image, ax_mask, ax_both) = plt.subplots(1, 3, figsize=(20, 20))
for ax in (ax_image, ax_mask, ax_both):
    ax.set_xticks([])
    ax.set_yticks([])

ax_image.imshow(sample_image)
ax_image.set_title('Original image')

ax_mask.imshow(sample_mask)
ax_mask.set_title('Mask')

ax_both.imshow(sample_image)
ax_both.imshow(sample_mask, alpha=0.2)
ax_both.set_title('Both')

plt.tight_layout()
plt.show()

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive stages of 3x3 convolution -> batch norm -> ReLU.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Padding of 1 with stride 1 preserves the spatial size.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for channels_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(channels_in, out_ch, kernel_size=3, stride=1, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.conv(x)

In [None]:
class InConv(nn.Module):
    """Input stage of the network: a single ``DoubleConv`` from ``in_ch`` to ``out_ch``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        """Run ``x`` through the wrapped double convolution."""
        return self.conv(x)

In [None]:
class Down(nn.Module):
    """Encoder step: 2x2 max pooling (stride 2) followed by a ``DoubleConv``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2, 2)
        convs = DoubleConv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` and convolve it."""
        return self.mpconv(x)

In [None]:
class Up(nn.Module):
    """Decoder step: upsample, concatenate with the skip tensor, then ``DoubleConv``.

    Args:
        in_ch: Channel count after concatenation; the upsampled tensor and
            the skip tensor are each expected to carry ``in_ch // 2`` channels.
        out_ch: Channel count produced by the double convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1`` by 2, align it to ``x2`` and fuse the two.

        Args:
            x1: Feature map coming from the deeper level.
            x2: Skip-connection feature map from the encoder.

        Returns:
            The output of the double convolution on ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)
        # When an encoder level had an odd height/width, pooling dropped a
        # row/column, so the upsampled map is smaller than the skip map.
        # Zero-pad it to the skip size; this is a no-op when sizes already match.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

In [None]:
class OutConv(nn.Module):
    """Output head: 1x1 convolution to ``out_ch`` channels followed by a sigmoid."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-pixel values in (0, 1) for each output channel."""
        return self.sigmoid(self.conv(x))

In [None]:
class UNet(nn.Module):
    """U-Net with four downsampling and four upsampling levels.

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output channels produced by the head.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Encoder
        self.inc = InConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder (first argument is the channel count after concatenation)
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        # Head
        self.outc = OutConv(64, num_classes)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and apply the head."""
        # Encoder outputs, shallowest first; the last entry is the bottleneck.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        # Walk back up, pairing each decoder step with the matching skip.
        out = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, features.pop())

        return self.outc(out)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def train_loop(model, optimizer, criterion, train_loader, device=device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network to optimise; switched to training mode.
        optimizer: Optimiser stepping the model parameters.
        criterion: Loss called as ``criterion(output, masks)``.
        train_loader: Loader yielding ``(imgs, masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        The sum of per-batch losses weighted by batch size, divided by
        ``len(train_loader.sampler)``.
    """
    model.train()
    loss_sum = 0
    for imgs, masks in tqdm(train_loader, desc='Iterating over train data'):
        imgs, masks = imgs.to(device), masks.to(device)

        loss = criterion(model(imgs), masks)
        # criterion returns a batch mean; scale back to a per-sample sum.
        loss_sum += loss.item() * imgs.shape[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss_sum / len(train_loader.sampler)

In [None]:
def eval_loop(model, criterion, eval_loader, device=device):
    """Evaluate ``model`` on ``eval_loader`` without updating it.

    Pixel-level confusion counts are accumulated over the whole loader and
    the metrics are computed once at the end. Averaging per-batch scores
    would over-weight a smaller last batch and would not equal the F1 of the
    dataset; the loss is already weighted by batch size in the same spirit.

    Args:
        model: Network to evaluate; switched to eval mode.
        criterion: Loss called as ``criterion(output, masks)``.
        eval_loader: Loader yielding ``(imgs, masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        A dict with:
            ``'accuracy'``: fraction of pixels classified correctly at a 0.5
                threshold.
            ``'f1_macro'``: F1 score of the positive class (the key name is
                kept for callers; it is a binary F1, not a macro average).
                0.0 when there are no positive pixels or predictions at all.
            ``'loss'``: batch-size-weighted mean loss over
                ``len(eval_loader.sampler)`` samples.
    """
    running_loss = 0
    true_pos = false_pos = false_neg = true_neg = 0
    model.eval()
    with torch.no_grad():
        pbar = tqdm(eval_loader, desc='Iterating over evaluation data')
        for imgs, masks in pbar:
            # pass to device
            imgs = imgs.to(device)
            masks = masks.to(device)
            # forward
            out = model(imgs)
            loss = criterion(out, masks)
            running_loss += loss.item()*imgs.shape[0]
            # threshold outputs and targets, then update the confusion counts
            predicted = out > 0.5
            target = masks > 0.5
            true_pos += (predicted & target).sum().item()
            false_pos += (predicted & ~target).sum().item()
            false_neg += (~predicted & target).sum().item()
            true_neg += (~predicted & ~target).sum().item()
    total = true_pos + false_pos + false_neg + true_neg
    acc = (true_pos + true_neg) / total
    f1_denominator = 2 * true_pos + false_pos + false_neg
    f1 = 2 * true_pos / f1_denominator if f1_denominator else 0.0
    running_loss /= len(eval_loader.sampler)
    return {
        'accuracy':acc,
        'f1_macro':f1, 
        'loss':running_loss}

In [None]:
def train(model, optimizer, criterion, train_loader, valid_loader,
          device=device, 
          num_epochs=30, 
          valid_loss_min=np.inf,
          logdir='logdir'):
    """Train ``model`` for several epochs, logging and checkpointing as it goes.

    After every epoch the model is evaluated on ``valid_loader``; a summary
    line is printed, the scalars are written to TensorBoard, and the weights
    are saved to ``'UNet.pt'`` whenever the validation loss does not exceed
    the best value seen so far.

    Args:
        model: Network to train.
        optimizer: Optimiser stepping the model parameters.
        criterion: Loss function shared by training and evaluation.
        train_loader: Loader for the training batches.
        valid_loader: Loader for the validation batches.
        device: Device passed on to the train/eval loops.
        num_epochs: Number of epochs to run.
        valid_loss_min: Initial best validation loss; a checkpoint is written
            only for losses less than or equal to it.
        logdir: Directory for the TensorBoard event files.
    """
    tb_writer = SummaryWriter(log_dir=logdir)
    try:
        for e in range(num_epochs):
            # train for epoch
            train_loss = train_loop(
                model, optimizer, criterion, train_loader, device=device)
            # evaluate on validation set
            metrics = eval_loop(
                model, criterion, valid_loader, device=device
            )
            # show progress
            print_string = f'Epoch: {e+1} '
            print_string+= f'TrainLoss: {train_loss:.5f} '
            print_string+= f'ValidLoss: {metrics["loss"]:.5f} '
            print_string+= f'ACC: {metrics["accuracy"]:.5f} '
            print_string+= f'F1: {metrics["f1_macro"]:.3f}'
            print(print_string)

            # Tensorboards Logging
            tb_writer.add_scalar('UNet/Train Loss', train_loss, e)
            tb_writer.add_scalar('UNet/Valid Loss', metrics["loss"], e)
            tb_writer.add_scalar('UNet/Accuracy', metrics["accuracy"], e)
            tb_writer.add_scalar('UNet/F1 Macro', metrics["f1_macro"], e)

            # save the model 
            if metrics["loss"] <= valid_loss_min:
                torch.save(model.state_dict(), 'UNet.pt')
                valid_loss_min = metrics["loss"]
    finally:
        # Flush pending events and release the file handle even when training
        # is interrupted or an epoch raises.
        tb_writer.close()

In [None]:
# Seed the Python, NumPy and PyTorch RNGs so that weight initialisation (and
# anything else drawing from them) is repeatable across kernel restarts.
SEED = 21
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Single-channel-output U-Net on 3-channel input, trained with binary
# cross-entropy on the sigmoid outputs.
model = UNet(3, 1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()
train(model, optimizer, criterion, dl_train, dl_test)

In [None]:
# Reload the best checkpoint written during training and score it on the
# held-out loader. map_location lets a checkpoint saved on a GPU be loaded on
# whatever device is in use now (e.g. a CPU-only session).
model.load_state_dict(torch.load('UNet.pt', map_location=device))
metrics = eval_loop(model, criterion, dl_test)
print('accuracy:', metrics['accuracy'])
print('f1 macro:', metrics['f1_macro'])
print('test loss:', metrics['loss'])