In [None]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
from lucent.modelzoo.util import get_model_layers
from lucent.optvis import render, param, transform, objectives
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').eval()
# print(model)

In [None]:
# tests
obj = f"vit_encoder_layer_0_layernorm_after:767" # layer_0-11
# render.render_vis(model, obj, lambda: param.image(224, 224, fft=False, channels=3), show_inline=True, transforms=lambda x:x)
img = render.render_vis(model, obj, save_image = True, show_inline=True, image_name = "test.jpg")

In [None]:
# print("Module", model.__module__)
# print(get_model_layers(model))

# inputs = processor(images=image, return_tensors="pt")
# print(inputs["pixel_values"].shape)
# model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])

for layer in range(1, 12):
    for neuron in range(0, 768, 50):
            obj = f"vit_encoder_layer_{layer}_layernorm_after:{neuron}" # layer_0-11
            image_name = f"vit_encoder_layer_{layer}_layernorm_after_{neuron}.jpg"
            # render.render_vis(model, obj, lambda: param.image(224, 224, fft=False, channels=3), show_inline=True, transforms=lambda x:x)
            img = render.render_vis(model, obj, show_inline=True, save_image=True, image_name = image_name)


In [None]:
# indian elephant: 385
# ⁠persian cat: 283
# ⁠Goose: 99
# ⁠Model T: 661
# ⁠Harp: 594
    
class_indices = [385, 283, 99, 661, 594]
class_names = ["IndianElephant", "PersianCat", "Goose", "ModelT", "Harp"]
for list_idx, neuron_class in enumerate(class_indices):
    obj = f"classifier:{str(neuron_class)}"
    image_name = f"classifier_{class_names[list_idx]}.jpg"
    img = render.render_vis(model, obj, show_inline=True, save_image=True, image_name = image_name, thresholds=[2560])


In [None]:
# AM for a complete ViT Layer
for layer in range(0, 12):
    obj = f"vit_encoder_layer_{layer}_layernorm_after" # layer_0-11
    image_name = f"vit_encoder_layer_{layer}_layernorm_after.jpg"
    # render.render_vis(model, obj, lambda: param.image(224, 224, fft=False, channels=3), show_inline=True, transforms=lambda x:x)
    img = render.render_vis(model, obj, show_inline=True, save_image=True, image_name = image_name, thresholds=[2560])



In [None]:
# Output maximization for attention heads
# Note: The ViTClassifier seems to have a fundamental error in its self-attention layer, where they don't divide the input onto the heads.
# Instead, they use the same dense layers (Q,K,V) for all attention heads, which is the same as if there was only one attention head.
for layer in range(0, 12):
    for head in range(0, 12):
        obj = f"vit_encoder_layer_{layer}_attention_attention_output_context:{head}" # layer_0-11
        image_name = f"vit_encoder_layer_{layer}_attention_attention_output_context_{head}.jpg"
        # render.render_vis(model, obj, lambda: param.image(224, 224, fft=False, channels=3), show_inline=True, transforms=lambda x:x)
        img = render.render_vis(model, obj, show_inline=True, save_image=True, image_name = image_name, thresholds=[560])

In [None]:
# save output maximization for attention heads in grid
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance

fig,ax = plt.subplots(12,12, tight_layout=True, figsize=(20,20))

filenames=[[f'AM_outputs/attention_heads_output_maximization/vit_encoder_layer_{i}_attention_attention_output_context_{j}.jpg' for j in range(12)] for i in range(12)] #or glob or any other way to describe filenames
for transformer_block in range(12):
    for head in range(12):
        with open(filenames[transformer_block][head],'rb') as f:
            image= Image.open(f)
            ax[head][transformer_block].set_axis_off()
            ax[head][transformer_block].imshow(image)
            
fig.show()

Next we try activation grids by saving the activation of passing a specific image. Then we optimize the input to correspond to exactly this activation vector. We can do so attention-head and transformer-layer-wise. This kind of corresponds to how the network "sees" the input image.

In [3]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# import requests
# import torch

# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw)
# device = "cuda" if torch.cuda.is_available() else "cpu"

# # processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').eval().to(device)
# # print(model)

# # Output maximization for attention heads
# # Note: The ViTClassifier seems to have a fundamental error in its self-attention layer, where they don't divide the input onto the heads.
# # Instead, they use the same dense layers (Q,K,V) for all attention heads, which is the same as if there was only one attention head.
# # for layer in range(0, 12):
# #     for head in range(0, 12):
# #         obj = f"vit_encoder_layer_{layer}_attention_attention_output_context:{head}" # layer_0-11
# #         image_name = f"vit_encoder_layer_{layer}_attention_attention_output_context_{head}.jpg"
# #         # render.render_vis(model, obj, lambda: param.image(224, 224, fft=False, channels=3), show_inline=True, transforms=lambda x:x)
# #         img = render.render_vis(model, obj, save_image=True, image_name = image_name, thresholds=[2560])

# # indian elephant: 385
# # ⁠persian cat: 283
# # ⁠Goose: 99
# # ⁠Model T: 661
# # ⁠Harp: 594
    
# class_indices = [385, 283, 99, 661, 594]
# class_names = ["IndianElephant", "PersianCat", "Goose", "ModelT", "Harp"]
# for list_idx, neuron_class in enumerate(class_indices):
#     obj = f"classifier:{str(neuron_class)}"
#     image_name = f"classifier_{class_names[list_idx]}.jpg"
#     img = render.render_vis(model, obj, show_inline=True, save_image=True, image_name = image_name, thresholds=[2560])

# activation grids
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
from lucent.modelzoo import *
from lucent.misc.io import show
import lucent.optvis.objectives as objectives
import lucent.optvis.param as param
import lucent.optvis.render as render
import lucent.optvis.transform as transform
from lucent.misc.channel_reducer import ChannelReducer
from lucent.misc.io import show
from itertools import product
import numpy as np
import torch
import torchvision

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def get_layer(model, layer, X):
    layers = layer.split(".")
    curr_layer = model
    for layer in layers:
        curr_layer = eval("curr_layer." + layer) # hackerman to allow recursion even through module lists
    hook = render.ModuleHook(curr_layer)
    model(X)
    hook.close()
    return hook.features


@objectives.wrap_objective()
def dot_compare(layer, acts, patch_no, batch=1):
    acts = torch.from_numpy(acts).to(device)
    def inner(T):
        pred = T(layer)
        # print(pred.shape) # (196, 197, 768)
        pred = pred[patch_no, patch_no+1, batch*64:(batch+1)*64] # get the 64 neurons corresponding to that attention head and that patch
        # print(pred.shape) # (64)
        # print(acts.shape) # (64)
        return -(pred * acts).sum(dim=0, keepdims=True).mean()
    return inner

def activation_grid_vit(
    img,
    model,
    layer,
    n_steps=1024
):
    cell_image_size=16 # size of each patch
    # Normalize and resize the image
    img = torch.tensor(np.transpose(img, [2, 0, 1])).to(device)
    normalize = (transform.normalize())
    transforms = transform.standard_transforms.copy() + [
        normalize,
        torch.nn.Upsample(size=224, mode="bilinear", align_corners=True),
    ]
    transforms_f = transform.compose(transforms)
    # shape: (1, 3, original height of img, original width of img)
    img = img.unsqueeze(0)
    # shape: (1, 3, 224, 224)
    img = transforms_f(img)


    attention_scores = get_layer(model, layer, img)[0]
    # print(attention_scores.shape) # (197, 768)
    attention_scores = attention_scores.reshape([197, 12, 64])
    
    attention_scores_np = attention_scores.cpu().numpy()
    # print(attention_scores_np.shape) # (197, 12, 64)
    num_patches, n_attention_heads, _ = attention_scores_np.shape
    num_patches -= 1 # remove the x_class
    
    # do naive implementation first
    
    # for each position `(y, x)` in the feature map `acts`, we optimize an image
    # to match with the features `acts[:, y, x]`
    # This means that the total number of cells (which is the batch size here) 
    # in the grid is layer_height*layer_width.
    nb_cells = (num_patches-1) * n_attention_heads # cell for each patch for each activation head
    
    # rename layer so that the render function finds it
    layer_converted = layer.replace(".", "_")
    layer_converted = layer_converted.replace("[", "_")
    layer_converted = layer_converted.replace("]", "")
    
    for attention_head in range(n_attention_heads):
        # Parametrization of the of each cell in the grid
        param_f = lambda: param.image(
            h=cell_image_size, w=cell_image_size, batch=num_patches
        )

        obj = objectives.Objective.sum(
            [
                # for each position in `acts`, maximize the dot product between the activations
                # `acts` at the position (y, x) and the features of the corresponding
                # cell image on our 'grid'. The activations at (y, x) is a vector of size
                # `layer_channels` (this depends on the `layer`). The features
                # of the corresponding cell on our grid is a tensor of shape
                # (layer_channels, cell_layer_height, cell_layer_width).
                # Note that cell_layer_width != layer_width and cell_layer_height != layer_weight
                # because the cell image size is smaller than the image size.
                # With `dot_compare`, we maximize the dot product between
                # cell_activations[y_cell, x_xcell] and acts[y,x] (both of size `layer_channels`)
                # for each possible y_cell and x_cell, then take the average to get a single
                # number. Check `dot_compare for more details.`
                # dot_compare(layer, attention_scores_np[y+1:y+2, x:x+1], batch=x + y * 64)
                # for i, (x, y) in enumerate(product(range(num_patches), range(n_attention_heads)))
                dot_compare(layer_converted, attention_scores_np[x+1, attention_head], patch_no=x, batch=attention_head) # skip the first patch (x_class)
                for i, x in enumerate(range(num_patches)) # try for one attention head first
            ]
        )
        results = render.render_vis(
            model,
            obj,
            param_f,
            thresholds=(n_steps,),
            progress=True,
            fixed_image_size=224,
            show_image=False,
            save_image=True,
            image_name=f"attention_head_{attention_head}.png"
        )
        # shape: (layer_height*layer_width, cell_image_size, cell_image_size, 3)
        imgs = results[-1] # last step results
        # shape: (layer_height*layer_width, 3, cell_image_size, cell_image_size)
        imgs = imgs.transpose((0, 3, 1, 2))
        imgs = torch.from_numpy(imgs)
        imgs = imgs[:, :, 2:-2, 2:-2]
        # turn imgs into a a grid
        grid = torchvision.utils.make_grid(imgs, nrow=int(np.sqrt(nb_cells)), padding=0)
        grid = grid.permute(1, 2, 0)
        torchvision.utils.save_image(grid, f"attention_head_{attention_head}.png")
        # grid = grid.numpy()
        # Image.fromarray(grid).save(f"attention_head_{attention_head}.png")
        # render.export(grid, f"attention_head_{attention_head}.png")
    
    """

    # negative matrix factorization `NMF` is used to reduce the number
    # of channels to n_groups. This will be used as the following.
    # Each cell image in the grid is decomposed into a sum of
    # (n_groups+1) images. First, each cell has its own set of parameters
    #  this is what is called `cells_params` (see below). At the same time, we have
    # a of group of images of size 'n_groups', which also have their own image parametrized
    # by `groups_params`. The resulting image for a given cell in the grid
    # is the sum of its own image (parametrized by `cells_params`)
    # plus a weighted sum of the images of the group. Each image from the group
    # is weighted by `groups[cell_index, group_idx]`. Basically, this is a way of having
    # the possibility to make cells with similar activations have a similar image, because
    # cells with similar activations will have a similar weighting for the elements
    # of the group.
    
    # reducer = ChannelReducer(n_patches, "NMF")
    
    attention_scores_np /= attention_scores_np.max(0) # (197, 768)
    
    attention_scores = torch.from_numpy(attention_scores_np)

    # Parametrization of the images of the groups (we have 12 activation groups / aka heads)
    attention_heads_params, attention_heads_image_f = param.fft_image(
        [64, 3, 16, 16] # every patch has 64 attention scores per patch
    )
    # Parametrization of the images of each patch in the grid (we have 196 patches + x_class)
    patches_params, patches_images_f = param.fft_image(
        [num_patches, 3, 16, 16] # every patch has channels RGB and size [16,16]
    )

    # First, we need to construct the images of the grid
    # from the parameterizations

    def image_f():
        attention_heads = attention_heads_image_f()
        patches_images = patches_images_f()
        X = []
        for i in range(num_patches):
            x = 0.7 * patches_images[i] + 0.5 * sum(
                attention_scores[i+1, j*64: (j+1)*64].squeeze() for j in range(n_attention_heads) # * attention_heads[j] # (i+1) to skip the x_class, i symbolizes the index of the patch; than take 64 attention scores per attention head.
            )
            X.append(x)
        X = torch.stack(X)
        return X

    # make sure the images are between 0 and 1
    image_f = param.to_valid_rgb(image_f, decorrelate=True)

    # After constructing the cells images, we sample randomly a mini-batch of cells
    # from the grid. This is to prevent memory overflow, especially if the grid
    # is large.
    # def sample(image_f, batch_size):
    #     def f():
    #         X = image_f()
    #         inds = torch.randint(0, len(X), size=(batch_size,))
    #         inputs = X[inds]
    #         # HACK to store indices of the mini-batch, because we need them
    #         # in objective func. Might be better ways to do that
    #         sample.inds = inds
    #         return inputs

    #     return f

    # image_f_sampled = sample(image_f, batch_size=batch_size)

    # Now, we define the objective function

    def objective_func(model):
        # shape: (batch_size, layer_channels, cell_layer_height, cell_layer_width)
        pred = model(layer)
        # use the sampled indices from `sample` to get the corresponding targets
        # target = attention_scores[sample.inds].to(pred.device)
        target = attention_scores.to(pred.device)
        
        # shape: (batch_size, layer_channels, 1, 1)
        target = target.view(target.shape[0], target.shape[1], 1, 1)
        dot = (pred * target).sum(dim=1).mean()
        return -dot

    obj = objectives.Objective(objective_func)

    def param_f():
        # We optimize the parametrizations of both the groups and the cells
        params = list(attention_heads_params) + list(patches_params)
        return params, image_f
        # return params, image_f_sampled

    results = render.render_vis(
        model,
        obj,
        param_f,
        thresholds=(n_steps,),
        show_image=False,
        progress=True,
        fixed_image_size=cell_image_size,
    )
    # shape: (layer_height*layer_width, 3, grid_image_size, grid_image_size)
    imgs = image_f()
    imgs = imgs.cpu().data
    imgs = imgs[:, :, 2:-2, 2:-2]
    # turn imgs into a a grid
    grid = torchvision.utils.make_grid(imgs, nrow=int(np.sqrt(num_patches)), padding=0)
    grid = grid.permute(1, 2, 0)
    grid = grid.numpy()
    render.show(grid)
    return imgs """


model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').eval().to(device)
img = np.array(Image.open("cat.png"), np.float32)
layer = "vit.encoder"
layer = "vit.encoder.layer[0].attention.attention.output_context"
_ = activation_grid_vit(img, model, layer=layer, n_steps=1)


100%|██████████| 1/1 [00:12<00:00, 12.36s/it]


TypeError: Cannot handle this data type: (1, 1, 60), |u1