# Imports

In [1]:
# Restrict this kernel to physical GPU 1 (torch then sees it as cuda:0).
# Must run before CUDA is initialized, which is why it is the first cell.
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt

In [3]:
import matplotlib
# White figure background (presumably so figures stay readable on dark notebook themes).
matplotlib.rcParams.update({'figure.facecolor': 'white'})

In [4]:
import pandas as pd
# Never truncate wide frames when displaying them.
pd.set_option('display.max_columns', None)

In [5]:
# Load shared helpers into the notebook namespace. `config_logging` and `logging`
# are not imported in any visible cell, so presumably both come from this %run.
# NOTE(review): `np` (used in the atlas-size cell) is also never imported visibly;
# presumably it is defined here too. Confirm.
%run ../utils/__init__.py
config_logging(logging.INFO)

# Load model

## Load std model (optional: `model` is replaced by the xrv autoencoder wrapper below)

In [None]:
# Presumably defines `load_compiled_model_classification` (used below). Confirm.
%run ../models/checkpoint/__init__.py

In [None]:
# Run whose checkpoint the next cell loads; the debug flag is forwarded to the loader.
run_name, debug_run = (
    '0123_174651_cxr14_mobilenet-v2_lr0.0001_hint_normS_size256_sch-roc_auc-p5-f0.1_noes',
    False,
)

In [None]:
# Load the checkpointed classifier for `run_name` and display its stored
# 'model_kwargs' metadata.
# NOTE(review): this cell was never executed in the saved run (In [None]), and
# `model` is reassigned to the xrv autoencoder wrapper further down.
compiled_model = load_compiled_model_classification(run_name, debug=debug_run)
model = compiled_model.model
compiled_model.metadata['model_kwargs']

## Load xrv autoencoder

In [6]:
import torchxrayvision as xrv

In [7]:
class XrvAEWrapper(nn.Module):
    """Linear classification head on top of the TorchXRayVision ResNet autoencoder.

    The autoencoder's `encode` output, flattened, is used as the embedding; a
    single linear layer maps it to one logit per label.

    Args:
        labels: sequence of labels. Only its length is used, to size the output
            layer. Defaults to an empty tuple (zero outputs).
    """

    def __init__(self, labels=()):
        # Immutable default: a `[]` default would be one list shared by all calls.
        super().__init__()
        self.ae = xrv.autoencoders.ResNetAE(weights="101-elastic")

        # Size of the flattened encoder output. It must match the embedding size
        # produced in forward(); it is 512 for the 64x64 inputs used in this notebook.
        self.features_size = 512

        self.prediction = nn.Linear(self.features_size, len(labels))

    def forward(self, images):
        """Encode a batch of images and classify the embedding.

        Args:
            images: batch passed to the autoencoder's `encode`
                (this notebook uses shape (bs, 1, 64, 64)).

        Returns:
            Tuple (output, embedding): logits of shape (bs, n_labels) and the
            flattened embedding of shape (bs, n_features).
        """
        embedding = self.ae.encode(images)
        # shape: bs, n_features, 1, 1

        embedding = torch.flatten(embedding, start_dim=1)
        # shape: bs, n_features

        output = self.prediction(embedding)
        # shape: bs, n_diseases

        return output, embedding

In [9]:
# Build the wrapper with 14 outputs on the GPU and switch it to eval mode.
# This notebook only runs inference, but nn.Modules start in training mode.
# In training mode, BatchNorm/Dropout layers (if the autoencoder has any) behave
# differently, and every forward pass would update the pretrained BatchNorm
# running statistics.
model = XrvAEWrapper(labels=list(range(14))).cuda().eval()
model

XrvAEWrapper(
  (ae): XRV-ResNetAE-101-elastic
  (prediction): Linear(in_features=512, out_features=14, bias=True)
)

In [11]:
# Overrides the checkpoint `run_name` set earlier; this is the run name the
# TensorBoard writer below is created with.
run_name, debug_run = 'xrv-ae-encoder', False

In [8]:
# Dummy batch (bs=1, 1 channel, 64x64) used only to check the forward-pass shapes.
images = torch.randn((1, 1, 64, 64))
images.shape

torch.Size([1, 1, 64, 64])

In [10]:
# Smoke-test the forward pass on the random input. This is inference only:
# no_grad skips autograd bookkeeping, and eval mode keeps the noise input from
# updating BatchNorm running statistics (if the model has BatchNorm layers).
model.eval()
with torch.no_grad():
    out, emb = model(images.cuda())
out.size(), emb.size()

(torch.Size([1, 14]), torch.Size([1, 512]))

# Load data

In [12]:
# Presumably defines `prepare_data_classification` (used in the next cell). Confirm.
%run ../datasets/__init__.py

In [13]:
dataset_kwargs = dict(
    dataset_name='cxr14',
    dataset_type='train',
    max_samples=None,
    # Image settings chosen for the xrv autoencoder:
    image_format='L',
    xrv_norm=True,
    image_size=(64, 64),
)
dataloader = prepare_data_classification(**dataset_kwargs)
# Samples are fetched one by one from the underlying dataset, not through the loader.
dataset = dataloader.dataset
len(dataset)

[__main__] INFO(02-09 15:51) Loading cxr14/train cl-dataset, bs=10 imgsize=(64, 64) version=None format=L
[__main__] INFO(02-09 15:51) 	Dataset size: 75713


75713

# Visualize embeddings

In [14]:
import math
import random
from collections import Counter, defaultdict

from torch.utils.data import Subset
from torch.nn.functional import interpolate
from torch import sigmoid
from tqdm.auto import tqdm

In [15]:
# Presumably defines `TBWriter` (used below).
# NOTE(review): `tensor_to_range01`, used in the embedding loop, is not defined in
# any visible cell. It may come from one of these %run files, possibly the
# commented-out one; confirm before deleting that line.
%run ../tensorboard/__init__.py
# %run ../utils/images.py

In [23]:
# TensorBoard writer named after the most recently assigned `run_name`.
tb_writer = TBWriter(
    run_name,
    debug=debug_run,
    task='cls',
    large=True,
)

In [17]:
# Number of random samples to project. Capped at the dataset size, because
# random.sample raises ValueError when k exceeds the population (e.g. debug runs
# with max_samples set).
N_SAMPLES = 1000
# Fixed seed, so re-running the notebook projects the same subset of images.
SAMPLING_SEED = 0
indexes = random.Random(SAMPLING_SEED).sample(
    range(len(dataset)),
    k=min(N_SAMPLES, len(dataset)),
)

In [19]:
# img_atlas_size = (50, 50)
img_atlas_size = dataset.image_size

# TensorBoard packs the label images into a square sprite with ceil(sqrt(N))
# thumbnails per side. Round the per-side count up *before* multiplying by the
# thumbnail size, and check the larger of height/width so neither sprite
# dimension can exceed the limit.
_n_per_side = math.ceil(math.sqrt(len(indexes)))
_ATLAS_SIZE = _n_per_side * max(img_atlas_size)

assert _ATLAS_SIZE <= 8192, f'Atlas wont fit in TB: {_ATLAS_SIZE}'

In [22]:
embeddings = []
label_img = []
metadata = []
# Three TensorBoard metadata columns per disease, in the same order as the values
# of each `sample_meta` row built below.
metadata_header = [
    f'{val}_{disease}'
    for disease in dataset.labels
    for val in ('pred', 'round', 'gt')
]

# Inference only: eval mode makes BatchNorm/Dropout layers (if the model has any)
# use their trained behaviour instead of per-sample batch statistics.
model.eval()

for item_idx in tqdm(indexes):
    item = dataset[item_idx]

    images = item.image.to('cuda').unsqueeze(0)
    # shape: bs=1, n_channels, height, width (n_channels=1 for image_format='L')

    with torch.no_grad():
        preds, emb = model(images)

    # Save predictions as metadata
    preds = sigmoid(preds).squeeze(0) # shape: n_diseases

    sample_meta = [
        val
        for pred, gt in zip(preds.tolist(), item.labels.tolist())
        for val in [f'{pred:.2f}', round(pred), gt]
    ]

    metadata.append(sample_meta)

    # Save embedding
    embeddings.append(emb)

    # Save images, resized to the atlas thumbnail size only when needed.
    # Compare spatial dims only: `images` is 4-D (bs, C, H, W) while
    # img_atlas_size is (H, W), so comparing the full size never matched.
    images = images.detach() # .cpu()
    if tuple(images.shape[-2:]) != tuple(img_atlas_size):
        images = interpolate(images, img_atlas_size, mode='nearest')
        # shape: bs=1, n_channels, atlas_h, atlas_w
    image = images.squeeze(0)
    label_img.append(tensor_to_range01(image))

label_img = torch.stack(label_img, dim=0)
embeddings = torch.cat(embeddings, dim=0)
embeddings.size(), label_img.size()

  0%|          | 0/1000 [00:00<?, ?it/s]

(torch.Size([1000, 512]), torch.Size([1000, 1, 64, 64]))

In [24]:
# Send embeddings, per-sample metadata rows and thumbnails to the TensorBoard
# projector, tagged with the dataset split.
tb_writer.writer.add_embedding(
    embeddings,
    tag=dataset.dataset_type,
    metadata_header=metadata_header,
    metadata=metadata,
    label_img=label_img,
)

In [None]:
# (column index, header) pairs for every metadata column mentioning 'Cardio'.
list(filter(lambda pair: 'Cardio' in pair[1], enumerate(metadata_header)))

In [None]:
# Value counts of metadata column 4. Each disease contributes three columns
# ('pred', 'round', 'gt'), so index 4 is the 'round' column of the second entry
# in dataset.labels. Confirm against metadata_header[4] (see the cell above).
Counter([row[4] for row in metadata])