# THE TRAINING NOTEBOOK (PART 1) CAN BE FOUND HERE:
# [Sartorius-Cell-Segmentation-DeepLabv3-Training](https://www.kaggle.com/albertozorzetto/sartorius-cell-segmentation-deeplabv3-training)

`█▀▀▄░░░░░░░░░░░▄▀▀█
░█░░░▀▄░▄▄▄▄▄░▄▀░░░█
░░▀▄░░░▀░░░░░▀░░░▄▀
░░░░▌░▄▄░░░▄▄░▐▀▀
░░░▐░░█▄░░░▄█░░▌▄▄▀▀▀▀█
░░░▌▄▄▀▀░▄░▀▀▄▄▐░░░░░░█
▄▀▀▐▀▀░▄▄▄▄▄░▀▀▌▄▄▄░░░█
█░░░▀▄░█░░░█░▄▀░░░░█▀▀▀
░▀▄░░▀░░▀▀▀░░▀░░░▄█▀
░░░█░░░░░░░░░░░▄▀▄░▀▄
░░░█░░░░░░░░░▄▀█░░█░░█
░░░█░░░░░░░░░░░█▄█░░▄▀
░░░█░░░░░░░░░░░████▀
░░░▀▄▄▀▀▄▄▀▀▄▄▄█▀
`

In [None]:
!pip install segmentation-models-pytorch 
!pip install sewar 

In [None]:
import sewar
import segmentation_models_pytorch  as smp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import cv2
from tqdm import tqdm
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gc
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from albumentations import (HorizontalFlip, VerticalFlip, 
                            ShiftScaleRotate, Normalize, Resize, 
                            Compose, GaussNoise)
from albumentations.pytorch import ToTensorV2


In [None]:
# Load the training annotations. SartoriusCellDataset reads its 'id' (image
# id / PNG file stem) and 'annotation' (RLE mask) columns; presumably one row
# per annotated cell instance.
train_df = pd.read_csv('../input/sartorius-cell-instance-segmentation/train.csv')
train_df

In [None]:
# Load the sample submission; its 'id' column supplies the test image ids
# iterated by the inference cells below.
sub_df = pd.read_csv('../input/sartorius-cell-instance-segmentation/sample_submission.csv')
sub_df

In [None]:
# Directories holding the test and training PNGs ("<id>.png" is appended).
TEST_IMGS_PATH = "../input/sartorius-cell-instance-segmentation/test/"
TRAIN_IMGS_PATH = "../input/sartorius-cell-instance-segmentation/train/"

# Size of the source images in pixels.
IMGS_WIDTH = 704
IMGS_HEIGHT = 520

# ImageNet channel mean/std used by Normalize for the resnet34 encoder.
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)

# Size images are resized to before being fed to the model.
TARGET_IMGS_HEIGHT=512
TARGET_IMGS_WIDTH=512

# Use the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using : ",DEVICE)

In [None]:
def rle_decode(x, shape, color=1):
    """Decode a run-length-encoded string into a dense mask.

    ``x`` holds whitespace-separated ``start length`` pairs. Each run sets
    ``length`` consecutive pixels, beginning at flat index ``start`` of the
    ``shape[0] * shape[1]`` pixel grid (row-major, as ``np.reshape`` uses).
    Starts are read as 0-based, the convention ``rle_encode`` in this
    notebook writes.

    NOTE(review): Kaggle RLE is usually 1-based; confirm against the
    competition format before relying on pixel-exact alignment.

    Args:
        x: RLE string. An empty or whitespace-only string gives an empty mask.
        shape: ``(height, width, channels)`` of the returned mask.
        color: value written into every covered pixel (all channels).

    Returns:
        float64 ndarray of ``shape``.

    Raises:
        ValueError: if ``x`` does not contain an even number of integers.
    """
    out = np.zeros((shape[0] * shape[1], shape[2]))
    # split() without an argument tolerates repeated/trailing whitespace and
    # yields [] for "", where split(" ") produced '' and crashed int().
    values = [int(v) for v in x.split()]
    if len(values) % 2:
        raise ValueError("RLE string must contain start/length pairs")
    for start, length in zip(values[0::2], values[1::2]):
        out[start:start + length] = color
    return np.reshape(out, shape)

In [None]:
def rle_encode(x):
    """Run-length encode the pixels of ``x`` that equal 1.

    ``x`` is flattened row-major and each maximal run of pixels equal to 1 is
    emitted exactly once as ``"start length"`` with a 0-based start (the
    convention ``rle_decode`` in this notebook reads).

    NOTE(review): Kaggle RLE is usually 1-based; confirm before submitting.

    Args:
        x: array-like mask; any value other than 1 counts as background.

    Returns:
        Space-separated ``start length`` pairs, or ``""`` when no pixel is 1.
    """
    on = (np.asarray(x).flatten() == 1).astype(np.int8)
    # Pad with background on both ends so every run has a rising and a falling
    # edge (including runs touching the first/last pixel); the nonzero
    # differences then alternate start, end, start, end, ...
    edges = np.flatnonzero(np.diff(np.concatenate(([0], on, [0]))))
    starts = edges[0::2]
    lengths = edges[1::2] - starts
    return " ".join(f"{s} {n}" for s, n in zip(starts, lengths))

In [None]:
class SartoriusCellDataset(Dataset):
    """Training images paired with one merged binary mask per image.

    Every annotation row of an image is RLE-decoded and summed into a single
    foreground mask, which is then clipped to {0, 1}.

    Args:
        train_df: annotation table with an 'id' column (image id, also the PNG
            file stem) and an 'annotation' column (RLE string).
        train_imgs_path: directory prefix that "<id>.png" is appended to.
        transforms: optional albumentations pipeline applied jointly to image
            and mask; assumed to end with ToTensorV2 (image becomes C, H, W),
            as the pipeline built in this notebook does.
    """

    def __init__(self,train_df,train_imgs_path,transforms):
        self.train_df=train_df
        # One item per distinct image id (np.unique also sorts them).
        self.image_ids=np.unique(train_df['id']).tolist()
        self.train_imgs_path=train_imgs_path
        self.transforms = transforms

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self,idx):
        """Return ``(image, mask)`` with the mask reshaped to (1, H, W)."""
        image_id = self.image_ids[idx]
        path = self.train_imgs_path + image_id + ".png"
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(path)
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        mask_shape = (image.shape[0], image.shape[1], 1)
        mask = np.zeros(mask_shape, dtype=np.float32)
        # Filter this dataset's own table (previously the global `train_df`,
        # which ignored the DataFrame passed to __init__).
        rle_masks = self.train_df[self.train_df["id"] == image_id]['annotation'].tolist()
        for rle_mask in rle_masks:
            mask += rle_decode(rle_mask, mask_shape).astype(np.float32)
        # Touching/overlapping instances can sum above 1.
        mask = mask.clip(0, 1)

        if self.transforms:
            aug = self.transforms(image=image, mask=mask)
            image, mask = aug['image'], aug['mask']
            # Image is (C, H, W) after ToTensorV2.
            height, width = image.shape[1], image.shape[2]
        else:
            # Untransformed image is the (H, W, 3) RGB array.
            height, width = image.shape[0], image.shape[1]

        return image, mask.reshape((1, height, width))

In [None]:
# Training pipeline: resize to the model input size, normalize with ImageNet
# statistics, random vertical/horizontal flips (albumentations applies them to
# image and mask together), then convert to torch tensors.
# NOTE(review): the random flips also apply to the inference check below,
# which reuses train_ds; harmless there since image and mask flip together.
transforms = Compose([Resize(TARGET_IMGS_HEIGHT,TARGET_IMGS_WIDTH),
                    Normalize(mean=RESNET_MEAN,std=RESNET_STD),
                    VerticalFlip(p=0.5),
                    HorizontalFlip(p=0.5),
                    ToTensorV2()])
train_ds = SartoriusCellDataset(train_df,
                          TRAIN_IMGS_PATH,
                          transforms)

# Loading the model

In [None]:

# Build the architecture only. The next cell's load_state_dict (strict by
# default) overwrites every weight, encoder included, so fetching ImageNet
# weights here is redundant. It would also fail in a notebook run without
# internet access, as Kaggle code-competition submissions typically are.
model = smp.DeepLabV3Plus("resnet34",
                          encoder_weights=None,
                          )


# Inference test on training data

In [None]:
# Load the trained weights. map_location places the tensors directly on
# DEVICE, so the same call works with or without a GPU.
model.load_state_dict(torch.load("../input/deeplabv3plus-resnet34-sartorius/best_model.pth",
                                 map_location=DEVICE))
model.to(DEVICE)
# Everything after this point is inference. In eval mode BatchNorm uses its
# stored running statistics instead of per-batch ones and stops updating them.
# Train mode cannot process single-image batches, which is presumably why the
# one-image TTA cell further down failed.
model.eval()

In [None]:
# Loader over the training set for the visual check below; shuffle=False keeps
# the same images in the first batch on every run.
BATCH_SIZE=32
train_loader=DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=False
)

In [None]:
# Visual sanity check on the first training batch: ground-truth masks (left)
# next to raw model outputs (right). Leaves `images`, `masks` and `preds`
# (a numpy array) in the notebook namespace for the threshold cells below.
# eval() + no_grad(): BatchNorm uses (and does not update) its running
# statistics, and no autograd graph is kept for the whole batch.
model.eval()
images, masks = next(iter(train_loader))
if torch.cuda.is_available():
    images, masks = images.cuda(), masks.cuda()
with torch.no_grad():
    preds = model(images)
print(preds.shape)
# Show at most 16 rows; fewer if the batch is smaller.
n_show = min(16, preds.shape[0])
fig, axs = plt.subplots(n_show, 2, figsize=(10, 5 * n_show), squeeze=False)
images, masks = images.cpu(), masks.cpu()
preds = preds.cpu().detach().numpy()
print(preds[0].max(), preds[0].min())
for i in range(n_show):
    axs[i][0].imshow(masks[i].reshape(TARGET_IMGS_HEIGHT, TARGET_IMGS_WIDTH, 1))
    axs[i][0].title.set_text("Ground truth")
    axs[i][1].imshow(preds[i].reshape(TARGET_IMGS_HEIGHT, TARGET_IMGS_WIDTH, 1))
    axs[i][1].title.set_text("Prediction")

plt.subplots_adjust(wspace=0.1)

plt.show()

In [None]:
# Raw model output for item 1 of the inspected batch, as a (512, 512, 1) image.
sample_pred_1=preds[1].reshape((512,512,1))
plt.imshow(sample_pred_1)

In [None]:
# Ground truth of item 1 next to its prediction binarized at several
# thresholds (cv2.THRESH_BINARY: values > t become 1, the rest 0).
# NOTE(review): the UQI score in each title compares the thresholded map with
# the raw prediction, not with the ground-truth mask; confirm that is intended.
thresholds=[0.05,0.1,0.5,0.7]
fig,axs=plt.subplots(4,2,figsize=(8,16))

for idx,t in enumerate(thresholds):
    axs[idx][0].imshow(masks[1].reshape(512,512,1))
    axs[idx][0].title.set_text("Ground Truth")
    thresh_img=cv2.threshold(sample_pred_1,t,1,cv2.THRESH_BINARY)[1]
    axs[idx][1].imshow(thresh_img)
    axs[idx][1].title.set_text("Pred,Treshold : "+str(t)+",UQI score : "+"{:.4f}".format(sewar.full_ref.uqi(sample_pred_1.reshape((512,512)),thresh_img)))

In [None]:
# Raw model output for item 2 of the inspected batch, as a (512, 512, 1) image.
sample_pred_2=preds[2].reshape((512,512,1))
plt.imshow(sample_pred_2) 

In [None]:
# Same threshold sweep for item 2 (values > t become 1, the rest 0).
# NOTE(review): UQI is computed against the raw prediction, not the ground truth.
thresholds=[0.05,0.1,0.5,0.7]
fig,axs=plt.subplots(4,2,figsize=(8,16))

for idx,t in enumerate(thresholds):
    axs[idx][0].imshow(masks[2].reshape(512,512,1))
    axs[idx][0].title.set_text("Ground Truth")
    thresh_img=cv2.threshold(sample_pred_2,t,1,cv2.THRESH_BINARY)[1]
    axs[idx][1].imshow(thresh_img)
    axs[idx][1].title.set_text("Pred,Treshold : "+str(t)+",UQI score : "+"{:.4f}".format(sewar.full_ref.uqi(sample_pred_2.reshape((512,512)),thresh_img)))

In [None]:
# Raw model output for item 7 of the inspected batch, as a (512, 512, 1) image.
sample_pred_3=preds[7].reshape((512,512,1))
plt.imshow(sample_pred_3)

In [None]:
# Threshold sweep for item 7, additionally trying t = 1.
# NOTE(review): UQI is computed against the raw prediction, not the ground truth.
thresholds=[0.05,0.1,0.5,0.7,1]
fig,axs=plt.subplots(5,2,figsize=(8,20))

for idx,t in enumerate(thresholds):
    axs[idx][0].imshow(masks[7].reshape(512,512,1))
    axs[idx][0].title.set_text("Ground Truth")
    thresh_img=cv2.threshold(sample_pred_3,t,1,cv2.THRESH_BINARY)[1]
    axs[idx][1].imshow(thresh_img)
    axs[idx][1].title.set_text("Pred,Treshold : "+str(t)+",UQI score : "+"{:.4f}".format(sewar.full_ref.uqi(sample_pred_3.reshape((512,512)),thresh_img)))
# Extra vertical space so the long titles do not overlap.
plt.subplots_adjust(hspace=0.3)

### A threshold of 0.5 often seems to work best

In [None]:
def preds_postprocess(mask,threshold=0.6,min_size=300):
    """Split a prediction map into one binary mask per connected component.

    Args:
        mask: single-channel float map, (H, W) or (H, W, 1).
        threshold: pixels strictly above this value are foreground.
        min_size: components with this many pixels or fewer are dropped.

    Returns:
        List of float32 (H, W) arrays, one per kept component (1 inside it,
        0 elsewhere), in connected-component label order.
    """
    binary = cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY)[1]
    n_component, component = cv2.connectedComponents(binary.astype(np.uint8))
    predictions = []
    # Label 0 is the background.
    for c in range(1, n_component):
        p = component == c
        if p.sum() > min_size:
            # Size each output like the label image (not a fixed 512x512) so
            # maps of any resolution work.
            a_prediction = np.zeros(component.shape, np.float32)
            a_prediction[p] = 1
            predictions.append(a_prediction)
    return predictions

# Post-process item 1 and report the number of kept components, the shape of
# the first component mask, and its distinct values.
post_preds=preds_postprocess(sample_pred_1)
print(len(post_preds))
print(post_preds[0].shape)
print(np.unique(post_preds[0]))

In [None]:
# Item 1 side by side: ground truth, raw prediction, prediction thresholded at 0.5.
fig,axs=plt.subplots(1,3,figsize=(15,5))
axs[0].imshow(masks[1].reshape((512,512,1)))
axs[0].title.set_text("Ground Truth")
axs[1].imshow(sample_pred_1.reshape((512,512,1)))
axs[1].title.set_text("Prediction")
axs[2].imshow(cv2.threshold(sample_pred_1.reshape((512,512,1)),0.5,1,cv2.THRESH_BINARY)[1])
axs[2].title.set_text("Thresholded Prediction")

In [None]:
# Round-trip check: encode one component mask with rle_encode, decode it back
# with rle_decode and plot both; the two images should look identical.
# NOTE(review): post_preds[2] assumes at least three components were kept.
print("RLE Test")
encode_sample_pred = post_preds[2]
fig,axs=plt.subplots(1,2)
reconstr_pred = rle_decode(rle_encode(encode_sample_pred),(512,512,1))
axs[0].imshow(encode_sample_pred)
axs[1].imshow(reconstr_pred)

axs[0].title.set_text("Original Image")
axs[1].title.set_text("Reconstructed Image")

# Test-Time Augmentation (TTA) Inference

In [None]:
def read_img(path,resize_shape=(512,512)):
    """Read an image file as an RGB float64 array resized to ``resize_shape``.

    ``resize_shape`` is passed to cv2.resize as ``dsize``, i.e. (width,
    height). NOTE(review): for non-square sizes, confirm callers use that order.
    """
    bgr = cv2.imread(path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, resize_shape)
    return resized.astype(np.double)

def np_to_torch(img,img_shape=(512,512)):
    """Convert an (H, W, 3) image array into a float32 (1, 3, H, W) tensor.

    The channel axis is moved with a transpose; a plain reshape from HWC to
    (1, 3, H, W) would interleave pixels across channels.

    Args:
        img: image of shape (H, W, 3), assumed HWC as returned by read_img.
        img_shape: expected (H, W); reshape raises if it does not match.
    """
    chw = np.ascontiguousarray(np.transpose(np.asarray(img, dtype=np.float32), (2, 0, 1)))
    return torch.tensor(chw).reshape(1, 3, img_shape[0], img_shape[1])

def torch_to_plt(img,shape=(512,512,1)):
    """Convert a tensor into a uint8 numpy array of ``shape`` for plt.imshow.

    Values are cast with astype(np.uint8), so fractions are truncated; this
    suits 0/1 masks rather than raw float outputs.
    """
    # .cpu() lets CUDA tensors be converted too (a no-op for CPU tensors).
    return img.detach().cpu().numpy().astype(np.uint8).reshape(shape)

In [None]:
# Inference pipeline: same resize and ImageNet normalization as training, no random flips.
test_transforms = Compose([Resize(512,512),Normalize(mean=RESNET_MEAN,std=RESNET_STD), ToTensorV2()])

In [None]:
# Test-time augmentation: average the predictions for the original image and
# its vertical, horizontal and combined flips. Each prediction is flipped back
# first so all four align with the original image.
# eval(): in train mode BatchNorm cannot process these single-image batches.
PREDS_THRESHOLD = 0.5
N_PLOTS = 3  # visualize only the first few images; a test set can be large
model.cpu()
model.eval()


def _predict(img):
    """Return the model output, shape (1, 1, H, W), for one HWC float image."""
    tensor = test_transforms(image=np.ascontiguousarray(img, dtype=np.float32))["image"]
    with torch.no_grad():
        return model(torch.unsqueeze(tensor, 0)).numpy()


single_mask_preds=[]
for n, image_id in enumerate(sub_df['id'].tolist()):
    image = read_img(TEST_IMGS_PATH + image_id +".png").astype(np.float32)

    # Inputs are flipped in HWC space; outputs are un-flipped on their spatial
    # axes (-2 = H, -1 = W). np.flipud/np.fliplr would act on the size-1
    # batch/channel axes of the (1, 1, H, W) output and do nothing.
    original_pred = _predict(image)
    v_flip_pred = np.flip(_predict(np.flipud(image)), axis=-2)
    h_flip_pred = np.flip(_predict(np.fliplr(image)), axis=-1)
    v_h_flip_pred = np.flip(_predict(np.fliplr(np.flipud(image))), axis=(-2, -1))

    pred = (original_pred+v_flip_pred+h_flip_pred+v_h_flip_pred) / 4
    single_mask_preds.append(pred)

    if n < N_PLOTS:
        plt.figure(figsize=(32,32))
        for k, aligned in enumerate((original_pred, v_flip_pred, h_flip_pred, v_h_flip_pred), start=1):
            plt.subplot(1,8,k)
            plt.imshow(aligned.reshape(512,512,1))

single_mask_preds=np.asarray(single_mask_preds)
thresh_preds= single_mask_preds > PREDS_THRESHOLD

In [None]:
# Single-pass inference (no TTA) for every test image; rebinds
# single_mask_preds / thresh_preds.
PREDS_THRESHOLD = 0.5
N_PLOTS = 3  # visualize only the first few images; a test set can be large
model.cpu()
model.eval()
single_mask_preds=[]
for n, image_id in enumerate(sub_df['id'].tolist()):
    image = read_img(TEST_IMGS_PATH + image_id +".png").astype(np.float32)
    tensor = test_transforms(image=image)["image"]
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(torch.unsqueeze(tensor,0)).numpy()
    single_mask_preds.append(pred)

    if n < N_PLOTS:
        plt.figure(figsize=(32,32))
        plt.subplot(1,8,1)
        plt.imshow(pred.reshape(512,512,1))

single_mask_preds=np.asarray(single_mask_preds)
thresh_preds= single_mask_preds > PREDS_THRESHOLD

In [None]:
# Show the thresholded masks of the first three test images.
plt.figure(figsize=(80,80))
plt.subplot(1,10,1)
plt.imshow(thresh_preds[0].reshape(512,512,1))

plt.subplot(1,10,2)

plt.imshow(thresh_preds[1].reshape(512,512,1))

plt.subplot(1,10,3)

plt.imshow(thresh_preds[2].reshape(512,512,1))

# Post-processing and submission

In [None]:
def remove_overlapping_pixels(mask, other_masks):
    """Zero, in place, every pixel of ``mask`` that is set in any of ``other_masks``.

    Returns the same ``mask`` object after modification.
    """
    for existing in other_masks:
        overlap = np.logical_and(mask, existing)
        if overlap.any():
            mask[overlap] = 0
    return mask

In [None]:
def preds_postprocess(mask,threshold=0.5,min_size=300):
    """Split a prediction map into one binary mask per connected component.

    Args:
        mask: single-channel float map, (H, W) or (H, W, 1).
        threshold: pixels strictly above this value are foreground.
        min_size: components with this many pixels or fewer are dropped.

    Returns:
        List of float32 (H, W) arrays, one per kept component (1 inside it,
        0 elsewhere), in connected-component label order.
    """
    binary = cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY)[1]
    n_component, component = cv2.connectedComponents(binary.astype(np.uint8))
    predictions = []
    # Label 0 is the background.
    for c in range(1, n_component):
        p = component == c
        if p.sum() > min_size:
            # Size each output like the label image instead of a hard-coded
            # 520x704, so maps of any resolution work.
            a_prediction = np.zeros(component.shape, np.float32)
            a_prediction[p] = 1
            predictions.append(a_prediction)
    return predictions

In [None]:
# Sample submission; its unique ids drive the submission loop below.
sample_sub = pd.read_csv('../input/sartorius-cell-instance-segmentation/sample_submission.csv')
sample_sub

In [None]:
# Inference pipeline for the submission: resize to 512x512 and ImageNet-normalize, no augmentation.
test_transforms = Compose([Resize(512,512),Normalize(mean=RESNET_MEAN,std=RESNET_STD), ToTensorV2()])

In [None]:
# Build the submission: one row per predicted cell instance, with its mask
# RLE-encoded at the source image resolution.
model.eval()
# Send inputs to wherever the model currently lives (CPU or GPU).
model_device = next(model.parameters()).device
pred_list=[]
for image_id in np.unique(sample_sub['id']).tolist():
    image = read_img(TEST_IMGS_PATH + str(image_id) + ".png").astype(np.float32)
    image = test_transforms(image=image)["image"]
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(model_device))
    single_mask_pred = output.squeeze(0).cpu().numpy().reshape(512, 512)

    # Back to the source resolution; cv2 dsize is (width, height).
    single_mask_pred = cv2.resize(single_mask_pred, (IMGS_WIDTH, IMGS_HEIGHT),
                                  interpolation=cv2.INTER_AREA)

    img_preds = preds_postprocess(single_mask_pred)
    kept_masks = []
    for img_pred in img_preds:
        # Clear pixels already claimed by an earlier instance of this image.
        fixed_mask = remove_overlapping_pixels(img_pred, kept_masks)
        kept_masks.append(fixed_mask)
        pred_list.append((str(image_id), rle_encode(fixed_mask)))
    # NOTE(review): an image with no kept components contributes no rows;
    # confirm whether the competition requires every test id to appear.

# Rows were collected in pred_list (the previous `sub_list` was undefined).
sub_df = pd.DataFrame(pred_list, columns=['id', 'predicted'])
sub_df

In [None]:
def build_masks(df, image_id, input_shape):
    """Merge every predicted RLE mask of ``image_id`` in ``df`` into one mask.

    Args:
        df: table with 'id' and 'predicted' (RLE string) columns.
        image_id: id whose rows are merged.
        input_shape: (height, width) of the output.

    Returns:
        float64 array of shape (height, width, 1), clipped to [0, 1].
    """
    height, width = input_shape
    rles = df.loc[df["id"] == image_id, "predicted"].tolist()
    combined = np.zeros((height, width, 1))
    for rle in rles:
        combined += rle_decode(rle, shape=(height, width, 1))
    return np.clip(combined, 0, 1)
# Overlay the predicted masks on (at most) the first three test images.
# Capping the count avoids an IndexError on axs when there are more than
# three test images (e.g. a hidden test set at submission time). That error
# would stop the notebook before submission.csv is written.
plot_ids = np.unique(sub_df['id']).tolist()[:3]
fig, axs = plt.subplots(1, 3, figsize=(30, 30))
for n, image_id in enumerate(plot_ids):
    sample_img = plt.imread(TEST_IMGS_PATH + image_id + ".png")
    sample_masks = build_masks(sub_df, image_id, input_shape=(IMGS_HEIGHT, IMGS_WIDTH))

    # Hide background pixels so only predicted cells are drawn over the image.
    masked = np.ma.masked_where(sample_masks == 0, sample_masks)

    axs[n].imshow(sample_img, cmap="seismic")
    axs[n].imshow(masked, alpha=0.6, cmap="bone")

plt.show()

`░░░░░░░░░░░░░░░░░░░░░▄▀░░▌ GUD.
░░░░░░░░░░░░░░░░░░░▄▀▐░░░▌
░░░░░░░░░░░░░░░░▄▀▀▒▐▒░░░▌
░░░░░▄▀▀▄░░░▄▄▀▀▒▒▒▒▌▒▒░░▌
░░░░▐▒░░░▀▄▀▒▒▒▒▒▒▒▒▒▒▒▒▒█
░░░░▌▒░░░░▒▀▄▒▒▒▒▒▒▒▒▒▒▒▒▒▀▄
░░░░▐▒░░░░░▒▒▒▒▒▒▒▒▒▌▒▐▒▒▒▒▒▀▄
░░░░▌▀▄░░▒▒▒▒▒▒▒▒▐▒▒▒▌▒▌▒▄▄▒▒▐
░░░▌▌▒▒▀▒▒▒▒▒▒▒▒▒▒▐▒▒▒▒▒█▄█▌▒▒▌
░▄▀▒▐▒▒▒▒▒▒▒▒▒▒▒▄▀█▌▒▒▒▒▒▀▀▒▒▐░░░▄
▀▒▒▒▒▌▒▒▒▒▒▒▒▄▒▐███▌▄▒▒▒▒▒▒▒▄▀▀▀▀
▒▒▒▒▒▐▒▒▒▒▒▄▀▒▒▒▀▀▀▒▒▒▒▄█▀░░▒▌▀▀▄▄
▒▒▒▒▒▒█▒▄▄▀▒▒▒▒▒▒▒▒▒▒▒░░▐▒▀▄▀▄░░░░▀
▒▒▒▒▒▒▒█▒▒▒▒▒▒▒▒▒▄▒▒▒▒▄▀▒▒▒▌░░▀▄
▒▒▒▒▒▒▒▒▀▄▒▒▒▒▒▒▒▒▀▀▀▀▒▒▒▄▀`

Aaaaaand, submission! ᕕ( ᐛ )ᕗ

In [None]:
# index=False: write only the 'id' and 'predicted' columns built above; by
# default pandas also writes the row index as an extra leading column.
sub_df.to_csv('submission.csv', index=False)

In [None]:
# Display the final submission table.
sub_df