# Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
%run ../utils/__init__.py
config_logging(logging.INFO)

# Load model

In [None]:
%run ../models/checkpoint/__init__.py

In [None]:
# Identifier of the training run to inspect; it is passed to the checkpoint
# loader and to the TensorBoard writer further down.
run_name = '0123_174651_cxr14_mobilenet-v2_lr0.0001_hint_normS_size256_sch-roc_auc-p5-f0.1_noes'
# Forwarded as `debug=` to both the checkpoint loader and the TBWriter.
debug_run = False

In [None]:
# Load the compiled classification model for `run_name`.
compiled_model = load_compiled_model_classification(run_name, debug=debug_run)
# Cell output: the 'model_kwargs' entry stored in the loaded model's metadata.
compiled_model.metadata['model_kwargs']

# Load data

In [None]:
%run ../datasets/__init__.py

In [None]:
# Arguments selecting the dataset and split to embed; `max_samples=None`
# is forwarded unchanged to the data-preparation helper.
dataset_kwargs = dict(
    dataset_name='cxr14',
    dataset_type='train',
    max_samples=None,
)
dataloader = prepare_data_classification(**dataset_kwargs)
# Samples are later accessed one by one through the underlying dataset.
dataset = dataloader.dataset
# Cell output: number of samples available.
len(dataset)

# Visualize embeddings

In [None]:
import random
from collections import Counter, defaultdict
from torch.utils.data import Subset
from torch.nn.functional import interpolate
from torch import sigmoid
from tqdm.auto import tqdm

In [None]:
%run ../tensorboard/__init__.py
# %run ../utils/images.py

In [None]:
# TensorBoard writer for this run (TBWriter is presumably provided by the
# %run of ../tensorboard/__init__.py above — confirm).
tb_writer = TBWriter(run_name, task='cls', large=True, debug=debug_run)

In [None]:
# Upper bound on how many dataset items are embedded and shown in the projector.
N_SAMPLES = 1000
# random.sample raises ValueError when k exceeds the population size, so clamp
# k to the dataset length (relevant for small datasets or when max_samples is set).
indexes = random.sample(range(len(dataset)), k=min(N_SAMPLES, len(dataset)))

In [None]:
# Size (height, width) of each thumbnail placed in the projector sprite.
img_atlas_size = (50, 50) # dataset.image_size

# The sprite is a square grid holding ceil(sqrt(N)) thumbnails per side, so the
# number of thumbnails per side must be rounded up BEFORE multiplying by the
# thumbnail width; rounding after the multiplication underestimates the size.
_ATLAS_SIZE = int(np.ceil(np.sqrt(len(indexes)))) * img_atlas_size[1]

# Guard against producing a sprite larger than the 8192-pixel limit.
assert _ATLAS_SIZE <= 8192, f'Atlas wont fit in TB: {_ATLAS_SIZE}'

In [None]:
# Accumulators filled once per sampled item, in the order of `indexes`:
#   embeddings -> second output of the model for each sample
#   label_img  -> down-scaled thumbnail of each input image
#   metadata   -> one row of projector metadata per sample
embeddings = []
label_img = []
metadata = []
# Column names for the metadata rows: three columns per disease, in
# dataset.labels order -- sigmoid score ('pred'), rounded score ('round')
# and ground truth ('gt'). Must match the order used for `sample_meta` below.
metadata_header = [
    f'{val}_{disease}'
    for disease in dataset.labels
    for val in ('pred', 'round', 'gt')
]

# Inference only: put the network in eval mode so train-time layer behaviour
# (e.g. BatchNorm batch statistics, Dropout) is not applied to these
# single-sample batches.
# NOTE(review): assumes compiled_model.model is a torch.nn.Module -- confirm.
compiled_model.model.eval()

for item_idx in tqdm(indexes):
    item = dataset[item_idx]

    images = item.image.to('cuda').unsqueeze(0) # shape: bs=1, n_channels=3, height, width

    # No gradients are needed for visualization.
    with torch.no_grad():
        preds, emb = compiled_model.model(images)

    # Save predictions as metadata
    preds = sigmoid(preds).squeeze(0) # shape: n_diseases

    # Flatten to [pred_0, round_0, gt_0, pred_1, round_1, gt_1, ...]
    sample_meta = [
        val
        for pred, gt in zip(preds.tolist(), item.labels.tolist())
        for val in [f'{pred:.2f}', round(pred), gt]
    ]

    metadata.append(sample_meta)

    # Save embedding
    embeddings.append(emb)

    # Save images
    images = images.detach() # .cpu()
    images = interpolate(images, img_atlas_size, mode='nearest') # shape: bs=1, 3, atlas_h, atlas_w
    image = images.squeeze(0)
    label_img.append(tensor_to_range01(image))

# Thumbnails are stacked along a new leading dimension; embeddings already carry
# a leading batch dimension of 1, so they are concatenated instead.
label_img = torch.stack(label_img, dim=0)
embeddings = torch.cat(embeddings, dim=0)
# Cell output: sizes of both tensors as a sanity check.
embeddings.size(), label_img.size()

In [None]:
# Export embeddings, thumbnails and per-sample metadata to TensorBoard,
# tagged with the dataset split name.
tb_writer.writer.add_embedding(
    embeddings,
    metadata=metadata,
    label_img=label_img,
    metadata_header=metadata_header,
    tag=dataset.dataset_type,
)

In [None]:
# Cell output: (column index, column name) for every metadata column whose
# name contains 'Cardio'.
[
    (col_idx, col_name)
    for col_idx, col_name in enumerate(metadata_header)
    if 'Cardio' in col_name
]

In [None]:
# Tally the values found in metadata column 4 across all samples. With three
# columns per label (pred, round, gt), index 4 is the 'round' column of the
# second entry in dataset.labels.
Counter(m[4] for m in metadata)