In [None]:
from wholeslidedata.iterators import create_batch_iterator
from matplotlib import pyplot as plt
from utils import init_plot, plot_batch, show_plot, print_dataset_statistics
from shapely.prepared import prep
import os
import numpy as np
import torch
import torch.nn as nn
from nn_archs import UNet
from pprint import pprint
from torchsummary import summary
import yaml
from wholeslidedata.annotation import utils as annotation_utils
from label_utils import to_dysplastic_vs_non_dysplastic

In [None]:
# Overlay color palettes passed to plot_batch (presumably indexed by label value):
# colors_1 is used for the original labels, colors_2 for the simplified
# dysplastic vs non-dysplastic labels.
colors_1 = [
    "white", "blue", "green", "orange", "red",
    "brown", "yellow", "purple", "pink", "grey",
]
# NOTE(review): "green" appears at both index 1 and index 9 — confirm this is intended.
colors_2 = [
    "white", "green", "red", "orange", "brown",
    "yellow", "purple", "pink", "grey", "green",
]

In [None]:
# The originally defined UNet (with valid convolutions), built for 3-channel
# input and 3 output classes, placed on the first GPU when one is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
model = UNet(n_channels=3, n_classes=3)
model = model.to(device)
# summary(model, input_size=(3, 512, 512), batch_size=8)

In [None]:
# Settings for visually sanity-checking the data pipeline.
user_config = './configs/plot_config.yml'
batches, repeats, cpus = 3, 1, 1
mode = 'training'

# Load the plotting config and echo its wholeslidedata defaults.
with open(user_config, 'r') as config_file:
    data = yaml.load(config_file, Loader=yaml.FullLoader)

default_settings = data['wholeslidedata']['default']
for setting_name, setting_value in default_settings.items():
    print('{}: {}'.format(setting_name, setting_value))

In [None]:
# Create a batch iterator for plotting. It is used as a context manager (like the
# training cell) so that its worker processes are stopped once plotting is done,
# instead of lingering in the kernel.
with create_batch_iterator(user_config=user_config,
                           number_of_batches=batches,
                           mode=mode,
                           cpus=cpus) as training_batch_generator:

    # print dataset statistics
    dataset = training_batch_generator.dataset
    print('Training dataset:')
    print_dataset_statistics(dataset)

    # Colors per annotation label, used by the optional annotation preview below.
    color_map = {'e-stroma': "blue", 'ndbe-g': "green", 'lgd-g': "orange", 'hgd-g': "red"}

    # Optional: uncomment to preview up to 10 annotations per label.
    # for label in dataset.sample_references.keys():
    #     references = dataset.sample_references[label][:10]
    #
    #     fig, ax = plt.subplots(1, len(references), figsize=(30,5))
    #     for idx, reference in enumerate(references):
    #
    #         # get the associated image
    #         image = dataset.get_wsi_from_reference(reference)
    #
    #         # get the polygon from the associated annotations
    #         wsa = dataset.get_wsa_from_reference(reference)
    #         annotation = wsa.annotations[reference.annotation_index]
    #
    #         # note the spacing 0.25 (magnification level = ?)
    #         patch = image.get_annotation(annotation, 0.25)
    #         ax[idx].imshow(patch)
    #         title = f'{label} {idx}\n area={int(annotation.area)} \n loc={annotation.center}'
    #         annotation_utils.plot_annotations([annotation], title=title, ax=ax[idx], use_base_coordinates=True, color_map=color_map)
    #     plt.show()

    # Make `repeats` passes over the iterator. Each batch is plotted twice: once
    # with the original labels (colors_1) and once with the simplified labels (colors_2).
    for r in range(repeats):
        for idx, (x_batch, y_batch, info) in enumerate(training_batch_generator):
            fig, axes = plt.subplots(1, len(y_batch), figsize=(30, 8), squeeze=False)
            plot_batch(axes, 0, x_batch, y_batch, alpha=0.3, colors=colors_1)
            plt.show()

            # show simplified dataset (dysplastic vs non-dysplastic)
            y_batch = to_dysplastic_vs_non_dysplastic(y_batch)
            fig, axes = plt.subplots(1, len(y_batch), figsize=(30, 8), squeeze=False)
            plot_batch(axes, 0, x_batch, y_batch, alpha=0.3, colors=colors_2)
            plt.show()

In [None]:
# The training loop: trains the UNet on dysplastic vs non-dysplastic labels and
# periodically plots ground truth vs prediction for the last batch of an epoch.
user_config = './configs/unet_training_config.yml'
cpus = 4
mode = 'training'
epochs = 10
batches = 5
criterion = nn.CrossEntropyLoss()
lr = 0.0005
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
plot_every = 50  # plot every `plot_every` epochs; the final epoch is always plotted too
# Border added on each side of the (smaller) label maps before plotting, so they line up
# with the input patch. Presumably the context lost by the valid-conv UNet — TODO confirm
# against nn_archs.UNet and the patch/label shapes in the training config.
label_pad = 94
avg_losses = []  # mean training loss per epoch
preds = []

model.train()

with create_batch_iterator(mode=mode,
                           number_of_batches=batches,
                           user_config=user_config,
                           cpus=cpus) as training_batch_generator:

    for n in range(epochs):
        losses = []
        for idx, (x, y, info) in enumerate(training_batch_generator):

            # Collapse labels to dysplastic vs non-dysplastic; keep numpy copies for plotting.
            y = to_dysplastic_vs_non_dysplastic(y)
            x_np, y_np = x, y

            # Assumes the iterator yields channels-last (B, H, W, C) patches (the original
            # code moved axis 3 into the channel position). permute(0, 3, 1, 2) gives
            # (B, C, H, W). The earlier transpose(1, 3) gave (B, C, W, H): it swapped
            # height and width, so predictions were spatially transposed relative to `y`.
            x = torch.tensor(x.astype('float32')).permute(0, 3, 1, 2).to(device)
            y = torch.tensor(y.astype('int64')).to(device)

            optimizer.zero_grad()
            y_hat = model(x)  # call the module (not .forward) so hooks run
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        epoch_loss = np.mean(losses)
        avg_losses.append(epoch_loss)
        print("Epoch: {}, loss: {}".format(n, epoch_loss))

        if n % plot_every == 0 or n == epochs - 1:
            # Plots use the last batch seen in this epoch.
            pad_width = ((0,), (label_pad,), (label_pad,))

            # pad and show input
            y_pad = np.pad(y_np, pad_width=pad_width)
            print("Ground truth")
            fig, axes = init_plot(1, training_batch_generator.batch_size, size=(30, 10))
            plot_batch(axes, 0, x_np, y_pad, alpha=0.3, colors=colors_2)
            plt.show()

            # pad and show prediction (class index per pixel)
            print("Prediction")
            y_hat = torch.argmax(y_hat, dim=1).detach().cpu().numpy()
            y_hat_pad = np.pad(y_hat, pad_width=pad_width)
            fig, axes = init_plot(1, training_batch_generator.batch_size, size=(30, 10))
            plot_batch(axes, 0, x_np, y_hat_pad, alpha=0.3, colors=colors_2)
            plt.show()