In [1]:
import os
import sys
import warnings
warnings.simplefilter("ignore", UserWarning)
from wholeslidedata.image.wholeslideimage import WholeSlideImage
from wholeslidedata.annotation.wholeslideannotation import WholeSlideAnnotation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm
from wholeslidedata.iterators import create_batch_iterator
sys.path.insert(1, os.path.join(sys.path[0], '../..'))
from utils import plot_batch, colors_1
import segmentation_models_pytorch as smp
import argparse
from train_ensemble import Ensemble
from confidence_calibration import avg_entropy_sk_per_patch
from preprocessing import tissue_mask_batch, get_preprocessing
from train_segmentation import load_trained_segmentation_model
from train_ensemble import load_model, Ensemble, SingleModel
from confidence_calibration import avg_entropy_sk, plot_class_probabilities_sample, plot_pred_sample, ece, brier_score, avg_entropy_sk_per_patch

from matplotlib.patches import Circle, Rectangle
from matplotlib.offsetbox import (TextArea, DrawingArea, OffsetImage, AnnotationBbox)

from nn_archs.set_transformer import SetTransformer
from train_slide_classification import SlideGradeModel
from metrics_lib import _validate_probabilities
from skimage.filters import gaussian
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

from sklearn.metrics import cohen_kappa_score, accuracy_score, roc_auc_score, confusion_matrix
from utils import plot_confusion_matrix, plot_roc_curves
import seaborn as sns
import yaml
from yaml.loader import SafeLoader

In [2]:
def entropy_slide(y_pred, epsilon=1e-5):
    """ One-vs-rest binary entropy of every class probability, normalised.

    For each class c the entry y_pred[c] is treated as the probability of
    "class c" against "not class c" and its binary entropy (natural log) is
    computed. Every entropy is then divided by log(2) * C.

    Args:
        y_pred: (C, ) class probabilities, checked by `_validate_probabilities`
        epsilon: small number added inside the logarithms to avoid log(0)

    Returns:
        (C, ) array with the normalised one-vs-rest entropy of each class
    """
    # validate probabilities
    _validate_probabilities(y_pred)

    def _binary_entropy(prob):
        # entropy of the two-outcome distribution (prob, 1 - prob)
        return -np.sum(prob * np.log(prob + epsilon) + (1 - prob) * np.log(1 - prob + epsilon))

    n_classes = y_pred.shape[0]
    normaliser = np.log(2) * n_classes

    entropies = np.array([_binary_entropy(y_pred[k]) for k in range(n_classes)], dtype=float)
    return entropies / normaliser

In [3]:
def entropy_pixel(y_pred, epsilon=1e-5):
    """ Computes, for every class, the average of the pixel-wise one-vs-rest
    binary entropy values.

    Args:
        y_pred: (N, C) probabilities, one row per pixel and one column per class
        epsilon: small number added inside the logarithms to avoid log(0)

    Returns:
        (C, ) array with the mean binary entropy over the N pixels per class
    """
    # validate probabilities
    _validate_probabilities(y_pred)

    def _mean_binary_entropy(prob):
        # prob holds the probability of one class for every pixel
        return -(1 / len(prob)) * np.sum(prob * np.log(prob + epsilon) + (1 - prob) * np.log(1 - prob + epsilon))

    per_class = [_mean_binary_entropy(y_pred[:, k]) for k in range(y_pred.shape[1])]
    return np.array(per_class, dtype=float)

def entropy_pixel_per_patch(y_pred):
    """ Applies `entropy_pixel` to every patch of a batch.

    Each patch is rearranged from channel-first (C, H, W) into a flat
    (H * W, C) list of pixels before the per-class entropy is computed.

    Args:
        y_pred: (B, C, H, W)

    Returns:
        h: (B, C) average pixel-wise entropy per patch and class
    """
    batch_size, num_classes = y_pred.shape[0], y_pred.shape[1]
    h = np.zeros((batch_size, num_classes))

    for b in range(batch_size):
        # channels last, then one row per pixel
        pixels = y_pred[b].transpose(1, 2, 0).reshape(-1, num_classes)
        h[b] = entropy_pixel(pixels)

    return h

In [4]:
class TileGenerator:
    '''Sliding-window tile generator for a single Numpy image.

    Args:
        image: array indexed as image[axis0, axis1, channel] with 3 channels
        step_size: stride between the starts of consecutive tiles (both axes)
        tile_size: side length of the square tiles
    '''
    def __init__(self, image, step_size, tile_size):
        
        self.image = image
        self.tile_size = tile_size
        self.step_size = step_size
    
    def get_generator(self):
        '''Yield (tile, coord) pairs for every grid position where a full tile fits.

        coord is (start_axis0, start_axis1), i.e. the top-left corner of the
        tile in array indices (not its centre). The outer loop runs over
        axis 1 and the inner loop over axis 0. Tiles that would extend past
        the image border, or that do not have exactly 3 channels, are skipped.
        '''
        img = self.image
        size_axis0, size_axis1 = img.shape[0], img.shape[1]

        # + 1: also try the last grid position. With floor(size / step) alone a
        # tile that still fits there is silently dropped when tile_size < step_size.
        # Positions where the tile does not fit are rejected by the shape check.
        n_axis0 = int(np.floor(size_axis0 / self.step_size)) + 1
        n_axis1 = int(np.floor(size_axis1 / self.step_size)) + 1
        full_tile_shape = (self.tile_size, self.tile_size, 3)

        for j in range(n_axis1):
            start_axis1 = int(np.round(j * self.step_size))
            for i in range(n_axis0):
                start_axis0 = int(np.round(i * self.step_size))
                tile = img[start_axis0: start_axis0 + self.tile_size,
                           start_axis1: start_axis1 + self.tile_size]

                # remove when doesnt fit
                if tile.shape == full_tile_shape:
                    yield tile, (start_axis0, start_axis1)

In [5]:
# (1) get tiles with overlap
tile_size = 512
step_size = 256
spacing = 1

def tile_generator(imgs, step_size, tile_size):
    '''Generates sliding-window tiles for a sequence of Numpy images.

    Args:
        imgs: iterable of arrays indexed as img[axis0, axis1, channel] with 3 channels
        step_size: stride between the starts of consecutive tiles (both axes)
        tile_size: side length of the square tiles

    Yields:
        (tile, coord) where coord is (start_axis0, start_axis1), i.e. the
        top-left corner of the tile in array indices (not its centre). Per
        image the outer loop runs over axis 1 and the inner loop over axis 0.
        Tiles that would extend past the image border are skipped.
    '''
    full_tile_shape = (tile_size, tile_size, 3)

    for img in imgs:
        size_axis0, size_axis1 = img.shape[0], img.shape[1]

        # + 1: also try the last grid position. With floor(size / step) alone a
        # tile that still fits there is silently dropped when tile_size < step_size.
        # Positions where the tile does not fit are rejected by the shape check.
        n_axis0 = int(np.floor(size_axis0 / step_size)) + 1
        n_axis1 = int(np.floor(size_axis1 / step_size)) + 1

        for j in range(n_axis1):
            start_axis1 = int(np.round(j * step_size))
            for i in range(n_axis0):
                start_axis0 = int(np.round(i * step_size))
                tile = img[start_axis0: start_axis0 + tile_size,
                           start_axis1: start_axis1 + tile_size]

                # remove when doesnt fit
                if tile.shape == full_tile_shape:
                    yield tile, (start_axis0, start_axis1)

In [6]:
# load ensemble of segmentation models
# run on the first GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
exp_dir = '/data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/3_classes/Ensemble_m5_UNet++_CE_IN/'

# input preprocessing built from the encoder's own preprocessing function
encoder_preprocessing_fn = smp.encoders.get_preprocessing_fn('efficientnet-b4', 'imagenet')
preprocessing = get_preprocessing(encoder_preprocessing_fn)

# ensemble of m=5 members loaded from the experiment directory
ensemble_m5_CE_IN = Ensemble(exp_dir, device=device, m=5)

Loading model: unet++, weights: imagenet
Loaded model from: /data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/3_classes/Ensemble_m5_UNet++_CE_IN/net_0/checkpoints/best_model.pt

Loading model: unet++, weights: imagenet
Loaded model from: /data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/3_classes/Ensemble_m5_UNet++_CE_IN/net_1/checkpoints/best_model.pt

Loading model: unet++, weights: imagenet
Loaded model from: /data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/3_classes/Ensemble_m5_UNet++_CE_IN/net_2/checkpoints/best_model.pt

Loading model: unet++, weights: imagenet
Loaded model from: /data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/3_classes/Ensemble_m5_UNet++_CE_IN/net_3/checkpoints/best_model.pt

Loading model: unet++, weights: imagenet
Loaded model from: /data/archief/AMC-data/Barrett/experiments/barrett_gland_grading/3_classes/Ensemble_m5_UNet++_CE_IN/net_4/checkpoints/best_model.pt



In [7]:
def extract_tiles_2(model, generator, preprocessing, device):
    """ Runs the model on every tile of a generator and collects per-tile statistics.

    Args:
        model: object whose `forward(x, y)` returns class probabilities as a
            Numpy array (assumed (B, C, H, W) -- TODO confirm against the model)
        generator: iterable yielding `(tile, loc)` pairs; each tile is fed to
            the model as a batch of one
        preprocessing: callable used as `preprocessing(image=...)['image']`,
            returning a tensor that is moved to `device`
        device: torch device on which the forward pass runs

    Returns:
        pd.DataFrame with one row per tile. The columns listed in `columns`
        below are always present, also when the generator yields no tiles,
        so callers can filter on them without a KeyError.
    """
    columns = ('loc', 'naive_pred', 'avg_msp', 'avg_msp_nd_vs_d', 'avg_msp_pred',
               'entropy_pred', 'entropy_sk', 'entropy_pixel', 'entropy_nd_vs_d',
               'avg_msp_dys', 'entropy_dys')
    info_tiles = []

    with torch.no_grad():
        for x_np, loc in tqdm(generator):

            # pre process and put on device
            x = preprocessing(image=np.expand_dims(x_np, axis=0))['image'].to(device)
            # forward() takes a target as second argument; an all-zero dummy is passed
            y = torch.zeros_like(x)

            # forward
            y_hat = model.forward(x, y)

            # naive max grade prediction: highest class index predicted anywhere in the tile
            y_pred_max_grade = np.max(np.argmax(y_hat, axis=1), axis=(1, 2))       # (B, )
            avg_msp = np.mean(y_hat, axis=(2, 3))                                  # (B, C)

            # entropy (uncertainty score)
            avg_entropy = avg_entropy_sk_per_patch(y_hat)
            avg_entropy_pixel = entropy_pixel_per_patch(y_hat)

            # collapse the channels to [0], [1], [2 and above]; the last group is
            # presumably "dysplasia" (all dysplastic grades summed) -- confirm
            y_hat_nd_vs_d = np.add.reduceat(y_hat, indices=[0, 1, 2], axis=1)
            avg_msp_nd_vs_d = np.mean(y_hat_nd_vs_d, axis=(2, 3))
            avg_entropy_nd_vs_d = avg_entropy_sk_per_patch(y_hat_nd_vs_d)

            for i in range(len(y_hat)):
                grade = y_pred_max_grade[i]
                info_tiles.append({'loc': loc,
                                   'naive_pred': grade,
                                   'avg_msp': avg_msp[i],
                                   'avg_msp_nd_vs_d': avg_msp_nd_vs_d[i],
                                   'avg_msp_pred': avg_msp[i][grade],
                                   'entropy_pred': avg_entropy[i][grade],
                                   'entropy_sk': avg_entropy[i],
                                   'entropy_pixel': avg_entropy_pixel[i],
                                   'entropy_nd_vs_d': avg_entropy_nd_vs_d[i],
                                   'avg_msp_dys': avg_msp_nd_vs_d[i][2],
                                   'entropy_dys': avg_entropy_nd_vs_d[i][2]})

    # explicit columns keep the schema stable for an empty result
    return pd.DataFrame(info_tiles, columns=list(columns))

### The training data

In [8]:
# load rbe case level diagnosis and add the grade as an integer (NDBE=1, LGD=2, HGD=3)
rbe_slide_df = (
    pd.read_csv('/data/archief/AMC-data/Barrett/labels/rbe_slide_level.csv')
    .assign(**{'grade normal': lambda df: df['grade'].map({'NDBE': 1, 'LGD': 2, 'HGD': 3})})
)
display(rbe_slide_df)

Unnamed: 0,slide,grade_num,grade,grade normal
0,ASL01_3_HE,3,LGD,2
1,ASL02_1_HE,1,NDBE,1
2,ASL03_1_HE,3,LGD,2
3,ASL04_1_HE,1,NDBE,1
4,ASL05_1_HE,1,NDBE,1
...,...,...,...,...
285,ROCT38_XI-HE1,4,HGD,3
286,ROCT38_XII-HE1,1,NDBE,1
287,ROCT39_V-HE1,4,HGD,3
288,ROCT39_VI-HE1,4,HGD,3


In [19]:
# path to the datasets
split_file = '/home/mbotros/code/barrett_gland_grading/configs/split.yml'
dataset = 'validation'

# parse the split definition (safe loader: plain data only) ...
with open(split_file) as f:
    data = yaml.safe_load(f)

# ... and collect the slide path of every entry in the chosen subset
wsi_path_list = [entry['wsi']['path'] for entry in data[dataset]]

print(len(wsi_path_list))
print(wsi_path_list[0])

29
/data/archief/AMC-data/Barrett/ASL/ASL24_1_HE.tiff


In [20]:
# get the files that are in both (the GT and in the file list)
groundtruth_cases = list(rbe_slide_df.slide)
groundtruth_lookup = set(groundtruth_cases)  # constant-time membership test

# case name = file name without directory and extension; splitext handles any
# extension length instead of assuming a five-character '.tiff'
file_cases = [os.path.splitext(os.path.basename(f))[0] for f in wsi_path_list]
cases = [slide for slide in file_cases if slide in groundtruth_lookup]
print(len(cases))

28


In [None]:
# Feature tensor: per case up to `max_tiles` ranked tiles, each described by its
# `num_features`-long 'entropy_sk' vector. Rows that are not filled stay zero.
max_tiles = 250
num_features = 4
x = np.zeros((len(cases), max_tiles, num_features))
y = np.zeros((len(cases), 1))
n = []

for idx, case in enumerate(tqdm(cases)):

    # look up the path of this case: match the exact file stem, a substring test
    # could also hit a directory name or a longer slide name
    n.append(case)
    wsi_path = [s for s in wsi_path_list if os.path.splitext(os.path.basename(s))[0] == case][0]

    # look up the label of this case (first matching row); int() on a whole
    # Series is deprecated in recent pandas, so the scalar is extracted explicitly
    case_grades = rbe_slide_df.loc[rbe_slide_df['slide'] == case, 'grade normal']
    grade = int(case_grades.iloc[0]) - 1
    y[idx] = grade
    print('Processing case: {}\nPath: {}\nLabel: {}\n'.format(case, wsi_path, grade))

    # open the image at the configured spacing
    with WholeSlideImage(wsi_path, backend='openslide') as wsi:
        slide = wsi.get_slide(spacing=spacing)

    # tile gen WSI level: overlapping sliding window
    tile_gn = tile_generator(imgs=[slide], step_size=step_size, tile_size=tile_size)

    # apply ensemble over the tiles
    info_tiles = extract_tiles_2(model=ensemble_m5_CE_IN, generator=tile_gn, preprocessing=preprocessing, device=device)

    # a slide that produced no tiles keeps its all-zero feature rows
    if info_tiles.empty:
        print('No tiles extracted for case: {}'.format(case))
        continue

    # rank on entropy: tiles predicted above grade 1, lowest 'entropy_dys' first
    sus_tiles = info_tiles[info_tiles['naive_pred'] > 1].sort_values(by=['entropy_dys'])[:max_tiles]

    if len(sus_tiles) == 0:
        # fall back to every tile predicted above grade 0
        print('Did not find any evidence of dysplasia.')
        sus_tiles = info_tiles[info_tiles['naive_pred'] > 0].sort_values(by=['entropy_dys'])[:max_tiles]

    # store the per-class entropy vector of each selected tile, in rank order
    for tile_idx, entropy_sk in enumerate(sus_tiles['entropy_sk']):
        x[idx, tile_idx, :] = entropy_sk

In [11]:
# store features
# NOTE(review): the file names say 'test' while the split selected earlier in this
# notebook is dataset = 'validation' -- confirm the names before enabling the saves.
classification_exp_dir = '/data/archief/AMC-data/Barrett/experiments/barrett_slide_classification/entropy_features_wsi/'
x_save_path, y_save_path = (
    os.path.join(classification_exp_dir, stem) for stem in ('x_test_overlap', 'y_test_overlap')
)

# np.save(file=x_save_path, arr=x)
# np.save(file=y_save_path, arr=y)
# y_names_save_path = os.path.join(classification_exp_dir, 'y_test_names')
# np.save(file=y_names_save_path, arr=n)