In [18]:
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os
import random
import cv2
import matplotlib
from sklearn.model_selection import train_test_split
from skimage.filters import threshold_otsu
import torchmetrics
import pickle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torchvision.models as models
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn
from torch.optim import Adam
from torch import LongTensor as LongTensor
from torch import FloatTensor as FloatTensor
import cv2

In [19]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Load data

In [20]:
train_set_path = "./Data/raw_data/train.csv"
train_tiles_encoding_folder = "./Data/processed_data/train_tiles_encoding/"

test_set_path = "./Data/raw_data/test.csv"
test_tiles_encoding_folder = "./Data/processed_data/test_tiles_encoding/"

In [21]:
train_set = pd.read_csv(train_set_path)
test_set = pd.read_csv(test_set_path)

In [22]:
## keep only if enough tiles retrieve
to_remove = []
for idx in range(train_set.shape[0]):
  value = train_set.iloc[idx]
  image_id = value.image_id
  # compute file
  image_embeddings_path = train_tiles_encoding_folder+image_id+'.pkl'

  # open embeddings  dict
  embeddings_dict = pickle.load(open(image_embeddings_path, 'rb'))
  if len(embeddings_dict) < 180:
    to_remove.append(image_id)

In [23]:
train_set = train_set.query('image_id not in @to_remove')

## Prepare Data

In [24]:
validation_ratio = 0.1

X = train_set.drop(columns = ['isup_grade'])
y = train_set[['isup_grade']]
train_samples, validation_samples, train_labels, validation_labels = train_test_split(X,y, test_size=validation_ratio, random_state=0, shuffle=True, stratify = y)

train_set = pd.concat([train_samples, train_labels], axis = 1)
validation_set = pd.concat([validation_samples, validation_labels], axis = 1)

## Define epoch training

In [25]:
# 1 epoch training

def train_one_epoch(model, trainloader, validationloader, optimizer, device): 
    losses = []

    val_auroc = torchmetrics.AUROC(num_classes = 6)
    val_f1 = torchmetrics.F1Score(num_classes = 6)
    train_f1 = torchmetrics.F1Score(num_classes = 6)
    best_validation_f1 = -np.inf

    ### traning 
    model.train()
    for (features, target) in tqdm(trainloader):
        features, target = features.to(device), target.to(device)

        optimizer.zero_grad()
        predictions = model(features)
        predicted_classes = torch.argmax(predictions, dim=1)
      
        criterion = nn.CrossEntropyLoss()
        loss = criterion(predictions, target)
        losses.append(float(loss))
        loss.backward()
        optimizer.step()
        f1_train = train_f1(predicted_classes.cpu(), target.cpu())

    ### model evaluation 
    model.eval()
    with torch.no_grad():
      for (features, target) in (validationloader):
          features, target = features.to(device), target.to(device)

          predictions = model(features)
          predicted_classes = torch.argmax(predictions, dim=1)
          
          validation_auroc = val_auroc(predictions.cpu().detach(), target.cpu().detach())
          f1_val = val_f1(predicted_classes.cpu().detach(), target.cpu().detach())


    val_f1_final = val_f1.compute()    
    train_f1_final = train_f1.compute()  
    val_auroc_final = val_auroc.compute()
    print('average train loss: ', np.mean(losses))
    print('validation f1: ', val_f1_final)
    print('train f1: ', train_f1_final)
    print('validation auroc: ', val_auroc_final)

    return val_auroc_final

In [26]:
def train_model(model, trainloader, validationloader, optimizer, device, n_epoch, checkpoint_path, early_stopping):
    model.to(device)

    
    best_roc = -np.inf
    previous_roc = -np.inf
    counter = 0
    for epoch in range(0, n_epoch):
        print(f"epoch {epoch+1}/{n_epoch}")
        roc = train_one_epoch(model , trainloader, validationloader, optimizer, device)
        if roc >= best_roc:
            best_roc = roc
            torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'roc': roc,
            }, checkpoint_path)

            print('New best model saved !')

        if roc <= previous_roc:
            print('no iprovement')
            counter += 1
        
        else:
            counter = 0

        previous_roc = roc

        if counter == early_stopping:
            print('early stopping !')
            break 
    print("load best model")

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    return(model)       

## Define data sets


In [27]:
class EmbeddingDatasetNoPosition(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, df, dir, bag_size = 180):
        """
        Args:
            df (dataframe): Path to the csv file with annotations.
            dir (string): Directory with tiles images embeddings
        """
        self.df = df
        self.dir = dir
        self.bag_size = bag_size

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        value = self.df.iloc[idx]
        image_id = value.image_id
        # compute file
        image_embeddings_path = self.dir+image_id+'.pkl'

        # open embeddings  dict
        embeddings_dict = pickle.load(open(image_embeddings_path, 'rb'))

        bag = None
        embedding_list = np.array(list(embeddings_dict.values()))
        np.random.shuffle(embedding_list)
        bag = np.vstack(embedding_list[:self.bag_size])

        bag = torch.Tensor(bag)
        # compute label
        label = torch.tensor(value.isup_grade)

        return bag, label

## Define Models

In [28]:
class ModelVanilla(nn.Module):
    def __init__(self, number_of_embeddings = 180, output_shape = 6, embedding_shape = 1280):
        super(ModelVanilla, self).__init__()
        
        # to encode each bag
        self.aggregator = torch.nn.AvgPool1d(number_of_embeddings)


        self.classifier = nn.Sequential(
                        nn.Linear(embedding_shape, int(embedding_shape/2)),
                        nn.ReLU(),
                        torch.nn.BatchNorm1d(int(embedding_shape/2)),
                        nn.Dropout(0.25),
                        nn.Linear(int(embedding_shape/2), int(embedding_shape/4)),
                        nn.ReLU(),
                        nn.Dropout(0.25),
                        torch.nn.BatchNorm1d(int(embedding_shape/4)),
                        nn.Linear(int(embedding_shape/4), output_shape)
                        )
        
        
        self.softmax = nn.Softmax(dim = 1)


        
    def forward(self, bag):
      bag = torch.transpose(bag, 1, 2)
      wsi_descriptor = self.aggregator(bag).squeeze(-1)
      out = self.classifier(wsi_descriptor)
      out = self.softmax(out)
      return out

In [29]:
class AttentionModel(nn.Module):
    def __init__(self, number_of_embeddings = 180, output_shape = 6, embedding_shape = 1280):
        super(AttentionModel, self).__init__()
        
        # to compute attenation map for aggregation
        self.attention = nn.Sequential(
            nn.Linear(embedding_shape, int(embedding_shape/4)),
            nn.Tanh(),
            nn.Linear(int(embedding_shape/4), 1)
            )


        self.classifier = nn.Sequential(
                        nn.Linear(embedding_shape, int(embedding_shape/2)),
                        nn.ReLU(),
                        torch.nn.BatchNorm1d(int(embedding_shape/2)),
                        nn.Dropout(0.25),
                        nn.Linear(int(embedding_shape/2), int(embedding_shape/4)),
                        nn.ReLU(),
                        torch.nn.BatchNorm1d(int(embedding_shape/4)),
                        nn.Dropout(0.25),
                        nn.Linear(int(embedding_shape/4), output_shape)
        )
        
        
        self.softmax = nn.Softmax(dim = 1)


        
    def forward(self, bag):
      # compute attention map
      attention_map = self.attention(bag)
      attention_map = torch.transpose(attention_map, 2, 1) # KxN
      attention_map = self.softmax(attention_map)
      # apply attention map
      M = torch.matmul(attention_map, bag).squeeze(1)
      # classification
      out = self.classifier(M)
      out = self.softmax(out)
      return out

In [30]:
class ModelChowder(nn.Module):
    def __init__(self, number_of_embeddings = 180, output_shape = 6, embedding_shape = 1280, R = 5):
        super(ModelChowder, self).__init__()
        
        # to encode each bag
        self.R = R
        self.conv = torch.nn.Conv1d(number_of_embeddings,number_of_embeddings,kernel_size =embedding_shape)
        
        self.classifier = nn.Sequential(
                        nn.Linear(2*self.R, 200),
                        nn.Sigmoid(),
                        torch.nn.BatchNorm1d(200),
                        nn.Dropout(0.25),
                        nn.Linear(200, 100),
                        nn.Sigmoid(),
                        torch.nn.BatchNorm1d(100),
                        nn.Dropout(0.25),
                        nn.Linear(100, output_shape)
                        )
        
        
        self.softmax = nn.Softmax(dim = 1)


        
    def forward(self, bag):
      post_conv = self.conv(bag).squeeze(-1)
      post_conv = torch.sort(post_conv)[0]

      max_values = post_conv[:,-self.R:]
      min_values = post_conv[:,:self.R]

      wsi_descriptor = torch.cat((max_values,min_values), 1)



      out = self.classifier(wsi_descriptor)
      out = self.softmax(out)
      return out

## Train models

In [31]:
batch_size = 16
number_of_tiles_per_bag = 20
train_set = EmbeddingDatasetNoPosition(train_set, train_tiles_encoding_folder, bag_size = number_of_tiles_per_bag)
validation_set = EmbeddingDatasetNoPosition(validation_set, train_tiles_encoding_folder, bag_size = number_of_tiles_per_bag)


trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last = True)
validationloader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size, shuffle=True, drop_last = True)


In [32]:
model = ModelChowder(number_of_embeddings = number_of_tiles_per_bag)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
checkpoint_path_chowder = "./Data/checkpoints/model_chowder"
best_model = train_model(model, trainloader, validationloader, optimizer, device, 100, checkpoint_path_chowder, 5)



epoch 1/100


100%|██████████| 19/19 [00:01<00:00, 12.49it/s]


average train loss:  1.7963831424713135
validation f1:  tensor(0.1562)
train f1:  tensor(0.1842)
validation auroc:  tensor(0.6403)
New best model saved !
epoch 2/100


100%|██████████| 19/19 [00:01<00:00, 14.37it/s]


average train loss:  1.7933124178334285
validation f1:  tensor(0.0938)
train f1:  tensor(0.1842)
validation auroc:  tensor(0.5282)
no iprovement
epoch 3/100


100%|██████████| 19/19 [00:01<00:00, 10.92it/s]


average train loss:  1.779329776763916
validation f1:  tensor(0.3125)
train f1:  tensor(0.2270)
validation auroc:  tensor(0.6368)
epoch 4/100


100%|██████████| 19/19 [00:01<00:00, 12.49it/s]


average train loss:  1.7856169939041138
validation f1:  tensor(0.2812)
train f1:  tensor(0.2270)
validation auroc:  tensor(0.5966)
no iprovement
epoch 5/100


100%|██████████| 19/19 [00:01<00:00, 12.75it/s]


average train loss:  1.7796101005453813
validation f1:  tensor(0.2812)
train f1:  tensor(0.1974)
validation auroc:  tensor(0.5254)
no iprovement
epoch 6/100


100%|██████████| 19/19 [00:01<00:00, 13.37it/s]


average train loss:  1.76618870935942
validation f1:  tensor(0.2812)
train f1:  tensor(0.2500)
validation auroc:  tensor(0.5610)
epoch 7/100


100%|██████████| 19/19 [00:01<00:00, 13.63it/s]


average train loss:  1.7647827424501117
validation f1:  tensor(0.3125)
train f1:  tensor(0.2632)
validation auroc:  tensor(0.6091)
epoch 8/100


100%|██████████| 19/19 [00:01<00:00, 13.29it/s]


average train loss:  1.7708625040556256
validation f1:  tensor(0.1875)
train f1:  tensor(0.2336)
validation auroc:  tensor(0.6038)
no iprovement
epoch 9/100


100%|██████████| 19/19 [00:01<00:00, 13.51it/s]


average train loss:  1.7581619463468854
validation f1:  tensor(0.2188)
train f1:  tensor(0.2730)
validation auroc:  tensor(0.6414)
New best model saved !
epoch 10/100


100%|██████████| 19/19 [00:01<00:00, 18.01it/s]


average train loss:  1.746614186387313
validation f1:  tensor(0.4062)
train f1:  tensor(0.2763)
validation auroc:  tensor(0.6533)
New best model saved !
epoch 11/100


100%|██████████| 19/19 [00:01<00:00, 12.15it/s]


average train loss:  1.7529224722008956
validation f1:  tensor(0.2812)
train f1:  tensor(0.2599)
validation auroc:  tensor(0.6548)
New best model saved !
epoch 12/100


100%|██████████| 19/19 [00:01<00:00, 12.24it/s]


average train loss:  1.7551368788668984
validation f1:  tensor(0.3438)
train f1:  tensor(0.2368)
validation auroc:  tensor(0.6441)
no iprovement
epoch 13/100


100%|██████████| 19/19 [00:01<00:00, 11.09it/s]


average train loss:  1.752535713346381
validation f1:  tensor(0.2188)
train f1:  tensor(0.2796)
validation auroc:  tensor(0.6227)
no iprovement
epoch 14/100


100%|██████████| 19/19 [00:01<00:00, 11.65it/s]


average train loss:  1.7375727766438533
validation f1:  tensor(0.3125)
train f1:  tensor(0.2862)
validation auroc:  tensor(0.6469)
epoch 15/100


100%|██████████| 19/19 [00:01<00:00, 13.41it/s]


average train loss:  1.7361961289456016
validation f1:  tensor(0.3750)
train f1:  tensor(0.3092)
validation auroc:  tensor(0.6260)
no iprovement
epoch 16/100


100%|██████████| 19/19 [00:01<00:00, 17.53it/s]


average train loss:  1.7326265887210244
validation f1:  tensor(0.3750)
train f1:  tensor(0.3257)
validation auroc:  tensor(0.6814)
New best model saved !
epoch 17/100


100%|██████████| 19/19 [00:01<00:00, 12.79it/s]


average train loss:  1.7221018703360307
validation f1:  tensor(0.3125)
train f1:  tensor(0.3092)
validation auroc:  tensor(0.6111)
no iprovement
epoch 18/100


100%|██████████| 19/19 [00:01<00:00, 12.65it/s]


average train loss:  1.7213397026062012
validation f1:  tensor(0.3750)
train f1:  tensor(0.3191)
validation auroc:  tensor(0.6406)
epoch 19/100


100%|██████████| 19/19 [00:01<00:00, 12.57it/s]


average train loss:  1.7138549654107345
validation f1:  tensor(0.2812)
train f1:  tensor(0.3191)
validation auroc:  tensor(0.6253)
no iprovement
epoch 20/100


100%|██████████| 19/19 [00:01<00:00, 11.16it/s]


average train loss:  1.7132046912845813
validation f1:  tensor(0.4375)
train f1:  tensor(0.3454)
validation auroc:  tensor(0.7421)
New best model saved !
epoch 21/100


100%|██████████| 19/19 [00:02<00:00,  9.11it/s]


average train loss:  1.723845362663269
validation f1:  tensor(0.2188)
train f1:  tensor(0.3092)
validation auroc:  tensor(0.6597)
no iprovement
epoch 22/100


100%|██████████| 19/19 [00:01<00:00, 13.08it/s]


average train loss:  1.710542176899157
validation f1:  tensor(0.3125)
train f1:  tensor(0.3355)
validation auroc:  tensor(0.6342)
no iprovement
epoch 23/100


100%|██████████| 19/19 [00:02<00:00,  7.76it/s]


average train loss:  1.7141458988189697
validation f1:  tensor(0.4062)
train f1:  tensor(0.3191)
validation auroc:  tensor(0.6234)
no iprovement
epoch 24/100


100%|██████████| 19/19 [00:02<00:00,  8.60it/s]


average train loss:  1.703666367028889
validation f1:  tensor(0.4062)
train f1:  tensor(0.3487)
validation auroc:  tensor(0.7608)
New best model saved !
epoch 25/100


100%|██████████| 19/19 [00:02<00:00,  8.80it/s]


average train loss:  1.7103547297025983
validation f1:  tensor(0.4062)
train f1:  tensor(0.3355)
validation auroc:  tensor(0.7305)
no iprovement
epoch 26/100


100%|██████████| 19/19 [00:02<00:00,  9.49it/s]


average train loss:  1.7173075111288774
validation f1:  tensor(0.3125)
train f1:  tensor(0.3191)
validation auroc:  tensor(0.7222)
no iprovement
epoch 27/100


100%|██████████| 19/19 [00:02<00:00,  9.25it/s]


average train loss:  1.696863757936578
validation f1:  tensor(0.3125)
train f1:  tensor(0.3750)
validation auroc:  tensor(0.6698)
no iprovement
epoch 28/100


100%|██████████| 19/19 [00:01<00:00,  9.53it/s]


average train loss:  1.7031103498057316
validation f1:  tensor(0.3438)
train f1:  tensor(0.3487)
validation auroc:  tensor(0.6518)
no iprovement
epoch 29/100


100%|██████████| 19/19 [00:02<00:00,  9.21it/s]


average train loss:  1.6963515595385903
validation f1:  tensor(0.3125)
train f1:  tensor(0.3520)
validation auroc:  tensor(0.6860)
epoch 30/100


100%|██████████| 19/19 [00:02<00:00,  8.66it/s]


average train loss:  1.7023609562924034
validation f1:  tensor(0.3438)
train f1:  tensor(0.3553)
validation auroc:  tensor(0.6873)
epoch 31/100


100%|██████████| 19/19 [00:01<00:00,  9.68it/s]


average train loss:  1.6912791415264732
validation f1:  tensor(0.3125)
train f1:  tensor(0.3783)
validation auroc:  tensor(0.7160)
epoch 32/100


100%|██████████| 19/19 [00:02<00:00,  8.92it/s]


average train loss:  1.693677268530193
validation f1:  tensor(0.4375)
train f1:  tensor(0.3289)
validation auroc:  tensor(0.6885)
no iprovement
epoch 33/100


100%|██████████| 19/19 [00:02<00:00,  9.05it/s]


average train loss:  1.6901293553804095
validation f1:  tensor(0.3438)
train f1:  tensor(0.3487)
validation auroc:  tensor(0.6123)
no iprovement
epoch 34/100


100%|██████████| 19/19 [00:02<00:00,  9.39it/s]


average train loss:  1.6902809205808138
validation f1:  tensor(0.3750)
train f1:  tensor(0.3388)
validation auroc:  tensor(0.6982)
epoch 35/100


100%|██████████| 19/19 [00:02<00:00,  8.39it/s]


average train loss:  1.6859457681053562
validation f1:  tensor(0.3438)
train f1:  tensor(0.3684)
validation auroc:  tensor(0.7069)
epoch 36/100


100%|██████████| 19/19 [00:02<00:00,  8.88it/s]


average train loss:  1.683136180827492
validation f1:  tensor(0.4062)
train f1:  tensor(0.3783)
validation auroc:  tensor(0.7262)
epoch 37/100


100%|██████████| 19/19 [00:01<00:00, 10.43it/s]


average train loss:  1.674877455360011
validation f1:  tensor(0.3438)
train f1:  tensor(0.3618)
validation auroc:  tensor(0.6771)
no iprovement
epoch 38/100


100%|██████████| 19/19 [00:01<00:00, 10.06it/s]


average train loss:  1.6857330799102783
validation f1:  tensor(0.2812)
train f1:  tensor(0.3750)
validation auroc:  tensor(0.6607)
no iprovement
epoch 39/100


100%|██████████| 19/19 [00:02<00:00,  8.97it/s]


average train loss:  1.683336985738654
validation f1:  tensor(0.3438)
train f1:  tensor(0.3651)
validation auroc:  tensor(0.7566)
epoch 40/100


100%|██████████| 19/19 [00:02<00:00,  9.28it/s]


average train loss:  1.6851973721855564
validation f1:  tensor(0.4375)
train f1:  tensor(0.3684)
validation auroc:  tensor(0.7476)
no iprovement
epoch 41/100


100%|██████████| 19/19 [00:01<00:00,  9.96it/s]


average train loss:  1.671561642696983
validation f1:  tensor(0.3438)
train f1:  tensor(0.3980)
validation auroc:  tensor(0.7003)
no iprovement
epoch 42/100


100%|██████████| 19/19 [00:02<00:00,  8.85it/s]


average train loss:  1.6694822813335217
validation f1:  tensor(0.3438)
train f1:  tensor(0.3947)
validation auroc:  tensor(0.7038)
epoch 43/100


100%|██████████| 19/19 [00:01<00:00, 10.41it/s]


average train loss:  1.675541275425961
validation f1:  tensor(0.3125)
train f1:  tensor(0.3750)
validation auroc:  tensor(0.7044)
epoch 44/100


100%|██████████| 19/19 [00:01<00:00, 10.62it/s]


average train loss:  1.660737156867981
validation f1:  tensor(0.3750)
train f1:  tensor(0.3980)
validation auroc:  tensor(0.7131)
epoch 45/100


100%|██████████| 19/19 [00:01<00:00, 11.70it/s]


average train loss:  1.6783776471489353
validation f1:  tensor(0.3438)
train f1:  tensor(0.3783)
validation auroc:  tensor(0.7318)
epoch 46/100


100%|██████████| 19/19 [00:01<00:00, 10.87it/s]


average train loss:  1.6672355062083195
validation f1:  tensor(0.3750)
train f1:  tensor(0.3914)
validation auroc:  tensor(0.7418)
epoch 47/100


100%|██████████| 19/19 [00:01<00:00, 10.28it/s]


average train loss:  1.657609694882443
validation f1:  tensor(0.3438)
train f1:  tensor(0.4079)
validation auroc:  tensor(0.6946)
no iprovement
epoch 48/100


100%|██████████| 19/19 [00:01<00:00, 11.23it/s]


average train loss:  1.654757116970263
validation f1:  tensor(0.4062)
train f1:  tensor(0.4178)
validation auroc:  tensor(0.6776)
no iprovement
epoch 49/100


100%|██████████| 19/19 [00:01<00:00, 10.44it/s]


average train loss:  1.6809629578339427
validation f1:  tensor(0.3125)
train f1:  tensor(0.3980)
validation auroc:  tensor(0.6974)
epoch 50/100


100%|██████████| 19/19 [00:01<00:00, 13.07it/s]


average train loss:  1.6451141583292108
validation f1:  tensor(0.3750)
train f1:  tensor(0.4375)
validation auroc:  tensor(0.7800)
New best model saved !
epoch 51/100


100%|██████████| 19/19 [00:01<00:00, 13.25it/s]


average train loss:  1.6649041677776135
validation f1:  tensor(0.2812)
train f1:  tensor(0.3816)
validation auroc:  tensor(0.6742)
no iprovement
epoch 52/100


100%|██████████| 19/19 [00:02<00:00,  9.36it/s]


average train loss:  1.6614178168146234
validation f1:  tensor(0.4375)
train f1:  tensor(0.3750)
validation auroc:  tensor(0.7387)
epoch 53/100


100%|██████████| 19/19 [00:03<00:00,  6.13it/s]


average train loss:  1.6654630648462396
validation f1:  tensor(0.3125)
train f1:  tensor(0.3947)
validation auroc:  tensor(0.7269)
no iprovement
epoch 54/100


100%|██████████| 19/19 [00:02<00:00,  9.02it/s]


average train loss:  1.6696757642846358
validation f1:  tensor(0.2500)
train f1:  tensor(0.3816)
validation auroc:  tensor(0.7089)
no iprovement
epoch 55/100


100%|██████████| 19/19 [00:01<00:00, 10.40it/s]


average train loss:  1.6752067176919234
validation f1:  tensor(0.3750)
train f1:  tensor(0.3717)
validation auroc:  tensor(0.6860)
no iprovement
epoch 56/100


100%|██████████| 19/19 [00:02<00:00,  8.96it/s]


average train loss:  1.6593243824808221
validation f1:  tensor(0.3438)
train f1:  tensor(0.3783)
validation auroc:  tensor(0.6994)
epoch 57/100


100%|██████████| 19/19 [00:01<00:00, 10.08it/s]


average train loss:  1.6530338086579974
validation f1:  tensor(0.3125)
train f1:  tensor(0.4079)
validation auroc:  tensor(0.6775)
no iprovement
epoch 58/100


100%|██████████| 19/19 [00:01<00:00, 11.44it/s]


average train loss:  1.6527169942855835
validation f1:  tensor(0.3438)
train f1:  tensor(0.3750)
validation auroc:  tensor(0.7190)
epoch 59/100


100%|██████████| 19/19 [00:01<00:00,  9.74it/s]


average train loss:  1.6610448423184847
validation f1:  tensor(0.3750)
train f1:  tensor(0.3882)
validation auroc:  tensor(0.7375)
epoch 60/100


100%|██████████| 19/19 [00:01<00:00, 12.09it/s]


average train loss:  1.6493666422994513
validation f1:  tensor(0.3750)
train f1:  tensor(0.4145)
validation auroc:  tensor(0.7507)
epoch 61/100


100%|██████████| 19/19 [00:02<00:00,  7.99it/s]


average train loss:  1.6382572399942499
validation f1:  tensor(0.3750)
train f1:  tensor(0.4145)
validation auroc:  tensor(0.7465)
no iprovement
epoch 62/100


100%|██████████| 19/19 [00:02<00:00,  8.49it/s]


average train loss:  1.6491620415135433
validation f1:  tensor(0.4062)
train f1:  tensor(0.4243)
validation auroc:  tensor(0.7090)
no iprovement
epoch 63/100


100%|██████████| 19/19 [00:01<00:00, 10.79it/s]


average train loss:  1.6462487045087313
validation f1:  tensor(0.3438)
train f1:  tensor(0.4276)
validation auroc:  tensor(0.7453)
epoch 64/100


100%|██████████| 19/19 [00:01<00:00, 11.47it/s]


average train loss:  1.6371510969965082
validation f1:  tensor(0.4062)
train f1:  tensor(0.4145)
validation auroc:  tensor(0.7277)
no iprovement
epoch 65/100


100%|██████████| 19/19 [00:01<00:00, 12.21it/s]


average train loss:  1.6529053261405544
validation f1:  tensor(0.4062)
train f1:  tensor(0.4309)
validation auroc:  tensor(0.7631)
epoch 66/100


100%|██████████| 19/19 [00:01<00:00, 12.11it/s]


average train loss:  1.6475661553834613
validation f1:  tensor(0.3438)
train f1:  tensor(0.3980)
validation auroc:  tensor(0.7255)
no iprovement
epoch 67/100


100%|██████████| 19/19 [00:01<00:00, 12.20it/s]


average train loss:  1.6432719042426662
validation f1:  tensor(0.3438)
train f1:  tensor(0.4342)
validation auroc:  tensor(0.7437)
epoch 68/100


100%|██████████| 19/19 [00:01<00:00, 11.95it/s]


average train loss:  1.6425678667269255
validation f1:  tensor(0.3438)
train f1:  tensor(0.4145)
validation auroc:  tensor(0.7113)
no iprovement
epoch 69/100


100%|██████████| 19/19 [00:01<00:00, 11.37it/s]


average train loss:  1.6451870077534725
validation f1:  tensor(0.3750)
train f1:  tensor(0.4112)
validation auroc:  tensor(0.7528)
epoch 70/100


100%|██████████| 19/19 [00:01<00:00, 11.30it/s]


average train loss:  1.639188992349725
validation f1:  tensor(0.3750)
train f1:  tensor(0.4145)
validation auroc:  tensor(0.7039)
no iprovement
epoch 71/100


100%|██████████| 19/19 [00:01<00:00, 11.06it/s]


average train loss:  1.6412990344198126
validation f1:  tensor(0.4375)
train f1:  tensor(0.4243)
validation auroc:  tensor(0.7497)
epoch 72/100


100%|██████████| 19/19 [00:02<00:00,  9.16it/s]


average train loss:  1.6284529849102622
validation f1:  tensor(0.3750)
train f1:  tensor(0.4408)
validation auroc:  tensor(0.7808)
New best model saved !
epoch 73/100


100%|██████████| 19/19 [00:02<00:00,  7.80it/s]


average train loss:  1.6429106310794228
validation f1:  tensor(0.3438)
train f1:  tensor(0.4145)
validation auroc:  tensor(0.7424)
no iprovement
epoch 74/100


100%|██████████| 19/19 [00:02<00:00,  8.01it/s]


average train loss:  1.6196934612173783
validation f1:  tensor(0.3125)
train f1:  tensor(0.4638)
validation auroc:  tensor(0.7637)
epoch 75/100


 26%|██▋       | 5/19 [00:00<00:02,  6.94it/s]


KeyboardInterrupt: 

In [33]:
model = AttentionModel(number_of_embeddings = number_of_tiles_per_bag)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
checkpoint_path_attention = "./Data/checkpoints/model_attention"
best_model = train_model(model, trainloader, validationloader, optimizer, device, 100, checkpoint_path_attention, 5)

epoch 1/100


100%|██████████| 19/19 [00:01<00:00, 16.09it/s]


average train loss:  1.7783041565041793
validation f1:  tensor(0.2500)
train f1:  tensor(0.2007)
validation auroc:  tensor(0.5649)
New best model saved !
epoch 2/100


100%|██████████| 19/19 [00:01<00:00, 18.08it/s]


average train loss:  1.7559023593601428
validation f1:  tensor(0.2500)
train f1:  tensor(0.2862)
validation auroc:  tensor(0.5698)
New best model saved !
epoch 3/100


100%|██████████| 19/19 [00:01<00:00, 17.24it/s]


average train loss:  1.7472642032723678
validation f1:  tensor(0.2500)
train f1:  tensor(0.2862)
validation auroc:  tensor(0.7013)
New best model saved !
epoch 4/100


100%|██████████| 19/19 [00:01<00:00, 14.86it/s]


average train loss:  1.740062675978008
validation f1:  tensor(0.2188)
train f1:  tensor(0.2961)
validation auroc:  tensor(0.6301)
no iprovement
epoch 5/100


100%|██████████| 19/19 [00:01<00:00, 14.16it/s]


average train loss:  1.7241172351335223
validation f1:  tensor(0.3125)
train f1:  tensor(0.3059)
validation auroc:  tensor(0.6293)
no iprovement
epoch 6/100


100%|██████████| 19/19 [00:02<00:00,  8.59it/s]


average train loss:  1.7188427573756169
validation f1:  tensor(0.2500)
train f1:  tensor(0.3355)
validation auroc:  tensor(0.6954)
epoch 7/100


100%|██████████| 19/19 [00:01<00:00, 11.69it/s]


average train loss:  1.7084923668911582
validation f1:  tensor(0.3750)
train f1:  tensor(0.3750)
validation auroc:  tensor(0.7417)
New best model saved !
epoch 8/100


100%|██████████| 19/19 [00:01<00:00, 11.84it/s]


average train loss:  1.6941894041864496
validation f1:  tensor(0.4375)
train f1:  tensor(0.3586)
validation auroc:  tensor(0.7625)
New best model saved !
epoch 9/100


100%|██████████| 19/19 [00:01<00:00, 11.20it/s]


average train loss:  1.695092659247549
validation f1:  tensor(0.3438)
train f1:  tensor(0.3882)
validation auroc:  tensor(0.7238)
no iprovement
epoch 10/100


100%|██████████| 19/19 [00:01<00:00, 11.66it/s]


average train loss:  1.6868887637790881
validation f1:  tensor(0.4062)
train f1:  tensor(0.3980)
validation auroc:  tensor(0.7155)
no iprovement
epoch 11/100


100%|██████████| 19/19 [00:01<00:00, 11.51it/s]


average train loss:  1.6864223166515953
validation f1:  tensor(0.3125)
train f1:  tensor(0.3520)
validation auroc:  tensor(0.7146)
no iprovement
epoch 12/100


100%|██████████| 19/19 [00:02<00:00,  8.38it/s]


average train loss:  1.6896901444384926
validation f1:  tensor(0.4062)
train f1:  tensor(0.3684)
validation auroc:  tensor(0.7251)
epoch 13/100


100%|██████████| 19/19 [00:01<00:00, 11.62it/s]


average train loss:  1.6751450174733211
validation f1:  tensor(0.3750)
train f1:  tensor(0.3684)
validation auroc:  tensor(0.6710)
no iprovement
epoch 14/100


100%|██████████| 19/19 [00:02<00:00,  9.01it/s]


average train loss:  1.6756357142799778
validation f1:  tensor(0.3125)
train f1:  tensor(0.3783)
validation auroc:  tensor(0.7219)
epoch 15/100


100%|██████████| 19/19 [00:01<00:00, 12.36it/s]


average train loss:  1.6521248754702116
validation f1:  tensor(0.3125)
train f1:  tensor(0.4178)
validation auroc:  tensor(0.7091)
no iprovement
epoch 16/100


100%|██████████| 19/19 [00:01<00:00, 10.86it/s]


average train loss:  1.6528458093342029
validation f1:  tensor(0.4375)
train f1:  tensor(0.4309)
validation auroc:  tensor(0.7545)
epoch 17/100


100%|██████████| 19/19 [00:01<00:00, 10.17it/s]


average train loss:  1.6586493692900006
validation f1:  tensor(0.3750)
train f1:  tensor(0.3816)
validation auroc:  tensor(0.7148)
no iprovement
epoch 18/100


100%|██████████| 19/19 [00:01<00:00, 10.36it/s]


average train loss:  1.6574791544362117
validation f1:  tensor(0.3125)
train f1:  tensor(0.4243)
validation auroc:  tensor(0.6893)
no iprovement
epoch 19/100


100%|██████████| 19/19 [00:02<00:00,  9.19it/s]


average train loss:  1.645497064841421
validation f1:  tensor(0.4375)
train f1:  tensor(0.4178)
validation auroc:  tensor(0.7568)
epoch 20/100


100%|██████████| 19/19 [00:01<00:00, 11.56it/s]


average train loss:  1.6492819158654464
validation f1:  tensor(0.3438)
train f1:  tensor(0.4309)
validation auroc:  tensor(0.7353)
no iprovement
epoch 21/100


100%|██████████| 19/19 [00:02<00:00,  8.62it/s]


average train loss:  1.651850888603612
validation f1:  tensor(0.4062)
train f1:  tensor(0.4013)
validation auroc:  tensor(0.7205)
no iprovement
epoch 22/100


100%|██████████| 19/19 [00:01<00:00, 10.65it/s]


average train loss:  1.6534667642492997
validation f1:  tensor(0.3750)
train f1:  tensor(0.4112)
validation auroc:  tensor(0.7274)
epoch 23/100


100%|██████████| 19/19 [00:01<00:00, 12.60it/s]


average train loss:  1.6415548387326693
validation f1:  tensor(0.3750)
train f1:  tensor(0.4408)
validation auroc:  tensor(0.7488)
epoch 24/100


100%|██████████| 19/19 [00:02<00:00,  8.30it/s]


average train loss:  1.6283128763500012
validation f1:  tensor(0.3438)
train f1:  tensor(0.4441)
validation auroc:  tensor(0.7276)
no iprovement
epoch 25/100


100%|██████████| 19/19 [00:01<00:00, 11.56it/s]


average train loss:  1.6200619371313798
validation f1:  tensor(0.3438)
train f1:  tensor(0.4474)
validation auroc:  tensor(0.7495)
epoch 26/100


100%|██████████| 19/19 [00:01<00:00, 11.84it/s]


average train loss:  1.6316117236488743
validation f1:  tensor(0.4688)
train f1:  tensor(0.4243)
validation auroc:  tensor(0.7314)
no iprovement
epoch 27/100


100%|██████████| 19/19 [00:01<00:00, 11.65it/s]


average train loss:  1.6336673372670223
validation f1:  tensor(0.3750)
train f1:  tensor(0.4408)
validation auroc:  tensor(0.7333)
epoch 28/100


100%|██████████| 19/19 [00:02<00:00,  8.55it/s]


average train loss:  1.5991214262811762
validation f1:  tensor(0.4375)
train f1:  tensor(0.4671)
validation auroc:  tensor(0.7449)
epoch 29/100


100%|██████████| 19/19 [00:02<00:00,  9.40it/s]


average train loss:  1.6163043724863153
validation f1:  tensor(0.3125)
train f1:  tensor(0.4572)
validation auroc:  tensor(0.6853)
no iprovement
epoch 30/100


100%|██████████| 19/19 [00:02<00:00,  8.74it/s]


average train loss:  1.618375627618087
validation f1:  tensor(0.4375)
train f1:  tensor(0.4572)
validation auroc:  tensor(0.7514)
epoch 31/100


100%|██████████| 19/19 [00:02<00:00,  7.77it/s]


average train loss:  1.610301318921541
validation f1:  tensor(0.3438)
train f1:  tensor(0.4605)
validation auroc:  tensor(0.7132)
no iprovement
epoch 32/100


100%|██████████| 19/19 [00:02<00:00,  8.17it/s]


average train loss:  1.6210713888469495
validation f1:  tensor(0.4688)
train f1:  tensor(0.4539)
validation auroc:  tensor(0.7748)
New best model saved !
epoch 33/100


100%|██████████| 19/19 [00:02<00:00,  7.90it/s]


average train loss:  1.6152246249349493
validation f1:  tensor(0.4375)
train f1:  tensor(0.4901)
validation auroc:  tensor(0.7515)
no iprovement
epoch 34/100


100%|██████████| 19/19 [00:02<00:00,  8.61it/s]


average train loss:  1.5898375887619822
validation f1:  tensor(0.4062)
train f1:  tensor(0.4868)
validation auroc:  tensor(0.7639)
epoch 35/100


100%|██████████| 19/19 [00:01<00:00,  9.66it/s]


average train loss:  1.6087970608159115
validation f1:  tensor(0.4062)
train f1:  tensor(0.4605)
validation auroc:  tensor(0.7649)
epoch 36/100


100%|██████████| 19/19 [00:01<00:00, 10.24it/s]


average train loss:  1.6053142108415301
validation f1:  tensor(0.4375)
train f1:  tensor(0.4836)
validation auroc:  tensor(0.7798)
New best model saved !
epoch 37/100


100%|██████████| 19/19 [00:02<00:00,  8.42it/s]


average train loss:  1.6217467973106785
validation f1:  tensor(0.5312)
train f1:  tensor(0.4605)
validation auroc:  tensor(0.7986)
New best model saved !
epoch 38/100


100%|██████████| 19/19 [00:01<00:00, 11.58it/s]


average train loss:  1.6084130751459222
validation f1:  tensor(0.4375)
train f1:  tensor(0.4671)
validation auroc:  tensor(0.8010)
New best model saved !
epoch 39/100


100%|██████████| 19/19 [00:01<00:00, 10.66it/s]


average train loss:  1.6032852687333758
validation f1:  tensor(0.4062)
train f1:  tensor(0.4737)
validation auroc:  tensor(0.7957)
no iprovement
epoch 40/100


100%|██████████| 19/19 [00:02<00:00,  7.82it/s]


average train loss:  1.6032401197835018
validation f1:  tensor(0.3438)
train f1:  tensor(0.4638)
validation auroc:  tensor(0.6972)
no iprovement
epoch 41/100


100%|██████████| 19/19 [00:01<00:00, 13.07it/s]


average train loss:  1.5880156692705656
validation f1:  tensor(0.3438)
train f1:  tensor(0.4868)
validation auroc:  tensor(0.7680)
epoch 42/100


100%|██████████| 19/19 [00:02<00:00,  8.28it/s]


average train loss:  1.592379419427169
validation f1:  tensor(0.4375)
train f1:  tensor(0.4770)
validation auroc:  tensor(0.7698)
epoch 43/100


100%|██████████| 19/19 [00:02<00:00,  8.28it/s]


average train loss:  1.589890906685277
validation f1:  tensor(0.3750)
train f1:  tensor(0.4901)
validation auroc:  tensor(0.7103)
no iprovement
epoch 44/100


100%|██████████| 19/19 [00:01<00:00, 10.25it/s]


average train loss:  1.5886763208790828
validation f1:  tensor(0.4375)
train f1:  tensor(0.4770)
validation auroc:  tensor(0.7384)
epoch 45/100


100%|██████████| 19/19 [00:01<00:00, 10.33it/s]


average train loss:  1.578935284363596
validation f1:  tensor(0.4375)
train f1:  tensor(0.4967)
validation auroc:  tensor(0.7321)
no iprovement
epoch 46/100


100%|██████████| 19/19 [00:02<00:00,  8.84it/s]


average train loss:  1.606710998635543
validation f1:  tensor(0.4375)
train f1:  tensor(0.4704)
validation auroc:  tensor(0.7887)
epoch 47/100


100%|██████████| 19/19 [00:02<00:00,  7.97it/s]


average train loss:  1.588147809630946
validation f1:  tensor(0.3750)
train f1:  tensor(0.4934)
validation auroc:  tensor(0.7633)
no iprovement
epoch 48/100


100%|██████████| 19/19 [00:02<00:00,  8.40it/s]


average train loss:  1.5738387170590853
validation f1:  tensor(0.4375)
train f1:  tensor(0.5197)
validation auroc:  tensor(0.7415)
no iprovement
epoch 49/100


100%|██████████| 19/19 [00:02<00:00,  8.07it/s]


average train loss:  1.590410458414178
validation f1:  tensor(0.4375)
train f1:  tensor(0.4803)
validation auroc:  tensor(0.8206)
New best model saved !
epoch 50/100


100%|██████████| 19/19 [00:02<00:00,  7.82it/s]


average train loss:  1.5828742855473568
validation f1:  tensor(0.4062)
train f1:  tensor(0.4836)
validation auroc:  tensor(0.7525)
no iprovement
epoch 51/100


100%|██████████| 19/19 [00:01<00:00, 10.27it/s]


average train loss:  1.5720102661534359
validation f1:  tensor(0.4062)
train f1:  tensor(0.5099)
validation auroc:  tensor(0.7524)
no iprovement
epoch 52/100


100%|██████████| 19/19 [00:02<00:00,  8.56it/s]


average train loss:  1.5655851238652279
validation f1:  tensor(0.4062)
train f1:  tensor(0.5066)
validation auroc:  tensor(0.7443)
no iprovement
epoch 53/100


100%|██████████| 19/19 [00:02<00:00,  9.25it/s]


average train loss:  1.5666543747249402
validation f1:  tensor(0.4062)
train f1:  tensor(0.5000)
validation auroc:  tensor(0.7835)
epoch 54/100


100%|██████████| 19/19 [00:01<00:00, 12.72it/s]


average train loss:  1.5634560710505436
validation f1:  tensor(0.4375)
train f1:  tensor(0.5296)
validation auroc:  tensor(0.8049)
epoch 55/100


100%|██████████| 19/19 [00:01<00:00, 10.64it/s]


average train loss:  1.5778636242214001
validation f1:  tensor(0.4375)
train f1:  tensor(0.4901)
validation auroc:  tensor(0.7998)
no iprovement
epoch 56/100


100%|██████████| 19/19 [00:02<00:00,  9.10it/s]


average train loss:  1.5753224711669118
validation f1:  tensor(0.3438)
train f1:  tensor(0.5000)
validation auroc:  tensor(0.7299)
no iprovement
epoch 57/100


100%|██████████| 19/19 [00:02<00:00,  9.32it/s]


average train loss:  1.5676479590566534
validation f1:  tensor(0.4375)
train f1:  tensor(0.5099)
validation auroc:  tensor(0.7826)
epoch 58/100


100%|██████████| 19/19 [00:02<00:00,  8.13it/s]


average train loss:  1.5623757964686344
validation f1:  tensor(0.4375)
train f1:  tensor(0.5197)
validation auroc:  tensor(0.7793)
no iprovement
epoch 59/100


100%|██████████| 19/19 [00:01<00:00, 11.85it/s]


average train loss:  1.5691145282042653
validation f1:  tensor(0.4688)
train f1:  tensor(0.5296)
validation auroc:  tensor(0.7772)
no iprovement
epoch 60/100


100%|██████████| 19/19 [00:01<00:00,  9.52it/s]


average train loss:  1.5753955150905408
validation f1:  tensor(0.3750)
train f1:  tensor(0.5099)
validation auroc:  tensor(0.7273)
no iprovement
epoch 61/100


100%|██████████| 19/19 [00:02<00:00,  9.10it/s]


average train loss:  1.5639142927370573
validation f1:  tensor(0.4062)
train f1:  tensor(0.5263)
validation auroc:  tensor(0.7370)
epoch 62/100


100%|██████████| 19/19 [00:01<00:00, 10.46it/s]


average train loss:  1.5615369583431042
validation f1:  tensor(0.4688)
train f1:  tensor(0.5362)
validation auroc:  tensor(0.8258)
New best model saved !
epoch 63/100


100%|██████████| 19/19 [00:02<00:00,  8.74it/s]


average train loss:  1.571148351619118
validation f1:  tensor(0.5000)
train f1:  tensor(0.5099)
validation auroc:  tensor(0.7697)
no iprovement
epoch 64/100


100%|██████████| 19/19 [00:02<00:00,  8.90it/s]


average train loss:  1.562252534063239
validation f1:  tensor(0.4375)
train f1:  tensor(0.5066)
validation auroc:  tensor(0.7793)
epoch 65/100


100%|██████████| 19/19 [00:01<00:00, 10.31it/s]


average train loss:  1.5548334184445833
validation f1:  tensor(0.3750)
train f1:  tensor(0.5230)
validation auroc:  tensor(0.7407)
no iprovement
epoch 66/100


100%|██████████| 19/19 [00:01<00:00,  9.62it/s]


average train loss:  1.5589533856040554
validation f1:  tensor(0.5000)
train f1:  tensor(0.5329)
validation auroc:  tensor(0.8480)
New best model saved !
epoch 67/100


100%|██████████| 19/19 [00:02<00:00,  8.90it/s]


average train loss:  1.56547548896388
validation f1:  tensor(0.4375)
train f1:  tensor(0.5099)
validation auroc:  tensor(0.7615)
no iprovement
epoch 68/100


100%|██████████| 19/19 [00:02<00:00,  9.41it/s]


average train loss:  1.5615800242674978
validation f1:  tensor(0.4375)
train f1:  tensor(0.5197)
validation auroc:  tensor(0.7607)
no iprovement
epoch 69/100


100%|██████████| 19/19 [00:01<00:00, 10.80it/s]


average train loss:  1.545233268486826
validation f1:  tensor(0.3438)
train f1:  tensor(0.5493)
validation auroc:  tensor(0.6927)
no iprovement
epoch 70/100


100%|██████████| 19/19 [00:01<00:00,  9.68it/s]


average train loss:  1.5649490983862626
validation f1:  tensor(0.5000)
train f1:  tensor(0.5033)
validation auroc:  tensor(0.7950)
epoch 71/100


100%|██████████| 19/19 [00:01<00:00, 11.84it/s]


average train loss:  1.5562966936512996
validation f1:  tensor(0.4062)
train f1:  tensor(0.5164)
validation auroc:  tensor(0.7155)
no iprovement
epoch 72/100


100%|██████████| 19/19 [00:02<00:00,  8.98it/s]


average train loss:  1.556228430647599
validation f1:  tensor(0.3438)
train f1:  tensor(0.5395)
validation auroc:  tensor(0.7040)
no iprovement
epoch 73/100


100%|██████████| 19/19 [00:01<00:00, 10.37it/s]


average train loss:  1.5548793516660993
validation f1:  tensor(0.5000)
train f1:  tensor(0.5164)
validation auroc:  tensor(0.7726)
epoch 74/100


100%|██████████| 19/19 [00:01<00:00, 10.27it/s]


average train loss:  1.5539780039536326
validation f1:  tensor(0.4062)
train f1:  tensor(0.5164)
validation auroc:  tensor(0.7624)
no iprovement
epoch 75/100


100%|██████████| 19/19 [00:02<00:00,  8.31it/s]


average train loss:  1.550021165295651
validation f1:  tensor(0.4062)
train f1:  tensor(0.5296)
validation auroc:  tensor(0.7654)
epoch 76/100


100%|██████████| 19/19 [00:01<00:00, 10.85it/s]


average train loss:  1.5447666645050049
validation f1:  tensor(0.4375)
train f1:  tensor(0.5461)
validation auroc:  tensor(0.7981)
epoch 77/100


100%|██████████| 19/19 [00:01<00:00, 10.22it/s]


average train loss:  1.5542997121810913
validation f1:  tensor(0.3750)
train f1:  tensor(0.5066)
validation auroc:  tensor(0.7131)
no iprovement
epoch 78/100


100%|██████████| 19/19 [00:01<00:00, 10.33it/s]


average train loss:  1.53761971624274
validation f1:  tensor(0.4062)
train f1:  tensor(0.5559)
validation auroc:  tensor(0.8020)
epoch 79/100


100%|██████████| 19/19 [00:02<00:00,  8.54it/s]


average train loss:  1.5455463999196102
validation f1:  tensor(0.4062)
train f1:  tensor(0.5230)
validation auroc:  tensor(0.7941)
no iprovement
epoch 80/100


100%|██████████| 19/19 [00:02<00:00,  9.19it/s]


average train loss:  1.5307790856612356
validation f1:  tensor(0.3125)
train f1:  tensor(0.5789)
validation auroc:  tensor(0.7346)
no iprovement
epoch 81/100


100%|██████████| 19/19 [00:02<00:00,  9.08it/s]


average train loss:  1.5355988866404484
validation f1:  tensor(0.4062)
train f1:  tensor(0.5362)
validation auroc:  tensor(0.8064)
epoch 82/100


100%|██████████| 19/19 [00:01<00:00, 11.59it/s]


average train loss:  1.5352203030335276
validation f1:  tensor(0.3750)
train f1:  tensor(0.5493)
validation auroc:  tensor(0.7619)
no iprovement
epoch 83/100


100%|██████████| 19/19 [00:01<00:00, 11.06it/s]


average train loss:  1.5397458641152633
validation f1:  tensor(0.3750)
train f1:  tensor(0.5559)
validation auroc:  tensor(0.7345)
no iprovement
epoch 84/100


100%|██████████| 19/19 [00:01<00:00,  9.83it/s]


average train loss:  1.529063576146176
validation f1:  tensor(0.3750)
train f1:  tensor(0.5559)
validation auroc:  tensor(0.7142)
no iprovement
epoch 85/100


100%|██████████| 19/19 [00:02<00:00,  8.99it/s]


average train loss:  1.5352874994277954
validation f1:  tensor(0.5000)
train f1:  tensor(0.5461)
validation auroc:  tensor(0.7821)
epoch 86/100


100%|██████████| 19/19 [00:01<00:00,  9.67it/s]


average train loss:  1.536015692510103
validation f1:  tensor(0.4062)
train f1:  tensor(0.5395)
validation auroc:  tensor(0.7735)
no iprovement
epoch 87/100


100%|██████████| 19/19 [00:01<00:00, 10.48it/s]


average train loss:  1.519401186390927
validation f1:  tensor(0.4375)
train f1:  tensor(0.5658)
validation auroc:  tensor(0.7824)
epoch 88/100


100%|██████████| 19/19 [00:01<00:00, 10.96it/s]


average train loss:  1.5326173305511475
validation f1:  tensor(0.4375)
train f1:  tensor(0.5559)
validation auroc:  tensor(0.7453)
no iprovement
epoch 89/100


100%|██████████| 19/19 [00:01<00:00, 10.48it/s]


average train loss:  1.5249162912368774
validation f1:  tensor(0.4375)
train f1:  tensor(0.5691)
validation auroc:  tensor(0.7751)
epoch 90/100


100%|██████████| 19/19 [00:01<00:00, 11.04it/s]


average train loss:  1.5291344868509393
validation f1:  tensor(0.4688)
train f1:  tensor(0.5526)
validation auroc:  tensor(0.7696)
no iprovement
epoch 91/100


100%|██████████| 19/19 [00:01<00:00, 11.99it/s]


average train loss:  1.541270701508773
validation f1:  tensor(0.4688)
train f1:  tensor(0.5428)
validation auroc:  tensor(0.7702)
epoch 92/100


100%|██████████| 19/19 [00:02<00:00,  8.55it/s]


average train loss:  1.5120077321403904
validation f1:  tensor(0.4375)
train f1:  tensor(0.5592)
validation auroc:  tensor(0.8177)
epoch 93/100


100%|██████████| 19/19 [00:02<00:00,  9.13it/s]


average train loss:  1.5376440412119816
validation f1:  tensor(0.4688)
train f1:  tensor(0.5329)
validation auroc:  tensor(0.7530)
no iprovement
epoch 94/100


100%|██████████| 19/19 [00:02<00:00,  9.10it/s]


average train loss:  1.5319359929938066
validation f1:  tensor(0.3750)
train f1:  tensor(0.5493)
validation auroc:  tensor(0.7521)
no iprovement
epoch 95/100


100%|██████████| 19/19 [00:02<00:00,  9.06it/s]


average train loss:  1.5005110941435162
validation f1:  tensor(0.3750)
train f1:  tensor(0.5987)
validation auroc:  tensor(0.7286)
no iprovement
epoch 96/100


100%|██████████| 19/19 [00:02<00:00,  9.23it/s]


average train loss:  1.5228570009532727
validation f1:  tensor(0.4062)
train f1:  tensor(0.5757)
validation auroc:  tensor(0.7635)
epoch 97/100


100%|██████████| 19/19 [00:02<00:00,  9.08it/s]


average train loss:  1.521694540977478
validation f1:  tensor(0.4375)
train f1:  tensor(0.5855)
validation auroc:  tensor(0.7745)
epoch 98/100


100%|██████████| 19/19 [00:02<00:00,  8.77it/s]


average train loss:  1.518570994075976
validation f1:  tensor(0.5000)
train f1:  tensor(0.5855)
validation auroc:  tensor(0.8015)
epoch 99/100


100%|██████████| 19/19 [00:01<00:00, 10.42it/s]


average train loss:  1.526691474412617
validation f1:  tensor(0.3750)
train f1:  tensor(0.5461)
validation auroc:  tensor(0.7606)
no iprovement
epoch 100/100


100%|██████████| 19/19 [00:01<00:00, 10.64it/s]


average train loss:  1.5132911895450794
validation f1:  tensor(0.4375)
train f1:  tensor(0.5559)
validation auroc:  tensor(0.7475)
no iprovement
load best model


In [None]:
model = ModelVanilla(number_of_embeddings = number_of_tiles_per_bag)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
checkpoint_path_vanilla = "./Data/checkpoints/model_vanilla"
best_model = train_model(model, trainloader, validationloader, optimizer, device, 100, checkpoint_path_vanilla, 5)

epoch 1/100


100%|██████████| 19/19 [00:01<00:00, 10.83it/s]


average train loss:  1.782977675136767
validation f1:  tensor(0.2500)
train f1:  tensor(0.2072)
validation auroc:  tensor(0.5777)
New best model saved !
epoch 2/100


100%|██████████| 19/19 [00:01<00:00, 10.11it/s]


average train loss:  1.7662819310238487
validation f1:  tensor(0.2812)
train f1:  tensor(0.2467)
validation auroc:  tensor(0.6170)
New best model saved !
epoch 3/100


100%|██████████| 19/19 [00:02<00:00,  8.60it/s]


average train loss:  1.7566714663254588
validation f1:  tensor(0.3125)
train f1:  tensor(0.2862)
validation auroc:  tensor(0.6286)
New best model saved !
epoch 4/100


100%|██████████| 19/19 [00:02<00:00,  7.07it/s]


average train loss:  1.7466106603020115
validation f1:  tensor(0.3438)
train f1:  tensor(0.3059)
validation auroc:  tensor(0.6904)
New best model saved !
epoch 5/100


100%|██████████| 19/19 [00:02<00:00,  7.24it/s]


average train loss:  1.7365853347276385
validation f1:  tensor(0.3125)
train f1:  tensor(0.3026)
validation auroc:  tensor(0.6905)
New best model saved !
epoch 6/100


100%|██████████| 19/19 [00:02<00:00,  7.12it/s]


average train loss:  1.7251869440078735
validation f1:  tensor(0.2812)
train f1:  tensor(0.3355)
validation auroc:  tensor(0.6520)
no iprovement
epoch 7/100


100%|██████████| 19/19 [00:02<00:00,  7.22it/s]


average train loss:  1.7079953896371942
validation f1:  tensor(0.3125)
train f1:  tensor(0.3816)
validation auroc:  tensor(0.7046)
New best model saved !
epoch 8/100


100%|██████████| 19/19 [00:02<00:00,  7.30it/s]


average train loss:  1.7103737216246755
validation f1:  tensor(0.3125)
train f1:  tensor(0.3421)
validation auroc:  tensor(0.6886)
no iprovement
epoch 9/100


100%|██████████| 19/19 [00:02<00:00,  7.37it/s]


average train loss:  1.6869408331419293
validation f1:  tensor(0.2812)
train f1:  tensor(0.4112)
validation auroc:  tensor(0.7166)
New best model saved !
epoch 10/100


100%|██████████| 19/19 [00:02<00:00,  7.41it/s]


average train loss:  1.6759391458410966
validation f1:  tensor(0.2812)
train f1:  tensor(0.4112)
validation auroc:  tensor(0.7296)
New best model saved !
epoch 11/100


100%|██████████| 19/19 [00:02<00:00,  8.42it/s]


average train loss:  1.6722377852389687
validation f1:  tensor(0.3125)
train f1:  tensor(0.4375)
validation auroc:  tensor(0.7375)
New best model saved !
epoch 12/100


100%|██████████| 19/19 [00:02<00:00,  6.96it/s]


average train loss:  1.676335479083814
validation f1:  tensor(0.3125)
train f1:  tensor(0.3882)
validation auroc:  tensor(0.7341)
no iprovement
epoch 13/100


100%|██████████| 19/19 [00:02<00:00,  7.61it/s]


average train loss:  1.6537839550721019
validation f1:  tensor(0.2812)
train f1:  tensor(0.4474)
validation auroc:  tensor(0.7261)
no iprovement
epoch 14/100


100%|██████████| 19/19 [00:02<00:00,  7.56it/s]


average train loss:  1.6521120447861521
validation f1:  tensor(0.3125)
train f1:  tensor(0.4408)
validation auroc:  tensor(0.7379)
New best model saved !
epoch 15/100


100%|██████████| 19/19 [00:02<00:00,  7.06it/s]


average train loss:  1.6601589165235822
validation f1:  tensor(0.2500)
train f1:  tensor(0.4178)
validation auroc:  tensor(0.7330)
no iprovement
early stopping !
load best model


# Prediction

In [None]:
bag_size = 170
final_df = pd.DataFrame(columns = ['Id','Predicted'])
for i in range(test_set.shape[0]):
  value = test_set.iloc[i]
  id = value.image_id


  # compute file
  image_embeddings_path = test_tiles_encoding_folder+id+'.pkl'

  # open embeddings  dict
  embeddings_dict = pickle.load(open(image_embeddings_path, 'rb'))

  bag = None
  embedding_list = np.array(list(embeddings_dict.values()))
  np.random.shuffle(embedding_list)
  bag = np.vstack(embedding_list[:bag_size])

  bag = torch.Tensor(bag).unsqueeze(0).to(device)

  # get prediction
  model.eval()
  pred = model(bag).detach().cpu().numpy()

  final_pred = int(np.argmax(pred))
  current_line = pd.DataFrame([[id, final_pred]], columns = ['Id','Predicted'])
  final_df = pd.concat([final_df, current_line], axis = 0)


final_df.to_csv("/content/drive/MyDrive/centrale_3A/Deep Learning for medical imaging/final_preds.csv", index = False)
