In [None]:
import os

# suppress warnings
import warnings

import cv2
import IPython.display as ipd
import nopdb
import numpy as np
import PIL
import torch
from matplotlib import pyplot as plt

from spoof.dataset.dataset import FaceDataset
from spoof.model.vit import ViT

# Installs a catch-all "ignore" filter: every warning category is silenced for
# the rest of the kernel session, including deprecation warnings from libraries.
warnings.filterwarnings("ignore")

#### Load and visualize an image returned from the `FaceDataset`

In [None]:
# Setup dataset to get a random image and read it
# The relative paths below are resolved against the working directory.
# Step up one directory only while the annotations file is not reachable, so
# that re-running this cell (or starting the kernel from the directory that
# already contains `data/`) does not walk further up and break the paths.
ANNOTATIONS_PATH = "data/casia/test/annotations.csv"
if not os.path.exists(ANNOTATIONS_PATH):
    os.chdir("..")
ds = FaceDataset(ANNOTATIONS_PATH)
# Unseeded draw: a different sample is picked on every execution of the cell.
img = ds[np.random.randint(len(ds))]["image"]

In [None]:
# Define some functions to map the tensor back to an image
def inv_normalize(tensor):
    """Min-max rescale a tensor to the half-open range [0, 256).

    The smallest element maps to 0 and the largest to just below 256, so that
    a later truncating cast to ``uint8`` yields values in 0..255.

    Args:
        tensor: A non-empty ``torch.Tensor``.

    Returns:
        A new tensor of the same shape with the rescaled values. When every
        element of ``tensor`` is identical there is no range to stretch, and
        an all-zero tensor is returned instead of dividing by zero (which
        would produce NaNs).
    """
    lo, hi = tensor.min(), tensor.max()
    shifted = tensor - lo
    span = hi - lo
    if span == 0:
        # Constant input: multiplying by 0.0 keeps shape and device and applies
        # the same float promotion as the regular path below.
        return shifted * 0.0
    # 256 - 1e-5 (rather than 255) gives every integer bucket an equal share of
    # the range once the result is truncated to uint8.
    return shifted / span * (256 - 1e-5)


def inv_transform(tensor, normalize=True):
    """Turn a 3-D tensor into a ``uint8`` numpy array with its first axis moved last.

    Args:
        tensor: A 3-D ``torch.Tensor`` (presumably channels-first, i.e.
            (C, H, W) — verify against the dataset).
        normalize: When true, the values are first rescaled with
            ``inv_normalize``; otherwise they are cast as they are.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` whose axes are the input's
        axes in the order (1, 2, 0).
    """
    source = inv_normalize(tensor) if normalize else tensor
    # Drop autograd tracking and move to host memory before handing over to numpy.
    first_axis_leading = source.detach().cpu().numpy()
    return np.transpose(first_axis_leading, (1, 2, 0)).astype(np.uint8)


def show_img(tensor):
    """Convert an image tensor to a ``PIL.Image.Image`` for display.

    Args:
        tensor: A 3-D image tensor accepted by ``inv_transform``.

    Returns:
        The PIL image built from the rescaled ``uint8`` array.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule; import it
    # explicitly so this function does not depend on another library having
    # imported it first.
    from PIL import Image

    array = inv_transform(tensor)
    return Image.fromarray(array)

In [None]:
# Bare call as the last expression of the cell: the returned PIL image is
# rendered as the cell output.
show_img(img)

#### Load the model

In [None]:
# Create model and let it load the weights in the constructor
# NOTE(review): no explicit `model.eval()` here — confirm the constructor
# switches to inference mode if the network has mode-dependent layers.
model = ViT()

#### Capture attention weights

In [None]:
def predict(input):
    """Return the liveness score the model assigns to ``input``.

    The forward pass runs with gradient tracking disabled and its output is
    passed through ``model.get_liveness_score``. Nothing is printed.

    Args:
        input: The value fed to the module-level ``model``.

    Returns:
        Whatever ``model.get_liveness_score`` returns for the model output.
    """
    with torch.no_grad():
        return model.get_liveness_score(model(input))


def get_attn_call(input, layer_idx):
    """Run one prediction while recording a call to an encoder layer's self-attention.

    Args:
        input: The value forwarded to ``predict``.
        layer_idx: Index into ``model.extractor.encoder.layers`` selecting the
            layer whose ``self_attention.forward`` call is captured.

    Returns:
        The ``nopdb`` capture object; its ``.locals`` mapping holds the local
        variables of the captured ``forward`` call.
    """
    attention_forward = model.extractor.encoder.layers[layer_idx].self_attention.forward
    with nopdb.capture_call(attention_forward) as attn_call:
        # Only the side effect (the captured call) matters; the score is discarded.
        predict(input).item()
    return attn_call

In [None]:
# Get the attention matrix for the last layer
# NOTE(review): index 11 presumes the encoder has 12 layers — confirm against
# the ViT configuration.
attn_mat = get_attn_call(img, 11).locals["attn_output_weights"]
print(f"Attention matrix shape: {attn_mat.shape}")

#### Visualize attention map

In [None]:
def get_attention_map(img, layer_idx, get_mask=False):
    """Get the attention map for an image.

    Captures the ``attn_output_weights`` local of the self-attention call of
    encoder layer ``layer_idx`` while the model runs on ``img``, averages it
    over dim 1, mixes in an identity matrix, and reshapes row 0 (without its
    first entry) into a square grid that is resized to the image's spatial
    size.

    Args:
        img: A 3-D image tensor; ``img.shape[1]`` and ``img.shape[2]`` are used
            as the rows and columns of the output (presumably (C, H, W) —
            verify against the dataset).
        layer_idx: Index into ``model.extractor.encoder.layers``.
        get_mask: If true, return the resized mask itself; otherwise return
            the image multiplied by the mask.

    Returns:
        numpy.ndarray: with ``get_mask=True`` a 2-D float array of shape
        ``(img.shape[1], img.shape[2])``; otherwise a ``uint8`` array of shape
        ``(img.shape[1], img.shape[2], img.shape[0])``.
    """
    attn_mat = get_attn_call(img, layer_idx).locals["attn_output_weights"]

    # Average the attention weights across all heads.
    # NOTE(review): assumes dim 1 of the captured tensor is the head axis — confirm.
    att_mat = torch.mean(attn_mat, dim=1)

    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1), device=att_mat.device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices
    # NOTE(review): the recursion runs over dim 0 of a single layer's weights;
    # if dim 0 is the batch axis and the batch has one element, the loop body
    # never executes — confirm this is the intended roll-out.
    joint_attentions = torch.zeros(aug_att_mat.size(), device=aug_att_mat.device)
    joint_attentions[0] = aug_att_mat[0]

    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    # Row 0 without its first column, laid out as a square grid.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()

    # cv2.resize expects dsize as (width, height) = (columns, rows), so the
    # column count comes first. This makes the resized mask
    # (img.shape[1], img.shape[2]) and therefore aligned with the transposed
    # image below, also for non-square inputs.
    mask = cv2.resize(mask / mask.max(), (img.shape[2], img.shape[1]))

    if get_mask:
        return mask

    # NOTE(review): the raw tensor values are used here (no inv_normalize); if
    # the dataset returns standardized tensors the uint8 cast will wrap or
    # truncate — confirm the expected value range of `img`.
    image = img.detach().cpu().numpy().transpose(1, 2, 0)
    return (mask[..., np.newaxis] * image).astype("uint8")

In [None]:
def plot_attention_map(original_img, att_map):
    """Draw the original image and its attention map side by side.

    Args:
        original_img: Image shown in the left panel, in any form accepted by
            ``Axes.imshow``.
        att_map: Image shown in the right panel.
    """
    _, axes = plt.subplots(ncols=2, figsize=(16, 16))
    panels = (
        ("Original", original_img),
        ("Attention Map Last Layer", att_map),
    )
    for axis, (title, image) in zip(axes, panels):
        axis.set_title(title)
        axis.imshow(image)

#### Visualize the attention-weighted image (`get_mask=False`)

In [None]:
# get_mask=False: the result is the image multiplied by the resized attention mask.
att_map = get_attention_map(img, layer_idx=11, get_mask=False)
# Min-max rescaled uint8 copy of the input tensor, used as the reference panel.
original_img = inv_transform(img)
plot_attention_map(original_img, att_map)

#### Visualize the attention mask alone (`get_mask=True`)

In [None]:
# get_mask=True: the result is the resized attention mask itself, without the image.
att_map = get_attention_map(img, layer_idx=11, get_mask=True)
# Min-max rescaled uint8 copy of the input tensor, used as the reference panel.
original_img = inv_transform(img)
plot_attention_map(original_img, att_map)