In [None]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2
from tqdm import tqdm
from glob import glob
from PIL import Image
from skimage.transform import resize
from sklearn.model_selection import train_test_split, KFold

import shutil

import torch
from torch.utils.data import DataLoader,Dataset
import torch.nn as nn
import albumentations
import torchvision 
from torchvision import transforms, models

import random

import segmentation_models_pytorch as smp
from PIL import Image

In [None]:
# Dataset locations: the competition root plus its train/ and test/ folders
# (trailing slash kept so the paths can be string-concatenated as well as joined).
ROOT = "../input/ultrasound-nerve-segmentation"
trainpath = f"{ROOT}/train/"
testpath = f"{ROOT}/test/"

In [None]:
# Pair every mask with its source image: the image path is the mask path with "_mask" removed.
# os.listdir returns entries in arbitrary order, so sort first; together with a fixed
# random_state this makes the train/val/test split reproducible across runs and machines.
SPLIT_SEED = 42
masks = sorted(os.path.join(trainpath, name) for name in os.listdir(trainpath) if "mask" in name)
imgs = [path.replace("_mask", "") for path in masks]

df = pd.DataFrame({"Image": imgs, "Mask": masks})

# Hold out 10% as test, then 20% of the remainder as validation.
# NOTE(review): this splits per image, not per subject. If filenames encode a subject id,
# images of the same subject can land in different splits (leakage) — consider a grouped split.
df_train, df_test = train_test_split(df, test_size=0.1, random_state=SPLIT_SEED)
df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=SPLIT_SEED)
print(df_train.values.shape)
print(df_val.values.shape)
print(df_test.values.shape)

In [None]:
# Preview a 3x3 grid of training images with their masks overlaid at 40% opacity
# (rows 1..9 of `df`).
rows, cols = 3, 3
fig, axes = plt.subplots(rows, cols, figsize=(10, 10))
for i, ax in enumerate(axes.ravel(), start=1):
    img = cv2.cvtColor(cv2.imread(df['Image'][i]), cv2.COLOR_BGR2RGB)
    msk = cv2.imread(df['Mask'][i])
    ax.imshow(img)
    ax.imshow(msk, alpha=0.4)
plt.show()

In [None]:
def displayimages(**images):
    """Show several images side by side in a single row.

    Each keyword argument is one panel: the keyword becomes the panel title
    (underscores replaced by spaces, title-cased) and the value is passed
    unchanged to ``plt.imshow``.

    Example: ``displayimages(image=img, ground_truth_mask=msk)``.

    Does nothing when called with no images.
    """
    if not images:
        return
    n_images = len(images)
    plt.figure(figsize=(12, 7))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([])
        plt.yticks([])
        # The keyword name was previously collected but never shown; label the panel with it.
        plt.title(name.replace('_', ' ').title(), fontsize=14)
        plt.imshow(image)
    plt.show()

    
def convert_to_tensor(x, **kwargs):
    """Convert a channels-last numpy array to channels-first float32.

    Args:
        x: array of shape (H, W, C), or (H, W) for single-channel data such as a
            grayscale mask. A 2-D input gets a trailing channel axis first, so it
            comes out as (1, H, W) instead of raising.
        **kwargs: ignored; accepted because albumentations' ``Lambda`` passes
            extra transform parameters to the callable.

    Returns:
        float32 array of shape (C, H, W).
    """
    if x.ndim == 2:
        x = x[:, :, np.newaxis]
    return x.transpose(2, 0, 1).astype("float32")

def func_for_preprocessing(preprocessing_fn=None):
    """Build the albumentations pipeline that runs after augmentation.

    Args:
        preprocessing_fn: optional image-only callable (e.g. an encoder's
            normalisation function). Skipped when falsy.

    Returns:
        ``albumentations.Compose`` that applies ``preprocessing_fn`` to the image
        (if given) and then converts both image and mask to channels-first float32.
    """
    steps = [albumentations.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps += [albumentations.Lambda(image=convert_to_tensor, mask=convert_to_tensor)]
    return albumentations.Compose(steps)

def trainaugs():
    """Training augmentation: resize to 256x256, then random horizontal/vertical flips.

    albumentations' ``Resize(interpolation=...)`` takes an OpenCV flag. The previous
    ``Image.BILINEAR`` is PIL's constant (value 2), which OpenCV reads as
    ``cv2.INTER_CUBIC``, so images were silently resized bicubically. Pass the
    OpenCV bilinear flag explicitly.
    """
    transform = [
        albumentations.Resize(height=256, width=256, interpolation=cv2.INTER_LINEAR),
        albumentations.HorizontalFlip(),
        albumentations.VerticalFlip(),
    ]
    return albumentations.Compose(transform)

def valaugs():
    """Deterministic validation pipeline: resize to 256x256 only.

    No random flips here. Validation IoU decides which checkpoint the training
    loop saves, and random augmentation would make that score noisy and
    non-reproducible between epochs.

    ``interpolation`` must be an OpenCV flag. PIL's ``Image.BILINEAR`` (value 2)
    would be read as ``cv2.INTER_CUBIC``.
    """
    transform = [
        albumentations.Resize(height=256, width=256, interpolation=cv2.INTER_LINEAR),
    ]
    return albumentations.Compose(transform)

In [None]:
# Model: DeepLabV3+ on a ResNet-101 encoder with ImageNet weights and a sigmoid
# output activation, plus the input-normalisation function matching that encoder.
encoder, encoder_wts, activation = "resnet101", "imagenet", "sigmoid"

model = smp.DeepLabV3Plus(
    encoder_name=encoder,
    encoder_weights=encoder_wts,
    activation=activation,
)

preprocess_func = smp.encoders.get_preprocessing_fn(encoder, encoder_wts)

In [None]:
class GetDataset(Dataset):
    """Paired image/mask segmentation dataset backed by file paths.

    Each item is read from disk with OpenCV. The image is converted to RGB. The
    mask is read as grayscale and binarised to {0, 1} with shape (H, W, 1). Both
    then go through the optional ``augment`` pipeline and then the optional
    ``preprocess`` pipeline. Each pipeline is called as ``f(image=..., mask=...)``
    and must return a dict with ``'image'`` and ``'mask'`` keys.

    Args:
        imagespath: sequence of image file paths.
        maskspath: sequence of mask file paths, aligned index-by-index with
            ``imagespath``.
        augment: optional albumentations-style transform applied first.
        preprocess: optional albumentations-style transform applied second.

    Raises:
        ValueError: if ``imagespath`` and ``maskspath`` differ in length.
    """

    def __init__(self, imagespath, maskspath, augment=None, preprocess=None):
        if len(imagespath) != len(maskspath):
            raise ValueError(
                f"got {len(imagespath)} images but {len(maskspath)} masks"
            )
        self.imagespath = imagespath
        self.maskspath = maskspath
        self.augment = augment
        self.preprocess = preprocess

    def __len__(self):
        return len(self.imagespath)

    @staticmethod
    def _binarize_mask(mask, threshold=0):
        """Map a single-channel mask to {0, 1} uint8 with shape (H, W, 1).

        Previously the mask stayed 3-channel with values 0/255, so Dice/IoU compared
        a 0..255, 3-channel target against a single-channel sigmoid output.
        The trailing channel axis keeps the array 3-D, so an HWC -> CHW conversion
        yields a (1, H, W) target.
        """
        return (mask > threshold).astype(np.uint8)[..., np.newaxis]

    @staticmethod
    def _read(path, flags):
        """Read ``path`` with ``cv2.imread``. Raise instead of returning None on failure."""
        data = cv2.imread(path, flags)
        if data is None:
            raise FileNotFoundError(f"could not read image file: {path}")
        return data

    def __getitem__(self, idx):
        image = cv2.cvtColor(self._read(self.imagespath[idx], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask = self._binarize_mask(self._read(self.maskspath[idx], cv2.IMREAD_GRAYSCALE))

        if self.augment is not None:
            sample = self.augment(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocess is not None:
            sample = self.preprocess(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

In [None]:
# Training set: random flips + encoder normalisation; shuffled batches of 16.
traindata = GetDataset(
    imagespath=df_train['Image'].tolist(),
    maskspath=df_train['Mask'].tolist(),
    augment=trainaugs(),
    preprocess=func_for_preprocessing(preprocess_func),
)

# Validation set: validation pipeline + the same normalisation; fixed-order batches of 8.
validationdata = GetDataset(
    imagespath=df_val['Image'].tolist(),
    maskspath=df_val['Mask'].tolist(),
    augment=valaugs(),
    preprocess=func_for_preprocessing(preprocess_func),
)

trainloader = DataLoader(traindata, batch_size=16, shuffle=True)
valloader = DataLoader(validationdata, batch_size=8, shuffle=False)

In [None]:
trainmodel = True
epochs = 100
# Fall back to CPU so the notebook still runs where no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# One path for both the existence check and the load. Previously the two strings
# differed ("deeplabv3-using-pytorch" vs "deeplabv3using-pytorch"), so loading
# crashed whenever the checkpoint actually existed.
checkpoint_path = './deeplabv3-using-pytorch/bestmodel.pth'

# Restore any checkpoint *before* building the optimizer. Adam holds references to
# the parameter tensors it is given, so replacing `model` afterwards would leave the
# optimizer updating the discarded network.
if os.path.exists(checkpoint_path):
    # torch.load unpickles an entire model object — only load checkpoints you trust.
    # NOTE(review): recent PyTorch defaults to weights_only=True, which may reject a
    # fully pickled model; confirm against the installed version.
    model = torch.load(checkpoint_path, map_location=device)

loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]
optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=0.0001)])
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2, eta_min=5e-5)

trainepoch = smp.utils.train.TrainEpoch(model, loss=loss, optimizer=optimizer, metrics=metrics, device=device, verbose=True)
validepoch = smp.utils.train.ValidEpoch(model, loss=loss, metrics=metrics, device=device, verbose=True)

In [None]:
# Create the history lists unconditionally so the log/plot cells that follow do not
# raise NameError when training is skipped (trainmodel = False).
train_logs_list, valid_logs_list = [], []
if trainmodel:
    best_iou_score = 0.0
    for epoch in range(epochs):
        print('\nEpoch: {}'.format(epoch))
        trainlogs = trainepoch.run(trainloader)
        validlogs = validepoch.run(valloader)
        train_logs_list.append(trainlogs)
        valid_logs_list.append(validlogs)
        # Advance the cosine warm-restart schedule once per epoch. Without this call
        # the scheduler built in the setup cell never changes the learning rate.
        lr_scheduler.step()
        # Keep the whole model object from the epoch with the best validation IoU so far.
        # NOTE(review): this saves to ./best_model.pth, a different path from the one the
        # setup cell loads from — confirm which location is intended.
        if validlogs['iou_score'] > best_iou_score:
            best_iou_score = validlogs['iou_score']
            torch.save(model, './best_model.pth')
    print("Model Training completed successfully !")

In [None]:
# One row per epoch. Columns presumably come from the keys of each epoch's log dict
# (e.g. 'iou_score'). Show the training history transposed: metrics as rows.
train_logs_df, valid_logs_df = (pd.DataFrame(logs) for logs in (train_logs_list, valid_logs_list))
train_logs_df.T

In [None]:
# Per-epoch IoU: training as a green line, validation as red dots.
fig, ax = plt.subplots(figsize=(20, 8))
for logs_df, fmt, label in ((train_logs_df, 'g-', 'Train'), (valid_logs_df, 'ro', 'Valid')):
    ax.plot(logs_df.index.tolist(), logs_df.iou_score.tolist(), fmt, lw=3, label=label)
ax.set_xlabel('Epochs', fontsize=20)
ax.set_ylabel('IoU Score', fontsize=20)
ax.set_title('IoU Score Plot', fontsize=20)
ax.legend(loc='best', fontsize=16)
ax.grid()
# fig.savefig('iou_score_plot.png')
plt.show()