In [None]:
import numpy as np
from tqdm.auto import tqdm
import os
import torch
import segmentation_models_pytorch as smp
import sys
import yaml
from wholeslidedata.iterators import create_batch_iterator
import segmentation_models_pytorch as smp
import torch.nn as nn
from sklearn.metrics import f1_score, confusion_matrix
from matplotlib import patches
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd

sys.path.insert(1, os.path.join(sys.path[0], '..'))
from utils import print_dataset_statistics, plot_pred_batch, load_config, plot_batch
from train_segmentation import load_segmentation_model
from preprocessing import tissue_mask_batch, get_preprocessing

In [None]:
def load_trained_segmentation_model(exp_dir, model_path):
    """ Rebuilds the segmentation model of an experiment and loads trained weights into it.

    The architecture is reconstructed from the experiment's own copy of the training
    config (<exp_dir>/src/configs/base_config.yml). The checkpoint at model_path is then
    loaded into it and the model is switched to eval mode. This function does not move
    the model to a device; the caller is expected to do that.

    Args:
        exp_dir: directory that holds all the information from an experiment (src, checkpoints)
        model_path: path to the checkpoint (state dict) that is loaded into the model

    Returns:
        model: the model with the checkpoint weights loaded, in eval mode
        preprocessing: result of get_preprocessing; built around the encoder specific
            preprocessing function when the config sets encoder weights, otherwise
            get_preprocessing() is called without arguments
    """
    user_config = os.path.join(exp_dir, 'src/configs/base_config.yml')
    _, train_config = load_config(user_config)

    # LOAD MODEL
    model = load_segmentation_model(train_config, activation=None)
    # map_location='cpu' lets checkpoints that were saved from a GPU load on CPU-only
    # machines as well; load_state_dict copies the values into the existing parameters.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    print('Loaded model from {}'.format(model_path))

    # LOAD PREPROCESSING
    if train_config['encoder_weights']:
        # pretrained encoder: reuse the normalisation that belongs to those weights
        preprocessing = get_preprocessing(smp.encoders.get_preprocessing_fn(
            train_config['encoder_name'], train_config['encoder_weights']))
    else:
        preprocessing = get_preprocessing()
    print('During training we used {} as encoder with weights from {}.'.format(train_config['encoder_name'], train_config['encoder_weights']))

    return model, preprocessing

In [None]:
def dys_score_batch(x, y, y_hat):
    """ Per-sample ratio of the summed NDBE probability to the summed DYS probability.

    The logits are soft-maxed over the class axis (dim 1). Class index 1 is read as NDBE
    and class index 2 as DYS; each probability map is summed over H and W and the two
    sums are divided (NDBE / DYS).

    Args:
        x: [B, H, W, CHANNELS]
            (np.array) not used by the computation
        y: [B, H, W]
            (np.array) not used by the computation
        y_hat: [B, CLASSES, H, W]
            (torch.Tensor) raw logits, at least 3 classes

    Returns
        dys_score: [B]
            (np.array) one ratio per sample in the batch
    """
    # softmax once, then work on the host copy of the probabilities
    class_probs = torch.softmax(y_hat, dim=1).detach().cpu().numpy()

    ndbe_mass = class_probs[:, 1].sum(axis=(1, 2))
    dys_mass = class_probs[:, 2].sum(axis=(1, 2))

    return ndbe_mass / dys_mass

In [None]:
# Paths: root of the code checkout and the classification config stored inside it.
base_dir = '/home/mbotros/code/barrett_gland_grading/'
classification_config = os.path.join(
    base_dir,
    'configs/classification_config.yml',
)

In [None]:
# load config
# NOTE(review): yaml.load with FullLoader should only be pointed at trusted config files.
with open(classification_config, 'r') as yamlfile:
    data = yaml.load(yamlfile, Loader=yaml.FullLoader)
print('Loaded config: {}'.format(classification_config))

wholeslide_config = data['wholeslidedata']

# Create the validation generator only. No training generator is created in this
# notebook, so the training dataset statistics are not printed here: the previous
# version referenced a training generator that was never defined, which raised a
# NameError on a fresh kernel.
# return_info=True makes the iterator yield (x_batch, y_batch, info); the info dict is
# what the inference loop uses to trace every patch back to its slide.
validation_batch_generator = create_batch_iterator(mode='validation',
                                                   user_config=classification_config,
                                                   presets=('slidingwindow',),
                                                   cpus=1,
                                                   number_of_batches=-1,
                                                   return_info=True)

print('\nValidation dataset ')
val_data_dict = print_dataset_statistics(validation_batch_generator.dataset)

print('number of annotations', len(validation_batch_generator))

In [None]:
# Experiment directory and the checkpoint that is evaluated in this notebook.
exp_dir = '/data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/NDvsD/UNet++_EfficientNet-b4_sp=1_Dice/'
model_path = os.path.join(exp_dir, 'checkpoints/model_epoch_142_loss_0.129_dice_0.822.pt')

# Builds the model, loads the checkpoint weights and switches the model to eval mode.
# The weights are therefore not loaded a second time below.
model, preprocessing = load_trained_segmentation_model(exp_dir, model_path)

# declare device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(model)
model.to(device)
# trailing ';' keeps the (very long) module repr out of the cell output
model.eval();

In [None]:
def infer(model, generator, preprocessing, device=None, verbose=True):
    """ Runs the model patch by patch over a generator and collects the scores per slide.

    Every patch of every batch is preprocessed and forwarded on its own (batch size 1)
    and scored with dys_score_batch. The function runs under torch.no_grad() and does
    not call model.eval() itself.

    Args:
        model: network that maps the preprocessed image tensor to class logits
        generator: iterable yielding (x_batch, y_batch, info); info['sample_references'][i]
            must provide 'point' and 'reference' (an object with a file_key attribute)
            for sample i of the batch
        preprocessing: callable taking image= and mask= (both with a leading axis of
            size 1 added here) and returning a dict whose 'image' entry is a tensor
        device: device the input is moved to; when None the device of the first model
            parameter is used (cpu if the model has no parameters)
        verbose: when True, prints one line per patch (slide, point, score)

    Returns:
        results: dict mapping the slide file key to a list with one dys_score_batch
            result per patch, in iteration order
    """
    if device is None:
        # follow the model instead of relying on a notebook-level global
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    results = {}

    with torch.no_grad():

        for x_batch, y_batch, info in tqdm(generator):
            for idx, (x_np, y_np) in enumerate(zip(x_batch, y_batch)):

                # keep track of where samples are coming from
                sample_reference = info['sample_references'][idx]
                point = sample_reference['point']
                wsi = sample_reference['reference'].file_key

                # preprocess patches (the mask is passed along with the image, but only
                # the image is needed on the device)
                sample = preprocessing(image=np.expand_dims(x_np, axis=0), mask=np.expand_dims(y_np, axis=0))
                x = sample['image'].to(device)

                # forward
                y_hat = model(x)

                # compute dysplasia score for this patch
                score = dys_score_batch(x_np, y_np, y_hat)
                if verbose:
                    print('idx', idx, 'x_shape', x_np.shape, 'wsi', wsi,  point, 'score', score)

                results.setdefault(wsi, []).append(score)

    return results

In [None]:
# check dysplasia probabilities: run inference over the validation generator
results = infer(
    model=model,
    generator=validation_batch_generator,
    preprocessing=preprocessing,
)

In [None]:
# read the rbe case level diagnosis table and show it
rbe_case_df = pd.read_csv(
    filepath_or_buffer='/data/archief/AMC-data/Barrett/labels/rbe_case_level.csv',
)
display(rbe_case_df)