In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
import IPython
# Plot style for the whole notebook: hairline strokes and no top/right spines.
mpl.rcParams.update({
    'lines.linewidth': 0.25,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.linewidth': 0.25,
})

import torch, argparse, os, shutil, inspect, json, numpy, math
import netdissect
from netdissect.easydict import EasyDict
from netdissect import pbar, nethook, renormalize, parallelfolder, pidfile
from netdissect import upsample, tally, imgviz, imgsave, bargraph, show
from experiment import dissect_experiment

# GPU this notebook runs on; adjust for the machine at hand.
GPU_ID = 5
torch.cuda.set_device(GPU_ID)

# Checkpoint to dissect; it is loaded from './<model_name>-torch16-0.pkl'.
model_name = 'DA-SS-sketch'
# `layer` is the key into the model's activation dict; `quantile` sets the
# percent level (1 - quantile) used when building visualizers.
args = EasyDict(model=model_name, dataset='pacs-s', seg='netpqc', layer='layer3', quantile=0.01)
# Result/cache directory, e.g. results/DA-SS-sketch-pacs-s-netpqc-layer3-10.
# round() rather than int() so float error cannot truncate the quantile tag.
resdir = 'results/%s-%s-%s-%s-%s' % (args.model, args.dataset, args.seg, args.layer, round(args.quantile * 1000))
print(resdir)

def resfile(f):
    """Return the path of file ``f`` inside the results directory ``resdir``."""
    return os.path.join(resdir, f)

# Classifier labels (cell-local imports kept; `renormalize` is also imported
# at the top of the notebook).
from urllib.request import urlopen
from netdissect import renormalize

print('### Load dataset!')
dataset = dissect_experiment.load_dataset(args)
# The tally_* calls below use this as their sample_size, i.e. the whole dataset.
sample_size = len(dataset)
print(sample_size)

classlabels = dataset.classes
print(classlabels)
# Renormalizer for this dataset with target 'zc' (presumably the
# model-input normalization -- TODO confirm in netdissect.renormalize).
renorm = renormalize.renormalizer(dataset, target='zc')

In [None]:
print('### Load model!')
# The checkpoint is a whole pickled model object (it is moved with .cuda() and
# called directly), so torch.load fully unpickles it -- only load trusted files.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# would reject this file; pass weights_only=False there (old releases may not
# accept that keyword, so it is not passed here).
model = torch.load('./' + model_name + '-torch16-0.pkl', map_location='cpu')
# eval() puts dropout / batch-norm layers in inference mode, so the dissected
# activations do not depend on training-mode behaviour or batch composition.
# Any rq.npz / topk.npz cached from a train-mode model should be regenerated.
model = model.cuda().eval()

# One probe image; its activations are handed to make_upfn_without_hooks
# (presumably to size the upsampler -- confirm in dissect_experiment).
indices = [0]
batch = torch.cat([dataset[i][0][None, ...] for i in indices])
# Inference only: no autograd graph needed.
with torch.no_grad():
    activations, res = model(batch.cuda())

layername = args.layer
upfn = dissect_experiment.make_upfn_without_hooks(args, dataset, layername, activations)
# Quantile level later passed to rq.quantiles(...) for the visualizers.
percent_level = 1.0 - args.quantile

print('### Collect quantile statistics!')
pbar.descnext('rq')
def compute_samples(batch, *_unused):
    """Upsampled `layername` activations, flattened to one row per spatial sample.

    The upsampled map is permuted channels-last and reshaped to
    (-1, acts.shape[1]), so each row holds every unit's value at one
    (image, position). Extra positional items from the loader are ignored.
    """
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        activations, _ = model(batch.cuda())
        acts = activations[layername].detach()
        hacts = upfn(acts)
    return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])
# Results are cached in rq.npz under resdir.
rq = tally.tally_quantile(compute_samples, dataset,
                          sample_size=sample_size,
                          r=8192,
                          num_workers=100,
                          pin_memory=True,
                          cachefile=resfile('rq.npz'))

print('### Collect TopK!')
pbar.descnext('topk')
def compute_image_max(batch, *_unused):
    """Per-image maximum activation of every unit in `layername`.

    All dims after dim 1 are flattened and max-reduced, giving an (N, C)
    tensor where dim 1 indexes units. Extra positional items are ignored.
    """
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        activations, _ = model(batch.cuda())
        acts = activations[layername].detach()
    # flatten(2) matches view(N, C, -1) but also tolerates non-contiguous tensors.
    return acts.flatten(2).max(2)[0]

# topk.result() -> (values, image indices); each is indexed [unit][rank]
# (see the single-image cell, which reads topk.result()[1][unit][rank]).
topk = tally.tally_topk(compute_image_max, dataset, sample_size=sample_size,
        batch_size=50, num_workers=30, pin_memory=True,
        cachefile=resfile('topk.npz'))

def scale(intermediate_layer_output, rmax=1, rmin=0):
    """Min-max rescale an array/tensor linearly into the range [rmin, rmax].

    Works on anything supporting ``.min()``, ``.max()`` and elementwise
    arithmetic (e.g. numpy arrays, torch tensors). The global minimum maps to
    ``rmin`` and the global maximum to ``rmax``. A constant input
    (max == min) is mapped entirely to ``rmin`` instead of producing NaN.
    """
    lo = intermediate_layer_output.min()
    span = intermediate_layer_output.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN: every shifted value is 0, so the result is rmin.
        span = 1
    X_std = (intermediate_layer_output - lo) / span
    return X_std * (rmax - rmin) + rmin

In [None]:
# Single-image visualization: mask one image by one unit's activation.
unit_number = 170
rank = 98
# Quantile level for the visualizers below. This overrides the
# 1.0 - args.quantile value set earlier, and the unit-images cell reuses it.
percent_level = 0.99
# Hard-coded image index. Set to None to show the image at position `rank`
# in this unit's top-k list instead.
IMAGE_OVERRIDE = 3862
# NOTE(review): predictions are named with this list while the ground truth
# uses dataset.classes; presumably both share the same order -- confirm.
real_classlabels = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']

if IMAGE_OVERRIDE is None:
    # topk.result()[1] holds image indices, indexed [unit][rank].
    image_number = topk.result()[1][unit_number][rank].item()
else:
    image_number = IMAGE_OVERRIDE
print('After Image No.: ' + str(image_number))

# NOTE: `iv` is not used in this cell; the unit-images cell rebuilds it.
iv = imgviz.ImageVisualizer((224, 224), source=dataset, quantiles=rq,
        level=rq.quantiles(percent_level))

# Fetch the sample once so the image and its label come from the same item.
sample = dataset[image_number]
batch = sample[0][None, ...]
truth = [classlabels[sample[1]]]

# Inference only: no autograd graph needed.
with torch.no_grad():
    activations, output = model(batch.cuda())

preds = output.max(1)[1]
acts = activations[layername].detach()
print(acts[0][unit_number].sum())

imgs = [renormalize.as_image(t, source=dataset) for t in batch]
prednames = [real_classlabels[p.item()] for p in preds]
print('pred: ' + prednames[0], 'true: ' + truth[0])

ivsmall = imgviz.ImageVisualizer((300, 300), source=dataset)
mask_img = ivsmall.masked_image(batch[0], acts, (0, unit_number), percent_level=percent_level)
display(show.blocks(
    [[[mask_img],
      [ivsmall.heatmap(acts.cpu(), (0, u), mode='nearest')]] for u in [unit_number]]
))

# round() rather than int() so float error cannot truncate the percent tag.
img_name = '%s-%s-%s-p%s-i%s-u%s.jpg' % (
    args.model, args.dataset, args.layer,
    round(percent_level * 1000), image_number, unit_number
)
print(img_name)

In [None]:
# Save the masked visualization as a JPEG (quality 95). img_name is a bare
# relative filename, so it lands in the current working directory, not resdir.
mask_img.save(img_name, quality=95)

In [None]:
# Development convenience: re-import tally so on-disk edits to it are picked up.
reload(tally)
pbar.descnext('unit_images')

# Uses whatever `percent_level` earlier cells left behind (0.99 if the
# single-image cell ran, otherwise 1.0 - args.quantile).
iv = imgviz.ImageVisualizer((100, 100), source=dataset, quantiles=rq,
        level=rq.quantiles(percent_level))

def compute_acts(*image_batch):
    """Return the `layername` activations for the first element of a loader batch."""
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        activations, _ = model(image_batch[0].cuda())
    return activations[layername]

# NOTE(review): k is the string '90,10' and is also embedded in the cache file
# name; presumably the local imgviz version parses it -- confirm.
k = '90,10'
unit_images = iv.masked_images_for_topk(
        compute_acts, dataset, topk, k=k, num_workers=30, pin_memory=True,
        cachefile=resfile('top' + k + 'images.npz'))
# Show the same unit that the single-image cell visualized.
for u in [unit_number]:
    display(unit_images[u])