## Import necessary packages

In [None]:
!pip install segmentation_models_pytorch

In [None]:
!pip install livelossplot

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from torch.optim.lr_scheduler import StepLR
from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout
import os
from PIL import Image
from PIL import ImageFile
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from livelossplot import PlotLosses 
import numpy as np
from datetime import datetime
import pandas as pd
import random 
from shutil import copyfile
import re
import albumentations as albu
from albumentations.pytorch import ToTensor
from albumentations import Compose,Resize,OneOf,RandomBrightness,RandomContrast,Normalize,HorizontalFlip,Blur,ElasticTransform,GridDistortion,OpticalDistortion,GaussNoise 
from sklearn.metrics import roc_auc_score
from skimage.io import imread, imsave
import skimage
import nibabel as nib
import time
import cv2
import copy
import segmentation_models_pytorch as smp
from tqdm import tqdm_notebook as tqdm

In [None]:
# Seed every random number generator this notebook may draw from so that
# runs are repeatable.
seed = 271
# Python's built-in RNG was previously left unseeded.
# NOTE(review): the albumentations transforms presumably sample from it in the
# installed version -- confirm.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seeds all visible GPUs (not only the current one); silently ignored when
# CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Ask PyTorch to release GPU memory that its caching allocator is holding but
# not currently using, before the datasets and model are created.
torch.cuda.empty_cache()

## Load dataset and perform augmentation

In [None]:
# Side length, in pixels, that every slice (and its mask) is resized to.
imgsize = 224

# Albumentations pipelines, keyed by what they are meant to be applied to:
#   'both'  - geometric transforms, run on the image and its mask together
#   'image' - intensity transforms and normalisation, run on the image only
transforms1 = dict(
    both=Compose([
        Resize(height=imgsize, width=imgsize),
        HorizontalFlip(p=0.5),
        # With probability 0.3, apply exactly one of the three distortions.
        OneOf(
            [
                ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                GridDistortion(),
                OpticalDistortion(distort_limit=2, shift_limit=0.5),
            ],
            p=0.3,
        ),
    ]),
    image=Compose([
        OneOf([
            RandomBrightness(limit=0.1, p=0.4),
            RandomContrast(limit=0.1, p=0.4),
        ]),
        GaussNoise(),
        Blur(blur_limit=3, p=0.1),
        # Per-channel mean/std; these look like the usual ImageNet statistics,
        # matching the ImageNet-pretrained encoder used later.
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [None]:
class Covid19_CT_Dataset(torch.utils.data.Dataset):
    """Serve 2-D slices of one CT volume together with one-hot masks.

    Item ``index`` is the pair ``(image, target_mask)`` built from slice
    ``index`` along the last axis of the CT volume and of the label volume.
    Both slices are rotated with ``np.rot90``; the image is converted to a
    3-channel array through PIL and the mask is expanded to one channel per
    foreground label.

    Args:
        ct_path: Path of the CT volume, loaded with ``nibabel.load``.
        masks_path: Path of the label volume for the same scan.
        transform1: Optional callable invoked as
            ``transform1(image=..., mask=...)``; must return a mapping with
            the keys ``'image'`` and ``'mask'``.
        transform2: Optional callable invoked as ``transform2(image=...)``;
            must return a mapping with the key ``'image'``.
        num_classes: Number of foreground labels. Label ``k`` (1-based) is
            written to mask channel ``k - 1``; label 0 is background.
            Defaults to 3, the channel count used previously.
    """

    def __init__(self, ct_path, masks_path, transform1=None, transform2=None, num_classes=3):
        self.transforms1 = transform1
        self.transforms2 = transform2
        self.ct_path = ct_path
        self.masks_path = masks_path
        self.num_classes = num_classes
        # The slice count is taken from the image's shape attribute instead of
        # loading the full voxel array just to measure it.
        self.len = nib.load(self.ct_path).shape[-1]

    @staticmethod
    def _load_slice(path, index):
        """Return slice ``index`` (last axis) of the volume at ``path``, rotated 90 degrees."""
        volume = nib.load(path)
        # Index the array proxy so that only the requested slice is turned
        # into an array, rather than the whole volume on every access.
        # float64 matches the dtype ``get_fdata()`` returns by default.
        plane = np.asarray(volume.dataobj[:, :, index], dtype=np.float64)
        # rot90 acts on the first two axes, so rotating the slice equals
        # slicing the rotated volume.
        return np.rot90(plane)

    @staticmethod
    def _one_hot(mask, num_classes=3):
        """Return an ``(H, W, num_classes)`` float64 array, channel ``k-1`` marking label ``k``."""
        # Comparing against each label directly avoids assuming that the
        # background value 0 is present in every slice.
        channels = [mask == label for label in range(1, num_classes + 1)]
        return np.stack(channels, axis=-1).astype(np.float64)

    def __getitem__(self, index):
        """Return ``(image, target_mask)`` tensors for slice ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.len
        if not 0 <= index < self.len:
            raise IndexError('slice index out of range')

        image = self._load_slice(self.ct_path, index)
        # NOTE(review): the PIL round-trip yields an 8-bit RGB array; intensity
        # values outside 0-255 are presumably clipped -- confirm whether a CT
        # intensity window should be applied first.
        image = np.array(Image.fromarray(image).convert('RGB'))

        mask = self._load_slice(self.masks_path, index)
        target_mask = self._one_hot(mask, self.num_classes)

        # Joint transform: image and mask must receive the same geometry.
        if self.transforms1 is not None:
            augmented = self.transforms1(image=image, mask=target_mask)
            image = augmented['image']
            target_mask = augmented['mask']

        # Image-only transform.
        if self.transforms2 is not None:
            image = self.transforms2(image=image)['image']

        target_mask = ToTensor()(image=target_mask)['image']
        image = ToTensor()(image=image)['image']

        return image, target_mask

    def __len__(self):
        """Return the number of slices in the CT volume."""
        return self.len

In [None]:
# Read the scan index; each row names a CT volume and its mask volume.
csv_file = r'../input/covid19-ct-scans/metadata.csv'
df = pd.read_csv(csv_file)
ct_path, masks_path = (
    df[column].tolist() for column in ('ct_scan', 'lung_and_infection_mask')
)

# Build a dataset over the first scan only, used for the visual check below.
dataset = Covid19_CT_Dataset(
    ct_path[0],
    masks_path[0],
    transform1=transforms1['both'],
    transform2=transforms1['image'],
)

In [None]:
img_id = 100
# Fetch the sample once. Every dataset[...] access re-runs the random
# augmentations, so indexing separately for the image and the mask could
# overlay a mask on a differently flipped or distorted image.
sample_image, sample_mask = dataset[img_id]
plt.figure(figsize=(8,8))
# Tensors are channel-first; imshow expects channel-last.
plt.imshow(sample_image.permute(1,2,0).numpy(), cmap='bone')
plt.imshow(sample_mask.permute(1,2,0).numpy(), alpha=0.5, cmap='gray')

In [None]:
# Read the scan index again so this cell can run on its own.
csv_file = r'../input/covid19-ct-scans/metadata.csv'
df = pd.read_csv(csv_file)
ct_path = df['ct_scan'].tolist()
masks_path = df['lung_and_infection_mask'].tolist()

# One slice dataset per scan, for the first 20 rows of the metadata file.
dataset_list = [
    Covid19_CT_Dataset(
        ct_path[i],
        masks_path[i],
        transform1=transforms1['both'],
        transform2=transforms1['image'],
    )
    for i in range(20)
]

In [None]:
# Split by scan: the first 16 scans train the model, the rest validate it.
dataset_train = torch.utils.data.ConcatDataset(dataset_list[:16])
dataset_val = torch.utils.data.ConcatDataset(dataset_list[16:])
# Report the number of slices in each split, one per line.
print(len(dataset_train), len(dataset_val), sep='\n')

## Save transformed image tensors

By saving the transformed tensors once and loading them afterwards, the runtime of each epoch is reduced considerably.

In [None]:
# Root folders for the cached tensors. exist_ok=True lets this cell be re-run
# without raising FileExistsError when the folders are already there.
os.makedirs('/kaggle/working/new_train', exist_ok=True)
os.makedirs('/kaggle/working/new_val', exist_ok=True)

In [None]:
# Separate image and mask folders per split. exist_ok=True lets this cell be
# re-run without raising FileExistsError when the folders are already there.
os.makedirs('/kaggle/working/new_train/img', exist_ok=True)
os.makedirs('/kaggle/working/new_val/img', exist_ok=True)
os.makedirs('/kaggle/working/new_train/mask', exist_ok=True)
os.makedirs('/kaggle/working/new_val/mask', exist_ok=True)

In [None]:
# Write every validation sample to disk once: the image tensor and the mask
# tensor go to sibling folders and share the same numeric suffix.
for sample_idx, (image_tensor, mask_tensor) in enumerate(dataset_val):
    torch.save(image_tensor, '/kaggle/working/new_val/img/val_transformed_img{}'.format(sample_idx))
    torch.save(mask_tensor, '/kaggle/working/new_val/mask/val_transformed_mask{}'.format(sample_idx))

In [None]:
# Write every training sample to disk once: the image tensor and the mask
# tensor go to sibling folders and share the same numeric suffix.
# NOTE(review): the files keep the 'val_' prefix even though they hold
# training data -- confirm nothing depends on the names before renaming.
for sample_idx, (image_tensor, mask_tensor) in enumerate(dataset_train):
    torch.save(image_tensor, '/kaggle/working/new_train/img/val_transformed_img{}'.format(sample_idx))
    torch.save(mask_tensor, '/kaggle/working/new_train/mask/val_transformed_mask{}'.format(sample_idx))

## Create new dataset to load the transformed tensors

In [None]:
class transformed_data(Dataset):
    """Load image/mask tensor pairs that were saved with ``torch.save``.

    The files of each directory are sorted by name, and item ``index`` pairs
    the ``index``-th image file with the ``index``-th mask file.

    Args:
        img: Directory containing the saved image tensors.
        mask: Directory containing the saved mask tensors.

    Raises:
        ValueError: If the two directories hold a different number of files,
            since pairing by sorted position would then be unreliable.
    """

    def __init__(self, img, mask):
        self.img = img  # image directory path
        self.mask = mask  # mask directory path
        # List and sort both directories once here; previously this was
        # repeated on every single item access.
        self._img_files = sorted(os.listdir(self.img))
        self._mask_files = sorted(os.listdir(self.mask))
        if len(self._img_files) != len(self._mask_files):
            raise ValueError(
                'image and mask directories hold different numbers of files: '
                '{} vs {}'.format(len(self._img_files), len(self._mask_files))
            )
        self.len = len(self._img_files)

    def __getitem__(self, index):
        """Return ``(img_tensor, mask_tensor)`` for the ``index``-th file pair."""
        img_file_path = os.path.join(self.img, self._img_files[index])
        mask_file_path = os.path.join(self.mask, self._mask_files[index])
        return torch.load(img_file_path), torch.load(mask_file_path)

    def __len__(self):
        """Return the number of image/mask pairs found at construction time."""
        return self.len

In [None]:
# Replace the on-the-fly datasets with ones that read the cached tensors.
dataset_train = transformed_data(
    img='/kaggle/working/new_train/img',
    mask='/kaggle/working/new_train/mask',
)
dataset_val = transformed_data(
    img='/kaggle/working/new_val/img',
    mask='/kaggle/working/new_val/mask',
)

In [None]:
# Training batches are reshuffled every epoch; validation keeps a fixed order.
unet_train_loader = DataLoader(
    dataset_train,
    batch_size=10,
    shuffle=True,
    num_workers=2,
)
unet_val_loader = DataLoader(
    dataset_val,
    batch_size=5,
    shuffle=False,
    num_workers=2,
)

## Train the model

In [None]:
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
# Covid19_CT_Dataset produces 3-channel masks, so the model must emit 3
# channels as well. The library default is a single output channel, which
# would only be broadcast against the 3-channel target in the loss and metric.
CLASSES = 3

# create segmentation model with pretrained encoder
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=CLASSES,
    activation=ACTIVATION,
)

# Input preprocessing function matching the chosen encoder and weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Dice loss is the training objective; IoU with predictions thresholded at 0.5
# is the metric reported alongside it.
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Adam over all model parameters, as a single parameter group.
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 0.0001},
])

In [None]:
# Settings shared by the training and validation runners.
_shared_epoch_args = dict(
    loss=loss,
    metrics=metrics,
    device=device,
    verbose=True,
)

# The training runner additionally receives the optimizer.
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_shared_epoch_args)

# The validation runner takes no optimizer.
valid_epoch = smp.utils.train.ValidEpoch(model, **_shared_epoch_args)

In [None]:
# Best validation IoU seen so far; a checkpoint is written whenever it improves.
max_score = 0
liveloss = PlotLosses()

for i in range(20):
    print('\nEpoch: {}'.format(i))

    train_logs = train_epoch.run(unet_train_loader)
    valid_logs = valid_epoch.run(unet_val_loader)

    # Keep the model with the highest validation IoU.
    val_iou = valid_logs['iou_score']
    if val_iou > max_score:
        max_score = val_iou
        torch.save(model, './best_model.pth')
        print('Model saved!')

    # A learning-rate drop could be added here, e.g. by setting
    # optimizer.param_groups[0]['lr'] at a chosen epoch.

    # Plot the train and validation Dice loss for this epoch.
    logs = {
        'train dice loss': train_logs['dice_loss'],
        'val dice loss': valid_logs['dice_loss'],
    }
    liveloss.update(logs)
    liveloss.send()