# Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=

In [None]:
import torch

In [None]:
# Pick the GPU when torch can see one, otherwise fall back to the CPU.
# NOTE(review): the notebook clears CUDA_VISIBLE_DEVICES in its first cell, so this
# is expected to resolve to the CPU — confirm that is intended.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [None]:
import matplotlib
import matplotlib.pyplot as plt
# Default figure style for the whole notebook: white background, wide 15x5 canvas.
matplotlib.rc('figure', facecolor='white', figsize=(15, 5))

In [None]:
import pandas as pd
# No cap on the number of columns shown when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

In [None]:
%run ../utils/__init__.py
# Configure logging, passing INFO as the level argument.
# NOTE(review): neither `config_logging` nor `logging` is imported in the visible
# cells; presumably both are provided by the %run of ../utils/__init__.py — verify.
config_logging(logging.INFO)

# Load model

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
# Name of the run whose checkpoint is inspected in this notebook.
run_name = '0606_183458'
# Alternative run, kept commented out for quick switching:
# run_name = '0607_000601'
# `RunId` is not defined in the visible cells; presumably it comes from the
# project modules %run above — verify.
run_id = RunId(run_name, debug=False, task='cls-spatial')

In [None]:
# Load the compiled model for `run_id`, targeting DEVICE.
compiled_model = load_compiled_model(run_id, device=DEVICE)
# Call eval() on the wrapped model; the result is bound to `_` so the cell's last
# expression is not echoed as notebook output.
_ = compiled_model.model.eval()

# Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# Start from the train split as a default. The kwargs stored in the model metadata
# are applied next, so a 'dataset_type' stored there replaces this default.
dataset_kwargs = {'dataset_type': 'train'}
dataset_kwargs.update(compiled_model.metadata['dataset_kwargs'])
# Notebook-specific settings are applied last and override the stored kwargs.
dataset_kwargs.update(
    num_workers=1,
    batch_size=10,
    sort_samples=False,
    shuffle=True,
)

dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
# Last expression of the cell: shows the dataset size.
len(dataset)

# Check some examples

In [None]:
from torch.nn.functional import interpolate

In [None]:
%run ../training/detection/cls_spatial.py
%run ../metrics/detection/__init__.py

In [None]:
# Build the step function for the cls-spatial task around the loaded model, with
# training disabled and 'wbce' as the classification loss name.
# `get_step_fn_cls_spatial` is not defined in the visible cells; presumably it comes
# from the %run of ../training/detection/cls_spatial.py — verify.
step_fn = get_step_fn_cls_spatial(
    compiled_model.model, training=False,
    cl_loss_name='wbce', device=DEVICE,
)

In [None]:
# Disabled alternative: pre-process the step output before feeding the metric.
# (`partial` and `_threshold_activations_and_keep_valid` are not defined in the
# visible cells, so re-enabling this needs them in scope.)
# parse_output_for_metric = partial(
#     _threshold_activations_and_keep_valid,
#     cls_thresh=None, heat_thresh=None, only='T',
# )
# IoO metric instance, reused by plot_sample below (reset/update/compute per sample).
# NOTE(review): reduce_sum=False presumably keeps one value per class instead of a
# summed scalar, since plot_sample indexes the result per disease — confirm.
ioo_metric = IoO(reduce_sum=False, device=DEVICE)

In [None]:
# Create the iterator once, so re-running the next cell advances to a new batch
# instead of restarting the dataloader.
dataloader_iter = iter(dataloader)

In [None]:
# Pull the next batch from the iterator.
batch = next(dataloader_iter)
# Sum the labels over dim 0 (presumably the number of positive samples per class in
# this batch — confirm labels are 0/1 with shape (batch, n_classes)).
batch.labels.sum(dim=0)

In [None]:
# Show the sizes of the image and mask tensors of the current batch.
batch.image.size(), batch.masks.size()

In [None]:
# Run one step on the current batch without tracking gradients. The first argument
# is passed as None (presumably an engine slot the step function does not use — confirm).
with torch.no_grad():
    output = step_fn(None, batch)

# Copy the outputs that plot_sample needs to the CPU, in the same lookup order.
out_cl_spatial_osize, out_cl_spatial, out_cl = (
    output[name].cpu()
    for name in ('activations_original_size', 'activations', 'pred_labels')
)

# Last expression of the cell: sizes of the three outputs.
tuple(t.size() for t in (out_cl, out_cl_spatial, out_cl_spatial_osize))

In [None]:
# Print the scalar value of each loss entry returned by the step function.
for k in ('loss', 'cl_loss', 'spatial_loss'):
    print(k, output[k].item())

In [None]:
def plot_sample(idx):
    """Plot ground-truth masks and spatial outputs for one sample of the batch.

    Reads the notebook globals ``batch``, ``out_cl``, ``out_cl_spatial``,
    ``out_cl_spatial_osize``, ``ioo_metric`` and ``dataloader``. Prints the IoO
    values computed for the sample and draws one row per disease with three
    panels: the ground-truth mask, the spatial activation, and the activation
    at original size (with its value range in the title).

    Args:
        idx: Position of the sample inside the current batch. Negative indices
            are supported, including ``-1`` for the last sample.
    """
    activations_osize = out_cl_spatial_osize[idx]
    activations = out_cl_spatial[idx]
    gt_masks = batch.masks[idx]
    labels = batch.labels[idx]
    preds = out_cl[idx]

    # The metric is fed a batch of one sample. Re-add the leading axis with
    # unsqueeze(0) instead of slicing idx:idx+1: that slice is empty for
    # idx == -1, which would score nothing while still plotting the last sample.
    metric_activations = activations.unsqueeze(0)
    metric_gt_masks = gt_masks.unsqueeze(0)
    valid = labels.unsqueeze(0).bool()

    # Compute the metric for this sample only.
    ioo_metric.reset()
    ioo_metric.update((metric_activations, metric_gt_masks, valid))
    ioo = ioo_metric.compute()
    print(ioo)

    # One row per disease, three panels per row.
    diseases = dataloader.dataset.labels
    n_rows = len(diseases)
    n_cols = 3
    plt.figure(figsize=(5*n_cols, 5*n_rows))

    for i, disease in enumerate(diseases):
        # Panel 1: ground-truth mask, titled with the ground-truth label.
        plt.subplot(n_rows, n_cols, i*n_cols + 1)
        plt.title(f'{disease} (gt={labels[i].item()})')
        plt.imshow(gt_masks[i])
        plt.colorbar()

        # Panel 2: spatial activation, titled with the predicted label and IoO.
        plt.subplot(n_rows, n_cols, i*n_cols + 2)
        plt.title(f'Spatial output (gen={preds[i].item():.2f}, ioo={ioo[i]:.2f})')
        plt.imshow(activations[i])
        plt.colorbar()

        # Panel 3: activation at original size, titled with its min/max values.
        plt.subplot(n_rows, n_cols, i*n_cols + 3)
        a = activations_osize[i]
        min_value = a.min().item()
        max_value = a.max().item()
        plt.title(f'O-size (range={min_value:.1f},{max_value:.1f})')
        plt.imshow(a)
        plt.colorbar()

In [None]:
# Show the labels of every sample in the current batch.
batch.labels

In [None]:
# Plot the sample at batch position 9, i.e. the last one when the batch holds the
# configured 10 samples (a shorter final batch would make this index invalid).
plot_sample(9)