# Colab dependencies

In [None]:
%%bash
# Colab-only setup. `stat` succeeds only when the google.colab package directory
# exists; the leading `!` negates that result, so off Colab the cell exits here
# before installing or cloning anything.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit 
# NOTE(review): ninja is installed unpinned; presumably needed by the tutorial
# code to build extensions — confirm and pin a version.
pip install ninja
# Clone the tutorial code into ./tutorial_code. The next cell appends
# /content/tutorial_code to sys.path (presumably Colab's working directory is
# /content), which presumably makes the `netdissect` package importable below.
git clone https://github.com/SIDN-IAP/interactivity.git tutorial_code

In [None]:
try:
    import google.colab, sys, torch
    sys.path.append('/content/tutorial_code')
    print("GPU available" if torch.cuda.is_available()
          else "Change runtime type to include a GPU.")
except:
    pass

# Import packages

In [None]:
from netdissect import nethook, setting, renormalize, zdataset, paintwidget, labwidget, show

# Load a GAN generator

In [None]:
# Load the 'church' progressive GAN, wrapped in nethook.InstrumentedModel so its
# layers can later be retained and edited, and move it to the GPU.
# Then pick one of 1000 sampled latent vectors as the image to probe.
# NOTE(review): the same sample is assumed to come back on every run (i.e.
# z_sample_for_model presumably uses a fixed default seed) — confirm.
PROBE_IMAGE_INDEX = 13  # which of the 1000 sampled latents to probe
G = nethook.InstrumentedModel(setting.load_proggan('church')).cuda()
# `[None]` prepends a size-1 leading dimension so the single latent can be fed to G.
probe_z = zdataset.z_sample_for_model(G, 1000)[PROBE_IMAGE_INDEX][None].cuda()

# Create Probing Widget

In [None]:
# Generator layer whose unit activations the probe inspects.
LAYERNAME = 'layer5'
# Top units (and their mean activations) found by the probe callback below;
# read by the painting cell when it edits the generator.
SELECTED_UNITS = []
SELECTED_VALUES = []

# Drop any paint edits left installed on G by the painting cell, so re-running
# this cell probes the unedited generator.
G.remove_edits()
# Retain LAYERNAME *before* the forward pass so one pass yields both the
# rendered image and the layer activations (previously G ran twice).
G.retain_layer(LAYERNAME)
probe_image = G(probe_z)[0]
activations = G.retained_layer(LAYERNAME)

prober = paintwidget.PaintWidget(image=renormalize.as_url(probe_image))
output_div = labwidget.Div()

def probe_changed(c):
    """Rank LAYERNAME units by their mean activation under the probe stroke.

    Callback for the `prober` widget's 'mask' event (registered below); the
    event argument `c` is unused. The painted mask is resized to
    `activations.shape[2:]` (presumably the layer's height and width). Each
    unit's activation is then averaged over the mask, weighted by the mask
    value. The 10 highest-scoring units and their means are stored in
    SELECTED_UNITS / SELECTED_VALUES and printed into `output_div`.

    The mask is cleared after a successful probe. Strokes that cover no area
    are ignored and left in place.
    """
    global SELECTED_UNITS, SELECTED_VALUES
    stroke = prober.mask
    if not stroke:
        # Also reached when clearing the mask below re-fires the event.
        return
    region = renormalize.from_url(stroke, target='pt', size=activations.shape[2:])[0]
    coverage = region.sum()
    if coverage <= 0.0:
        return
    prober.mask = ''
    weighted = activations.cpu()[0] * region[None]
    unit_means = weighted.sum(2).sum(1) / (coverage + 1e-8)  # epsilon guards the division
    top_values, top_units = unit_means.sort(0, descending=True)
    SELECTED_UNITS = top_units[:10].tolist()
    SELECTED_VALUES = top_values[:10].tolist()
    output_div.innerHTML = ''
    output_div.print('SELECTED_UNITS:', SELECTED_UNITS)
    rounded = [float('%.2f' % v) for v in SELECTED_VALUES]
    output_div.print('SELECTED_VALUES:', rounded)
prober.on('mask', probe_changed)

show(prober)
show(output_div)


# Create Painting Widget

In [None]:
import torch

# Start the canvas from the unedited generator (clears edits from earlier paints).
G.remove_edits()
# Latent number 70 of the 1000 samples, with a size-1 leading dimension added.
canvas_z = zdataset.z_sample_for_model(G, 1000)[70].unsqueeze(0).cuda()
canvas = paintwidget.PaintWidget(
    image=renormalize.as_url(G(canvas_z)[0]))

def canvas_changed(c):
    """Paint the probed units into the canvas image wherever the user drew.

    Callback for the `canvas` widget's 'mask' event (registered below); the
    event argument `c` is unused. The stroke mask is resized to
    `activations.shape[2:]`. For each unit in SELECTED_UNITS, the matching
    SELECTED_VALUES entry scaled by the mask is added to that unit's output
    of LAYERNAME. The canvas image is then re-rendered from the edited
    generator and the mask is cleared.

    The selection is snapshotted when the stroke is made, so probing new units
    later does not alter an edit that is already installed. The edit is left
    installed on G after this returns, so later G calls reflect the most
    recent stroke; call G.remove_edits() to restore the unedited generator.
    """
    if not canvas.mask:
        return  # also reached when clearing the mask below re-fires the event
    if not SELECTED_UNITS:
        # Nothing has been probed yet, so there is nothing to paint.
        canvas.mask = ''
        return
    area = renormalize.from_url(canvas.mask, target='pt', size=activations.shape[2:])[0]
    # Precompute the additive edit once instead of on every forward pass.
    units = list(SELECTED_UNITS)
    delta = area[None] * torch.tensor(SELECTED_VALUES)[:, None, None]

    def editrule(x, imodel, **buffers):
        # Add the painted activation pattern to the selected units in place.
        x[:, units] += delta.to(x.device)
        return x

    G.edit_layer(LAYERNAME, rule=editrule)
    edited_url = renormalize.as_url(G(canvas_z)[0])
    canvas.mask, canvas.image = '', edited_url
canvas.on('mask', canvas_changed)
canvas.brushsize = 10
show(canvas)