In [None]:
from wholeslidedata.iterators import create_batch_iterator
from matplotlib import pyplot as plt
from utils import init_plot, plot_batch, show_plot, print_dataset_statistics, mean_metrics
import numpy as np
import torch
import torch.nn as nn
from nn_archs import UNet
from pprint import pprint
from torchsummary import summary
from wholeslidedata.annotation import utils as annotation_utils
from label_utils import to_dysplastic_vs_non_dysplastic
from train_unet import load_config
from sklearn.metrics import f1_score
import pandas as pd
from tqdm.notebook import tqdm
import yaml
import wandb
import os
from distutils.dir_util import copy_tree
import shutil

In [None]:
# colour palettes for the overlay plots; colors_2 is the one handed to plot_batch
colors_1 = [
    'white', 'green', 'orange', 'red', 'yellow',
    'yellow', 'purple', 'pink', 'grey', 'blue',
]
colors_2 = [
    'white', 'green', 'red', 'yellow', 'brown',
    'yellow', 'purple', 'pink', 'grey', 'green',
]

In [None]:
def crop_center(img, cropx, cropy):
    """Return the central window of every image in a 4-D batch.

    Args:
        img: array-like with four axes, indexed as (batch, axis 1, axis 2, channels).
        cropx: size of the window along axis 1.
        cropy: size of the window along axis 2.

    Returns:
        ``img`` sliced to shape (batch, cropx, cropy, channels); the first and
        last axes are left untouched.

    Raises:
        ValueError: if the requested window is larger than the image along
            either spatial axis (this would otherwise yield a silently wrong
            slice because the start offset becomes negative).
    """
    _, x, y, _ = img.shape
    if cropx > x or cropy > y:
        raise ValueError(
            'crop size ({}, {}) exceeds image size ({}, {})'.format(cropx, cropy, x, y))

    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)

    # startx/cropx are derived from axis 1 and starty/cropy from axis 2, so each
    # pair has to slice its own axis; swapping them only works for square inputs.
    return img[:, startx:startx + cropx, starty:starty + cropy, :]

def plot_center_batch(x, y, y_hat, patches=4):
    """Plot ground-truth and predicted masks over center-cropped input patches.

    Args:
        x: input image batch; cropped with ``crop_center`` to the spatial size
            of the prediction before plotting.
        y: ground-truth masks for the batch.
        y_hat: raw model output tensor; the class per pixel is taken with an
            argmax over dim 1.
        patches: maximum number of patches to draw (capped at ``len(y_hat)``).
    """
    # never draw more patches than the batch contains
    n_shown = min(patches, len(y_hat))

    # class index per pixel as a numpy array on the host
    predicted = torch.argmax(y_hat, dim=1).cpu().detach().numpy()

    # shrink the images to the spatial size of the prediction
    _, height, width = predicted.shape
    images = crop_center(x, height, width)

    # one row of overlays for the labels, then one for the model output
    for title, masks in (("Ground truth", y), ("Prediction", predicted)):
        print(title)
        _, axes = init_plot(1, n_shown, size=(30, 10))
        plot_batch(axes, 0, images[:n_shown], masks[:n_shown], alpha=0.3, colors=colors_2)
        plt.show()

In [None]:
# locations of the code base, the experiment output root and the training config
base_dir = '/home/mbotros/code/barrett_gland_grading/'
experiments_dir = '/home/mbotros/experiments/barrett_gland_grading'
run_name = 'test_bolero'
user_config = os.path.join(base_dir, 'configs/unet_training_config.yml')
train_config = load_config(user_config)

# snapshot the sources used for this run (configs, architectures, training script)
# under <experiment dir>/src
exp_dir = os.path.join(experiments_dir, run_name)
print('Experiment stored at: {}'.format(exp_dir))
src_snapshot_dir = os.path.join(exp_dir, 'src')
for package in ('configs', 'nn_archs'):
    copy_tree(os.path.join(base_dir, package), os.path.join(src_snapshot_dir, package))
shutil.copy2(os.path.join(base_dir, 'train_unet.py'), src_snapshot_dir)

In [None]:
# show the default wholeslidedata settings found in the user config
with open(user_config, 'r') as yamlfile:
    data = yaml.load(yamlfile, Loader=yaml.FullLoader)

default_settings = data['wholeslidedata']['default']
for k, v in default_settings.items():
    print('{}: {}'.format(k, v))

# the training and validation iterators share the config file and the cpus setting
iterator_kwargs = dict(user_config=user_config, cpus=train_config['cpus'])
training_batch_generator = create_batch_iterator(mode='training', **iterator_kwargs)
validation_batch_generator = create_batch_iterator(mode='validation', **iterator_kwargs)

# report the statistics of both datasets
for heading, generator in (('\nTraining dataset ', training_batch_generator),
                           ('\nValidation dataset ', validation_batch_generator)):
    print(heading)
    print_dataset_statistics(generator.dataset)

In [None]:
# UNet as originally defined (valid convolutions), placed on the first GPU when
# CUDA is available and on the CPU otherwise
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
unet_kwargs = {'n_channels': train_config['n_channels'],
               'n_classes': train_config['n_classes']}
model = UNet(**unet_kwargs)
model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=train_config['learning_rate'])
criterion = nn.CrossEntropyLoss()

# Weights & Biases logging (currently disabled).
# Do not hardcode the API key in the notebook: export WANDB_API_KEY in the
# environment that launches Jupyter (or run `wandb login`) instead.
# NOTE(review): a key was previously pasted into this cell; it should be revoked.
# wandb.init(project="Barrett's Gland Grading", dir=exp_dir)
# wandb.run.name = run_name


def _to_model_inputs(x, y):
    """Convert one numpy batch into tensors on ``device``.

    Assumes ``x`` is channels-last, i.e. (batch, height, width, channels), and
    ``y`` is (batch, height, width) -- TODO confirm against the batch iterator.
    ``permute(0, 3, 1, 2)`` moves the channel axis to position 1 while keeping
    height and width in order; ``transpose(1, 3)`` would also swap height and
    width and thereby misalign the images with their label masks.

    Returns:
        (x, y) as a float32 image tensor and an int64 label tensor on ``device``.
    """
    x = torch.tensor(x.astype('float32')).permute(0, 3, 1, 2).to(device)
    y = torch.tensor(y.astype('int64')).to(device)
    return x, y


def _batch_metrics(loss, y, y_hat):
    """Return the loss value and the per-class / weighted F1 (dice) of one batch."""
    y_true = y.cpu().detach().numpy().flatten()
    y_pred = torch.argmax(y_hat, dim=1).cpu().detach().numpy().flatten()
    return {'loss': loss.item(),
            'dice per class': f1_score(y_true, y_pred, average=None, labels=[0, 1, 2]),
            'dice weighted': f1_score(y_true, y_pred, average='weighted')}


# lowest mean validation loss seen so far; a checkpoint is written on improvement
min_val = float('inf')

for n in range(train_config['epochs']):

    train_metrics = {}
    validation_metrics = {}

    # training mode, so that layers such as dropout / batch norm (if the model
    # has any) behave as intended while optimising
    model.train()
    for idx in tqdm(range(train_config['train_batches']), desc='Epoch {}'.format(n + 1)):
        x, y, info = next(training_batch_generator)

        # dysplastic vs non-dysplastic
        y = to_dysplastic_vs_non_dysplastic(y)

        # keep the most recent batch (as numpy) for the example plots
        example_train_batch_x = x
        example_train_batch_y = y

        # transform x and y
        x, y = _to_model_inputs(x, y)

        # forward and update
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # detached so the stored example does not keep the autograd graph alive
        example_train_batch_y_hat = y_hat.detach()

        # compute and store metrics
        train_metrics[idx] = _batch_metrics(loss, y, y_hat)

    # validate in evaluation mode and without building autograd graphs
    model.eval()
    with torch.no_grad():
        for idx in tqdm(range(train_config['val_batches']), desc='Validating'):
            x, y, info = next(validation_batch_generator)

            # dysplastic vs non-dysplastic
            y = to_dysplastic_vs_non_dysplastic(y)

            # keep the most recent batch (as numpy) for the example plots
            example_val_batch_x = x
            example_val_batch_y = y

            # transform x and y
            x, y = _to_model_inputs(x, y)

            # forward and validate
            y_hat = model(x)
            example_val_batch_y_hat = y_hat
            loss = criterion(y_hat, y)

            # compute and store metrics
            validation_metrics[idx] = _batch_metrics(loss, y, y_hat)

    # compute and print metrics
    training_means = mean_metrics(train_metrics)
    validation_means = mean_metrics(validation_metrics)
    print("Train loss: {:.3f}, val loss: {:.3f}".format(training_means['loss'], validation_means['loss']))
    print("Train dice: {}, val dice: {}".format(np.round(training_means['dice per class'], decimals=2),
                                                np.round(validation_means['dice per class'], decimals=2)))
    # plot predictions
    print('Training examples: ')
    plot_center_batch(example_train_batch_x, example_train_batch_y, example_train_batch_y_hat)
    print('Validation examples: ')
    plot_center_batch(example_val_batch_x, example_val_batch_y, example_val_batch_y_hat)

    # wandb.log({'epoch': n + 1,
    #            'train loss': training_means['loss'], 'train dice': training_means['dice weighted'],
    #            'val loss': validation_means['loss'], 'val dice': validation_means['dice weighted']})

    # save best model
    if validation_means['loss'] < min_val:
        # format only the file name, so braces in exp_dir cannot break the template
        checkpoint_name = 'model_epoch_{}_loss_{:.3f}.pt'.format(n, validation_means['loss'])
        os.makedirs(exp_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(exp_dir, checkpoint_name))
        min_val = validation_means['loss']

In [None]:
# Close the Weights & Biases run. wandb.init is commented out in the training
# cell, so there is presumably no active run here -- TODO confirm this call is
# still needed while logging is disabled.
wandb.finish()