## Find the ImageNet images whose patch tokens give the largest projections onto the singular vectors U and V, and the largest product of the U and V projections.

In [None]:
import config as c
from load_model import load_model
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm


# Hugging Face checkpoint identifiers of every vision transformer to process,
# grouped by model family.
model_names = [
    # supervised ViT
    "google/vit-base-patch16-224",
    "google/vit-base-patch32-384",
    "google/vit-large-patch16-224",
    "google/vit-large-patch32-384",
    "google/vit-huge-patch14-224-in21k",
    # DINO
    "facebook/dino-vitb16",
    "facebook/dino-vits16",
    # CLIP
    "openai/clip-vit-base-patch16",
    "openai/clip-vit-base-patch32",
    "openai/clip-vit-large-patch14",
    # distilled DeiT
    "facebook/deit-base-distilled-patch16-224",
    "facebook/deit-small-distilled-patch16-224",
    "facebook/deit-tiny-distilled-patch16-224",
]

# For every model and every validation image, project the patch-token hidden
# states of each layer onto the per-head singular vectors (U and V) and record,
# per (layer, head, mode): the mean, the max and the mean of the top-5 values
# over patch tokens, plus the product max(U projection) * max(V projection).

# The validation-set location is identical for every model, so configure it and
# list the images once instead of on every model iteration.
# NOTE(review): os.listdir() returns entries in arbitrary order and the saved
# arrays are indexed by position in this list -- any consumer has to obtain the
# same ordering; confirm before changing it (e.g. by sorting).
c.imagenet_val_path = "/opt/scratch/group/odelia/vit_2023/ImageNet1k/ILSVRC/Data/CLS-LOC/val/"
im_list = os.listdir(c.imagenet_val_path)

for model_name in model_names:

    (model, processor) = load_model(model_name)

    # deit has two extra tokens before the image patch tokens; the others have one.
    start_index = 2 if 'deit' in model_name else 1

    # Read the singular vectors from disk once per model; both sign passes reuse them.
    # The indexing below requires U as (layer, head, hidden_dim, mode) and
    # Vt as (layer, head, mode, hidden_dim).
    U_numpy = np.load(os.path.join(c.data_path, "UVS", f"{model_name}_U_total.npy"))
    Vt_numpy = np.load(os.path.join(c.data_path, "UVS", f"{model_name}_Vt_total.npy"))

    for pn in ["p", "n"]: # the direction of a singular vector is arbitrary, so look at both directions.

        flip = -1 if pn == "n" else 1
        U_total = torch.from_numpy(flip * U_numpy).float().to(c.device) # left singular vectors
        Vt_total = torch.from_numpy(flip * Vt_numpy).float().to(c.device) # right singular vectors

        # One list per statistic; each entry is a (layer, head, mode) array for one image.
        stats = {
            "mean_U": [], "mean_V": [],
            "max_U": [], "max_V": [],
            "max5_U": [], "max5_V": [],
            "product": [],
        }

        for im_file in tqdm(im_list):
            im = plt.imread(os.path.join(c.imagenet_val_path, im_file))
            if im.ndim < 3: # gray images: replicate the single channel to three channels
                im = np.repeat(im[:, :, np.newaxis], 3, axis=2)
            pixel_values = processor(images=im, return_tensors="pt")["pixel_values"].float().to(c.device)

            # Nothing is back-propagated here, so do not record an autograd graph.
            with torch.no_grad():
                output = model(pixel_values, output_hidden_states=True, output_attentions=False)

                per_layer = {key: [] for key in stats}
                for layer in range(Vt_total.shape[0]):
                    # Patch tokens of the single image in the batch: (tokens, hidden_dim).
                    patches = output['hidden_states'][layer][0, start_index:, :]

                    # A 2-D @ 3-D matmul broadcasts over the head axis, so every head
                    # is projected at once; both results are (head, token, mode).
                    proj_U = patches @ U_total[layer]
                    proj_V = patches @ Vt_total[layer].transpose(-1, -2)

                    max_U = proj_U.max(dim=1).values
                    max_V = proj_V.max(dim=1).values

                    per_layer["mean_U"].append(proj_U.mean(dim=1))
                    per_layer["mean_V"].append(proj_V.mean(dim=1))
                    per_layer["max_U"].append(max_U)
                    per_layer["max_V"].append(max_V)
                    per_layer["max5_U"].append(torch.topk(proj_U, 5, dim=1).values.mean(dim=1))
                    per_layer["max5_V"].append(torch.topk(proj_V, 5, dim=1).values.mean(dim=1))
                    per_layer["product"].append(max_U * max_V)

                # One device-to-host copy per statistic and image.
                for key, layers in per_layer.items():
                    stats[key].append(torch.stack(layers).to('cpu').numpy())

        # model_name contains a "/", so the files land in a per-organisation
        # sub-directory of optimal_natural_image/data; create it if needed.
        prefix = os.path.join(c.data_path, "optimal_natural_image/data", model_name)
        os.makedirs(os.path.dirname(prefix), exist_ok=True)

        # U/V statistics are saved as (2, image, layer, head, mode); the product
        # as (image, layer, head, mode).
        np.save(f"{prefix}_mean_{pn}.npy", np.array([stats["mean_U"], stats["mean_V"]]))
        np.save(f"{prefix}_max_{pn}.npy", np.array([stats["max_U"], stats["max_V"]]))
        np.save(f"{prefix}_max5_{pn}.npy", np.array([stats["max5_U"], stats["max5_V"]]))
        np.save(f"{prefix}_product_{pn}.npy", np.array(stats["product"]))

## Show the top 7 images ranked by attention (product) score, together with their corresponding U map and V map.

In [None]:
import PIL
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

# Hugging Face checkpoint identifiers of every vision transformer to plot,
# grouped by model family.
model_names = [
    # supervised ViT
    "google/vit-base-patch16-224",
    "google/vit-base-patch32-384",
    "google/vit-large-patch16-224",
    "google/vit-large-patch32-384",
    "google/vit-huge-patch14-224-in21k",
    # DINO
    "facebook/dino-vitb16",
    "facebook/dino-vits16",
    # CLIP
    "openai/clip-vit-base-patch16",
    "openai/clip-vit-base-patch32",
    "openai/clip-vit-large-patch14",
    # distilled DeiT
    "facebook/deit-base-distilled-patch16-224",
    "facebook/deit-small-distilled-patch16-224",
    "facebook/deit-tiny-distilled-patch16-224",
]

# For every model, layer, head and sign, draw one figure: each row is a singular
# mode and shows the U.V dot product followed by the top-scoring images, each
# with its U-projection map and V-projection map over the patch grid.
#
# NOTE(review): `c` and `load_model` are not imported in this cell; it relies on
# them already being defined in the session.
# NOTE(review): the score arrays are indexed by position in this listing, so it
# must come back in the same order as when the scores were computed --
# os.listdir() does not guarantee an order; confirm.
im_list = os.listdir(c.imagenet_val_path)

num_modes = 10 # singular modes per figure (one row each)
num_photos = 7 # top-scoring images per row (three columns each: photo, U map, V map)

for model_name in model_names:

    (model, processor) = load_model(model_name)

    # deit has two extra tokens before the image patch tokens; the others have one.
    start_index = 2 if 'deit' in model_name else 1

    image_size = model.config.image_size
    num_attention_heads = model.config.num_attention_heads
    num_hidden_layers = model.config.num_hidden_layers
    patch_size = model.config.patch_size
    token_size = int(image_size / patch_size) # patch tokens per image side

    U_total = np.load(os.path.join(c.data_path, "UVS", f"{model_name}_U_total.npy"))
    Vt_total = np.load(os.path.join(c.data_path, "UVS", f"{model_name}_Vt_total.npy"))

    U_total = torch.from_numpy(U_total).float().to(c.device) # left singular vectors
    Vt_total = torch.from_numpy(Vt_total).float().to(c.device) # right singular vectors

    total_product_p = np.load(os.path.join(c.data_path, "optimal_natural_image/data", f"{model_name}_product_p.npy"))
    total_product_n = np.load(os.path.join(c.data_path, "optimal_natural_image/data", f"{model_name}_product_n.npy"))
    total_product = {'p': total_product_p, 'n': total_product_n}

    # Create each output folder independently: checking only the "_p" folder
    # fails as soon as exactly one of the two already exists.
    figure_dirs = {
        pn: os.path.join(c.data_path, f"optimal_natural_image/figure/{model_name}_product_{pn}")
        for pn in ('p', 'n')
    }
    for figure_dir in figure_dirs.values():
        os.makedirs(figure_dir, exist_ok=True)

    for layer in range(num_hidden_layers):
        for head in range(num_attention_heads):
            print(f"model: {model_name}, layer: {layer}, head: {head}")
            for sign,pn in [(1,'p'),(-1,'n')]:
                scores = total_product[pn] # indexed as [image, layer, head, mode]
                fig, axs = plt.subplots(num_modes, 1 + 3 * num_photos, figsize=(22, 10))
                for mode in range(num_modes):
                    u_vec = U_total[layer,head,:,mode]
                    v_vec = Vt_total[layer,head,mode,:]

                    # Image indices sorted by descending product score.
                    ranking = np.argsort(scores[:,layer,head,mode])[::-1]
                    for photo_i, im_no in enumerate(ranking[:num_photos]):
                        im_path = os.path.join(c.imagenet_val_path, im_list[im_no])
                        photo_ax, u_ax, v_ax = axs[mode, 1+photo_i*3 : 1+photo_i*3+3]

                        photo_ax.imshow(PIL.Image.open(im_path).resize((224,224)))
                        photo_ax.set_axis_off()
                        # overlay the attention (product) score on the photo.
                        attention_score = scores[im_no,layer,head,mode]
                        photo_ax.text(0, 20, f'{attention_score:.4g}', fontsize = 6)

                        im = plt.imread(im_path)
                        if im.ndim < 3: # gray images: replicate the single channel to three channels
                            im = np.repeat(im[:, :, np.newaxis], 3, axis=2)
                        pixel_values = processor(images=im, return_tensors="pt")["pixel_values"].float().to(c.device)

                        # Inference only: do not record an autograd graph.
                        with torch.no_grad():
                            output = model(pixel_values, output_hidden_states=True, output_attentions=False)
                            # Patch tokens of the single image: (tokens, hidden_dim);
                            # e.g. 14*14 tokens of 12*64 dims for a base patch16-224 model.
                            patches = output['hidden_states'][layer][0, start_index:, :]
                            Umap = (patches @ (sign * u_vec)).unflatten(0,(token_size,token_size)).to('cpu').numpy()
                            Vmap = (patches @ (sign * v_vec)).unflatten(0,(token_size,token_size)).to('cpu').numpy()

                        # Clip the colour range to the 30th-95th percentile of each map.
                        u_ax.imshow(Umap, cmap="gist_heat", vmin=np.percentile(Umap,30), vmax=np.percentile(Umap,95))
                        v_ax.imshow(Vmap, cmap="gist_heat", vmin=np.percentile(Vmap,30), vmax=np.percentile(Vmap,95))
                        u_ax.set_axis_off()
                        v_ax.set_axis_off()

                    # Dot product of the mode's U and V vectors, shown in the first column
                    # (a cosine similarity if both vectors are unit-norm -- TODO confirm).
                    # Both operands are 1-D, so no transpose is involved.
                    cosine = torch.dot(u_vec, v_vec).item()
                    axs[mode,0].text(0.2,0.5,f"{cosine:.3f}")
                    axs[mode,0].set_axis_off()

                fig.tight_layout(pad=0.3)
                fig.savefig(os.path.join(figure_dirs[pn], f"L{layer}h{head}_product_{pn}.png"))
                # Close this specific figure so open figures do not accumulate.
                plt.close(fig)
