In [None]:
# Install segmentation_models_pytorch from the GitHub default branch.
# NOTE(review): unpinned; later cells rely on `smp.utils` (DiceLoss, IoU, TrainEpoch/ValidEpoch),
# so consider pinning a release/commit known to provide it to keep re-runs reproducible.
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
# Clone the early-stopping helper into ./esp so it can be imported as `esp.pytorchtools`.
# NOTE(review): unpinned clone; git refuses to clone into an existing non-empty ./esp,
# so this cell errors when re-run in the same working directory.
!git clone https://github.com/Bjarten/early-stopping-pytorch.git esp

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision 
from torchvision import transforms

import pandas as pd
import numpy as np
import os

from torchvision import transforms
from PIL import Image

import segmentation_models_pytorch as smp
from esp.pytorchtools import EarlyStopping

import matplotlib.pyplot as plt

# Workaround: make EarlyStopping pickle the whole model

In [None]:
def save_checkpoint(self, val_loss, model):
    """Replacement for ``EarlyStopping.save_checkpoint`` that pickles the whole model.

    Args:
        self: the EarlyStopping instance (this function is attached to that class).
        val_loss: the latest validation loss; it becomes ``self.val_loss_min``.
        model: object serialised with pickle to ``self.path``.
    """
    import pickle

    if self.verbose:
        message = (f'Validation loss decreased ({self.val_loss_min:.6f} --> '
                   f'{val_loss:.6f}).  Saving model ...')
        self.trace_func(message)

    # Pickle the full module object (not just a state_dict) so it can be
    # restored later with a plain pickle.load.
    with open(self.path, 'wb') as checkpoint_file:
        pickle.dump(model, checkpoint_file)

    self.val_loss_min = val_loss

In [None]:
# Attach the pickling save_checkpoint defined above to the EarlyStopping class,
# so every EarlyStopping instance (including the one built in train()) uses it.
EarlyStopping.save_checkpoint = save_checkpoint

# Dataset

In [None]:
# Root of the Kaggle ultrasound-nerve-segmentation data and its train/test folders.
input_ = '../input/ultrasound-nerve-segmentation'

train_path = f'{input_}/train'
test_path = f'{input_}/test'

# Index CSV (one row per mask/image pair) written by create_csv and read into train_df.
train_csv_path = 'train_annotation.csv'

In [None]:
def create_csv(data_path, out_csv_path, key_word='mask'):
    """Write a CSV index pairing every mask file with its source image.

    A mask is any file in ``data_path`` whose name contains ``key_word``; its
    image is the same name with ``_{key_word}`` removed (``x_mask.tif`` ->
    ``x.tif``). Masks whose image file is not present are skipped.

    Args:
        data_path: directory containing both the images and the masks.
        out_csv_path: CSV file to (over)write, columns ``mask`` then ``img``.
        key_word: substring that marks a file name as a mask.

    The result does not depend on ``os.listdir`` order, and re-running the
    function overwrites the CSV instead of appending duplicate rows.
    """
    to_delete = f'_{key_word}'
    file_names = sorted(os.listdir(data_path))
    available = set(file_names)

    rows = []
    for file_name in file_names:
        if key_word not in file_name:
            continue
        img = file_name.replace(to_delete, '')
        # Skip masks without a matching image, and names where nothing was stripped.
        if img != file_name and img in available:
            rows.append({'mask': file_name, 'img': img})

    # Column order matters: ImageDataset reads the mask from column 0 and the image from column 1.
    pd.DataFrame(rows, columns=['mask', 'img']).to_csv(out_csv_path, header=True, index=False)

In [None]:
# Build the mask/image index CSV for the training directory.
create_csv(data_path=train_path, out_csv_path=train_csv_path)

In [None]:
class ImageDataset(Dataset):
    """Dataset of (image, mask) pairs listed in a DataFrame.

    Args:
        df: DataFrame whose column 0 holds the mask file name and column 1 the
            image file name, both relative to ``root_dir``.
        root_dir: directory the file names are resolved against.
        transform: optional callable applied separately to the image and to the
            mask; without it the opened PIL images are returned unchanged.
    """

    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _open(self, idx, column):
        # Resolve the file name stored at (row idx, column) and open it with PIL.
        return Image.open(os.path.join(self.root_dir, self.df.iloc[idx, column]))

    def __getitem__(self, idx):
        # Samplers may pass a tensor index; pandas positional indexing wants int/list.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        mask = self._open(idx, 0)
        image = self._open(idx, 1)

        if not self.transform:
            return image, mask
        # NOTE(review): image and mask are transformed independently, so a random
        # transform would not be applied consistently to both.
        return self.transform(image), self.transform(mask)

In [None]:
# Load the index written by create_csv (first column: mask file, second: image file).
train_df = pd.read_csv(train_csv_path)

In [None]:
# Peek at the first rows of the index.
train_df.head()

In [None]:
# Untransformed dataset (raw PIL images), used for the visual sanity check below.
train_samples = ImageDataset(df=train_df, root_dir=train_path)

In [None]:
def draw_samples(data, n_col, n_row):
    """Show the first ``n_col`` samples of ``data``: images on top, masks below.

    Args:
        data: indexable of (image, mask) pairs, e.g. an ImageDataset.
        n_col: number of samples (grid columns) to draw.
        n_row: number of grid rows; masks go to subplot ``col + 1 + n_col``,
            i.e. the second row, so this should be at least 2.
    """
    fig = plt.figure(figsize=(15, 5))

    for col in range(n_col):
        top = fig.add_subplot(n_row, n_col, col + 1)
        bottom = fig.add_subplot(n_row, n_col, col + 1 + n_col)

        top.imshow(data[col][0], cmap='gray')
        bottom.imshow(data[col][1], cmap='gray')

    fig.show()

In [None]:
# Visual sanity check: the first 5 images (top row) and their masks (bottom row).
draw_samples(data=train_samples, n_col=5, n_row=2)

# Model

In [None]:
# Model configuration consumed by the smp.Unet cell below.
ENCODER = 'vgg11_bn'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid'
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# U-Net with the configured encoder/weights: one input channel, one output
# channel (the nerve mask), and ACTIVATION applied to the output.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=1,
    classes=1,
    activation=ACTIVATION
)

In [None]:
# Dice loss is the training objective; IoU is reported as a metric in the epoch logs.
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU()]
# Passed as classes, not instances: train() instantiates the optimizer with the
# model parameters and the scheduler with that optimizer.
optimizer = torch.optim.Adam
scheduler = lr_scheduler.StepLR

In [None]:
# Resize to 224x224 and convert to a tensor. ImageDataset applies this to both
# the image and the mask, and the submission cell applies it to test images.
# NOTE(review): masks go through the same Resize (its default interpolation),
# so resized masks may not be strictly binary — confirm that is acceptable.
my_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()
])

# Training

In [None]:
def split_df(df, fraction=0.8, random_state=None):
    """Randomly split ``df`` into two disjoint parts.

    Args:
        df: DataFrame to split. Its index labels should be unique, because the
            second part is built by dropping the sampled labels.
        fraction: share of rows placed in the first part.
        random_state: optional seed forwarded to ``DataFrame.sample`` for a
            reproducible split; ``None`` keeps the unseeded behaviour.

    Returns:
        ``(sampled, rest)``: the sampled rows and all remaining rows.
    """
    sampled = df.sample(frac=fraction, random_state=random_state)
    return sampled, df.drop(sampled.index)

In [None]:
def train(model, train_df, train_dir, optimizer, loss, metrics,
          learning_rate=0.01, batch_size=20, epochs=10, patience=3,
          scheduler=None, step_size=5, gamma=0.1, device='cpu', transform=None):
    """Train ``model`` with smp's TrainEpoch/ValidEpoch runners and early stopping.

    ``train_df`` is split once with ``split_df`` into a training part and a
    held-out validation part. The validation loss of each epoch is passed to
    an ``EarlyStopping`` instance configured with ``path='best_model.pkl'``.

    Args:
        model: segmentation model to train.
        train_df: index DataFrame (mask file, image file) as read from the CSV.
        train_dir: directory the file names in ``train_df`` are relative to.
        optimizer: optimizer class (e.g. ``torch.optim.Adam``), instantiated
            here as ``optimizer(model.parameters(), learning_rate)``.
        loss: loss object; ``loss.__name__`` is the key of the loss in the logs.
        metrics: list of metric objects passed to the epoch runners.
        learning_rate: learning rate passed to ``optimizer``.
        batch_size: batch size of both data loaders.
        epochs: maximum number of epochs.
        patience: forwarded to ``EarlyStopping``.
        scheduler: optional LR-scheduler class (e.g. ``StepLR``), instantiated
            as ``scheduler(optimizer, step_size, gamma)`` and stepped per epoch.
        step_size, gamma: scheduler arguments.
        device: device passed to the epoch runners.
        transform: transform ImageDataset applies to images and masks.

    Returns:
        ``(train_logs, valid_logs)``: lists with one log dict per epoch run.
    """
    early_stopping = EarlyStopping(patience, path='best_model.pkl', verbose=True)
    optimizer = optimizer(model.parameters(), learning_rate)

    if scheduler:
        scheduler = scheduler(optimizer, step_size, gamma)

    train_epoch = smp.utils.train.TrainEpoch(
        model, loss, metrics, optimizer, device, verbose=True
    )
    valid_epoch = smp.utils.train.ValidEpoch(
        model, loss, metrics, device, verbose=True
    )

    # Split ONCE, before the epoch loop: re-splitting every epoch mixes each
    # epoch's validation images into other epochs' training data, so the loss
    # that drives early stopping would not be measured on held-out images.
    train_dataframe, val_dataframe = split_df(train_df)

    train_dataset = ImageDataset(train_dataframe, train_dir, transform=transform)
    valid_dataset = ImageDataset(val_dataframe, train_dir, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    # The validation loader must wrap the held-out split, not the training split.
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=batch_size,
                                               shuffle=False)

    train_logs, valid_logs = [], []

    for epoch in range(epochs):
        print(f'\nEpoch: {epoch+1}/{epochs}')

        train_log = train_epoch.run(train_loader)
        valid_log = valid_epoch.run(valid_loader)

        train_logs.append(train_log)
        valid_logs.append(valid_log)

        early_stopping(valid_log[loss.__name__], model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        if scheduler:
            scheduler.step()

    return train_logs, valid_logs

In [None]:
# Train for up to 20 epochs with the classes/objects configured above:
# lr=0.01, StepLR with step_size=10 (gamma keeps train()'s default of 0.1),
# and patience=3 forwarded to EarlyStopping.
# `res` is the (train_logs, valid_logs) tuple returned by train().
res = train(model=model,
            train_df=train_df, 
            train_dir=train_path, 
            optimizer=optimizer,
            loss=loss,
            learning_rate=0.01,
            metrics=metrics,
            batch_size=20,
            epochs=20,
            scheduler=scheduler,
            step_size=10,
            patience=3,
            device=DEVICE, 
            transform=my_transforms)

# Save results / Load best checkpoint

In [None]:
import pickle

In [None]:
# Persist the (train_logs, valid_logs) tuple returned by train().
with open('results.pkl', 'wb') as f:
    pickle.dump(res, f)

In [None]:
# Reload the model object pickled to best_model.pkl by the patched save_checkpoint.
# NOTE(review): pickle.load can execute arbitrary code — only load checkpoints this notebook wrote.
with open('best_model.pkl', 'rb') as f:
    best_model = pickle.load(f)

# Plot training curves

In [None]:
# One row per epoch, one column per logged value (e.g. 'dice_loss', 'iou_score').
train_logs_df = pd.DataFrame(res[0])
valid_logs_df = pd.DataFrame(res[1])

# Series label -> per-epoch log frame, as consumed by draw_graphic.
res_dict = {'train': train_logs_df, 'valid': valid_logs_df}

In [None]:
def draw_graphic(df_dict, title, criteria, xlab, ylab, colors=('b', 'r'),
                 legend_loc='best', figsize=(10, 5), fontsize=16):
    """Plot one column of several per-epoch log DataFrames on shared axes.

    Args:
        df_dict: mapping of series label -> DataFrame whose index is the epoch.
        title: figure title.
        criteria: column plotted from every DataFrame (e.g. 'iou_score').
        xlab, ylab: axis labels.
        colors: matplotlib colour codes, cycled when there are more series
            than colours (previously a 3rd series raised IndexError).
        legend_loc, figsize, fontsize: formatting options.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for i, (label, frame) in enumerate(df_dict.items()):
        ax.plot(frame.index.tolist(), frame[criteria].tolist(),
                colors[i % len(colors)], lw=3, label=label)
    ax.set_xlabel(xlab, fontsize=fontsize)
    ax.set_ylabel(ylab, fontsize=fontsize)
    ax.set_title(title, fontsize=fontsize)
    ax.legend(loc=legend_loc, fontsize=fontsize)
    ax.grid()
    fig.show()

In [None]:
# Per-epoch IoU for the training and validation runs.
draw_graphic(df_dict=res_dict, title='IoU Scores', criteria='iou_score', xlab='epochs', ylab='IoU score')

In [None]:
# Per-epoch Dice loss for the training and validation runs.
draw_graphic(df_dict=res_dict, title='Dice Losses', criteria='dice_loss', xlab='epochs', ylab='Dice loss')

# Create submission

In [None]:
def rle_encoding(x):
    """Run-length encode a mask for the submission file.

    The mask is transposed before flattening, so pixels are scanned in
    column-major order; positions whose value equals 1 are grouped into runs
    of consecutive indices.

    Args:
        x: mask supporting ``.T`` and ``.flatten()``, intended to be 2-D
            (H, W); only values equal to 1 count as foreground.

    Returns:
        Flat list ``[start_1, length_1, start_2, length_2, ...]`` with 1-based
        starts; empty when no value equals 1.
    """
    foreground = np.where(x.T.flatten() == 1)[0]
    encoded = []
    last = -2  # sentinel so the first foreground pixel always opens a new run
    for pos in foreground:
        if pos > last + 1:
            encoded += [pos + 1, 1]
        else:
            encoded[-1] += 1
        last = pos
    return encoded

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Test image file names, ordered by the integer before the first '.' in each name.
imgs = sorted(os.listdir(test_path), key=lambda name: int(name.split('.')[0]))

In [None]:
def create_csv_submission(model, data_path, img_list, out_path,
                          threshold=0.5, out_size=(420, 580)):
    """Predict a mask for every test image and write a run-length-encoded CSV.

    Args:
        model: segmentation model exposing ``predict``; moved to DEVICE and
            switched to eval mode.
        data_path: directory containing the test images.
        img_list: image file names inside ``data_path``. The ``img`` column is
            each name's part before the first '.', not its list position.
        out_path: destination CSV with columns ``img`` and ``pixels``.
        threshold: probability above which a pixel is foreground.
        out_size: (height, width) the predicted mask is resized to before
            encoding.
    """
    model.to(DEVICE)
    model.eval()

    rows = []
    for img in tqdm(img_list):
        with Image.open(os.path.join(data_path, img)) as pil_img:
            x = my_transforms(pil_img)
        x = x.unsqueeze(0).to(DEVICE)

        probs = model.predict(x).cpu()
        # Resize the probability map first, then binarise, so the mask edges follow
        # the interpolated probabilities rather than a blocky upscaled binary mask.
        probs = transforms.Resize(size=out_size)(probs)
        # The model outputs probabilities (ACTIVATION='sigmoid'); without this
        # threshold rle_encoding's `== 1` test matches almost nothing. Squeeze
        # to a 2-D numpy array so the column-major transpose is well defined.
        mask = (probs > threshold).squeeze().numpy().astype(np.uint8)

        pixels = ' '.join(map(str, rle_encoding(mask)))
        rows.append({'img': img.split('.')[0], 'pixels': pixels})

    # Build the frame once instead of growing it row by row.
    pd.DataFrame(rows, columns=['img', 'pixels']).to_csv(out_path, index=False)

In [None]:
# Predict with `best_model` (reloaded from best_model.pkl) rather than `model`,
# whose weights come from the last epoch run, not necessarily the best
# validation epoch.
create_csv_submission(model=best_model,
                      data_path=test_path,
                      img_list=imgs,
                      out_path='submission.csv')