## Imports

In [None]:
# Expose only physical GPU 1 to this kernel (it is then addressed as 'cuda').
# Must run before torch initializes CUDA for the restriction to take effect.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import os
import json
from tqdm.auto import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import matplotlib

# Force a white figure background for every plot in this notebook.
matplotlib.rcParams.update({'figure.facecolor': 'white'})

## Load model

In [None]:
# Execute the checkpoint helpers in this namespace; presumably this is where
# load_compiled_model_classification (used in the next cell) comes from.
%run ../models/checkpoint/__init__.py

In [None]:
# Checkpoint whose predictions are visualised in this notebook.
# Alternative runs (kept commented out for reference):
# run_name = '1215_174443_cxr14_resnet-50-v2_lr0.0001_os_Cardiomegaly_normS_size256_sch-roc_auc-p5-f0.1'
# run_name = '1203_223059_cxr14_densenet-121-v2_lr0.0001_aug_normS_size256_sch-roc_auc-p5-f0.1'
debug_run = False  # forwarded as `debug=` when the checkpoint is loaded
run_name = '0123_174651_cxr14_mobilenet-v2_lr0.0001_hint_normS_size256_sch-roc_auc-p5-f0.1_noes'

In [None]:
# Restore the trained classifier for `run_name`, then show the kwargs it was
# built with (last expression, so the notebook displays it).
compiled_model = load_compiled_model_classification(
    run_name,
    debug=debug_run,
)
compiled_model.metadata['model_kwargs']

In [None]:
# Put the network in evaluation mode before running inference (presumably a
# torch nn.Module, so this affects layers such as dropout/batch-norm).
# Assigning to `_` keeps the returned object's repr out of the cell output.
_ = compiled_model.model.eval()

## Load data

In [None]:
# Execute the dataset helpers in this namespace; presumably this is where
# prepare_data_classification (used in the next cell) comes from.
%run ../datasets/__init__.py

In [None]:
# CXR14 'test-bbox' split at 1024x1024, no sample cap, norm_by_sample enabled.
# NOTE(review): the loaded run was trained with size256 (per its name); confirm
# that evaluating at 1024x1024 is intended.
dataset_kwargs = dict(
    dataset_name='cxr14',
    dataset_type='test-bbox',
    max_samples=None,
    image_size=(1024, 1024),
    norm_by_sample=True,
)
dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

## Facets Dive parameters

In [None]:
# Output files for Facets Dive, stored two directories up in data/.
data_folder = os.path.join('..', '..', 'data')
ATLAS_FPATH, RECORDS_FPATH = (
    os.path.join(data_folder, fname)
    for fname in ('test-bbox-atlas.png', 'records.json')
)

In [None]:
# Per-sample sprite (thumbnail) size, in pixels, inside the Facets Dive atlas.
target_h = target_w = 50

## Prepare data for Facets Dive

### Create JSON file

In [None]:
import json
from torch import sigmoid

In [None]:
# Run inference on the GPU when CUDA is available; otherwise fall back to the
# CPU instead of failing on `.to('cuda')` (identical to before on GPU hosts).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Run the classifier over every dataset sample and build one flat record per
# image: its file name, the ground-truth value of each label, and the sigmoid
# score of each label under the key '<label>-pred'.
records = []
label_names = dataloader.dataset.labels

with torch.no_grad():
    for item in tqdm(dataloader.dataset):
        # Add a batch dimension of 1 and move the image to the model's device.
        images = item.image.unsqueeze(0).to(device)
        outputs = compiled_model.model(images)
        # NOTE(review): outputs[0] is treated as the logits with a leading
        # batch dim (presumably the model returns a tuple) — confirm.
        preds = sigmoid(outputs[0].cpu().squeeze(0))

        # zip() would silently truncate on a size mismatch (e.g. a checkpoint
        # trained on a single disease), pairing scores with the wrong labels.
        if preds.numel() != len(label_names):
            raise ValueError(
                f'Model produced {preds.numel()} scores but the dataset has '
                f'{len(label_names)} labels; cannot pair predictions with label names.'
            )

        record = {'fname': item.image_fname}
        for label_name, label_value, pred in zip(label_names, item.labels, preds):
            # int() converts numpy/torch integer scalars (e.g. int64) to a
            # built-in int so json can serialize them.
            record[label_name] = int(label_value)
            record[f'{label_name}-pred'] = pred.item()
        records.append(record)

len(records)

In [None]:
# Persist the per-image records so the Facets Dive section can reload them
# without re-running inference.
with open(RECORDS_FPATH, mode='w') as records_file:
    records_file.write(json.dumps(records))

In [None]:
# Serialize the records to a JSON string (same output as json.dumps with
# default settings); uncomment the last line to inspect it.
json_records = json.JSONEncoder().encode(records)
# json_records

### Create atlas image

In [None]:
import os
# from torch.nn.functional import interpolate
from PIL import Image

In [None]:
# Execute the image utilities in this namespace; presumably this is where
# create_sprite_atlas and tensor_to_range01 (used below) come from.
%run ../utils/images.py

In [None]:
# Tile every dataset image as a target_h x target_w, 3-channel sprite into one
# atlas tensor, then show the atlas size.
atlas = create_sprite_atlas(
    dataset, target_h=target_h, target_w=target_w, n_channels=3,
)
atlas.size()

In [None]:
# Convert the (C, H, W) float atlas into an (H, W, C) uint8 array and save it
# as the sprite sheet that Facets Dive loads via `atlas-url`.
# Values are assumed to lie in [0, 1]; clamp first because a float -> uint8
# cast does not saturate, so any overshoot (e.g. from resizing) would wrap
# around. Round instead of truncating so values near 1.0 map to 255.
atlas_np = (
    atlas.clamp(0.0, 1.0)
    .mul(255)
    .round()
    .to(torch.uint8)
    .permute(1, 2, 0)
    .cpu()
    .numpy()
)
# Pillow infers 'RGB' from an (H, W, 3) uint8 array; the explicit `mode`
# argument is deprecated in recent Pillow releases.
atlas_pil = Image.fromarray(atlas_np)
atlas_pil.save(ATLAS_FPATH)

In [None]:
# Preview the saved sprite atlas.
plt.imshow(np.asarray(atlas_pil))

In [None]:
# Side-by-side check of one test image at full resolution (left) and shrunk
# to the sprite size target_h x target_w (right).
# The sample is chosen explicitly and the resized image is computed here;
# previously this cell relied on a leftover loop variable and an undefined
# `resized_image`.
from torch.nn.functional import interpolate

sample = dataset[0]
full_image = tensor_to_range01(sample.image)  # assumes a (C, H, W) float tensor
# Bilinear resize to sprite size; this approximates, but may not exactly
# reproduce, the resizing performed by create_sprite_atlas.
resized_image = interpolate(
    full_image.unsqueeze(0),
    size=(target_h, target_w),
    mode='bilinear',
    align_corners=False,
).squeeze(0)

plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.title(f'{sample.image_fname} ({full_image.shape[-2]}x{full_image.shape[-1]})')
plt.imshow(full_image.permute(1, 2, 0))

plt.subplot(1, 2, 2)
plt.title(f'Sprite size ({target_h}x{target_w})')
plt.imshow(resized_image.permute(1, 2, 0))

## Run Facets Dive!

In [None]:
from IPython.core.display import display, HTML
import json

In [None]:
# Reload the records written earlier so this section works without re-running
# inference; yields the parsed list of per-image dicts.
with open(RECORDS_FPATH) as records_file:
    json_records = json.loads(records_file.read())

In [None]:
%%html
<!-- Give the Facets Dive info card a fixed 600px height, matching the
     height="600" set on the facets-dive element. -->
<style>
facets-dive-info-card {
    height: 600px;
}
</style>

In [None]:
# HTML fragment for Facets Dive: loads the web-components polyfill and the
# Facets Jupyter bundle, declares a <facets-dive> element with the sprite size
# and atlas location, and assigns the records to it from a script.
# `{jsonstr}` must be a valid JavaScript/JSON literal.
HTML_TEMPLATE = """
    <script src="https://cdnjs.cloudflare.com/ajax/libs/webcomponentsjs/1.3.3/webcomponents-lite.js">
    </script>
    <link rel="import" href="https://raw.githubusercontent.com/PAIR-code/facets/master/facets-dist/facets-jupyter.html">
    <facets-dive
        sprite-image-width="{sprite_w}"
        sprite-image-height="{sprite_h}"
        id="elem"
        height="600"
        atlas-url="{atlas_fpath}"
        >
    </facets-dive>
    <script>
      document.querySelector("#elem").data = {jsonstr};
    </script>
"""

# Embed the records as JSON text. Formatting the Python list directly would
# insert its repr, which is not valid JavaScript in general (e.g. `nan`,
# `True`, `None`).
if isinstance(json_records, str):
    # Already serialized (e.g. the output of json.dumps).
    jsonstr = json_records
else:
    jsonstr = json.dumps(json_records)
# Inside <script>, a literal "</" (e.g. "</script>" in a file name) would end
# the script early; "<\/" is an equivalent escape in JSON/JS string literals.
jsonstr = jsonstr.replace('</', '<\\/')

# Load the records, sprite size and atlas location into the template.
# NOTE(review): the browser resolves atlas-url relative to the notebook page
# URL, not the kernel's filesystem — confirm the atlas is served at this path.
html = HTML_TEMPLATE.format(jsonstr=jsonstr,
                            sprite_h=target_h,
                            sprite_w=target_w,
                            atlas_fpath=ATLAS_FPATH,
                           )

# Display the template
display(HTML(html))