In [None]:
import numpy as np
from aix360.algorithms.protodash import ProtodashExplainer
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torch

In [None]:
# Execute the shared utils package into this namespace; presumably this is where
# `tensor_to_range01` (used in the last cell) comes from — TODO confirm.
%run ../utils/__init__.py

## Load data

In [None]:
# Execute the datasets package into this namespace; presumably defines
# `prepare_data_classification`, used in the next cell — TODO confirm.
%run ../datasets/__init__.py

In [None]:
# Test split of the IU X-Ray dataset, loaded as a classification dataloader.
# 'shuffle' is False so that every pass over this loader (and every re-run of the
# notebook) yields the samples in the same order. Cells below index images/labels
# with positions computed from features, which requires a stable order. The whole
# split is loaded into memory anyway, so shuffling would buy nothing.
kwargs = {
    'dataset_type': 'test',
    'dataset_name': 'iu-x-ray',
    'image_size': (512, 512),
    'shuffle': False,
    'frontal_only': True,
}

dataloader = prepare_data_classification(**kwargs)
len(dataloader.dataset)

## Load model

In [None]:
# Execute the checkpoint helpers into this namespace; presumably defines
# `load_compiled_model_classification`, used below — TODO confirm.
%run ../models/checkpoint/__init__.py

In [None]:
# Identifier of the trained run whose saved model is loaded in the next cell.
# NOTE(review): the run name ends in `size256`, but the dataloader above uses
# image_size=(512, 512) — confirm this resolution mismatch is intended.
debug = False
run_name = '0917_161952_iu-x-ray_mobilenet_lr1e-06_aug-0-cls0_size256'

In [None]:
# Load the trained classifier for `run_name` onto the GPU and show its metadata.
compiled_model = load_compiled_model_classification(
    run_name,
    debug=debug,
    device='cuda',
)
compiled_model.metadata

## ProtoDash

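### Random data (toy example)

A small random feature matrix for trying ProtoDash without the dataset; the cells below do not use it.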

In [None]:
# Toy feature matrix (rows = samples, columns = features) for trying ProtoDash
# without the real data. It uses its own variable names so that running this cell
# never overwrites the `images` / `batch_size` globals used by the real pipeline.
# It is seeded so the toy data is reproducible.
RANDOM_SEED = 0
n_random_samples = 10
n_random_features = 15

rng = np.random.default_rng(RANDOM_SEED)
random_features = rng.random((n_random_samples, n_random_features))
random_features.shape

### Load images

In [None]:
# Load the whole split into memory. Images and labels are collected in the same
# pass, so row i of `images` matches row i of `labels`.
image_batches, label_batches = [], []
for batch in dataloader:
    image_batches.append(batch.image)
    label_batches.append(batch.labels)

images, labels = torch.cat(image_batches, dim=0), torch.cat(label_batches, dim=0)
del image_batches, label_batches  # only the concatenated tensors are kept, as before
images.size(), labels.size()

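#### Raw pixels as features

Two alternative feature extractors follow (raw pixels, CNN output). Each one defines `features`, so run only the one ProtoDash should use.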

In [None]:
# One row per image: flatten all pixel values into a numpy feature vector.
batch_size = images.size(0)
features = (
    images
    .detach()
    .cpu()
    .view(batch_size, -1)
    .numpy()
)
features.shape

#### CNN output as features

In [None]:
# CNN features: one flattened feature vector per image.
# They are computed from the already-loaded `images` tensor (see "Load images")
# rather than from a second pass over the dataloader. Row i of `features` therefore
# always corresponds to images[i] / labels[i], even if the dataloader reorders
# samples between passes.
CNN_FEATURE_BATCH_SIZE = 16  # images per forward pass; lower it if the GPU runs out of memory

cnn = compiled_model.model
cnn.eval()  # inference mode: e.g. BatchNorm uses (and does not update) its running stats

feature_chunks = []
with torch.no_grad():  # no autograd graph is needed for feature extraction
    for image_chunk in torch.split(images, CNN_FEATURE_BATCH_SIZE):
        chunk_features = cnn(image_chunk.to('cuda'), features=True).cpu()
        feature_chunks.append(chunk_features.flatten(start_dim=1))

features = torch.cat(feature_chunks, dim=0)
features.size()

### Run ProtoDash

In [None]:
# AIX360 ProtoDash explainer. The data and the number of prototypes are passed to
# `explain` in the next cell, not to the constructor.
proto = ProtodashExplainer()
proto  # display the explainer object

In [None]:
%%time

weights, samples, other_values = proto.explain(features, features, 10)
samples.shape

In [None]:
def get_label(dataloader, item_label):
    """Return a human-readable label string for one dataset item.

    Args:
        dataloader: loader whose ``dataset`` exposes ``multilabel`` (bool) and
            ``labels`` (sequence of class-name strings).
        item_label: for multilabel datasets, a per-class presence vector aligned
            with ``dataset.labels``; otherwise assumed to be the integer class index
            of the item (TODO confirm against the dataset).

    Returns:
        For multilabel datasets, the comma-separated names of the present classes
        (empty string if none); otherwise the single class name.
    """
    dataset = dataloader.dataset
    class_names = dataset.labels

    if dataset.multilabel:
        return ','.join(
            name
            for name, present in zip(class_names, item_label)
            if present
        )

    # Single-label case: index the class names with this item's label. The original
    # code referenced an undefined name `l` here, which raised NameError.
    return class_names[int(item_label)]

In [None]:
# Gather the prototype images, in the order ProtoDash returned their indices,
# together with their readable labels.
selected_images = torch.stack([images[idx] for idx in samples], dim=0)
selected_labels = [get_label(dataloader, labels[idx]) for idx in samples]
selected_images.size(), list(enumerate(selected_labels))

In [None]:
# Tile the prototypes into one image, rescaling each tile separately, then move
# channels last (H, W, C) for matplotlib.
grid = make_grid(
    selected_images,
    nrow=5,
    normalize=True,
    scale_each=True,
).permute(1, 2, 0)
grid.size()

In [None]:
# Show all selected prototypes as one grid; a title makes the figure readable on its own.
fig, ax = plt.subplots(figsize=(15, 5))
ax.imshow(grid)
ax.set_title(f'ProtoDash prototypes ({len(selected_images)})')
ax.axis('off');

In [None]:
S_IDX = 7  # position within the selected prototypes, as enumerated above

# tensor_to_range01 comes from the %run helpers (presumably rescales values to [0, 1]).
single_image = tensor_to_range01(selected_images[S_IDX]).permute(1, 2, 0)
# imshow rejects (H, W, 1) arrays: drop a trailing singleton channel so single-channel
# images also render. This is a no-op for 3-channel images.
single_image = single_image.squeeze(-1)

fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(single_image, cmap='gray')  # cmap only applies to 2-D (single-channel) data
ax.set_title(f'Prototype {S_IDX}: {selected_labels[S_IDX]}')
ax.axis('off');