In [None]:
# Jupyter notebook for training Gleason Segmentation Model
#
# Based upon pytorch_resnet18_unet.ipynb by Naoto Usuyama
#
# http://github.com/usuyama/pytorch-unet/
#
# Original source code licensed under MIT License:
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os,sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import helper

In [None]:
from torch.utils.data import Dataset, DataLoader
import cv2
import glob
import numpy as np

class GlandDataset(Dataset):
    """Dataset of multi-level image tiles paired with a per-pixel class-index mask.

    Files are looked up as ``root_dir/<case>/Level<k>/<name>.png`` for the image
    levels and ``root_dir/<case>/<label>/<name>.png`` for the label masks. The
    tiles found under the first entry of ``levels`` define the samples.
    """

    def __init__(self, root_dir, net_tile_size=(500, 500), labels=['Mask'], levels=[1, 2, 3], transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            net_tile_size (tuple): (rows, cols) every image and mask is resized to.
            labels (list of string): Sub-directory names holding one binary mask
                per class; ``labels[i]`` is encoded as class index ``i + 1``.
            levels (list): Level identifiers; each level contributes three
                channels to the returned image.
            transform (callable, optional): Optional transform to be applied
                on the image of a sample (the mask is not transformed).

        Raises:
            ValueError: If ``levels`` is empty.
        """
        # Copy so that neither the caller's lists nor the shared default lists
        # can be mutated behind the dataset's back.
        self.labels = list(labels)
        self.levels = list(levels)
        if not self.levels:
            raise ValueError("levels must contain at least one entry")

        # Sorted so that an index always refers to the same tile between runs.
        self.mask_paths = sorted(
            glob.glob(os.path.join(root_dir, '*', 'Level' + str(self.levels[0]), '*.png'))
        )

        self.net_tile_size = tuple(net_tile_size)
        self.transform = transform

    def __len__(self):
        """Number of tiles found under the first level directory."""
        return len(self.mask_paths)

    def _fit_to_tile(self, array, interpolation):
        """Resize ``array`` to ``net_tile_size`` when its height/width differ."""
        rows, cols = self.net_tile_size[0], self.net_tile_size[1]
        if array.shape[0] != rows or array.shape[1] != cols:
            # cv2.resize expects dsize as (width, height), i.e. (cols, rows).
            array = cv2.resize(array, (cols, rows), interpolation=interpolation)
        return array

    def __getitem__(self, idx):
        """Return ``[image, mask]`` for tile ``idx``.

        ``image`` is the channel-wise concatenation of all levels as float32 in
        [0, 1] (after ``transform`` if one was given). ``mask`` is an int64 array
        of shape ``net_tile_size`` where 0 means "no label" and ``i + 1`` means
        ``labels[i]``; when masks overlap, the later label wins.

        Raises:
            FileNotFoundError: If a level image is missing or unreadable.
            IOError: If a label mask exists but cannot be decoded.
        """
        image_root, image_name = os.path.split(self.mask_paths[idx])
        image_root = os.path.dirname(image_root)

        # Start at -1 everywhere so that the final "+ 1" maps unlabeled pixels to 0.
        mask = -np.ones(self.net_tile_size, dtype=np.int64)

        for i, label in enumerate(self.labels):
            mask_path = os.path.join(image_root, label, image_name)
            if not os.path.exists(mask_path):
                # A missing mask file means this class is absent from the tile.
                continue

            mask_part = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask_part is None:
                raise IOError("Could not read mask image: {}".format(mask_path))

            mask_part = self._fit_to_tile(mask_part.astype(np.float32) / 255, cv2.INTER_NEAREST)
            mask[mask_part > 0] = i

        mask = mask + 1

        level_images = []
        for level in self.levels:
            image_path = os.path.join(image_root, 'Level' + str(level), image_name)
            im_part = cv2.imread(image_path)
            if im_part is None:
                raise FileNotFoundError("Could not read image: {}".format(image_path))

            im_part = self._fit_to_tile(im_part, cv2.INTER_AREA)
            # cv2.imread returns BGR; reverse the channel axis to get RGB.
            level_images.append(im_part[:, :, ::-1].astype(np.float32) / 255)

        image = np.concatenate(level_images, axis=2)

        if self.transform:
            image = self.transform(image)

        return [image, mask]

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

# Images only need converting to tensors; no augmentation is applied.
trans = transforms.Compose([transforms.ToTensor()])

# Image levels fed to the network and the per-class mask folders.
lev = [0, 1]
lab = ['Normal', 'PIN', 'Gleason3', 'Gleason4', 'Gleason5']

train_set = GlandDataset(
    '../../training_data/new_patches/Training/',
    net_tile_size=(224, 224),
    levels=lev,
    labels=lab,
    transform=trans,
)
val_set = GlandDataset(
    '../../training_data/new_patches/Validation/',
    net_tile_size=(224, 224),
    levels=lev,
    labels=lab,
    transform=trans,
)

image_datasets = {'train': train_set, 'val': val_set}

batch_size = 1

# One shuffling, single-process loader per split.
dataloaders = {
    split: DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    for split, dataset in image_datasets.items()
}

dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

In [None]:
import torch
import torch.nn as nn

def convrelu(in_channels, out_channels, kernel, padding):
    """Build a Conv2d followed by an in-place ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel: Kernel size passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.

    Returns:
        nn.Sequential containing the convolution and the activation.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)

class ResNetUNet(nn.Module):
    """U-Net style segmentation network built on a pretrained ResNet-18 encoder.

    The encoder stages of ``torchvision.models.resnet18`` provide skip
    connections; each skip is passed through a 1x1 ``convrelu`` and merged into
    a decoder that upsamples by 2 at every stage. A full-resolution branch
    (``conv_original_size*``) is concatenated right before the classifier.

    Args:
        n_class: Number of output channels (one logit map per class).
        n_channels: Number of input channels. For values other than 3 the first
            ResNet convolution is replaced by a freshly initialised one.
    """

    def __init__(self, n_class, n_channels=3):
        super().__init__()

        self.base_model = models.resnet18(pretrained=True)
        # Channel counts leaving each encoder stage (index 0 is the raw input) ...
        self.base_in_features = [n_channels, 64, 64, 128, 256, 512]
        # ... and after the matching 1x1 skip projection.
        self.base_out_features = [64, 64, 256, 512, 512, 1024]
        # Decoder widths: conv_up3, conv_up2, conv_up1, conv_up0,
        # conv_original_size0, conv_original_size1, conv_original_size2.
        self.model_features = [512, 512, 256, 128, 64, 64, 64]

        self.base_layers = list(self.base_model.children())
        if n_channels != 3:
            # The pretrained stem only accepts 3 channels, so swap in a new first
            # convolution with the same geometry. It is left on the default device;
            # callers move the whole model with .to(device).
            stem_conv = nn.Conv2d(
                self.base_in_features[0],
                self.base_in_features[1],
                kernel_size=(7, 7),
                stride=(2, 2),
                padding=(3, 3),
                bias=False,
            )
            self.layer0 = nn.Sequential(*([stem_conv] + self.base_layers[1:3]))  # size=(N, 64, x.H/2, x.W/2)
        else:
            self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(self.base_in_features[1], self.base_out_features[1], 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(self.base_in_features[2], self.base_out_features[2], 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(self.base_in_features[3], self.base_out_features[3], 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(self.base_in_features[4], self.base_out_features[4], 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(self.base_in_features[5], self.base_out_features[5], 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(self.base_out_features[4] + self.base_out_features[5], self.model_features[0], 3, 1)
        self.conv_up2 = convrelu(self.base_out_features[3] + self.model_features[0], self.model_features[1], 3, 1)
        self.conv_up1 = convrelu(self.base_out_features[2] + self.model_features[1], self.model_features[2], 3, 1)
        self.conv_up0 = convrelu(self.base_out_features[1] + self.model_features[2], self.model_features[3], 3, 1)

        self.conv_original_size0 = convrelu(n_channels, self.model_features[4], 3, 1)
        self.conv_original_size1 = convrelu(self.model_features[4], self.model_features[5], 3, 1)
        self.conv_original_size2 = convrelu(self.model_features[5] + self.model_features[3], self.model_features[6], 3, 1)

        self.conv_last = nn.Conv2d(self.model_features[6], n_class, 1)

    def forward(self, input):
        """Return per-class logits with the same spatial size as ``input``.

        NOTE(review): the skip concatenations presumably require H and W to be
        multiples of 32 (five halvings followed by five x2 upsamples) - confirm.
        """
        # Full-resolution branch, merged back in just before the classifier.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder.
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder: upsample, concatenate the projected skip, fuse with a 3x3 conv.
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

In [None]:
import torch
import torch.nn as nn

def convrelu(in_channels, out_channels, kernel, padding, dropout=0.7):
    """Build a Conv2d -> in-place ReLU -> Dropout block.

    NOTE: this definition replaces any earlier module-level ``convrelu``.
    Because the name is resolved when it is called, every model constructed
    after this cell runs (including ``ResNetUNet`` instances) gets the Dropout
    layer as well.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel: Kernel size passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.
        dropout: Drop probability of the trailing ``nn.Dropout`` (default 0.7,
            the value that used to be hard-coded).

    Returns:
        nn.Sequential with the three layers in the order listed above.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
        nn.Dropout(p=dropout)
    )

class ResNetUNetEnsemble(nn.Module):
    """Several ResNetUNet encoders sharing a single decoder.

    Each of the ``n_models`` branches receives its own 3-channel slice of the
    input. The per-branch encoder features are concatenated channel-wise at
    every scale and decoded by one shared decoder.

    Args:
        n_class: Number of output channels of the final 1x1 convolution.
        n_models: Number of 3-channel branches; the input is expected to carry
            ``3 * n_models`` channels.
    """

    def __init__(self, n_class, n_models):
        super().__init__()

        self.n_class = n_class
        self.n_models = n_models

        self.base_models = nn.ModuleList([ResNetUNet(n_class, 3) for _ in range(self.n_models)])
        branches = self.base_models

        # Channel widths of the concatenated branch features: element-wise sums
        # of the per-branch widths.
        self.base_in_features = [sum(widths) for widths in zip(*[b.base_in_features for b in branches])]
        self.base_out_features = [sum(widths) for widths in zip(*[b.base_out_features for b in branches])]
        self.model_features = [512, 512, 256, 128, 64, sum([b.model_features[5] for b in branches]), 64]

        enc_in = self.base_in_features
        enc_out = self.base_out_features
        dec = self.model_features

        # 1x1 projections applied to the concatenated skip features.
        self.layer0_1x1 = convrelu(enc_in[1], enc_out[1], 1, 0)
        self.layer1_1x1 = convrelu(enc_in[2], enc_out[2], 1, 0)
        self.layer2_1x1 = convrelu(enc_in[3], enc_out[3], 1, 0)
        self.layer3_1x1 = convrelu(enc_in[4], enc_out[4], 1, 0)
        self.layer4_1x1 = convrelu(enc_in[5], enc_out[5], 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # 3x3 fusion convolutions of the shared decoder, deepest first.
        self.conv_up3 = convrelu(enc_out[4] + enc_out[5], dec[0], 3, 1)
        self.conv_up2 = convrelu(enc_out[3] + dec[0], dec[1], 3, 1)
        self.conv_up1 = convrelu(enc_out[2] + dec[1], dec[2], 3, 1)
        self.conv_up0 = convrelu(enc_out[1] + dec[2], dec[3], 3, 1)

        self.conv_original_size2 = convrelu(dec[5] + dec[3], dec[6], 3, 1)

        self.conv_last = nn.Conv2d(dec[6], n_class, 1)

    def forward(self, input):
        """Return the ``n_class`` logit maps for a ``3 * n_models`` channel input."""
        branch_ids = range(self.n_models)
        # Branch k sees channels [3k, 3k + 3) of the input.
        tiles = [input[:, (3 * k):(3 * (k + 1)), :, :] for k in branch_ids]

        # Full-resolution features of every branch, stacked channel-wise.
        full_res = [self.base_models[k].conv_original_size0(tiles[k]) for k in branch_ids]
        full_res = [self.base_models[k].conv_original_size1(full_res[k]) for k in branch_ids]
        x_original = torch.cat(full_res, dim=1)

        # Run the encoders stage by stage, keeping every stage for the skips.
        layer0 = [self.base_models[k].layer0(tiles[k]) for k in branch_ids]
        layer1 = [self.base_models[k].layer1(layer0[k]) for k in branch_ids]
        layer2 = [self.base_models[k].layer2(layer1[k]) for k in branch_ids]
        layer3 = [self.base_models[k].layer3(layer2[k]) for k in branch_ids]
        layer4 = [self.base_models[k].layer4(layer3[k]) for k in branch_ids]

        x = self.layer4_1x1(torch.cat(layer4, dim=1))

        # Each decoder stage: upsample, project the concatenated skip, fuse.
        decoder_stages = (
            (layer3, self.layer3_1x1, self.conv_up3),
            (layer2, self.layer2_1x1, self.conv_up2),
            (layer1, self.layer1_1x1, self.conv_up1),
            (layer0, self.layer0_1x1, self.conv_up0),
        )
        for branch_feats, project, fuse in decoder_stages:
            x = self.upsample(x)
            skip = project(torch.cat(branch_feats, dim=1))
            x = fuse(torch.cat([x, skip], dim=1))

        x = self.upsample(x)
        x = self.conv_original_size2(torch.cat([x, x_original], dim=1))

        return self.conv_last(x)

In [None]:
from collections import defaultdict
import torch.nn.functional as F
import time
import copy

def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Class-weighted cross-entropy loss that also accumulates a running total.

    Class 0 is weighted 0.5 and every other class 1.0. The weight vector is
    sized from ``pred`` and created on its device, so the function works on CPU
    and for any number of classes.

    Args:
        pred: Logits with the class dimension at index 1.
        target: Integer class indices; entries equal to -1 are ignored.
        metrics: Mapping whose ``'loss'`` entry is incremented by the batch
            loss multiplied by ``target.size(0)``.
        bce_weight: Unused; kept so existing callers keep working.

    Returns:
        The scalar loss tensor (still attached to the autograd graph).
    """
    class_weights = torch.ones(pred.shape[1], dtype=pred.dtype, device=pred.device)
    class_weights[0] = 0.5
    loss = F.cross_entropy(pred, target, weight=class_weights, ignore_index=-1)

    # .item() yields a plain Python float, so the running total carries no tensor state.
    metrics['loss'] += loss.item() * target.size(0)

    return loss

def print_metrics(metrics, epoch_samples, phase):
    """Print every accumulated metric averaged over ``epoch_samples``.

    Args:
        metrics: Mapping of metric name to its running total for the epoch.
        epoch_samples: Number of samples the totals were accumulated over.
        phase: Label printed in front of the metrics (e.g. 'train' or 'val').
    """
    summary = ", ".join(
        "{}: {:4f}".format(name, metrics[name] / epoch_samples) for name in metrics.keys()
    )
    print("{}: {}".format(phase, summary))

def train_model(model, optimizer, scheduler, num_epochs=25, best_loss=1e10):
    """Train ``model`` and keep the weights with the lowest validation loss.

    Reads the module-level ``dataloaders`` (keys 'train' and 'val') and
    ``device``. Every epoch runs a training phase followed by a validation
    phase; the learning-rate scheduler is stepped once per epoch after the
    training phase, so the first epoch uses the optimizer's initial rate.

    Args:
        model: Network to optimise; updated in place.
        optimizer: Optimizer bound to the model's parameters.
        scheduler: Learning-rate scheduler bound to ``optimizer``.
        num_epochs: Number of epochs to run.
        best_loss: Validation loss to beat, so training can be resumed across
            several calls.

    Returns:
        Tuple ``(model, best_loss)`` with the best weights loaded into ``model``.
    """
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            num_batches = len(dataloaders[phase])

            for batch_idx, (inputs, labels) in enumerate(dataloaders[phase], start=1):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

                time_elapsed = time.time() - since
                # Use the batch index so a smaller final batch still counts correctly.
                print('{:d}/{:d} {:.0f}m {:.0f}s, loss: {:.2f}'.format(batch_idx, num_batches, time_elapsed // 60, time_elapsed % 60, loss.item()), end='\r')

            if is_train:
                # Step the schedule only after this epoch's optimizer updates.
                scheduler.step()

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_loss

In [None]:
import torch

# Prefer the first GPU and fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# One output channel per label plus one extra class; one branch per image level.
num_class = 1 + len(train_set.labels)
num_models = len(train_set.levels)

model = ResNetUNetEnsemble(n_class=num_class, n_models=num_models)
model = model.to(device)

# Validation loss to beat; carried across repeated runs of the training cell.
best_loss = 1e10

In [None]:
from torch.optim import lr_scheduler
import torch.optim as optim

# Optimise only the parameters that still require gradients.
optimizer_ft = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=1e-6,
)

# Multiply the learning rate by 0.1 every 30 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)

model, best_loss = train_model(
    model,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=5,
    best_loss=best_loss,
)

In [None]:
import math
import torchvision.utils

def reverse_transform(inp):
    """Convert a channels-first tensor into a uint8 image for plotting.

    The tensor is moved to channels-last order, clipped to [0, 1], scaled to
    [0, 255] and truncated to uint8. Only the first three channels are kept.
    """
    channels_last = np.transpose(inp.numpy(), (1, 2, 0))
    scaled = np.clip(channels_last, 0, 1) * 255
    as_bytes = scaled.astype(np.uint8)
    return as_bytes[:, :, :3]

def label_to_onehot(inp):
    """Expand a 2-D label map into ``num_class`` binary channel-first masks.

    Channel ``k`` marks the pixels whose label is ``k + 1``; label 0 lands in
    the last channel because index ``0 - 1`` wraps around.
    NOTE(review): this rotated channel order looks like it is meant to match
    the colour order of ``helper.masks_to_colorimg`` - confirm.

    Reads the module-level ``num_class``.
    """
    planes = np.zeros(inp.shape + (num_class, ))

    for label in range(num_class):
        channel = label - 1  # -1 for label 0, i.e. the last channel
        planes[:, :, channel] = (inp == label)

    return np.transpose(planes, (2, 0, 1))

model.eval()   # Set model to evaluate mode

# Take one batch from the validation loader.
inputs, labels = next(iter(dataloaders['val']))
inputs = inputs.to(device)
labels = labels.to(device)

# Inference only: skip autograd bookkeeping so no graph is kept in memory.
with torch.no_grad():
    pred = F.softmax(model(inputs), dim=1)
pred = pred.cpu().numpy()
pred = np.argmax(pred, axis=1)

# Change channel-order and make 3 channels for matplot
input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(label_to_onehot(x)) for x in labels.cpu().numpy()]
pred_rgb = [helper.masks_to_colorimg(label_to_onehot(x)) for x in pred]

helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])

In [None]:
import os

# The whole model object is pickled by torch.save, not just its state_dict.
save_location = '../../models/NewModel.h5'

# Make sure the target folder exists before writing.
os.makedirs(os.path.split(save_location)[0], exist_ok=True)
torch.save(obj=model, f=save_location)

In [None]:
load_location = '../../models/NewModel.h5'

# NOTE: torch.load unpickles the file, which can execute arbitrary code - only
# load model files from a trusted source.
# map_location remaps the stored tensors so a model saved on a GPU can also be
# loaded on a machine without CUDA.
model = torch.load(load_location, map_location=device).to(device)