In [None]:
import jupyter_black

# Activate jupyter_black for this notebook session (formats code cells with black).
# NOTE(review): lab=False presumably targets the classic Notebook frontend
# rather than JupyterLab — confirm against the environment in use.
jupyter_black.load(lab=False)

In [None]:
from glob import glob
from natsort import natsorted

import torch
import numpy as np
# from umap import UMAP
from torchvision.utils import make_grid
from torchvision import transforms

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib import colormaps

import data_utils

# Styles are applied left to right, so the colorblind palette is layered on
# top of the ggplot base style.
plt.style.use(["ggplot", "seaborn-v0_8-colorblind"])

In [None]:
# Prerequisite: run the activation-extraction script first, e.g.
# !python save_activations.py
# It writes the activation files into saved_activations/, named:
# ./saved_activations/raw-{layer_name}-{batch_start}.pt

In [None]:
# Collect the saved activation files of one layer, in natural sort order.
layer_name = 'conv1'
activation_pattern = f'./saved_activations/raw-{layer_name}-*.pt'
matching_files = glob(activation_pattern)
fns = natsorted(matching_files)
display(fns)

## Load the mini-batch activation files and compute the per-neuron quantile on each mini-batch

In [None]:
# Quantile level estimated for every neuron (channel).
quantile = 0.95

# One entry per activation file: the per-neuron quantile of that mini-batch.
per_batch_quantiles = []
for fn in tqdm(fns):
    # map_location="cpu" lets tensors saved from a GPU load on any machine and
    # guarantees the .numpy() conversion below is valid.
    # NOTE: `act` deliberately stays bound after the loop; the visualisation
    # cell below reads the last loaded mini-batch from it.
    act = torch.load(fn, map_location="cpu")

    if act.ndim == 4:  # conv layer
        # Channel first, then flatten all remaining (instance and spatial)
        # dimensions into a single axis.
        act1 = act.permute(1, 0, 2, 3).reshape(act.shape[1], -1)
    elif act.ndim == 2:  # fc layer
        # Neuron first.
        act1 = act.permute(1, 0)
    else:
        # Without this branch `act1` would be undefined on the first file, or
        # silently reuse the previous file's activations on later ones.
        raise ValueError(
            f"Unsupported activation shape {tuple(act.shape)} in {fn}; "
            "expected a 4-D (conv) or 2-D (fc) tensor."
        )

    # detach() is a no-op for plain tensors and makes .numpy() safe if the
    # activations were saved while still tracking gradients.
    q = np.quantile(act1.detach().numpy(), q=quantile, axis=1)
    per_batch_quantiles.append(q)

if not per_batch_quantiles:
    # Fail with an actionable message instead of np.stack's generic error.
    raise ValueError(
        "No activation files were found in `fns`; run save_activations.py "
        "and check `layer_name` before estimating quantiles."
    )

# Rows index the mini-batch files, columns index the neurons.
quantile_samples = np.stack(per_batch_quantiles)
quantile_samples.shape

In [None]:
# One quantile estimate per neuron (channel): average the per-mini-batch
# samples over the mini-batch axis.
quantile_estimates = quantile_samples.mean(axis=0)
(quantile_estimates.shape, quantile_estimates)

## visualize binary masks

In [None]:
# Channel-first view of `act`, so each iteration handles one channel's
# activation maps together with that channel's threshold.
channel_first_act = act.permute(1, 0, 2, 3)

for channel_act, channel_threshold in zip(channel_first_act, quantile_estimates):
    # True wherever the activation exceeds the channel's quantile threshold.
    act_mask = channel_act > channel_threshold

    # Tile the masks 16 per row without padding; index 0 keeps one channel of
    # the tiled result so it can be drawn as a 2-D image.
    tiled_masks = make_grid(act_mask.unsqueeze(1), nrow=16, padding=0)
    grid = tiled_masks[0]

    fig, ax = plt.subplots(figsize=[12, 6])
    ax.imshow(grid)
    ax.axis("off")
    plt.show()

## visualize original images

In [None]:
# NOTE: `dataset` is created by the model/dataset loading cell that appears
# further down in this notebook; that cell has to be run before this one.

# Number of trailing subset images to show, and the side length of the crop.
n_last_images = 136
crop_size = 224

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
    ]
)

# Size the buffer from the indices actually available, so a subset shorter
# than `n_last_images` never leaves uninitialised rows in `imgs`.
last_indices = dataset.indices[-n_last_images:]
imgs = torch.empty(len(last_indices), 3, crop_size, crop_size)
for i, img_index in enumerate(last_indices):
    # Index the wrapped dataset directly; element 0 of each item is the image.
    img = dataset.dataset[img_index][0]
    img = transform(img)
    imgs[i] = img

grid = make_grid(imgs, nrow=16)
plt.figure(figsize=[12, 6])
# imshow expects channels last, so move the channel axis to the end.
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis("off")
# End with show() so the cell renders only the figure, not the axis limits repr.
plt.show()

In [None]:
# Target network and its preprocessing pipeline.
# NOTE(review): "cpu" is presumably the device the model is placed on — confirm
# against data_utils.get_target_model.
target_model, target_preprocess = data_utils.get_target_model("resnet50", "cpu")

# Validation data requested with preprocess=None, then thinned to every 10th
# sample by wrapping it in a Subset.
dataset = data_utils.get_data("imagenet_val", preprocess=None)
subset = list(range(len(dataset))[::10])
dataset = torch.utils.data.Subset(dataset, subset)

# Display the preprocessing pipeline as the cell output.
target_preprocess