In [None]:
import os
import warnings
from os import getenv

import cv2
import nopdb
import numpy as np
import PIL
import torch
from matplotlib import pyplot as plt

from spoof.dataset.dataset import FaceDataset
from spoof.model.vit import ViT

warnings.filterwarnings("ignore")

#### Load and visualize a random image from `FaceDataset`

In [None]:
# Work from the project root so the relative data/ and figures/ paths below resolve.
# os.path.expanduser("~") resolves the home directory even when HOME is unset
# (where getenv("HOME") + "..." would raise TypeError); SPOOF_ROOT can override it.
PROJECT_ROOT = getenv("SPOOF_ROOT", os.path.join(os.path.expanduser("~"), "mipt", "spoof"))
os.chdir(PROJECT_ROOT)

# Seed NumPy's global RNG so the randomly chosen samples in this notebook are
# reproducible across Restart & Run All.
SEED = 0
np.random.seed(SEED)

# Pick one random test image to inspect.
ds = FaceDataset("data/casia/test/annotations.csv")
img = ds[np.random.randint(len(ds))]["image"]

In [None]:
# Define some functions to map the tensor back to an image
def inv_normalize(tensor):
    """Min-max rescale an image tensor into the [0, 256) range.

    The minimum maps to 0 and the maximum to just under 256 (``256 - 1e-5``), so a
    later truncating cast to ``uint8`` turns the maximum into 255 instead of
    overflowing. A constant tensor (max == min) has no range to stretch; it is
    returned as zeros rather than the NaNs a 0/0 division would produce.

    Args:
        tensor: torch tensor of any shape.

    Returns:
        Floating-point tensor of the same shape.
    """
    low, high = tensor.min(), tensor.max()
    if high == low:
        # Same dtype the division below would produce (float for integer inputs).
        return torch.zeros_like(tensor, dtype=torch.result_type(tensor, 1.0))
    return (tensor - low) / (high - low) * (256 - 1e-5)


def inv_transform(tensor, normalize=True):
    """Convert a channels-first image tensor into a ``PIL.Image``.

    Args:
        tensor: image tensor laid out as (C, H, W); it is transposed with
            ``(1, 2, 0)`` to channels-last before conversion.
        normalize: if True, min-max stretch the values with ``inv_normalize``
            first; if False, the values are expected to already be in 0-255.

    Returns:
        ``PIL.Image.Image`` whose mode PIL infers from the array shape.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly instead of relying on another library having imported it.
    from PIL import Image

    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    # Saturate out-of-range values instead of letting the uint8 cast wrap them.
    array = np.clip(array, 0, 255).astype(np.uint8)
    if array.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        array = array[..., 0]
    return Image.fromarray(array)

In [None]:
# Stretch the sampled tensor back to 0-255 and display it as a PIL image.
inv_transform(tensor=img, normalize=True)

#### Load the model

In [None]:
# Create model and let it load the weights in the constructor
# (weight loading happens inside the project's ViT class and is not visible here).
# `model` is read as a global by capture_attn_call below.
# NOTE(review): the model is not switched to eval() here. If ViT contains dropout,
# captured attention may vary between calls. However, eval() plus no_grad can route
# torch.nn.MultiheadAttention through its fused fast path, which may not populate
# the `attn_output_weights` local that the capture relies on — verify before adding it.
model = ViT()

#### Capture attention weights

In [None]:
def capture_attn_call(input, layer_idx):
    """Run the global ``model`` on ``input`` and record one layer's attention call.

    ``nopdb.capture_call`` is armed on
    ``model.extractor.encoder.layers[layer_idx].self_attention.forward`` while the
    model performs a forward pass under ``torch.no_grad()``.

    Args:
        input: model input, passed to ``model`` unchanged.
        layer_idx: index into ``model.extractor.encoder.layers``.

    Returns:
        The captured call object; callers read ``.locals["attn_output_weights"]``
        from it.
    """
    target = model.extractor.encoder.layers[layer_idx].self_attention.forward
    # Same enter order as nesting: capture first, then disable autograd.
    with nopdb.capture_call(target) as attn_call, torch.no_grad():
        model(input)
    return attn_call


# Raw attention weights of the last encoder layer (index 11) for the sample image;
# [0] keeps only the first entry along the leading dimension.
last_layer_idx = 11
last_layer_call = capture_attn_call(img, last_layer_idx)
attn_matrix = last_layer_call.locals["attn_output_weights"][0]
print(f"Attention matrix shape: {attn_matrix.shape}")

#### Compute the attention map of a single layer

In [None]:
def get_attention_map(img, layer_idx, get_mask=False):
    """Compute one layer's attention map for ``img``, upsampled to image size.

    Captures the ``attn_output_weights`` local of layer ``layer_idx``'s
    self-attention (via ``capture_attn_call``), averages it over dim 1, adds the
    identity to account for the residual connection, re-normalizes rows, and
    takes row 0 without its first column as a square grid of patch scores.

    Args:
        img: image tensor laid out as (C, H, W) — ``img.shape[1]``/``img.shape[2]``
            are used as height/width and it is transposed with ``(1, 2, 0)``.
        layer_idx: index of the encoder layer to capture.
        get_mask: if True, return only the upsampled mask; otherwise return the
            mask multiplied into the image.

    Returns:
        numpy array of shape (H, W) when ``get_mask`` is True, else (H, W, C).
        The mask is scaled so its maximum is 1.
    """
    attn_matrix = capture_attn_call(img, layer_idx).locals["attn_output_weights"]

    # Average the attention weights across all heads.
    # NOTE(review): assumes the captured weights are (batch, heads, tokens, tokens) — TODO confirm.
    attn_matrix = torch.mean(attn_matrix, dim=1)

    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize each row to sum to 1. Build it on the
    # same device/dtype as the weights so a GPU model works too.
    residual_att = torch.eye(
        attn_matrix.size(1), dtype=attn_matrix.dtype, device=attn_matrix.device
    )
    aug_att_mat = attn_matrix + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Recursively multiply the weight matrices along dim 0.
    # NOTE(review): dim 0 is presumably the batch dimension of a single layer's
    # capture, not a stack of layers, so for one image this loop is a no-op and
    # the result is not a multi-layer attention rollout — TODO confirm intent.
    joint_attentions = torch.zeros_like(aug_att_mat)
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

    v = joint_attentions[-1]
    # Presumably token 0 is the class token and the rest form a square patch grid.
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()

    height, width = img.shape[1], img.shape[2]
    # cv2.resize takes dsize as (width, height); passing (H, W) would transpose
    # the mask for non-square images.
    mask = cv2.resize(mask / mask.max(), (width, height))
    if get_mask:
        return mask
    return mask[..., np.newaxis] * img.detach().cpu().numpy().transpose(1, 2, 0)

In [None]:
def plot_attention(img, save_path=None, num_layers=12):
    """Plot per-layer attention maps of ``img`` as two grids of subplots.

    The first figure shows each layer's attention mask multiplied into the image
    (what ``get_attention_map(..., get_mask=False)`` returns). The second figure
    shows the mask alone (``get_mask=True``).

    Args:
        img: image tensor passed to ``get_attention_map``; presumably (C, H, W)
            as returned by ``FaceDataset`` — TODO confirm.
        save_path: if given, the mask grid is saved to ``save_path`` and the
            attention-weighted image grid to ``<stem>_weighted<ext>`` beside it.
            Both figures are then closed so repeated calls do not accumulate open
            figures. Otherwise both figures are shown with ``plt.show()``.
        num_layers: number of encoder layers (indices 0..num_layers-1) to plot.
    """
    n_cols = 3
    n_rows = (num_layers + n_cols - 1) // n_cols
    fig_weighted = plt.figure(figsize=(16, 16))
    fig_mask = plt.figure(figsize=(16, 16))

    for i in range(num_layers):
        # One capture per layer: the weighted image is the mask broadcast over the
        # channels and multiplied into the image, exactly what
        # get_attention_map(get_mask=False) computes, so reuse the mask instead of
        # running a second forward pass.
        mask = get_attention_map(img, i, get_mask=True)
        weighted = mask[..., np.newaxis] * img.detach().cpu().numpy().transpose(1, 2, 0)

        for fig, picture in ((fig_weighted, weighted), (fig_mask, mask)):
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.imshow(picture)
            ax.axis("off")
            ax.set_title(f"Layer {i + 1}")

    # "without mask": this figure shows the attention-weighted image, not the mask itself.
    fig_weighted.tight_layout()
    fig_weighted.suptitle("Attention Maps, without mask", fontsize=16)
    fig_mask.tight_layout()
    fig_mask.suptitle("Attention Maps, with mask", fontsize=16)

    if save_path:
        stem, ext = os.path.splitext(save_path)
        fig_mask.savefig(save_path)
        fig_weighted.savefig(f"{stem}_weighted{ext}")
        # Close both figures; closing only the current one leaked the other.
        plt.close(fig_weighted)
        plt.close(fig_mask)
    else:
        plt.show()

#### Visualize the attention maps of every layer

In [None]:
# No save_path, so the figures are displayed inline instead of written to disk.
plot_attention(img, save_path=None)

#### Save attention maps for a few random CASIA train and test images

In [None]:
num_samples = 10
split = ["train", "test"]

for s in split:
    # Create the output subfolder for this split.
    folder = f"figures/attention/casia/{s}"
    os.makedirs(folder, exist_ok=True)

    casia = FaceDataset(f"data/casia/{s}/annotations.csv")
    # Draw distinct indices: sampling with replacement could pick the same image
    # twice and silently overwrite its figure.
    picks = np.random.choice(len(casia), size=min(num_samples, len(casia)), replace=False)
    for idx in picks.tolist():
        # Index the dataset once per sample instead of once per field.
        sample = casia[idx]
        # NOTE(review): files in different subfolders that share a basename would
        # overwrite each other here — TODO confirm basenames are unique.
        fn = os.path.basename(sample["filename"])
        plot_attention(sample["image"], save_path=f"{folder}/{fn}")

#### Save attention maps for the Replay train and test sets (real and attack samples)

In [None]:
kind = ["real", "attack"]
# Upper bound on images loaded per (split, kind) while searching for matching samples.
max_draws = 20 * num_samples

for s in split:
    replay = FaceDataset(f"data/replay/{s}/annotations.csv")
    for k in kind:
        # Create the output subfolder for this split and kind.
        folder = f"figures/attention/replay/{s}/{k}"
        os.makedirs(folder, exist_ok=True)

        saved = 0
        # A random permutation gives distinct indices, so no figure is overwritten
        # by a repeated draw.
        for idx in np.random.permutation(len(replay))[:max_draws].tolist():
            sample = replay[idx]
            # Keep only samples of the current kind; without this filter the
            # "real" and "attack" folders receive arbitrary images.
            # NOTE(review): assumes attack samples have "attack" in their path, as
            # in the Replay-Attack release layout — TODO confirm against
            # annotations.csv (a label column would be a more robust filter).
            if ("attack" in sample["filename"]) != (k == "attack"):
                continue
            fn = os.path.basename(sample["filename"])
            plot_attention(sample["image"], save_path=f"{folder}/{fn}")
            saved += 1
            if saved == num_samples:
                break

        if saved < num_samples:
            print(f"replay/{s}: saved only {saved}/{num_samples} '{k}' samples")