# Imports and setup

In [None]:
# Expose only GPU 0 to this process; set before `import torch` below.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch

In [None]:
# White figure background and a wide default figure size.
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
# Show every column when displaying DataFrames (e.g. `df.head(2)` below).
import pandas as pd
pd.options.display.max_columns = None

In [None]:
# Project utilities. Presumably this defines `config_logging` and brings
# `logging` into scope (neither is imported in this notebook) — confirm.
%run ../utils/__init__.py
config_logging(logging.INFO)

In [None]:
# Device for the model and for the concept images fed to TCAV.
DEVICE = 'cuda'

# Load model

In [None]:
# Presumably provides `RunId` and `load_compiled_model` used below — confirm.
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
# Identifies the trained run to load. task='cls' presumably selects the
# classification runs — confirm.
run_id = RunId('0321_052008', debug=False, task='cls')

In [None]:
# Load the checkpointed model onto DEVICE and display its construction kwargs.
compiled_model = load_compiled_model(run_id, device=DEVICE)
compiled_model.metadata['model_kwargs']

# Load data

In [None]:
# Presumably defines `prepare_data_classification` — confirm.
%run ../datasets/__init__.py

In [None]:
# Kept as a dict so later cells can build fresh, independent dataset
# instances with the same settings (see the cardiomegaly concept cell).
# NOTE(review): max_samples=None presumably means "no limit" — confirm.
dataset_kwargs = {
    'dataset_name': 'chexpert',
    'dataset_type': 'train-val',
    'max_samples': None,
    'num_workers': 1,
}
dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

# TCAV (Testing with Concept Activation Vectors)

In [None]:
import math

In [None]:
# All labels except the first one.
# NOTE(review): the skipped label is presumably 'No Finding' — confirm.
ACTUAL_DISEASES = list(dataset.labels[1:])

In [None]:
# Copy first so that removing Cardiomegaly does not mutate ACTUAL_DISEASES.
other_diseases = list(ACTUAL_DISEASES)
other_diseases.remove('Cardiomegaly')

In [None]:
# Candidate concept examples: frontal views with Cardiomegaly == 1 and every
# other disease column exactly 0. Rows holding any other value in those
# columns (NaN compares unequal to 0, as would any uncertainty code) are dropped.
df = dataset.label_index
df = df.loc[df['Cardiomegaly'] == 1]
df = df.loc[df['Frontal/Lateral'] == 'Frontal']
df = df.loc[(df[other_diseases] == 0).all(axis=1)]
print(len(df))
df.head(2)

In [None]:
def plot_items(indexes, n_cols=3, panel_size=5):
    """Plot dataset images in a grid, each titled (and printed) by its index.

    Items are read from the notebook-global ``dataset``; the first channel of
    each item's image (``item.image[0]``) is drawn in grayscale.

    Args:
        indexes: iterable of indexes into ``dataset`` (list, pandas Index,
            generator, ...).
        n_cols: number of grid columns (default 3, as before).
        panel_size: width and height, in inches, of each subplot (default 5).
    """
    # Materialize once: len() is needed for the grid shape, and generators
    # have no len().
    indexes = list(indexes)
    if not indexes:
        # Nothing to draw; avoid creating an empty zero-height figure.
        return

    n_rows = math.ceil(len(indexes) / n_cols)
    plt.figure(figsize=(n_cols * panel_size, n_rows * panel_size))

    for position, idx in enumerate(indexes, start=1):
        item = dataset[idx]

        title = idx  # item.image_fname would be an alternative title
        print(title)

        plt.subplot(n_rows, n_cols, position)
        plt.imshow(item.image[0], cmap='gray')
        plt.title(title)

In [None]:
# Draw 10 random candidate rows (unseeded, so different on each run).
idxs = list(df.sample(10).index)
idxs

In [None]:
# Display the sampled candidates.
plot_items(idxs)

In [None]:
# Hand-picked indexes for the cardiomegaly concept (displayed with plot_items
# below). 106333 used to appear twice. A repeated example is double-weighted
# by the concept classifier and can land in both its train and test splits,
# so it was removed. Add a new index here if exactly 10 examples are wanted.
cardiom_idxs = [72850, 106333, 38118, 11365, 123427, 71215, 59573, 2188, 91875]

In [None]:
# Show the hand-picked cardiomegaly concept images.
plot_items(cardiom_idxs)

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
# TODO: consider moving the device transfer into a DataLoader collate_fn.
class DatasetWrapper(Dataset):
    """Adapt a project dataset so that indexing yields only the image tensor.

    The batches produced by a DataLoader over this wrapper are handed to the
    model as inputs (via ``captum.concept.Concept``), so items must be bare
    tensors rather than the project's item objects. Each image is moved to
    ``device`` when accessed. Unknown attributes are forwarded to the
    wrapped dataset.

    NOTE: moving tensors to CUDA inside ``__getitem__`` is only safe with
    ``num_workers=0`` (the DataLoader default used in this notebook).
    """

    def __init__(self, dataset, device='cuda'):
        self.dataset = dataset
        self.device = device

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return item.image.to(self.device)

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        # __getattr__ runs only when normal lookup fails. If `dataset` itself
        # is missing, `self.dataset` would re-enter __getattr__ until
        # RecursionError. This happens on the bare instance that copy/pickle
        # create before restoring __dict__. Raising AttributeError instead lets
        # hasattr() and friends behave normally.
        if name == 'dataset':
            raise AttributeError(name)
        return getattr(self.dataset, name)

In [None]:
# Build an independent dataset instance for the random concept, as the
# cardiomegaly cell does. Assigning `random_dataset = dataset` would only alias
# the object. The in-place `label_index` edits below would then shrink
# `dataset` itself to 10 random rows, changing what later cells such as
# `dataset[4]` return.
# NOTE(review): the sample is unseeded; pass random_state=... to make the
# CAVs reproducible.
random_dataset = prepare_data_classification(**dataset_kwargs).dataset
random_dataset.label_index = random_dataset.label_index.sample(10)
random_dataset.label_index.reset_index(drop=True, inplace=True)
random_dataloader = DataLoader(DatasetWrapper(random_dataset, device=DEVICE), batch_size=10, shuffle=True)
len(random_dataloader), len(random_dataloader.dataset)

In [None]:
# Fresh dataset instance restricted to the hand-picked cardiomegaly rows.
# NOTE(review): cardiom_idxs were taken from `df.index` (row labels), but
# `.iloc` selects by position. The two agree only if `label_index` has a
# default RangeIndex — confirm, or switch to `.loc`.
cardiom_dataset = prepare_data_classification(**dataset_kwargs).dataset
cardiom_dataset.label_index = cardiom_dataset.label_index.iloc[cardiom_idxs]
# Renumber rows 0..n-1 (presumably what the dataset's __getitem__ expects).
cardiom_dataset.label_index.reset_index(drop=True, inplace=True)
len(cardiom_dataset)

In [None]:
# batch_size=10 covers every concept example, so this yields a single batch.
cardiom_dataloader = DataLoader(DatasetWrapper(cardiom_dataset, device=DEVICE), batch_size=10)
len(cardiom_dataloader), len(cardiom_dataloader.dataset)

In [None]:
from captum.concept import TCAV, Concept

In [None]:
# Private captum module. In this notebook section its classes are referenced
# only by the disabled classifier cell below.
from captum.concept._utils.classifier import DefaultClassifier, Classifier

In [None]:
# Presumably defines `ModelWrapper`, used by the TCAV cell below — confirm.
%run ../training/classification/grad_cam.py

In [None]:
# Disabled: an explicit CAV classifier moved to DEVICE. To use it, uncomment
# this cell and the `classifier=` argument of the TCAV call.
# classifier = DefaultClassifier()
# classifier.lm.to(DEVICE)
# classifier

In [None]:
# Ensure the network is on DEVICE. This may be redundant if
# load_compiled_model(..., device=DEVICE) already placed it there.
compiled_model.model = compiled_model.model.to(DEVICE)

In [None]:
# The layer path is resolved relative to the ModelWrapper, so the wrapper
# presumably stores the network as `.model`. The name points to a conv inside
# the final dense block (denseblock4) of a DenseNet-style network — confirm.
tcav = TCAV(
    ModelWrapper(compiled_model.model),
    layers='model.features.denseblock4.denselayer16.conv2',
    # classifier=classifier,
)
tcav

In [None]:
# Concept ids 0 and 1 form the pair key '0-1' used below.
# NOTE: the name `random` shadows the stdlib module if it is imported later.
cardiom = Concept(0, 'cardiomegaly', cardiom_dataloader)
random = Concept(1, 'random', random_dataloader)

In [None]:
# Train the CAV for the (cardiomegaly vs. random) concept pair.
cavs = tcav.compute_cavs([[cardiom, random]])
cavs

In [None]:
# Results are keyed by concept pair, then by layer. TCAV was given a single
# layer, so take the first (only) entry.
d = cavs['0-1']
k = list(d.keys())[0]
cav = d[k]
cav.__dict__

In [None]:
# Presumably the concept classifier's accuracy — confirm in the captum docs.
cav.stats['accs']

In [None]:
# Reads from the module-level `dataset`; make sure no earlier cell has
# mutated its `label_index`.
item = dataset[4]
item.image.size(), item.labels

In [None]:
plt.imshow(item.image[0], cmap='gray')

In [None]:
# Single image: add a batch dimension and move it to the model's device.
inputs = item.image.unsqueeze(0).to(DEVICE)
inputs.size()

In [None]:
# NOTE(review): target=1 selects output index 1 of the model; confirm it is
# Cardiomegaly in the model's label order.
scores = tcav.interpret(inputs, [[cardiom, random]], target=1)
scores

In [None]:
# Ground-truth labels of the interpreted image, for comparison.
item.labels

# Reproducing bugs

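Minimal reproduction of Captum issues #719 and #721. It uses a toy model and dummy data, so it runs without the checkpoint or dataset loaded above.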

In [None]:
# Self-contained reproduction: re-import everything this section uses so it
# does not depend on the cells above.
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from captum.concept import TCAV, Concept

In [None]:
# The model, concept data and inputs of this reproduction are all put here.
DEVICE = 'cuda'

In [None]:
class MyModel(nn.Module):
    """Tiny CNN: one conv layer, global average pooling, one linear output.

    Maps a ``(batch, 3, height, width)`` image batch to ``(batch, 1)`` scores.
    The ``conv`` submodule is the layer TCAV inspects in this reproduction.
    """

    def __init__(self):
        super().__init__()
        # Registered in the same order as before, so parameter initialization
        # and state_dict layout are unchanged.
        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=10)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.flatten = nn.Flatten()
        self.classifier = nn.Linear(in_features=10, out_features=1)

    def forward(self, images):
        # conv:       (B, 3, H, W)   -> (B, 10, H', W')
        # pool:       (B, 10, H', W') -> (B, 10, 1, 1)
        # flatten:    (B, 10, 1, 1)  -> (B, 10)
        # classifier: (B, 10)        -> (B, 1)
        feature_maps = self.conv(images)
        pooled = self.flatten(self.pool(feature_maps))
        return self.classifier(pooled)

In [None]:
# Untrained (randomly initialized) toy model, placed on DEVICE.
model = MyModel().to(DEVICE)

In [None]:
class DummyDataset(Dataset):
    """Fixed-length dataset of ten all-zero ``(3, 256, 256)`` images.

    Every access returns a freshly allocated zero tensor moved to ``device``.
    The index is ignored and not bounds-checked.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        shape = (3, 256, 256)
        blank = torch.zeros(*shape)
        return blank.to(self.device)

In [None]:
# Two concepts with identical all-zero data, each served as one batch of 10
# tensors already on DEVICE.
concept0 = Concept(0, 'concept0', DataLoader(DummyDataset(device=DEVICE), batch_size=10))
concept1 = Concept(1, 'concept1', DataLoader(DummyDataset(device=DEVICE), batch_size=10))

In [None]:
# 'conv' is the name of MyModel's convolution submodule.
tcav = TCAV(model, layers='conv')
tcav

In [None]:
cavs = tcav.compute_cavs([[concept0, concept1]])
cavs

In [None]:
# 7 random inputs on DEVICE, a batch size different from the concept batches
# (10). NOTE(review): presumably part of the reported reproduction — confirm
# against the issues.
inputs = torch.rand(7, 3, 256, 256).to(DEVICE)

In [None]:
# No target is passed; MyModel has a single output column (Linear(10, 1)).
scores = tcav.interpret(inputs, [[concept0, concept1]])