In [None]:
!pip install segmentation-models-pytorch 

In [None]:
import segmentation_models_pytorch  as smp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import cv2
import math
import time
from tqdm import tqdm
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torchmetrics
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gc
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from albumentations import (HorizontalFlip, VerticalFlip, 
                            ShiftScaleRotate, Normalize, Resize, 
                            Compose, GaussNoise)
from albumentations.pytorch import ToTensorV2


In [None]:
# Load the training annotation table and display it.
train_df = pd.read_csv(
    "../input/sartorius-cell-instance-segmentation/train.csv"
)
train_df

In [None]:
# Distinct values found in the `cell_type` column.
train_df["cell_type"].unique()

In [None]:
# Print the distinct values of the `height` column, then of the `width` column.
for size_column in ("height", "width"):
    print(train_df[size_column].unique())

In [None]:
# Load the sample submission file and display it.
sub_df = pd.read_csv(
    "../input/sartorius-cell-instance-segmentation/sample_submission.csv"
)
sub_df

In [None]:
# --- Image directories ---------------------------------------------------
TEST_IMGS_PATH = "../input/sartorius-cell-instance-segmentation/test/"
TRAIN_IMGS_PATH = "../input/sartorius-cell-instance-segmentation/train/"

# --- Size (in pixels) used when decoding masks for display ----------------
IMGS_WIDTH, IMGS_HEIGHT = 704, 520

# --- Per-channel statistics handed to albumentations' Normalize -----------
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)

# --- Size every image/mask pair is resized to before reaching the model ---
TARGET_IMGS_HEIGHT = 512
TARGET_IMGS_WIDTH = 512

# --- Compute device: GPU when one is visible to torch, CPU otherwise ------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using : ", DEVICE)

In [None]:
def rle_decode(x,shape,color=1):
    """Decode a run-length-encoded mask string into a dense array.

    Parameters
    ----------
    x : str
        Whitespace-separated integers forming ``start length`` pairs.
        ``start`` is a 1-indexed position in the row-major flattened image,
        which is the convention of the competition's annotation format, so
        it is shifted by one before being used as an array offset.
        An empty string decodes to an all-zero mask.
    shape : tuple
        ``(height, width, channels)`` of the returned array.
    color : scalar or array-like, optional
        Value written to every pixel covered by a run. It is broadcast over
        the channel axis, so it may be a scalar or one value per channel.

    Returns
    -------
    numpy.ndarray
        Array of the requested ``shape`` (default float dtype of ``np.zeros``),
        zero everywhere except on the encoded runs.

    Raises
    ------
    ValueError
        If ``x`` contains a non-integer token, an odd number of values, or a
        start position smaller than 1.
    """
    values = [int(token) for token in x.split()]
    if len(values) % 2 != 0:
        raise ValueError("RLE string must contain an even number of values (start/length pairs)")

    # Work on a (pixels, channels) view so that a run is a simple row slice.
    out = np.zeros((shape[0] * shape[1], shape[2]))
    for start, length in zip(values[0::2], values[1::2]):
        if start < 1:
            raise ValueError("RLE start positions are 1-indexed; got %d" % start)
        begin = start - 1  # 1-indexed start -> 0-indexed array offset
        out[begin:begin + length] = color

    return np.reshape(out, shape)

In [None]:
def rle_encode(x):
    """Run-length encode a binary mask.

    Parameters
    ----------
    x : array-like
        Mask of any shape; it is flattened in row-major order and every
        element equal to 1 is treated as foreground.

    Returns
    -------
    str
        Space-separated ``start length`` pairs, one pair per maximal run of
        foreground pixels. Starts are 1-indexed positions in the flattened
        mask (the competition's submission convention). An empty string is
        returned when the mask has no foreground pixel.
    """
    foreground = np.asarray(x).reshape(-1) == 1

    # Pad with background on both sides so that runs touching either end of
    # the mask still produce a rising and a falling edge.
    padded = np.concatenate(([False], foreground, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])

    # Edges alternate: run start (0-indexed), then first position after the run.
    starts = edges[0::2]
    stops = edges[1::2]

    out = []
    for start, stop in zip(starts.tolist(), stops.tolist()):
        out.append(str(start + 1))  # 0-indexed -> 1-indexed
        out.append(str(stop - start))
    return " ".join(out)

In [None]:
def show_masks(img_id):
    """Plot a training image with all of its annotated masks overlaid.

    Every annotation row of ``train_df`` whose ``id`` equals ``img_id`` is
    decoded with ``rle_decode`` using a random RGB color, the decoded masks
    are summed into one overlay, and the overlay is drawn semi-transparently
    on top of the image read from ``TRAIN_IMGS_PATH``.
    """
    overlay_shape = (IMGS_HEIGHT, IMGS_WIDTH, 3)
    annotations = train_df.loc[train_df["id"] == img_id, "annotation"].tolist()

    # Accumulate one randomly colored layer per annotation.
    overlay = np.zeros(overlay_shape)
    for annotation in annotations:
        random_color = np.random.rand(3)
        overlay = overlay + rle_decode(annotation, overlay_shape, color=random_color)

    # OpenCV loads BGR; matplotlib expects RGB.
    bgr_image = cv2.imread(TRAIN_IMGS_PATH + img_id + ".png")
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(16, 32))
    plt.imshow(rgb_image)
    plt.imshow(overlay, alpha=0.3)
    plt.show()

In [None]:
# Preview the mask overlay for three training images.
for example_id in ("0030fd0e6378", "ffdb3cc02eef", "0df9d6419078"):
    show_masks(example_id)

In [None]:
class SartoriusCellDataset(Dataset):
    """Dataset yielding ``(image, binary_mask)`` pairs, one per image id.

    Parameters
    ----------
    train_df : pandas.DataFrame
        Annotation table with at least an ``id`` column (image id) and an
        ``annotation`` column (RLE string); an image may span several rows.
    train_imgs_path : str
        Directory prefix (including the trailing separator) that holds
        ``<id>.png`` files.
    transforms : callable or None
        Called as ``transforms(image=..., mask=...)`` and expected to return a
        dict with ``image`` and ``mask`` entries, the image being
        channel-first. When falsy, the raw image is returned unchanged.
    isVal : bool, optional
        Stored on the instance; not used by this class.
    """

    def __init__(self,train_df,train_imgs_path,transforms,isVal=False):
        self.train_df=train_df
        # np.unique returns the ids sorted, which fixes the index -> image mapping.
        self.image_ids=np.unique(train_df['id']).tolist()
        self.train_imgs_path=train_imgs_path
        self.transforms = transforms
        self.isVal = isVal

    def __len__(self):
        """Number of distinct image ids."""
        return len(self.image_ids)

    def _annotations_for(self, image_id):
        """Return every RLE annotation string recorded for ``image_id``.

        Reads from the frame given to the constructor (``self.train_df``),
        never from a notebook-level global.
        """
        rows = self.train_df["id"] == image_id
        return self.train_df.loc[rows, "annotation"].tolist()

    def __getitem__(self,idx):
        """Load image ``idx`` and the union of its instance masks.

        Returns
        -------
        tuple
            ``(image, mask)`` where ``mask`` has shape ``(1, height, width)``
            and values in ``[0, 1]``.

        Raises
        ------
        FileNotFoundError
            If the image file cannot be read.
        """
        image_id = self.image_ids[idx]
        image_path = self.train_imgs_path + image_id + ".png"

        # cv2.imread signals failure by returning None instead of raising.
        bgr_image = cv2.imread(image_path)
        if bgr_image is None:
            raise FileNotFoundError("Could not read image: " + image_path)
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        mask_shape = (image.shape[0], image.shape[1], 1)
        mask = np.zeros(mask_shape, dtype=np.float32)
        for rle_mask in self._annotations_for(image_id):
            mask += rle_decode(rle_mask, mask_shape).astype(np.float32)
        # Overlapping instances would sum above 1; clip back to a binary mask.
        mask = mask.clip(0, 1)

        if self.transforms:
            aug = self.transforms(image=image, mask=mask)
            image, mask = aug['image'], aug['mask']
            # Transformed image is channel-first: (C, H, W).
            height, width = image.shape[1], image.shape[2]
        else:
            # Untransformed image is still channel-last: (H, W, C).
            height, width = image.shape[0], image.shape[1]

        return image, mask.reshape((1, height, width))

In [None]:
# Augmentation pipeline applied to every (image, mask) pair, in order:
# resize to the target size, normalize the image, random flips, convert to tensors.
transforms = Compose(
    [
        Resize(TARGET_IMGS_HEIGHT, TARGET_IMGS_WIDTH),
        Normalize(mean=RESNET_MEAN, std=RESNET_STD),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
        ToTensorV2(),
    ]
)

# Dataset over all training images, using the pipeline above.
cell_dataset = SartoriusCellDataset(
    train_df=train_df,
    train_imgs_path=TRAIN_IMGS_PATH,
    transforms=transforms,
)

In [None]:
# Hold out a fraction of the dataset for validation. The explicitly seeded
# generator makes the random split identical on every run.
val_split = 0.2
val_len = math.floor(len(cell_dataset) * val_split)
train_len = len(cell_dataset) - val_len

split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = torch.utils.data.random_split(
    cell_dataset,
    [train_len, val_len],
    generator=split_generator,
)

In [None]:
# Pull one training sample and report the shape and dtype of image and mask.
sample_img, sample_mask = train_ds[1]
print(sample_img.shape, "\n", sample_mask.shape)
for sample_tensor in (sample_img, sample_mask):
    print(sample_tensor.dtype)

In [None]:
# Show the first channel of the sample image in grayscale.
plt.imshow(
    sample_img[0],
    cmap="gray",
)

In [None]:
# Show the single channel of the sample mask.
plt.imshow(
    sample_mask[0],
)

In [None]:
BATCH_SIZE=32

# Training batches are reshuffled every epoch so the model does not see the
# samples in the same order (and the same batch composition) each time.
train_loader=DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Validation order does not affect the averaged loss, so it stays fixed.
val_loader=DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False
)

In [None]:

# Segmentation network: DeepLabV3+ with an ImageNet-pretrained ResNet-34
# encoder and a single output channel (cell vs. background).
#
# NOTE: this cell previously went on to overwrite `model` with torchvision's
# hub `deeplabv3_resnet50`. That model returns a dict of tensors and predicts
# 21 classes, so it cannot be fed to the loss functions in this notebook and
# cannot load the DeepLabV3+/ResNet-34 checkpoint restored further down.
model = smp.DeepLabV3Plus(
    encoder_name='resnet34',
    encoder_weights="imagenet",
    classes=1,
)


In [None]:
class dice_bce_loss(nn.Module):
    """Sum of soft Dice loss and binary cross-entropy.

    ``inputs`` must already be probabilities in ``[0, 1]``: no sigmoid is
    applied here, and ``F.binary_cross_entropy`` rejects values outside that
    range. The ``weight`` and ``size_average`` constructor arguments are
    accepted but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(dice_bce_loss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``bce + (1 - dice)`` computed over all elements.

        Parameters
        ----------
        inputs : torch.Tensor
            Predicted probabilities, any shape.
        targets : torch.Tensor
            Ground-truth mask with the same number of elements as ``inputs``.
        smooth : number, optional
            Added to the Dice numerator and denominator to avoid 0/0.
        """
        # reshape (unlike view) also works for non-contiguous tensors.
        inputs = inputs.reshape(-1)
        # binary_cross_entropy needs float targets of the same dtype as inputs.
        targets = targets.reshape(-1).to(inputs.dtype)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        bce = F.binary_cross_entropy(inputs, targets, reduction='mean')

        return bce + dice_loss

In [None]:
class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss computed on raw logits.

    A sigmoid is applied to ``inputs`` inside ``forward``, so the model must
    NOT apply its own sigmoid. The ``weight`` and ``size_average`` constructor
    arguments are accepted but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - IoU`` over all elements.

        Parameters
        ----------
        inputs : torch.Tensor
            Raw model outputs (logits), any shape.
        targets : torch.Tensor
            Ground-truth mask with the same number of elements as ``inputs``.
        smooth : number, optional
            Added to intersection and union to avoid 0/0.
        """
        # Remove this line if the model already ends with a sigmoid.
        inputs = torch.sigmoid(inputs)

        # Flatten; reshape (unlike view) also works for non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Soft intersection plays the role of the true-positive count;
        # union = |prediction| + |target| - intersection.
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection

        IoU = (intersection + smooth)/(union + smooth)

        return 1 - IoU

In [None]:
EPOCHS=50
LEARNING_RATE=5e-4

model.to(DEVICE)

loss_fn=IoULoss()
optimizer = torch.optim.Adam(model.parameters(),lr=LEARNING_RATE)
# Lower the learning rate when the validation loss stops improving for 5 epochs.
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode="min",patience=5,verbose=True)
best_val_loss=float(1e6)
since = time.time()

# range(1, EPOCHS + 1) so that exactly EPOCHS epochs run, matching the "n/EPOCHS" log line.
for n_epoch in range(1, EPOCHS + 1):

    print("EPOCH : "+str(n_epoch)+"/"+str(EPOCHS))

    running_train_loss=0.0
    running_val_loss=0.0

    #TRAINING
    model.train()
    for images,masks in train_loader:
        images,masks=images.to(DEVICE),masks.to(DEVICE)

        optimizer.zero_grad()
        preds=model(images)
        train_loss=loss_fn(preds,masks)

        #BACKPROPAGATION
        train_loss.backward()
        optimizer.step()
        running_train_loss += train_loss.item()

        # Drop the batch references first, then collect, so the memory can
        # actually be released before the next batch is moved to the device.
        del images
        del masks
        del preds
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    #VALIDATION
    model.eval()
    with torch.no_grad():
        for images,masks in val_loader:
            images,masks=images.to(DEVICE),masks.to(DEVICE)
            val_preds=model(images)
            val_loss=loss_fn(val_preds,masks)
            running_val_loss+=val_loss.item()

            del images
            del masks
            del val_preds
            gc.collect()

    # Average over the number of batches; max(..., 1) keeps an empty loader
    # from raising instead of reporting a loss of 0.
    running_train_loss /= max(len(train_loader), 1)
    running_val_loss /= max(len(val_loader), 1)

    #Reduce LR on Plateau
    scheduler.step(running_val_loss)

    print(f"EPOCH : {n_epoch} Train Loss : {running_train_loss:.5f}, Val Loss : {running_val_loss:.5f}")
    # Keep only the weights with the lowest validation loss seen so far.
    if(running_val_loss < best_val_loss):
        torch.save(model.state_dict(), "/kaggle/working/best_model.pth")
        print("Model Saved")
        best_val_loss=running_val_loss

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
# '{:.4f}' = four decimals (the previous '{:4f}' only set a minimum width).
print('Best Val Loss: {:.4f}'.format(best_val_loss))

Inference test on training data

In [None]:
# Restore previously trained weights. They are mapped to CPU while loading,
# then the model is moved to the selected device.
checkpoint_state = torch.load(
    "../input/deeplabv3plus-resnet34-sartorius/best_model.pth",
    map_location=torch.device('cpu'),
)
model.load_state_dict(checkpoint_state)
model.to(DEVICE)

In [None]:
# Visual sanity check: ground-truth masks next to raw model outputs for the
# first batch of the training loader.
MAX_PREVIEW_ROWS = 16

# Inference mode: fixed batch-norm statistics, no dropout, no autograd graph.
model.eval()
images,masks = next(iter(train_loader))
with torch.no_grad():
    preds = model(images.to(DEVICE))
print(preds.shape)

masks = masks.cpu()
preds = preds.cpu().numpy()
# Range of the raw (pre-sigmoid) outputs for the first sample.
print(preds[0].max(),preds[0].min())

# Never ask for more rows than the batch holds.
n_rows = min(MAX_PREVIEW_ROWS, len(masks))
fig,axs = plt.subplots(n_rows,2,figsize=(10,5*n_rows),squeeze=False)
for i in range(n_rows):
    # Index the single channel instead of reshaping to a hard-coded size.
    axs[i][0].imshow(masks[i][0])
    axs[i][0].title.set_text("Ground truth")
    axs[i][1].imshow(preds[i][0])
    axs[i][1].title.set_text("Prediction")

plt.subplots_adjust(wspace=0.1)

plt.show()

# INFERENCE (PART 2) HERE:
# [Sartorius-cell-segmentation-Deeplabv3-INFERENCE](https://www.kaggle.com/albertozorzetto/sartorius-cell-segmentation-deeplabv3-inference)