In [None]:
import os
import sys
import numpy as np
from matplotlib import patches
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
from wholeslidedata.annotation.parser import MaskAnnotationParser
from wholeslidedata.image.wholeslideimage import WholeSlideImage
from wholeslidedata.iterators import create_batch_iterator

# Add the parent of the first sys.path entry to the module search path (at index 1)
# before importing the local `utils` module.
# NOTE(review): presumably `utils` lives in that parent directory — confirm.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from utils import load_config

### In this notebook we evaluate trained models on slide level.
 1. Load the configuration and the trained model from the corresponding experiment folder. 
 2. Extract patches from the WSI in a sliding-window fashion.
     * Use tissue mask to only extract patches from interesting regions.
 3. Run inference on the patches and stitch the results back together.
 4. Compute Dice score and confusion matrix on pixel level.

In [None]:
path_to_wsi = 'C:/Users/mbotros/PhD/data/ASL/ASL01_3_HE.tiff'

# Open the same slide once per backend (asap first, then openslide) and
# print the class of the object each backend hands back.
for backend in ('asap', 'openslide'):
    with WholeSlideImage(path_to_wsi, backend=backend) as wsi:
        print(f'Backend used: {wsi.__class__}\n')

#### Load config and trained model

In [None]:
# Locations of the sliding-window config and of the trained checkpoint.
user_config = '/home/mbotros/code/barrett_gland_grading/configs/slidingwindowconfig.yml'

exp_dir = '/data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/NDvsD_DeepLab_Res34_sp1_ps_1024_aug_scheduler/'
model_path = os.path.join(exp_dir, 'checkpoints/model_epoch_54_loss_0.147_dice_0.946.pt')

# Report both paths, one per line (nothing is loaded in this cell).
print('Loading config: {}'.format(user_config),
      'Loading model from {}'.format(model_path),
      sep='\n')

In [None]:
# Build a batch iterator in 'test' mode from the sliding-window config and
# print, for every sample of every batch, its index, array shapes and point.
mode = 'test'

with create_batch_iterator(
    mode=mode,
    user_config=user_config,
    presets=('folders',),
    cpus=1,
    number_of_batches=-1,
    return_info=True,
) as test_iterator:

    print('number of annotations', len(test_iterator))

    for x_batch, y_batch, info in tqdm(test_iterator):
        for idx, (x_sample, y_sample) in enumerate(zip(x_batch, y_batch)):
            # `info` carries one reference per sample; pull out the point for this one
            reference = info['sample_references'][idx]
            point = reference['point']
            print('idx', idx,
                  'x_shape', x_sample.shape,
                  'mask_shape', y_sample.shape,
                  point)

In [None]:
# Print the `annotations_per_label_per_key` attribute of the iterator's dataset.
# NOTE(review): `batch_generator` is not defined in any cell visible here (the cell
# above binds `test_iterator` instead) — confirm where it comes from, otherwise this
# cell raises NameError on a fresh kernel.
print(batch_generator.dataset.annotations_per_label_per_key)

In [None]:
# Run the model on batches from `batch_generator` and store Dice (F1) scores per
# batch index in `metrics`.
# NOTE(review): `batch_generator`, `model`, `device`, `torch`, `f1_score` and
# `to_dysplastic_vs_non_dysplastic` are not imported or defined in the cells visible
# here — confirm they exist before this cell on a fresh kernel.
# NOTE(review): assumes `model` has already been put in eval mode — TODO confirm.
metrics = {}

# tqdm wraps the iterable itself (not the enumerate object) so it can use the
# iterable's length for the progress bar, when it has one.
for i, (x, y, info) in enumerate(tqdm(batch_generator)):

    # INPUT
    # x: [B, H, W, C]
    # y: [B, H, W]

    # dysplastic vs non-dysplastic
    y = to_dysplastic_vs_non_dysplastic(y)

    # TENSOR
    # x: [B, C, H, W]  (channels-last -> channels-first)
    x = torch.tensor(x.astype('float32')).permute(0, 3, 1, 2).to(device)

    # forward: no gradients are needed for evaluation, and calling the module
    # (rather than .forward) keeps any registered hooks working
    with torch.no_grad():
        logits = model(x)

    # compute and store metrics on flattened per-pixel labels;
    # the targets never need to visit the device, so they stay in numpy
    y = y.astype('int64').flatten()
    y_hat = torch.argmax(logits, dim=1).cpu().numpy().flatten()
    metrics[i] = {'dice per class': f1_score(y, y_hat, average=None, labels=[0, 1, 2]),
                  'dice weighted': f1_score(y, y_hat, average='weighted')}

    # stops after the first batch, so `metrics` ends up with a single entry
    break