In [None]:
import os
import torch
from wholeslidedata.iterators import create_batch_iterator
from utils import load_config
from label_utils import to_dysplastic_vs_non_dysplastic
from tqdm.notebook import tqdm

### In this notebook we evaluate trained models at slide level.
 1. Load the configuration and the trained model from the corresponding experiment folder.
 2. Extract patches from the WSI in a sliding-window fashion.
 3. Run inference on the patches and stitch the results back together.
 4. Compute the Dice score and confusion matrix at pixel level.

#### Load config and trained model

In [None]:
# Experiment folder and the checkpoint inside it that should be evaluated.
exp_dir = '/home/mbotros/experiments/barrett_gland_grading/NDvsD_ASL_LANS_RBE_sp1_ps1024/'
model_path = os.path.join(exp_dir, 'checkpoints/model_epoch_40_loss_0.249_dice_0.900.pt')

# User config; load_config splits it into a whole-slide-data part and a
# training part.
user_config = '/home/mbotros/code/barrett_gland_grading/configs/unet_training_config.yml'
wholeslide_config, train_config = load_config(user_config)
print('Loading config: {}\n'.format(user_config))
print('Loading model from {}'.format(model_path))

# Device the inference cells move their tensors to. It was previously used
# without being defined anywhere, which broke Restart & Run All.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the message above announces the model, but no cell shown here
# actually builds it or loads `model_path`. The checkpoint format (full module
# vs. state_dict) is not visible from this notebook -- TODO add the matching
# load step so that `model` exists on `device` before the inference loop runs.

In [None]:
# Build the batch iterator in 'inference' mode. How patches are sampled
# (sliding window, patch size, ...) is determined by the user config.
# cpus=1 presumably keeps patch extraction on a single worker -- confirm
# against the wholeslidedata documentation.
batch_generator = create_batch_iterator(
    user_config=user_config,
    mode='inference',
    cpus=1,
)

In [None]:
# Sanity check of what the iterator's dataset contains: show its
# annotations-per-label-per-key overview before running inference.
annotation_overview = batch_generator.dataset.annotations_per_label_per_key
print(annotation_overview)

In [None]:
# Per-batch scores, keyed by batch index.
metrics = {}

# Number of batches to evaluate. This cell has always stopped after the first
# batch; the limit is now explicit instead of a bare `break`.
# Set to None to consume the whole iterator.
# NOTE(review): confirm the iterator is finite before lifting the limit.
max_batches = 1

# NOTE(review): `model` and `f1_score` are used below but are not defined or
# imported in the cells shown here. `f1_score` looks like
# sklearn.metrics.f1_score -- TODO add the import and the model-loading step.

# Evaluation mode: dropout is disabled and normalisation layers use their
# running statistics instead of per-batch statistics.
model.eval()

# No gradients are needed for inference; skipping autograd bookkeeping saves
# memory and time.
with torch.no_grad():
    for i, (x, y, info) in tqdm(enumerate(batch_generator), total=max_batches):

        # INPUT
        # x: [B, H, W, C]
        # y: [B, H, W]

        # dysplastic vs non-dysplastic
        y = to_dysplastic_vs_non_dysplastic(y)

        # TENSOR
        # x: [B, C, H, W] (channels moved in front of the spatial axes)
        x = torch.tensor(x.astype('float32'))
        x = x.permute(0, 3, 1, 2).to(device)

        # forward; calling the module (rather than .forward) also runs any
        # registered hooks
        y_hat = model(x)

        # The labels are only needed for the metrics, so they stay on the host
        # as a flat int64 array instead of making a round trip to the device.
        y = y.astype('int64').flatten()
        # Predicted class per pixel: argmax over the channel axis, flattened.
        y_hat = torch.argmax(y_hat, dim=1).cpu().numpy().flatten()

        # compute and store metrics
        metrics[i] = {'dice per class': f1_score(y, y_hat, average=None, labels=[0, 1, 2]),
                      'dice weighted': f1_score(y, y_hat, average='weighted')}

        if max_batches is not None and i + 1 >= max_batches:
            break