# Imports

In [None]:
# Restrict this kernel to GPU 0. CUDA reads this variable when it initialises,
# so it must be set before torch touches the GPU.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch

In [None]:
import matplotlib
import matplotlib.pyplot as plt
# Notebook-wide figure defaults: white background and wide 15x5-inch figures.
matplotlib.rcParams.update({
    'figure.facecolor': 'white',
    'figure.figsize': (15, 5),
})

In [None]:
import pandas as pd
# Never truncate DataFrame columns when displaying them.
pd.set_option('display.max_columns', None)

In [None]:
# Run the project's utils package inside this notebook's namespace.
# `config_logging` and `logging` are not imported in any visible cell;
# presumably both are provided by this %run — verify.
%run ../utils/__init__.py
config_logging(logging.INFO)

# Load model

In [None]:
# Load checkpoint/file helpers into the notebook namespace. Presumably these
# provide `RunId` and `load_compiled_model`, used in the cells below — verify.
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
# Checkpoint to inspect: run '0605_000039' of the 'cls-spatial' task.
# NOTE(review): debug=True presumably points at the debug runs folder — confirm.
run_id = RunId('0605_000039', task='cls-spatial', debug=True)

In [None]:
# Load the compiled model for the selected run; `.model` and `.metadata`
# are the attributes used in later cells.
compiled_model = load_compiled_model(run_id)
# Display the 'model_kwargs' entry stored in the checkpoint metadata.
compiled_model.metadata['model_kwargs']

# Load data

In [None]:
# Load the dataset helpers into the notebook namespace (presumably including
# `prepare_data_classification`, used in the next cell — verify).
%run ../datasets/__init__.py

In [None]:
# Dataloader kwargs: 'dataset_type' defaults to 'train', then the kwargs
# recorded in the checkpoint metadata are merged in. Because the metadata is
# applied afterwards, a 'dataset_type' key stored there (if any) takes
# precedence over 'train'.
# NOTE(review): if the intent is to always inspect the train split, apply the
# 'train' value after the metadata instead — confirm.
dataset_kwargs = {'dataset_type': 'train'}
dataset_kwargs.update(compiled_model.metadata['dataset_kwargs'])

dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
# Number of samples in the dataset.
len(dataset)

# Check some examples

In [None]:
from torch.nn.functional import interpolate

In [None]:
# Presumably defines `get_step_fn_cls_spatial`, used in the next cell — verify.
%run ../training/detection/cls_spatial.py

In [None]:
# Step function for running the model on one batch with training=False
# (it is called under torch.no_grad() further down).
# NOTE(review): the classification loss 'wbce' is hard-coded here; confirm it
# matches the loss used in training, since the reported 'cl_loss' presumably
# depends on it.
step_fn_kwargs = dict(training=False, cl_loss_name='wbce', device='cuda')
step_fn = get_step_fn_cls_spatial(compiled_model.model, **step_fn_kwargs)

In [None]:
# Fresh iterator over the dataloader; re-run this cell to restart iteration.
dataloader_iter = iter(dataloader)

In [None]:
# Pull the next batch; re-run this cell to step through the data
# (raises StopIteration once the iterator is exhausted).
batch = next(dataloader_iter)

In [None]:
# Shapes of the batch's input images and ground-truth masks.
batch.image.shape, batch.masks.shape

In [None]:
# Forward pass only: no autograd graph is needed for inspection.
# NOTE(review): the first step_fn argument (an engine, presumably) is unused
# here — confirm.
with torch.no_grad():
    output = step_fn(None, batch)

# Move the spatial activation maps and the label predictions to the CPU so
# they can be indexed and plotted.
out_cl_spatial, out_cl = (
    output[key].cpu() for key in ('activations', 'pred_labels')
)
out_cl.shape, out_cl_spatial.shape

In [None]:
# Report each loss term of the step output as a plain Python number.
for loss_name in ('loss', 'cl_loss', 'spatial_loss'):
    loss_value = output[loss_name].item()
    print(f'{loss_name} {loss_value}')

In [None]:
# TODO: wrap this cell and the plotting cell into a `plot_sample(idx)` helper.
# Select a single sample of the batch to inspect in the plot below.
idx = 4

activations, gt_masks, labels, preds = (
    out_cl_spatial[idx], batch.masks[idx], batch.labels[idx], out_cl[idx],
)
gt_masks.shape, activations.shape, labels.shape, preds.shape

In [None]:
# One row per label of the dataset:
#   left  column: ground-truth mask, titled with the ground-truth label;
#   right column: the model's spatial activation map, titled with the
#                 corresponding value from 'pred_labels'.
# Each panel gets its own colorbar, so colour scales are not shared.
diseases = dataloader.dataset.labels
n_rows = len(diseases)
n_cols = 2
plt.figure(figsize=(5 * n_cols, 5 * n_rows))

for row, disease in enumerate(diseases):
    first_slot = row * n_cols + 1  # subplot indices are 1-based
    panels = (
        (f'{disease} (gt={labels[row].item()})', gt_masks[row]),
        (f'Spatial output (gen={preds[row].item():.2f})', activations[row]),
    )
    for offset, (panel_title, panel_data) in enumerate(panels):
        plt.subplot(n_rows, n_cols, first_slot + offset)
        plt.title(panel_title)
        plt.imshow(panel_data)
        plt.colorbar()