In [None]:
from wholeslidedata.iterators import create_batch_iterator
from matplotlib import pyplot as plt
from utils import init_plot, plot_batch, show_plot, print_dataset_statistics
import numpy as np
import torch
import torch.nn as nn
from nn_archs import UNet
from pprint import pprint
from torchsummary import summary
from wholeslidedata.annotation import utils as annotation_utils
from label_utils import to_dysplastic_vs_non_dysplastic
from train_unet import load_config
from tqdm.notebook import tqdm
import torchvision.transforms as transforms
from sklearn.metrics import f1_score
import pandas as pd

In [None]:
# Colour palettes handed to the plotting helpers; list position selects the
# colour (presumably indexed by label value — confirm against plot_batch).
colors_1 = [
    "white",
    "blue",
    "green",
    "orange",
    "red",
    "brown",
    "yellow",
    "purple",
    "pink",
    "grey",
]
# NOTE(review): "green" occurs twice here (positions 1 and 9) — confirm intended.
colors_2 = [
    "white",
    "green",
    "red",
    "orange",
    "brown",
    "yellow",
    "purple",
    "pink",
    "grey",
    "green",
]

In [None]:
# Read the training configuration from the YAML file.
# NOTE(review): the original comment says the config is also stored in the
# experiment dir; that would have to happen inside load_config — confirm.
user_config = './configs/unet_training_config.yml'
train_config = load_config(user_config)

# Batch iterator feeding the training loop.
training_batch_generator = create_batch_iterator(
    user_config=user_config,
    mode='training',
    cpus=train_config['cpus'],
    number_of_batches=train_config['train_batches'],
)

# Batch iterator feeding the validation loop.
validation_batch_generator = create_batch_iterator(
    user_config=user_config,
    mode='validation',
    cpus=train_config['cpus'],
    number_of_batches=train_config['val_batches'],
)

# Report the statistics of both datasets, training first.
for heading, batch_generator in (
    ('\nTraining dataset ', training_batch_generator),
    ('\nValidation dataset ', validation_batch_generator),
):
    print(heading)
    print_dataset_statistics(batch_generator.dataset)

In [None]:
def crop_center(img, cropx, cropy):
    """Return the central window of every image in a batch.

    Args:
        img: 4-D array indexed as (batch, rows, cols, channels); the first
            and last axes are kept whole.
        cropx: size of the window along axis 2 (columns / width).
        cropy: size of the window along axis 1 (rows / height).

    Returns:
        The slice ``img[:, r0:r0 + cropy, c0:c0 + cropx, :]`` where ``r0`` and
        ``c0`` place the window in the middle of the row and column axes.
    """
    _, height, width, _ = img.shape
    # Centre each window on the axis it is applied to. Deriving the row offset
    # from the column count (and vice versa) only works for square images.
    starty = height // 2 - (cropy // 2)
    startx = width // 2 - (cropx // 2)
    return img[:, starty:starty + cropy, startx:startx + cropx, :]

def plot_center_batch(x, y, y_hat, patches=4):
    """Plot ground-truth and predicted masks over centre-cropped input patches.

    Two rows of figures are produced: the first overlays ``y`` on the cropped
    inputs, the second overlays the arg-max prediction derived from ``y_hat``.

    Args:
        x: batch of input images as an array indexed
            (batch, rows, cols, channels), as ``crop_center`` requires.
        y: ground-truth label masks for the batch, forwarded to ``plot_batch``
            (assumed to have the prediction's spatial size — TODO confirm).
        y_hat: torch tensor of raw network outputs with class scores on dim 1.
        patches: maximum number of patches to draw; capped at the batch size.
    """
    # never ask for more patches than the batch contains
    patches = min(patches, len(y_hat))

    # collapse the class scores into one label per pixel
    y_hat = torch.argmax(y_hat, dim=1).cpu().detach().numpy()

    # center crop the images to the spatial size of the prediction;
    # keywords make sure rows get the height and columns get the width
    _, h, w = y_hat.shape
    x = crop_center(x, cropx=w, cropy=h)

    print("Ground truth")
    _, axes = init_plot(1, patches, size=(30, 10))
    plot_batch(axes, 0, x[:patches], y[:patches], alpha=0.3, colors=colors_2)
    plt.show()

    # show prediction on the same cropped patches
    print("Prediction")
    _, axes = init_plot(1, patches, size=(30, 10))
    plot_batch(axes, 0, x[:patches], y_hat[:patches], alpha=0.3, colors=colors_2)
    plt.show()

In [None]:
# Build the UNet as originally defined (with valid convolutions, per the
# author's note), sized from the training configuration.
model = UNet(
    n_channels=train_config['n_channels'],
    n_classes=train_config['n_classes'],
)

# Wrap the network for multi-GPU execution when more than one GPU is visible.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    model = nn.DataParallel(model)

# Use the first CUDA device when available, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Left as a bare expression so the notebook cell still displays the model.
model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=train_config['learning_rate'])
criterion = nn.CrossEntropyLoss()

min_val = float('inf')

# Score every class on every batch so the per-batch F1 arrays always have the
# same length and column k always means class k. Without an explicit label
# list, a batch that lacks a class yields a shorter array and the per-class
# averaging below breaks. Assumes label values lie in [0, n_classes), which
# CrossEntropyLoss requires of its targets as well.
f1_labels = list(range(train_config['n_classes']))


def _to_model_tensors(x, y):
    """Turn one NumPy batch (images, label masks) into tensors on ``device``."""
    x = torch.tensor(x.astype('float32'))
    # channels-last -> channels-first. permute keeps the row/column order so
    # the image stays pixel-aligned with y; transpose(1, 3) would also swap
    # the two spatial axes. Assumes x is (batch, rows, cols, channels), the
    # layout the plotting helpers use for the same arrays.
    x = x.permute(0, 3, 1, 2).to(device)
    y = torch.tensor(y.astype('int64')).to(device)
    return x, y


def _per_class_f1(y, y_hat):
    """Return one F1 score per class for a batch of targets and raw outputs."""
    y_true = y.cpu().detach().numpy().flatten()
    y_pred = torch.argmax(y_hat, dim=1).cpu().detach().numpy().flatten()
    # a class absent from both target and prediction scores 0 for that batch
    return f1_score(y_true, y_pred, labels=f1_labels, average=None, zero_division=0)


for n in range(train_config['epochs']):

    tr_losses = []
    val_losses = []
    tr_f1 = []
    val_f1 = []

    # put mode-dependent layers (dropout, batch norm, ...) in training mode
    model.train()
    for x, y, info in tqdm(training_batch_generator, desc='Epoch {}'.format(n+1)):

        # dysplastic vs non-dysplastic
        y = to_dysplastic_vs_non_dysplastic(y)

        # store one example as numpy
        example_train_batch_x = x
        example_train_batch_y = y

        # transform x and y
        x, y = _to_model_tensors(x, y)

        # forward and update
        optimizer.zero_grad()
        y_hat = model(x)
        # detached so the stored example does not keep the autograd graph alive
        example_train_batch_y_hat = y_hat.detach()
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        tr_losses.append(loss.item())

        # compute f1
        tr_f1.append(_per_class_f1(y, y_hat))

    # validate with mode-dependent layers switched to inference behaviour
    model.eval()
    with torch.no_grad():
        for x, y, info in validation_batch_generator:

            # dysplastic vs non-dysplastic
            y = to_dysplastic_vs_non_dysplastic(y)

            # store one example as numpy
            example_val_batch_x = x
            example_val_batch_y = y

            # transform x and y
            x, y = _to_model_tensors(x, y)

            # forward and validate
            y_hat = model(x)
            example_val_batch_y_hat = y_hat
            loss = criterion(y_hat, y)
            val_losses.append(loss.item())

            # compute f1
            val_f1.append(_per_class_f1(y, y_hat))

    # compute & plot metrics (F1 is averaged per class over the batches)
    avg_tr_loss = np.mean(tr_losses)
    avg_val_loss = np.mean(val_losses)
    avg_tr_f1 = np.round(np.mean(np.asarray(tr_f1), axis=0), decimals=2)
    avg_val_f1 = np.round(np.mean(np.asarray(val_f1), axis=0), decimals=2)
    print("Train loss: {:.3f}, val loss: {:.3f}".format(avg_tr_loss, avg_val_loss))
    print("Train dice: {}, val dice: {}".format(avg_tr_f1, avg_val_f1))

    # plot every 50 epochs, using the last batch seen in each loop
    if n % 50 == 0:
        print('Training examples: ')
        plot_center_batch(example_train_batch_x, example_train_batch_y, example_train_batch_y_hat)
        print('Validation examples: ')
        plot_center_batch(example_val_batch_x, example_val_batch_y, example_val_batch_y_hat)