# Training and prediction notebook

This notebook contains everything needed to:
* define the Deep Learning model
* define the data sets and data loaders
* train the models
* make and save final predictions

In [None]:
import os
os.add_dll_directory("C:\\Users\\33631\\Desktop\\openslide-win64-20171122\\bin")

import openslide
from histolab.tiler import RandomTiler, GridTiler
from histolab.slide import Slide

In [None]:
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os
import random
import cv2
import matplotlib
from sklearn.model_selection import train_test_split
from skimage.filters import threshold_otsu
import torchmetrics
import pickle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torchvision.models as models
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn
from torch.optim import Adam
from torch import LongTensor as LongTensor
from torch import FloatTensor as FloatTensor
import cv2
from collections import Counter
from sklearn.model_selection import StratifiedKFold

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Parameters

In [None]:
batch_size = 16  # number of bags per batch in the data loaders
number_of_tiles_per_bag = 10  # tiles sampled per image; also the minimum tile count required to keep a training image
number_of_folds = 10  # number of stratified cross-validation folds
n_epoch = 100  # maximum number of training epochs per model
early_stopping = 10  # number of epochs without validation improvement before training stops
learning_rate = 0.0001  # learning rate given to the Adam optimizer
embedding_shape = 1408  # length of one tile embedding, passed to the models as `embedding_shape`

## Load data

In [None]:
# CSV table describing the training images
train_set_path = "./Data/raw_data/train.csv"
# Folder prefix of the pickled tile embeddings ('<image_id>.pkl' is appended as-is, so keep the trailing slash)
train_tiles_encoding_folder = "./Data/processed_data/train_tiles_encoding_grid/"

# Same layout for the test images
test_set_path = "./Data/raw_data/test.csv"
test_tiles_encoding_folder = "./Data/processed_data/test_tiles_encoding_grid/"

In [None]:
# Load the train and test tables; the rest of the notebook reads their `image_id` column
# (and `isup_grade` / `gleason_score` for the training table).
initial_train_set = pd.read_csv(train_set_path)
test_set = pd.read_csv(test_set_path)

In [None]:
# Optional target: Gleason scores instead of ISUP grades.
# For every Gleason score, keep the ISUP grade of the first training row that carries it.
gleason_isup = (
    initial_train_set[['isup_grade', 'gleason_score']]
    .groupby('gleason_score')
    .first()
)
gleason_isup_dict = gleason_isup.to_dict()['isup_grade']

# Integer-encode the Gleason scores in order of first appearance and keep both directions of the mapping.
id2gleason = dict(enumerate(initial_train_set.gleason_score.unique()))
gleason2id = {score: class_id for class_id, score in id2gleason.items()}
initial_train_set.gleason_score = initial_train_set.gleason_score.apply(lambda score: gleason2id[score])

In [None]:
## keep only images where enough tiles have been retrieved
# A bag needs `number_of_tiles_per_bag` embeddings, so training images with fewer tiles are dropped.
to_remove = []
for image_id in initial_train_set.image_id:
    image_embeddings_path = train_tiles_encoding_folder + image_id + '.pkl'

    # Close each file as soon as it has been read so that no handle stays open across the loop.
    # NOTE(review): pickle executes code on load; only use it on trusted embedding files.
    with open(image_embeddings_path, 'rb') as embeddings_file:
        embeddings_dict = pickle.load(embeddings_file)

    if len(embeddings_dict) < number_of_tiles_per_bag:
        to_remove.append(image_id)

initial_train_set = initial_train_set.query('image_id not in @to_remove')

## Cross-validation sets computation

In [None]:
# Stratify the folds on the ISUP grade; the splitter shuffles without a fixed seed,
# so the folds differ from one run to the next.
kf = StratifiedKFold(n_splits=number_of_folds, random_state=None, shuffle=True)
fold_features = initial_train_set.drop(columns = ['isup_grade'])
fold_labels = initial_train_set[['isup_grade']]

# fold number -> (train row positions, validation row positions)
folds_dict = {
    fold_id: (train_index, validation_index)
    for fold_id, (train_index, validation_index) in enumerate(kf.split(fold_features, fold_labels))
}

## Define epoch training

In [None]:
def train_one_epoch(model, trainloader, validationloader, optimizer, device, num_classes = 6):
    """
    Train a model for one epoch, then evaluate it on the validation data.

    inputs:
        - model: model to train (it is not moved here, so it is expected to already be on `device`)
        - trainloader: loader yielding (features, target) batches for training
        - validationloader: loader yielding (features, target) batches for validation
        - optimizer: optimizer to use
        - device: cuda or cpu
        - num_classes: number of classes to predict (6 for ISUP and 11 for gleason)
    outputs:
        - the validation AUROC accumulated over every batch of `validationloader`
    side effects:
        - updates the model weights in place
        - prints the average training loss, the train/validation F1 and the validation AUROC
    """
    losses = []

    val_auroc = torchmetrics.AUROC(num_classes = num_classes)
    val_f1 = torchmetrics.F1Score(num_classes = num_classes)
    train_f1 = torchmetrics.F1Score(num_classes = num_classes)

    # The loss object is stateless: build it once instead of once per batch.
    # NOTE(review): the models in this notebook end with a softmax, while CrossEntropyLoss
    # expects raw logits (it applies log-softmax itself) -- confirm this is intended.
    criterion = nn.CrossEntropyLoss()

    ### training
    model.train()
    for (features, target) in tqdm(trainloader):
        features, target = features.to(device), target.to(device)

        optimizer.zero_grad()
        predictions = model(features)
        loss = criterion(predictions, target)
        losses.append(float(loss))
        loss.backward()
        optimizer.step()

        # The metric objects live on the CPU; calling one accumulates the batch into its state.
        predicted_classes = torch.argmax(predictions, dim=1)
        train_f1(predicted_classes.cpu(), target.cpu())

    ### model evaluation
    model.eval()
    with torch.no_grad():
        for (features, target) in validationloader:
            features, target = features.to(device), target.to(device)

            predictions = model(features)
            predicted_classes = torch.argmax(predictions, dim=1)

            val_auroc(predictions.cpu(), target.cpu())
            val_f1(predicted_classes.cpu(), target.cpu())

    # Aggregate the per-batch states into epoch-level scores.
    val_f1_final = val_f1.compute()
    train_f1_final = train_f1.compute()
    val_auroc_final = val_auroc.compute()
    print('average train loss: ', np.mean(losses))
    print('validation f1: ', val_f1_final)
    print('train f1: ', train_f1_final)
    print('validation auroc: ', val_auroc_final)

    return val_auroc_final

In [None]:
def train_model(model, trainloader, validationloader, optimizer, device, n_epoch, checkpoint_path, early_stopping, num_classes = 6):
    """
    Run the whole training of a model, with early stopping when the validation AUROC has not
    improved for a given number of consecutive epochs.

    inputs:
        - model: model to train
        - trainloader: loader for the training data
        - validationloader: loader for the validation data
        - optimizer: optimizer to use
        - device: cuda or cpu
        - n_epoch: maximum number of epochs (if no early stopping)
        - checkpoint_path: where to save the best model (overwritten at every improvement)
        - early_stopping: number of epochs without improvement after which training stops
        - num_classes: number of classes to predict (6 for ISUP and 11 for gleason)
    outputs:
        - None; the best weights are only available through the checkpoint file, which stores
          'epoch', 'model_state_dict', 'optimizer_state_dict' and 'roc'
    side effects:
        - the model is moved to `device` for training and back to the CPU before returning
    """
    model.to(device)

    # torch.save does not create missing folders; make sure the target folder exists up front
    # rather than failing after a full epoch of training.
    if isinstance(checkpoint_path, (str, os.PathLike)):
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

    best_roc = -np.inf
    counter = 0  # consecutive epochs without improvement
    for epoch in range(0, n_epoch):
        print(f"epoch {epoch+1}/{n_epoch}")
        roc = train_one_epoch(model , trainloader, validationloader, optimizer, device, num_classes = num_classes)
        if roc >= best_roc:
            # ties count as an improvement: the checkpoint is refreshed and the patience counter reset
            best_roc = roc
            torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'roc': roc,
            }, checkpoint_path)

            print('New best model saved !')
            counter = 0

        else:
            print('no iprovement')
            counter += 1

        if counter == early_stopping:
            print('early stopping !')
            break
    model.cpu() # remove model from gpu



## Define data sets


In [None]:
class EmbeddingDatasetNoPosition(Dataset):
    """
    Data set that outputs bags and labels but no information about the position of the tiles.
    """

    def __init__(self, df, dir, bag_size = 180, test = False):
        """
        Inputs:
            df: test set or train set; needs an `image_id` column and, when test = False,
                an `isup_grade` column
            dir: folder prefix of the pickled tile embeddings; '<image_id>.pkl' is appended
                 as-is, so it must end with a path separator
            bag_size: maximum number of tiles per bag
            test: True for the test set (no label is returned)
        outputs:
            - bag content and label (if test = False)
        """
        self.df = df
        self.dir = dir
        self.bag_size = bag_size
        self.test = test

    def __len__(self):
        """Number of images (one bag per row of the data frame)."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """
        Build one bag by sampling tile embeddings of image number `idx` at random.

        Returns `(bag, label)` when test = False and `bag` alone otherwise. The sampling uses
        numpy's global random state, so two calls with the same index give different bags.
        An image with fewer than `bag_size` tiles yields a smaller bag.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        value = self.df.iloc[idx]
        image_id = value.image_id
        # compute file
        image_embeddings_path = self.dir+image_id+'.pkl'

        # open embeddings dict; the context manager closes the file instead of leaking the handle
        # NOTE(review): pickle executes code on load; only use it on trusted embedding files.
        with open(image_embeddings_path, 'rb') as embeddings_file:
            embeddings_dict = pickle.load(embeddings_file)

        # shuffle the tiles, keep the first `bag_size` ones and stack them row-wise
        embedding_list = np.array(list(embeddings_dict.values()))
        np.random.shuffle(embedding_list)
        bag = torch.Tensor(np.vstack(embedding_list[:self.bag_size]))

        if self.test:
            return bag
        # compute label
        label = torch.tensor(value.isup_grade)
        return bag, label

## Define Models

In [None]:
class ModelVanilla(nn.Module):
    """
    Baseline model: the tile embeddings of a bag are averaged, then the averaged vector is classified.
    """
    def __init__(self, number_of_embeddings = 180, output_shape = 6, embedding_shape = 1280):
        super().__init__()

        half_width = int(embedding_shape/2)
        quarter_width = int(embedding_shape/4)

        # average pooling whose window covers the `number_of_embeddings` tiles of a bag
        self.aggregator = nn.AvgPool1d(number_of_embeddings)

        # layer order is kept as-is: saved checkpoints address the layers by their position
        classifier_layers = [
            nn.Linear(embedding_shape, half_width),
            nn.ReLU(),
            nn.BatchNorm1d(half_width),
            nn.Dropout(0.25),
            nn.Linear(half_width, quarter_width),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.BatchNorm1d(quarter_width),
            nn.Linear(quarter_width, output_shape),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

        # turns the classifier scores into class probabilities
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, bag):
        # expects bag as (batch, tiles, features); swap the last two axes so pooling runs over the tiles
        pooled = self.aggregator(bag.transpose(1, 2))
        wsi_descriptor = pooled.squeeze(-1)
        return self.softmax(self.classifier(wsi_descriptor))

In [None]:
class AttentionModel(nn.Module):
    """
    Initial attention model: every tile gets a learned attention weight, the tile embeddings are
    scaled by their weight, pooled over the bag and classified.
    """
    def __init__(self, number_of_embeddings = 180, output_shape = 6, embedding_shape = 1280):
        """
        inputs:
            - number_of_embeddings: number of tiles per bag (size of the pooling window)
            - output_shape: number of classes
            - embedding_shape: length of one tile embedding
        """
        super(AttentionModel, self).__init__()

        # to compute attention map for aggregation: one scalar score per tile
        self.attention = nn.Sequential(
            nn.Linear(embedding_shape, int(embedding_shape/4)),
            nn.Tanh(),
            nn.Linear(int(embedding_shape/4), 1)
            )

        self.aggregator = torch.nn.AvgPool1d(number_of_embeddings)


        self.classifier = nn.Sequential(
                        nn.Linear(embedding_shape, int(embedding_shape/2)),
                        nn.ReLU(),
                        torch.nn.BatchNorm1d(int(embedding_shape/2)),
                        nn.Dropout(0.25),
                        nn.Linear(int(embedding_shape/2), int(embedding_shape/4)),
                        nn.ReLU(),
                        torch.nn.BatchNorm1d(int(embedding_shape/4)),
                        nn.Dropout(0.25),
                        nn.Linear(int(embedding_shape/4), output_shape)
        )


        self.softmax_attention = nn.Softmax(dim = 2)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, bag):
        """
        Expects bag as (batch, tiles, features) with tiles == number_of_embeddings and returns
        class probabilities of shape (batch, output_shape).
        """
        # compute attention map: scores are normalised over the tiles of each bag
        attention_map = self.attention(bag)
        attention_map = torch.transpose(attention_map, 2, 1)
        attention_map = self.softmax_attention(attention_map).squeeze(1)

        # apply attention map: scale each tile embedding by its weight. Broadcasting gives the
        # same result as diag(weights) @ bag without building a tiles x tiles matrix per bag.
        M = attention_map.unsqueeze(-1) * bag
        M = torch.transpose(M, 2, 1)
        # aggregate (average pooling: the weighted sum is divided by the number of tiles)
        wsi_descriptor = self.aggregator(M).squeeze(-1)

        # classification
        out = self.classifier(wsi_descriptor)
        out = self.softmax(out)
        return out

In [None]:
class GatedAttentionModel(nn.Module):
    """
    Model with gated attention: the tile scores come from the element-wise product of a tanh
    branch and a sigmoid branch, then the weighted tile embeddings are pooled and classified.
    """
    def __init__(self, number_of_embeddings = 180, output_shape = 6, embedding_shape = 1280):
        """
        inputs:
            - number_of_embeddings: number of tiles per bag (size of the pooling window)
            - output_shape: number of classes
            - embedding_shape: length of one tile embedding
        """
        super(GatedAttentionModel, self).__init__()

        # to compute attention map for aggregation
        self.attention_tanh = nn.Sequential(
            nn.Linear(embedding_shape, int(embedding_shape/4)),
            nn.Tanh()
            )

        self.attention_sig = nn.Sequential(
            nn.Linear(embedding_shape, int(embedding_shape/4)),
            nn.Sigmoid()
            )

        # maps the gated features of a tile to a single attention score
        self.attention_global = nn.Linear(int(embedding_shape/4), 1)

        self.aggregator = torch.nn.AvgPool1d(number_of_embeddings)


        self.classifier = nn.Sequential(
                        nn.Linear(embedding_shape, int(embedding_shape/2)),
                        nn.ReLU(),
                        torch.nn.BatchNorm1d(int(embedding_shape/2)),
                        nn.Dropout(0.25),
                        nn.Linear(int(embedding_shape/2), int(embedding_shape/4)),
                        nn.ReLU(),
                        torch.nn.BatchNorm1d(int(embedding_shape/4)),
                        nn.Dropout(0.25),
                        nn.Linear(int(embedding_shape/4), output_shape)
        )


        self.softmax_attention = nn.Softmax(dim = 2)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, bag):
        """
        Expects bag as (batch, tiles, features) with tiles == number_of_embeddings and returns
        class probabilities of shape (batch, output_shape).
        """
        # compute attention map: gate the tanh features with the sigmoid features
        attention_sig = self.attention_sig(bag)
        attention_tan = self.attention_tanh(bag)
        attention_map = self.attention_global(torch.mul(attention_sig, attention_tan))

        # normalise the scores over the tiles of each bag
        attention_map = torch.transpose(attention_map, 2, 1)
        attention_map = self.softmax_attention(attention_map).squeeze(1)

        # apply attention map: scale each tile embedding by its weight. Broadcasting gives the
        # same result as diag(weights) @ bag without building a tiles x tiles matrix per bag.
        M = attention_map.unsqueeze(-1) * bag
        M = torch.transpose(M, 2, 1)
        # aggregate (average pooling: the weighted sum is divided by the number of tiles)
        wsi_descriptor = self.aggregator(M).squeeze(-1)

        # classification
        out = self.classifier(wsi_descriptor)
        out = self.softmax(out)
        return out

In [None]:
class ModelChowder(nn.Module):
    """
    Chowder model: a convolution turns each bag into a list of scores, and the R largest and
    R smallest scores are classified.
    """
    def __init__(self, number_of_embeddings = 180, output_shape = 6, embedding_shape = 1280, R = 5):
        super().__init__()

        # number of extreme scores kept on each side
        self.R = R

        # The kernel spans a whole embedding, so each of the `number_of_embeddings` output
        # channels is reduced to a single value.
        # NOTE(review): the tiles are used as input channels, so every output score mixes
        # all the tiles of the bag -- confirm this is intended.
        self.conv = nn.Conv1d(number_of_embeddings, number_of_embeddings, kernel_size = embedding_shape)

        classifier_layers = [
            nn.Linear(2*self.R, 200),
            nn.Sigmoid(),
            nn.BatchNorm1d(200),
            nn.Dropout(0.25),
            nn.Linear(200, 100),
            nn.Sigmoid(),
            nn.BatchNorm1d(100),
            nn.Dropout(0.25),
            nn.Linear(100, output_shape),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

        # turns the classifier scores into class probabilities
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, bag):
        # one score per output channel, sorted in ascending order within each bag
        scores = self.conv(bag).squeeze(-1)
        ranked = torch.sort(scores)[0]

        # descriptor = R largest scores followed by R smallest scores
        top_scores = ranked[:, -self.R:]
        bottom_scores = ranked[:, :self.R]
        wsi_descriptor = torch.cat((top_scores, bottom_scores), 1)

        return self.softmax(self.classifier(wsi_descriptor))

## Model training

In [None]:
# Number of target classes (the training docstrings give 6 for ISUP grades and 11 for Gleason scores).
number_of_classes = 6

In [None]:
# training loop for each model and each fold
for fold_number, (train_index, val_index) in folds_dict.items():
    print(f"train fold: {fold_number}")

    # split train/val (the fold indices are row positions, hence iloc)
    train_df, validation_df = initial_train_set.iloc[train_index], initial_train_set.iloc[val_index]

    # create data sets
    train_set = EmbeddingDatasetNoPosition(train_df, train_tiles_encoding_folder, bag_size = number_of_tiles_per_bag)
    validation_set = EmbeddingDatasetNoPosition(validation_df, train_tiles_encoding_folder, bag_size = number_of_tiles_per_bag)

    # create loaders
    # training: shuffled, and only full batches (BatchNorm1d cannot train on a single-sample batch)
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last = True)
    # validation: every sample of the fold must be scored, so nothing is dropped and no shuffling is needed
    validationloader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size, shuffle=False, drop_last = False)

    # train models
    # Attention model
    """
    print("start training Attention model")
    model = AttentionModel(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes, embedding_shape=embedding_shape)
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate )
    checkpoint_path_attention = f"./Data/checkpoints/model_attention_{fold_number+1}"
    train_model(model, trainloader, validationloader, optimizer, device, n_epoch, checkpoint_path_attention, early_stopping, num_classes = number_of_classes)
    """

    print("start training Gated Attention model")
    model = GatedAttentionModel(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes, embedding_shape=embedding_shape)
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate )
    # checkpoints are numbered from 1 (fold_number starts at 0)
    checkpoint_path_attention = f"./Data/checkpoints/model_gated_attention_{fold_number+1}"
    train_model(model, trainloader, validationloader, optimizer, device, n_epoch, checkpoint_path_attention, early_stopping, num_classes = number_of_classes)

    """
    # Chowder model
    print("start training Chowder model")
    model = ModelChowder(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes,  embedding_shape=embedding_shape)
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
    checkpoint_path_chowder = f"./Data/checkpoints/model_chowder_fold_{fold_number+1}"
    train_model(model, trainloader, validationloader, optimizer, device, n_epoch, checkpoint_path_chowder, early_stopping, num_classes = number_of_classes)

    # Vanilla model
    print("start training Vanilla model")
    model = ModelVanilla(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes,  embedding_shape=embedding_shape)
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate )
    checkpoint_path_vanilla = f"./Data/checkpoints/model_vanilla_{fold_number+1}"
    best_model = train_model(model, trainloader, validationloader, optimizer, device, n_epoch, checkpoint_path_vanilla, early_stopping, num_classes = number_of_classes)
    """


# Prediction

In [None]:
# Models used for the final predictions. The Chowder and Vanilla variants are disabled
# (kept inside the string below); only the gated attention model is instantiated.
"""
model_chowder = ModelChowder(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes, embedding_shape=embedding_shape)
model_vanilla = ModelVanilla(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes, embedding_shape=embedding_shape)
"""
# The weights are filled in later from the per-fold checkpoints.
model_attention_g = GatedAttentionModel(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes, embedding_shape=embedding_shape)

In [None]:
# Prediction loop for each model and each fold. For every fold, the class probabilities of
# `n_prediction_rounds` randomly re-sampled bags are summed per test image, and the argmax of
# the sum is the fold's vote.
n_prediction_rounds = 100

total_predictions = []
testing_set = EmbeddingDatasetNoPosition(test_set, test_tiles_encoding_folder, bag_size = number_of_tiles_per_bag, test = True)
# one unshuffled batch holding the whole test set, so the summed predictions stay row-aligned
testloader = torch.utils.data.DataLoader(testing_set, batch_size=len(test_set), shuffle=False)

# inference mode (disables dropout, uses the BatchNorm running statistics); set once for all folds
model_attention_g.eval()

for fold_number in range(number_of_folds):
    print(f"current fold: {fold_number+1}")
    # open checkpoints
    # map_location: the checkpoints may hold GPU tensors while the prediction model lives on the CPU
    #checkpoint_chowder = torch.load(f"./Data/checkpoints/model_chowder_fold_{fold_number+1}")
    #checkpoint_vanilla = torch.load(f"./Data/checkpoints/model_vanilla_{fold_number+1}")
    checkpoint_attention_g = torch.load(f"./Data/checkpoints/model_gated_attention_{fold_number+1}", map_location="cpu")

    # load weights
    #model_chowder.load_state_dict(checkpoint_chowder['model_state_dict'])
    #model_vanilla.load_state_dict(checkpoint_vanilla['model_state_dict'])
    model_attention_g.load_state_dict(checkpoint_attention_g['model_state_dict'])
    pred_fold = None

    # no gradients are needed for inference
    with torch.no_grad():
        for _ in range(n_prediction_rounds):
            for bag in testloader:
                # get predictions
                # model_chowder.eval()
                # pred_chowder = model_chowder(bag).detach().cpu().numpy()
                # model_vanilla.eval()
                # pred_vanilla = model_vanilla(bag).detach().cpu().numpy()
                pred_attention_g = model_attention_g(bag).detach().cpu().numpy()

                if pred_fold is None:
                    pred_fold = pred_attention_g #+ pred_vanilla  + pred_chowder
                else:
                    pred_fold += pred_attention_g

    # the vote of this fold: most probable class once all the rounds are accumulated
    pred_current_fold = list(np.argmax(pred_fold, axis =1))
    total_predictions.append(pred_current_fold)

In [None]:
# max voting between folds: for every test image, keep the class predicted by the most folds
# (rows of `a` are folds, columns are test images; ties go to the class seen first)
a = np.array(total_predictions)
final_prediction = [
    Counter(a[:, column]).most_common(1)[0][0]
    for column in range(a.shape[1])
]


In [None]:
# Build the final data frame (one row per test image, in test-set order) and save it as CSV.
final_df = pd.DataFrame({
    'Id': list(test_set.image_id),
    'Predicted': final_prediction,
})

final_df.to_csv("final_preds.csv", index = False)

## Visualisation of attention map

In [None]:
# load a trained attention model (checkpoint of fold 1)
# NOTE: this checkpoint is only written by the Attention-model training, which is currently
# disabled (wrapped in a string) in the training loop; enable it first if the file is missing.
model_attention = AttentionModel(number_of_embeddings = number_of_tiles_per_bag, output_shape = number_of_classes, embedding_shape=embedding_shape)
checkpoint_attention = torch.load(f"./Data/checkpoints/model_attention_{1}")
model_attention.load_state_dict(checkpoint_attention['model_state_dict']) 

In [None]:
class VisualisationSet(Dataset):
    """
    Data set that outputs bags, labels, positions of the tiles and image ids.
    """

    def __init__(self, df, dir, bag_size = 180, test = False):
        """
        Inputs:
            df: test set or train set; needs an `image_id` column and, when test = False,
                an `isup_grade` column
            dir: folder prefix of the pickled tile embeddings; '<image_id>.pkl' is appended
                 as-is, so it must end with a path separator
            bag_size: maximum number of tiles per bag
            test: True for the test set (no label is returned)
        outputs:
            - bag content
            - tiles coordinates (the keys of the embeddings dict, one row per selected tile)
            - image_id: id of the image whose tiles are in the bag
            - label (ISUP grades), only if test = False
        """
        self.df = df
        self.dir = dir
        self.bag_size = bag_size
        self.test = test

    def __len__(self):
        """Number of images (one bag per row of the data frame)."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """
        Build one bag by sampling tiles of image number `idx` at random.

        Returns `(bag, selected_keys, image_id, label)` when test = False and
        `(bag, selected_keys, image_id)` otherwise; row i of `selected_keys` is the dict key
        of the tile stored in row i of `bag`. The sampling uses numpy's global random state.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        value = self.df.iloc[idx]
        image_id = value.image_id
        # compute file
        image_embeddings_path = self.dir+image_id+'.pkl'

        # open embeddings dict; the context manager closes the file instead of leaking the handle
        # NOTE(review): pickle executes code on load; only use it on trusted embedding files.
        with open(image_embeddings_path, 'rb') as embeddings_file:
            embeddings_dict = pickle.load(embeddings_file)

        # shuffle the keys (not the values) so that embeddings and positions stay paired
        embedding_keys = list(embeddings_dict.keys())
        np.random.shuffle(embedding_keys)
        selected_keys = embedding_keys[:self.bag_size]
        bag = torch.Tensor(np.vstack([embeddings_dict[key] for key in selected_keys]))
        selected_keys = np.vstack(selected_keys)

        if self.test:
            return bag, selected_keys, image_id
        # compute label
        label = torch.tensor(value.isup_grade)
        return bag, selected_keys, image_id, label

In [None]:
class PlotModel(nn.Module):
    """
    Model that only outputs the attention map of a trained attention model.
    """
    def __init__(self, attention_model):
        """
        inputs:
            - attention_model: trained attention model; its `attention` sub-module is reused as-is
        """
        super().__init__()

        # shared with the trained model (same module object, not a copy)
        self.attention = attention_model.attention

        # normalises the tile scores; applied once the tiles are on the last axis
        self.softmax = nn.Softmax(dim = 2)

    def forward(self, bag):
        # score every tile, then move the tile axis last so the softmax runs over the tiles of a bag
        tile_scores = self.attention(bag).transpose(2, 1)
        return self.softmax(tile_scores).squeeze(1)

In [None]:
# Rebuild the training loader with the visualisation data set, which also yields the tile
# positions and the image ids (this rebinds `train_set` and `trainloader` from the training cells).
train_set = VisualisationSet(initial_train_set, train_tiles_encoding_folder, bag_size = number_of_tiles_per_bag)
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last = True)
# Wrap the trained attention model so that a forward pass returns only the attention weights.
plot_model = PlotModel(model_attention)

In [None]:
# get an example of attention map: take the first batch of the (shuffled) loader
emb, loc, id_list, lab = next(iter(trainloader))

attention_values = plot_model(emb).detach().cpu().numpy()

# keep the first bag of the batch (`image_id` rather than `id`, which would shadow the builtin)
image_id = id_list[0]
attention = attention_values[0]
locations = loc[0]

# retrieve corresponding image and mask
# The slides stay open on purpose: the following cells read regions from them.
image = openslide.OpenSlide("Data/raw_data/train/train/"+image_id+".tiff")
mask = openslide.OpenSlide("Data/raw_data/train_label_masks/train_label_masks/"+image_id+".tiff")

In [None]:
# plot mask
# Read the whole mask at level 0 into memory.
mask_array = np.array(mask.read_region((0,0), 0, mask.level_dimensions[0]))
# One colour per value from 0 to 5; only the first channel of the mask is displayed.
# NOTE(review): presumably the label values are stored in that first channel -- confirm.
cmap = matplotlib.colors.ListedColormap(['black', 'gray', 'green', 'yellow', 'orange', 'red'])
plt.imshow(mask_array[:,:,0], cmap=cmap, interpolation='nearest', vmin=0, vmax=5)

In [None]:
# compute attention heat map
# Read the whole slide at level 0 and start from an all-zero map with the same height and width.
image_array = np.array(image.read_region((0,0), 0, image.level_dimensions[0]))
heat_map = np.zeros(image_array.shape[:2])

# Paint every selected tile with its attention weight.
# assumes each location is (x upper-left, y upper-left, x bottom-right, y bottom-right)
# in level-0 pixels -- TODO confirm against the tiling code
for tile_index, tile_box in enumerate(locations):
    x_ul_wsi, y_ul_wsi, x_br_wsi, y_br_wsi = tile_box
    heat_map[y_ul_wsi:y_br_wsi, x_ul_wsi:x_br_wsi] = attention[tile_index]

In [None]:
# plot the attention heat map on top of the WSI image
fig, ax = plt.subplots(figsize=(20, 16))
ax.imshow(image_array)
# half-transparent overlay so the tissue stays visible under the attention values
ax.imshow(heat_map, alpha = 0.5, cmap='jet' )