In [None]:
import json
import os
from collections import defaultdict

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.nn.functional import softmax
from tqdm.notebook import tqdm

In [None]:
%run ../utils/__init__.py
%run ../metrics/__init__.py

In [None]:
%run ../utils/plots.py

In [None]:
# Force an opaque white figure background (the notebook inline backend presumably
# defaults to a transparent one), so plots stay legible on dark themes and in exports.
matplotlib.rcParams.update({'figure.facecolor': 'white'})

In [None]:
# Use the GPU when one is available, otherwise fall back to CPU so the notebook still
# runs (slowly) end to end on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

## Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# COVID-UC evaluation data. dataset_type='all' presumably loads every split, and
# max_samples=None presumably means "no cap" — confirm in prepare_data_classification.
dataset_kwargs = dict(
    dataset_name='covid-uc',
    max_samples=None,
    frontal_only=False,
    image_size=(512, 512),
)
dataloader = prepare_data_classification(dataset_type='all', **dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

## Load model

In [None]:
%run ../models/checkpoint/__init__.py

In [None]:
# Training run whose checkpoint is evaluated; both values are handed to
# load_compiled_model_classification in the next cell.
run_name, debug = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid', False

In [None]:
# Load the checkpoint identified by `run_name` (DEVICE is passed through, presumably as
# the target device) and display the metadata stored with it.
compiled_model = load_compiled_model_classification(run_name, debug, DEVICE)
compiled_model.metadata

In [None]:
# Put the network in evaluation mode (affects layers such as dropout / batch-norm)
# before running inference.
model = compiled_model.model
model.eval();  # trailing `;` hides the module repr without clobbering IPython's `_`

## Run the model over the dataset

In [None]:
from torch.nn.functional import softmax

In [None]:
from collections import defaultdict

In [None]:
# Class names in index order; class indices used below (cm[gt][pred], COVID_IDX) index into this.
LABELS = dataset.labels
LABELS

In [None]:
# Index of the COVID class in LABELS; get_predictions applies its COVID threshold to this class.
COVID_IDX = 0
LABELS[COVID_IDX]  # display the selected class name so the hard-coded index can be checked by eye

In [None]:
# Scratchpad for ad-hoc inspection during development; get_predictions does not write to it.
debugging = defaultdict(list)

def get_predictions(model, dataloader, covid_threshold=None):
    """Run `model` over `dataloader` and bucket every sample's output by (GT, predicted) class.

    The predicted class is the argmax of the model output. When `covid_threshold` is
    given and the argmax is COVID_IDX, a COVID prediction whose softmax probability is
    below the threshold is replaced by the best-scoring non-COVID class.

    Autograd is disabled only for the duration of this call (previous grad mode is
    restored on exit). The caller is expected to have put `model` in eval mode.

    Args:
        model: callable applied to each image batch (moved to DEVICE); the first element
            of its return value is taken as the per-sample output scores
            (assumed shape (batch, n_labels) — TODO confirm against the model).
        dataloader: iterable of (images, labels) tensor batches that also exposes
            `.dataset.labels`, the sequence of class names.
        covid_threshold: minimum COVID softmax probability needed to keep a COVID
            prediction; None disables the thresholding.

    Returns:
        (cm_predictions, cm): cm_predictions[gt][pred] is a list of the CPU output
        tensors of every sample with that ground truth / prediction, and
        cm[gt][pred] == len(cm_predictions[gt][pred]) is the confusion matrix of counts.
    """
    n_labels = len(dataloader.dataset.labels)
    cm_predictions = [[[] for _ in range(n_labels)] for _ in range(n_labels)]

    # torch.no_grad() instead of torch.set_grad_enabled(False): the latter switched
    # autograd off for the whole kernel and never switched it back on.
    with torch.no_grad():
        for images, gt_labels in tqdm(dataloader):
            outputs = model(images.to(DEVICE))
            # Move the whole batch to CPU once; per-sample .item() calls on GPU tensors
            # would each force a device synchronisation.
            outputs = outputs[0].detach().cpu()

            for gt, prediction in zip(gt_labels, outputs):
                gt = int(gt.item())
                label_predicted = int(prediction.argmax().item())

                # Enforce the COVID threshold: an unconfident COVID argmax falls back to
                # the best-scoring non-COVID class.
                if (covid_threshold is not None
                        and label_predicted == COVID_IDX
                        and softmax(prediction, dim=-1)[COVID_IDX].item() < covid_threshold):
                    masked = prediction.clone()
                    masked[COVID_IDX] = float('-inf')
                    label_predicted = int(masked.argmax().item())

                cm_predictions[gt][label_predicted].append(prediction)

    cm = [[len(bucket) for bucket in row] for row in cm_predictions]
    return cm_predictions, cm

In [None]:
def plot_distribution(cm_preds, gt_chosen=2, pred_chosen=0, labels=None, ax=None, **kwargs):
    """Histogram the predicted-class probability for one confusion-matrix cell.

    Args:
        cm_preds: nested lists as returned by get_predictions; cm_preds[gt][pred] is a
            list of 1-D output-score tensors, one per sample.
        gt_chosen: ground-truth class index of the cell to plot.
        pred_chosen: predicted class index of the cell to plot; the softmax probability
            of this class is what gets histogrammed.
        labels: class names indexed by class id; defaults to the notebook-level LABELS.
        ax: matplotlib Axes to draw on; defaults to the current axes (plt.gca()).
        **kwargs: forwarded to Axes.hist (e.g. bins=20).

    If the selected cell is empty, prints a message and draws nothing.
    """
    if labels is None:
        labels = LABELS
    predictions = cm_preds[gt_chosen][pred_chosen]
    selection = f'GT: {labels[gt_chosen]}, Pred: {labels[pred_chosen]}'
    if len(predictions) == 0:
        print(f'No predictions match! ({selection})')
        return
    if ax is None:
        ax = plt.gca()

    # (n_cases, n_labels) scores -> (n_cases,) probability of the predicted class
    probas_predicted = softmax(torch.stack(predictions), dim=-1)[:, pred_chosen]

    # Explicit conversion: matplotlib cannot consume CUDA / grad-tracking tensors.
    ax.hist(probas_predicted.detach().cpu().numpy(), **kwargs)
    ax.set_xlabel(f'Prediction probability of class {labels[pred_chosen]}')
    ax.set_ylabel('Sample count')
    ax.set_title(f'{selection} ({len(predictions)} cases)')

In [None]:
# A COVID prediction is kept only if its softmax probability reaches this value;
# otherwise get_predictions falls back to the next most likely class.
COVID_THRESHOLD = 0.6
cm_preds, cm = get_predictions(model, dataloader, covid_threshold=COVID_THRESHOLD)

In [None]:
# Plot the thresholded confusion matrix; cm[gt][pred] counts samples with ground truth gt
# that were predicted as pred.
plot_cm(cm, LABELS)

In [None]:
# Exact counts behind the confusion-matrix plot: rows are ground truth, columns are predictions.
pd.DataFrame(cm, index=LABELS, columns=LABELS).rename_axis(index='GT', columns='Pred')

In [None]:
# Left: confusion matrix. Middle / right: distribution of the predicted-class (COVID)
# probability for cases predicted as COVID whose ground truth is class 2 / COVID itself.
panels = [
    lambda: plot_cm(cm, LABELS, title='Trained on COVID-X, tested on COVID-UC'),
    lambda: plot_distribution(cm_preds, 2, COVID_IDX, bins=20),
    lambda: plot_distribution(cm_preds, COVID_IDX, COVID_IDX, bins=20),
]

plt.figure(figsize=(15, 4))
for position, draw_panel in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    draw_panel()

In [None]:
# Distribution of the COVID probability among cases predicted as COVID, one panel per
# ground-truth class (a loop instead of a copy-pasted figure with hand-picked classes).
n_classes = len(LABELS)
plt.figure(figsize=(5 * n_classes, 4))
for gt_idx in range(n_classes):
    plt.subplot(1, n_classes, gt_idx + 1)
    plot_distribution(cm_preds, gt_idx, COVID_IDX, bins=20)