# In [None]:
##########################################################################################################
# Author: Mihaly Sulyok & Peter Karacsonyi                                                               #
# Last updated: 2024 jan 7                                                                               #
# doing inference on one or more wsi(s)                                                                  #
# Input: model type, saved weights, session name and wsi(s)                                              #
# Output: accuracy, roc/auc graph, top false negative images                                             #
##########################################################################################################

import util.utils
import torch
import numpy as np
from pathlib import Path
from torch.nn import Linear
import matplotlib.pyplot as plt
from torch.hub import load as hub_load
from torchvision.utils import make_grid
from sklearn.metrics import roc_curve, auc
from torch.utils.tensorboard import SummaryWriter 
# tensorboard --logdir /mnt/bigdata/placenta/tensorboard_data
# firefox http://localhost:6006/

def _worst_false_negatives(true_labels, scores, threshold):
    """Return ``(score, position)`` pairs of positive samples scored below ``threshold``, lowest first."""
    # A list instead of a dict keyed by score: samples that share a score must all be kept.
    misses = [
        (score, position)
        for position, (label, score) in enumerate(zip(true_labels, scores))
        if label == 1 and score < threshold
    ]
    return sorted(misses, key=lambda item: item[0])


def _plot_roc(true_labels, scores):
    """Draw the ROC curve, each segment coloured by its decision threshold, and return the figure."""
    fpr, tpr, thresholds = roc_curve(true_labels, scores)
    roc_auc = auc(fpr, tpr)

    # roc_curve's first threshold lies above every score (max+1 or inf depending
    # on the sklearn version); clip it so the colour scale stays within [0, 1].
    thresholds = np.clip(thresholds, 0, 1)
    norm = plt.Normalize(vmin=thresholds.min(), vmax=thresholds.max())
    cmap = plt.cm.viridis

    fig, ax = plt.subplots()

    # Plot each segment of the ROC curve with its colour mapped to the threshold
    for i in range(len(fpr) - 1):
        ax.plot(fpr[i:i + 2], tpr[i:i + 2], color=cmap(norm(thresholds[i])), lw=2)

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])  # required so the colorbar works with the per-segment colours
    fig.colorbar(sm, ax=ax, label='Threshold')

    # Diagonal line for random chance
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic with Thresholds')
    ax.legend(['ROC curve (area = {:.2f})'.format(roc_auc)], loc="lower right")
    return fig


def test_model(dataloader, model, session_name=None, log_dir=None, fn_threshold=0.3) -> list:
    """Evaluate a two-class classifier: accuracy, ROC/AUC plot and worst false negatives.

    Runs ``model`` over every batch of ``dataloader`` without gradients, prints
    the accuracy and plots a threshold-coloured ROC curve. When ``session_name``
    is given, it also logs the accuracy, the ROC figure and the model graph to
    TensorBoard.

    Args:
        dataloader: iterable of ``(images, _, labels_dict, _)`` batches, where
            ``labels_dict['class']`` holds the integer class labels.
        model: module returning one logit per class (``outputs[:, 1]`` is the
            positive class).
        session_name: when not None, results are written to TensorBoard.
        log_dir: TensorBoard log directory. Defaults to the module-level
            ``tensorboard_log_dir`` and is only used when ``session_name`` is set.
        fn_threshold: positive samples whose positive-class probability is
            below this value are reported as false negatives.

    Returns:
        List of ``(score, position)`` tuples for those false negatives, sorted
        from the lowest (worst) score. ``position`` is the sample's index in
        iteration order, which equals its dataset index only if the dataloader
        is not shuffled.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using {device}")
    model = model.to(device)
    model.eval()

    true_labels = []
    predictions = []
    correct = 0
    total = 0
    images = None  # last batch, reused as example input for add_graph

    with torch.no_grad():
        for images, _, labels_dict, _ in dataloader:
            images = images.to(device)
            labels = labels_dict['class'].to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            # The head emits raw logits (no activation). Softmax turns them into
            # class probabilities consistent with the argmax above; a per-logit
            # sigmoid would ignore the other class's logit.
            probabilities = torch.softmax(outputs, dim=1)
            # roc_curve needs only the positive-class probability
            predictions.extend(probabilities[:, 1].cpu().tolist())
            true_labels.extend(labels.cpu().tolist())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot evaluate the model")

    accuracy = 100 * correct / total
    print('Accuracy: {:.2f}%'.format(accuracy))

    worst_fns = _worst_false_negatives(true_labels, predictions, fn_threshold)

    writer = None
    if session_name is not None:
        if log_dir is None:
            log_dir = tensorboard_log_dir  # module-level setting from the usage section
        writer = SummaryWriter(log_dir=log_dir, comment="test-results")
        writer.add_text("Accuracy", str(accuracy), global_step=None, walltime=None)

    if len(set(true_labels)) < 2:
        # With a single class, the FPR or TPR denominator is zero.
        print("Only one class present in the labels; ROC/AUC is undefined, skipping the plot.")
    else:
        fig = _plot_roc(true_labels, predictions)
        plt.show()
        if writer is not None:
            writer.add_figure("ROC/AUC fig", fig, global_step=None, close=True, walltime=None)

    if writer is not None:
        writer.add_graph(model, images)
        writer.close()

    return worst_fns


# The name starts with "test_", so stop pytest from collecting it as a test case.
test_model.__test__ = False


#########
# usage #
#########

#######################################################################
# set the training session name, the model type and its saved weights #
#######################################################################

# training session name: results are appended to its tensorboard run
session_name = "adept-orangutan"
base_dir = Path("/mnt/bigdata/placenta")
model_checkpoint = "adept-orangutan5.ckpt"
tensorboard_log_dir = base_dir / "tensorboard_data" / session_name

# model type
model = hub_load(
    'moskomule/senet.pytorch',
    'se_resnet50',
    pretrained=True,
    verbose=True
)
# custom model configuration: replace the head with the 2-class layer used in
# training so the checkpoint's state dict matches
num_ftrs = model.fc.in_features
model.fc = Linear(num_ftrs, 2)

# saved weights / checkpoint file; map_location lets a checkpoint saved on a GPU
# load on a CPU-only host (the model is moved to the right device later)
model_checkpoint_file = base_dir / "training_checkpoints" / model_checkpoint
checkpoint = torch.load(model_checkpoint_file, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

# initiate tensorboard writer (used for the figures logged at the end)
writer = SummaryWriter(log_dir=tensorboard_log_dir, comment="evaluation-results")

#########################################################################
# inserting tiles from h5path with TransformedPathmlTileSet to datasets #
#########################################################################
base_dir = Path("/mnt/bigdata/placenta")
h5folder = base_dir / Path("h5mini")
# sorted so the sample order (and therefore reported indices) is reproducible
h5files = sorted(h5folder.glob("*.h5path"))
if not h5files:
    raise FileNotFoundError(f"no *.h5path files found in {h5folder}")

datasets = []
for h5file in h5files:
    print(f"creating dataset from {str(h5file)} with TransformedPathmlTileSet")
    datasets.append(util.utils.TransformedPathmlTileSet(h5file))

ds_fullsize = sum(ds.dataset_len for ds in datasets)

full_ds = torch.utils.data.ConcatDataset(datasets)

# shuffle must stay off: test_model reports false negatives by their position
# in iteration order, and that position is used to index full_ds directly
dataloader = torch.utils.data.DataLoader(
    full_ds, batch_size=64, shuffle=False, num_workers=0
)

# run
worst_fns = test_model(
    dataloader, model, session_name
)

# retrieve the top 20 worst false negative images.
# image_id is the position in dataloader iteration order, so it matches a
# full_ds index only when the dataloader is not shuffled.
fn_top20_images = []
fn_top20_preds = []
for pred, image_id in worst_fns[:20]:
    # dataset items carry more than (image, label): the dataloader loop unpacks
    # four fields per item, so take the image (first field) by index
    fn_top20_images.append(full_ds[image_id][0])
    fn_top20_preds.append(str(pred))

def show(inp, label):
    """Display a channels-first image tensor titled with ``label``; return the figure it was drawn on."""
    current_figure = plt.gcf()
    channels_last = inp.permute(1, 2, 0)  # imshow expects (H, W, C)
    plt.imshow(channels_last)
    plt.title(label)
    plt.show()
    return current_figure

# write to tensorboard as well
if fn_top20_images:
    # arrange the false-negative tiles into one image and title it with their scores
    grid = make_grid(fn_top20_images)
    grid_labels = list(fn_top20_preds)  # already strings, one per image
    fig = show(grid, label=grid_labels)
    writer.add_figure("top_20_false_negatives", fig, global_step=None, close=True, walltime=None)
else:
    # make_grid cannot build a grid from an empty list
    print("no false negatives below the threshold; nothing to plot")
writer.close()