# Imports

In [None]:
# Expose only GPU 1 to this kernel (IPython magic that sets the environment
# variable). It runs before `import torch`, presumably so it is in effect
# before CUDA gets initialized.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import os
import torch
from torch import nn
import torchxrayvision as xrv
import matplotlib.pyplot as plt

In [None]:
# White figure background for all matplotlib plots (presumably so figures stay
# readable on dark notebook themes / when exported).
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
# Show every column when displaying DataFrames (no column truncation).
import pandas as pd
pd.options.display.max_columns = None

# Load model

## Compiled model

In [None]:
# Execute the checkpoint package in the notebook namespace; presumably this is
# what defines `load_compiled_model_classification` used below — confirm.
%run ../models/checkpoint/__init__.py

In [None]:
# Identifier of the training run to load; `debug_run` is forwarded as `debug=`.
run_name = '1215_174443_cxr14_resnet-50-v2_lr0.0001_os_Cardiomegaly_normS_size256_sch-roc_auc-p5-f0.1'
debug_run = False

In [None]:
# Load the compiled classification model for `run_name` and display the
# model kwargs stored in its metadata.
# NOTE(review): `compiled_model` is not used by the evaluation cells further
# down, which evaluate the wrapped XRV `model` instead.
compiled_model = load_compiled_model_classification(run_name, debug=debug_run)
compiled_model.metadata['model_kwargs']

## XRV model

In [None]:
# TorchXrayVision DenseNet with its "nih" pretrained weights and a
# single-channel input.
xrv_model = xrv.models.DenseNet(weights="nih", in_channels=1)
xrv_model

In [None]:
# Print the repr of the next class in DenseNet's MRO, bypassing DenseNet's own
# __repr__ (presumably overridden by xrv to a short summary) to see the layers.
print(super(type(xrv_model), xrv_model).__repr__())

In [None]:
# Inspect the concrete class of the model object.
type(xrv_model)

In [None]:
# Inspect the class object that XRVWrapper checks against.
xrv.models.DenseNet

In [None]:
class XRVWrapper(nn.Module):
    """Wrapper for TorchXrayVision models, to comply with own API.

    The wrapped model emits one output column per entry of
    ``model.pathologies``. Falsy entries (e.g. empty strings, presumably
    placeholders for pathologies the weights do not cover) are dropped, and
    the remaining columns are returned as a 1-tuple ``(output,)``.

    Attributes:
        model: the wrapped ``xrv.models.DenseNet``.
        valid_idxs: 1-D int64 tensor (non-persistent buffer) holding the
            indices of the truthy entries of ``model.pathologies``.
        valid_pathologies: tuple of the corresponding pathology names, in the
            same order as the wrapper's output columns.
    """

    def __init__(self, model):
        super().__init__()
        assert isinstance(model, xrv.models.DenseNet), (
            f'wrapped model should be instance of xrv.models.DenseNet, got {type(model)}'
        )
        self.model = model

        idxs, self.valid_pathologies = self._split_valid_pathologies(model.pathologies)

        # A (non-persistent) buffer follows `.to(device)` / `.cuda()` on the
        # wrapper, so forward() does not copy it to the device on every call,
        # and it stays out of the state_dict.
        self.register_buffer(
            'valid_idxs',
            torch.tensor(idxs, dtype=torch.long),
            persistent=False,
        )

    @staticmethod
    def _split_valid_pathologies(pathologies):
        """Return ``(indices, names)`` tuples for the truthy entries of *pathologies*.

        Raises:
            ValueError: if no entry is truthy (there would be nothing to output).
        """
        pairs = [(index, pat) for index, pat in enumerate(pathologies) if pat]
        if not pairs:
            raise ValueError(
                'model.pathologies has no non-empty entries; nothing to select'
            )
        idxs, names = zip(*pairs)
        return idxs, names

    def forward(self, x):
        """Run the wrapped model and keep only the valid pathology columns.

        Args:
            x: input batch, passed unchanged to the wrapped model.

        Returns:
            1-tuple ``(output,)``; ``output`` has one column per entry of
            ``self.valid_pathologies``.
        """
        output = self.model(x)
        # shape: bs, n_total_pathologies

        # The buffer already lives on the module's device, so `.to` is normally
        # a no-op; it only guards against the inner model being moved alone.
        output = output.index_select(dim=-1, index=self.valid_idxs.to(output.device))
        # shape: bs, n_valid_pathologies

        return (output,)

In [None]:
# Wrap the XRV model so it returns a 1-tuple holding only the output columns
# whose pathology name is non-empty.
model = XRVWrapper(xrv_model)
model

In [None]:
# Move the wrapped model to the GPU.
# NOTE(review): no `model.eval()` call is visible here — confirm that
# evaluate_model switches the model to eval mode before inference.
model = model.to('cuda')
model

# Load data

In [None]:
# Execute the datasets package in the notebook namespace; presumably this is
# what defines `prepare_data_classification` used below — confirm.
%run ../datasets/__init__.py

In [None]:
# Build the CXR14 'test-bbox' dataloader with single-channel ('L') images,
# matching the XRV model's in_channels=1; `max_samples: None` presumably means
# no cap on the number of samples.
dataset_kwargs = {
    'dataset_name': 'cxr14',
    'dataset_type': 'test-bbox',
    'max_samples': None,
    'image_format': 'L', # XRV accepts 1 channel
}
dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

In [None]:
# Extra imports for inspecting one raw image by hand.
from PIL import Image
import numpy as np
from torchvision import transforms

In [None]:
# Path of one image from the dataset's label index.
# NOTE(review): `[1]` on a pandas Series is label-based when the index is
# integer-typed; use `.iloc[1]` if "second row" was intended (the index may
# not be a plain 0..n-1 range after subsetting).
image_fpath = os.path.join(dataset.image_dir, dataset.label_index['FileName'][1])

In [None]:
# NOTE(review): this converts to 'RGB' while the dataloader above uses
# image_format 'L', so the stats below describe a 3-channel version of the image.
image = Image.open(image_fpath).convert('RGB')
image.size

In [None]:
image_np = np.array(image)
image_np.shape

In [None]:
# Raw pixel value range.
image_np.min(), image_np.max()

In [None]:
# Plain ToTensor (scales uint8 pixels to [0, 1]); applying the dataset's own
# transform is left as the commented-out alternative.
# image_tensor = dataset.transform(image)
image_tensor = transforms.ToTensor()(image)
image_tensor.size()

In [None]:
image_tensor.min(), image_tensor.max()

In [None]:
# Inspect the dataset's transform pipeline.
# NOTE(review): confirm it produces the input size/value range that the XRV
# model expects before trusting the metrics below.
dataset.transform

In [None]:
plt.imshow(image)

# Evaluate on dataset

In [None]:
# `-n`: run the script without setting __name__ to "__main__", so its
# main-guarded code is skipped; presumably this defines `evaluate_model` — confirm.
%run -n ../eval_classification.py

In [None]:
# Evaluate the wrapped XRV model on the test-bbox dataloader and display metrics.
# NOTE(review): the wrapper outputs every named XRV pathology, while the
# compiled run above targets Cardiomegaly only — confirm evaluate_model aligns
# output columns with the dataset's labels.
metrics = evaluate_model(model, dataloader)
metrics