In [None]:
"""
Link to notebook:
https://www.kaggle.com/parthvora/mlsp-final-project-preprocessed-data/

"""

In [None]:
# Imports
import os
from glob import glob
import random

from collections import OrderedDict
import numpy as np
import pandas as pd
import cv2 as cv
from PIL import Image
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF


In [None]:
### Read in data paths

# Seed every RNG used below (split, NumPy-based augmentation, torch init/shuffling)
# so that Restart & Run All reproduces the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_PATH = "/kaggle/input/lggsegmentationpreprocessed/kaggle_3m - processed/"

# Mask files match "<sub-directory>/*_mask*". glob returns paths in arbitrary order,
# so sort them to make the split below reproducible.
mask_files = sorted(glob(os.path.join(DATA_PATH, "*", "*_mask*")))
if not mask_files:
    raise FileNotFoundError("No mask files found under {0}".format(DATA_PATH))

# The matching image has the same file name without "_mask". Only the file name is
# edited, so a "_mask" elsewhere in the directory path cannot corrupt the result.
image_files = [
    os.path.join(os.path.dirname(path), os.path.basename(path).replace("_mask", ""))
    for path in mask_files
]

# Train / validation / test split: 10% test, then 20% of the remainder for validation.
# NOTE(review): this splits by slice. The sub-directories are presumably one per patient,
# so slices of one patient can land in several splits and inflate val/test scores.
# Consider sklearn.model_selection.GroupShuffleSplit grouped by directory — TODO.
df = pd.DataFrame(data={"images": image_files, "masks": mask_files})
df_train, df_test = train_test_split(df, test_size=0.1, random_state=SEED)
df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=SEED)
print(df_train.values.shape)
print(df_val.values.shape)
print(df_test.values.shape)

In [None]:
""" 

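Class for the MRI training dataset.

The reference code this is based on crops, pads, resizes, and normalizes the images; here cropping, padding, and resizing are disabled, and images are only normalized.

Images are also augmented by random scaling, rotation, and horizontal flipping.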

"""

class MRITrainDataset(Dataset):
    """Training split of the MRI segmentation data, with random augmentation.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``images`` and ``masks`` columns of file paths, paired row by row.
    transform_params : tuple or None
        ``(degrees, scale, pflip, mean, stdev)``:

        - ``degrees`` and ``scale`` are ``(low, high)`` ranges for a random affine
          rotation and zoom.
        - ``pflip`` is the probability of a horizontal flip.
        - ``mean`` and ``stdev`` are per-channel values for ``transforms.Normalize``.
          They are applied to the image *after* ``ToTensor`` has rescaled it to [0, 1].

        ``None`` disables augmentation and normalization.

    Each item is an ``(image, mask)`` pair of float tensors:

    - image: ``(3, H, W)``, channels in BGR order because ``cv.imread`` loads BGR.
    - mask: ``(1, H, W)``, values in [0, 1].

    Augmentation draws from NumPy's global RNG. With ``DataLoader(num_workers > 0)``,
    seed each worker (``worker_init_fn``) or the draws may repeat across workers.
    """

    def __init__(self, df, transform_params):
        self.image_list = list(df.images)
        self.mask_list = list(df.masks)
        self.transform_params = transform_params

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, i):
        image_path = self.image_list[i]
        mask_path = self.mask_list[i]

        # cv.imread returns None (it does not raise) for missing or unreadable files.
        image = cv.imread(image_path)
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError("Could not read image: {0}".format(image_path))
        if mask is None:
            raise FileNotFoundError("Could not read mask: {0}".format(mask_path))

        if self.transform_params is None:
            return transforms.ToTensor()(image), transforms.ToTensor()(mask)

        # Pass the uint8 arrays straight to transform(). Converting them to float tensors
        # first and back to PIL goes through mul(255).byte(), which truncates and may
        # shift pixel values by one.
        return self.transform(image, mask)

    def transform(self, img, mask):
        """Jointly rotate/scale and flip ``img`` and ``mask``, then normalize ``img``.

        ``img`` and ``mask`` may be uint8 ndarrays (HWC / HW) or CHW tensors, i.e.
        anything ``transforms.ToPILImage`` accepts. Returns float tensors ``(img, mask)``.
        """
        degrees, scale, pflip, mean, stdev = self.transform_params

        img = transforms.ToPILImage()(img)
        mask = transforms.ToPILImage()(mask)

        # One shared random draw keeps image and mask aligned.
        # Translation and shear are disabled.
        angle = np.random.uniform(low=degrees[0], high=degrees[1])
        zoom = np.random.uniform(low=scale[0], high=scale[1])
        img = TF.affine(img, angle=angle, scale=zoom, translate=(0, 0), shear=0)
        mask = TF.affine(mask, angle=angle, scale=zoom, translate=(0, 0), shear=0)

        if np.random.uniform(0, 1) < pflip:
            img = TF.hflip(img)
            mask = TF.hflip(mask)

        img = transforms.ToTensor()(img)
        mask = transforms.ToTensor()(mask)

        # Normalize the image only; the mask stays a [0, 1] target.
        img = transforms.Normalize(mean, stdev)(img)

        return img, mask


In [None]:
""" 

Class for the MRI validation and test datasets.

No augmentation is applied; images are only normalized.

"""

class MRITestDataset(Dataset):
    """Validation/test split of the MRI data: no augmentation, image normalization only.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``images`` and ``masks`` columns of file paths, paired row by row.
    transform_params : tuple or None
        ``(mean, stdev)`` for ``transforms.Normalize``, applied to the image after
        ``ToTensor`` has rescaled it to [0, 1]. ``None`` disables normalization.

    Each item is an ``(image, mask)`` pair of float tensors:

    - image: ``(3, H, W)``, BGR channel order from ``cv.imread``.
    - mask: ``(1, H, W)``.
    """

    def __init__(self, df, transform_params):
        self.image_list = list(df.images)
        self.mask_list = list(df.masks)
        self.transform_params = transform_params

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, i):
        image_path = self.image_list[i]
        mask_path = self.mask_list[i]

        # cv.imread returns None (it does not raise) for missing or unreadable files.
        image = cv.imread(image_path)
        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError("Could not read image: {0}".format(image_path))
        if mask is None:
            raise FileNotFoundError("Could not read mask: {0}".format(mask_path))

        image = transforms.ToTensor()(image)
        mask = transforms.ToTensor()(mask)
        if self.transform_params is not None:
            image, mask = self.transform(image, mask)
        return image, mask

    def transform(self, img, mask):
        """Normalize ``img`` with ``(mean, stdev)``; ``mask`` is returned unchanged."""
        mean, stdev = self.transform_params
        img = transforms.Normalize(mean, stdev)(img)
        return img, mask
  


In [None]:
# Calculate per-channel mean and stdev of the training images.
# NOTE: this cell is disabled (the code sits inside a string literal); the values it
# produced are hard-coded as "precalculated" in the parameter cell instead.
# - cv.imread loads BGR uint8 images, so the stats are per BGR channel (not RGB) and on
#   the raw 0-255 scale. ToTensor() later rescales images to [0, 1], so divide these by
#   255 before passing them to transforms.Normalize.
# - std_pixels is the average of per-image standard deviations. It approximates, but is
#   not equal to, the pooled standard deviation over all training pixels.
"""
mean_pixels = np.zeros(3)
std_pixels = np.zeros(3)
for fn in df_train.images:
    
    img = cv.imread(fn)
    
    mean_pixels += np.mean(img, axis=(0, 1))
    std_pixels += np.std(img, axis=(0, 1))


mean_pixels /= len(df_train.images)
std_pixels /= len(df_train.images)
"""


In [None]:

# Data-augmentation parameters for MRITrainDataset.transform.
# Cropping / padding / resizing from the reference code are disabled.
degrees = (0, 0)  # (low, high) rotation range in degrees; (0, 0) means no rotation
scale = (1, 1)    # (low, high) zoom range; (1, 1) means no scaling
pflip = 0.5       # probability of a horizontal flip

# Per-channel mean/std of the training images (BGR order, since images come from
# cv.imread), precalculated on raw 0-255 pixel values to save time.
# The datasets apply ToTensor() (which rescales to [0, 1]) *before* Normalize, so the
# stats must be divided by 255 to be on the same scale. The raw values would map every
# pixel to roughly -0.7.
mean = np.array([22.65424621, 19.91834147, 23.59800301]) / 255
stdev = np.array([31.34933184, 30.29778424, 30.87938004]) / 255

# Order must match the unpacking in MRITrainDataset.transform.
transform_params = (
    degrees,
    scale,
    pflip,
    mean,
    stdev,
)

In [None]:

# Create MRI datasets and dataloaders
train_dataset = MRITrainDataset(df_train, transform_params)
val_dataset = MRITestDataset(df_val, (mean, stdev))
test_dataset = MRITestDataset(df_test, (mean, stdev))

val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Show a few augmented training samples, image and mask side by side.
N_PREVIEW = 5
preview_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Per-channel stats shaped (C, 1, 1) so they broadcast over a (C, H, W) image.
mean_t = torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
std_t = torch.as_tensor(stdev, dtype=torch.float32).view(-1, 1, 1)

fig, axes = plt.subplots(N_PREVIEW, 2, figsize=(6, 3 * N_PREVIEW), squeeze=False)
for row, (image, mask) in zip(axes, preview_loader):
    # Undo Normalize. This is exact for whatever mean/stdev the dataset used, and works
    # for any image size (no hard-coded reshape).
    image = image[0] * std_t + mean_t
    # cv.imread loads BGR; flip the channel axis so matplotlib shows true colours.
    image = image.flip(0).clamp(0, 1).permute(1, 2, 0).numpy()
    row[0].imshow(image)
    row[0].set_title("image")
    row[1].imshow(mask[0, 0].numpy(), cmap="gray")
    row[1].set_title("mask")
    for ax in row:
        ax.axis("off")
plt.tight_layout()
plt.show()

# Proper training batch size
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
# Soft Dice loss (adapted from other code).
class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - Dice``, computed over the whole batch at once.

    ``inputs`` must already be probabilities. The UNet in this notebook ends with
    ``torch.sigmoid``, so no extra activation is applied here. ``targets`` is the
    matching 0/1 mask. Both are flattened, so they need the same number of elements.
    """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted for signature compatibility only;
        # they are not used.
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - (2|X∩Y| + smooth) / (|X| + |Y| + smooth)`` as a scalar tensor.

        ``smooth`` keeps the ratio defined (equal to 1) when both tensors are all zeros.
        """
        # reshape (not view) also handles non-contiguous tensors, e.g. after a transpose.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [None]:
def train(model, train_loader, val_loader, epochs, lr, mean, stdev, device=None, loss_fn=None):
    """Train ``model`` with Adam and report the mean train/validation loss per epoch.

    Training batches run in ``model.train()`` mode. Validation runs in ``model.eval()``
    mode under ``torch.no_grad()``, so BatchNorm uses its running statistics, is not
    updated from validation data, and no autograd graph is built.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(image)``; its output is compared with ``mask``.
    train_loader, val_loader : iterable of (image, mask) tensor pairs
    epochs : int
    lr : float
        Adam learning rate.
    mean, stdev :
        Unused; kept for call compatibility (the datasets normalize images themselves).
    device : str or torch.device, optional
        ``None`` (default) keeps the original behaviour: run on CUDA, or print an
        error and return ``([], [])`` if CUDA is unavailable.
    loss_fn : callable, optional
        ``loss_fn(prediction, target) -> scalar tensor``; defaults to ``DiceLoss()``.

    Returns
    -------
    (list of float, list of float)
        Per-epoch mean training loss and mean validation loss.
    """
    if device is None:
        if not torch.cuda.is_available():
            print("ERROR: CUDA NOT AVAILABLE")
            return [], []
        device = "cuda"

    compute_loss = DiceLoss() if loss_fn is None else loss_fn
    model = model.to(device)
    # Can try different optimizers
    optimizer = torch.optim.Adam(model.parameters(), lr)

    train_loss, val_loss = [], []

    for epoch in range(epochs):

        # Train
        model.train()
        batch_losses = []
        for image, mask in train_loader:
            image = image.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            loss = compute_loss(model(image), mask)
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())
        train_loss.append(np.mean(batch_losses))

        # Validate: eval mode + no_grad so BatchNorm stats are left alone and no graph is kept.
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for image, mask in val_loader:
                image = image.to(device)
                mask = mask.to(device)
                batch_losses.append(compute_loss(model(image), mask).item())
        val_loss.append(np.mean(batch_losses))

        print("Epoch: {0}".format(epoch + 1))
        print("Train Loss: {0}".format(train_loss[-1]))
        print("Validation Loss: {0}".format(val_loss[-1]))

    return train_loss, val_loss
            

In [None]:
def test(model, test_loader, mean, stdev):
    
    compute_loss = DiceLoss()
    test_loss = []
    
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        print("ERROR: CUDA NOT AVAILABLE")
        return []
    
    for step, (image, mask) in enumerate(test_loader):
            
        image = image.cuda()
        mask = mask.cuda()
        
        mask_hat = model(image)
        loss = compute_loss(mask_hat, mask)
        test_loss.append(loss.item())
        
        del image, mask, mask_hat
        
    return test_loss  
        
        

In [None]:
class UNet(nn.Module):
    """U-Net with four down/up-sampling levels and a sigmoid output.

    Encoder level ``k`` (1..4) produces ``init_features * 2**(k-1)`` channels. The
    bottleneck produces ``init_features * 16``. Each decoder level up-samples,
    concatenates the matching encoder output, and reduces the width back.

    Spatial size is halved four times, so H and W should be divisible by 16 for the
    skip concatenations to line up. Sub-module and layer names
    (``encoder1.enc1conv1`` ...) match a hand-written listing, so ``state_dict`` keys
    are stable.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()
        width = init_features

        # Modules are registered in the order encoderN, poolN, ..., bottleneck,
        # upconvN, decoderN, ..., conv. This fixes both the state_dict ordering and the
        # sequence in which weights are randomly initialised.
        channels_in = in_channels
        for level in range(1, 5):
            channels_out = width * 2 ** (level - 1)
            setattr(self, "encoder{0}".format(level),
                    UNet._block(channels_in, channels_out, name="enc{0}".format(level)))
            setattr(self, "pool{0}".format(level), nn.MaxPool2d(kernel_size=2, stride=2))
            channels_in = channels_out

        self.bottleneck = UNet._block(width * 8, width * 16, name="bottleneck")

        for level in range(4, 0, -1):
            channels_out = width * 2 ** (level - 1)
            setattr(self, "upconv{0}".format(level),
                    nn.ConvTranspose2d(channels_out * 2, channels_out, kernel_size=2, stride=2))
            # Input is the up-sampled tensor concatenated with the skip: twice the width.
            setattr(self, "decoder{0}".format(level),
                    UNet._block(channels_out * 2, channels_out, name="dec{0}".format(level)))

        self.conv = nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        out = x
        for level in range(1, 5):
            out = getattr(self, "encoder{0}".format(level))(out)
            skips.append(out)
            out = getattr(self, "pool{0}".format(level))(out)

        out = self.bottleneck(out)

        for level in range(4, 0, -1):
            out = getattr(self, "upconv{0}".format(level))(out)
            out = torch.cat((out, skips[level - 1]), dim=1)
            out = getattr(self, "decoder{0}".format(level))(out)

        return torch.sigmoid(self.conv(out))

    @staticmethod
    def _block(in_channels, features, name):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; layer names are prefixed with ``name``."""
        layers = OrderedDict()
        for stage, stage_in in ((1, in_channels), (2, features)):
            layers[name + "conv" + str(stage)] = nn.Conv2d(
                in_channels=stage_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[name + "norm" + str(stage)] = nn.BatchNorm2d(num_features=features)
            layers[name + "relu" + str(stage)] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)

In [None]:
# Training hyperparameters (taken from other code)
lr = 1e-4
epochs = 20

model = UNet()
train_loss, val_loss = train(
    model, train_loader, val_loader, epochs=epochs, lr=lr, mean=mean, stdev=stdev
)