In [None]:
import os

import cv2
import IPython.display as ipd
import nopdb
import numpy as np
import PIL
import torch
from timm.data import create_transform, resolve_data_config

from spoof.dataset.dataset import FaceDataset
from spoof.model.vit import ViT

#### Helper functions

In [None]:
# Define some functions to map the tensor back to an image
def inv_normalize(tensor):
    """Min-max rescale a tensor so that its values span the 0-255 range.

    Args:
        tensor: Numeric tensor/array that supports ``min()``, ``max()`` and
            element-wise arithmetic (e.g. a ``torch.Tensor``). Not modified.

    Returns:
        A new tensor of the same shape whose minimum maps to 0 and whose
        maximum maps to just under 256. If all elements are equal, an
        all-zero tensor is returned instead of NaNs.
    """
    lowest = tensor.min()
    shifted = tensor - lowest
    value_range = tensor.max() - lowest
    # A constant input has zero range; dividing would produce 0/0 = NaN for
    # every element, so hand back the (all-zero) shifted tensor instead.
    if value_range == 0:
        return shifted
    # Scale to slightly below 256 so the maximum truncates to 255 when the
    # result is later cast to uint8, rather than overflowing.
    return shifted / value_range * (256 - 1e-5)


def inv_transform(tensor, normalize=True):
    """Turn an image tensor into a ``PIL.Image``.

    The tensor is rescaled with ``inv_normalize``, detached and moved to the
    CPU, has its first axis moved to the end, and is truncated to ``uint8``.

    Args:
        tensor: Image tensor with exactly three axes (required by the
            transpose below); presumably laid out as (C, H, W).
        normalize: Accepted for call compatibility but never consulted --
            the tensor is always rescaled.
            NOTE(review): callers passing ``normalize=False`` still get a
            rescaled image; confirm whether that is intended.

    Returns:
        The image built by ``PIL.Image.fromarray``.
    """
    rescaled = inv_normalize(tensor)
    channels_first = rescaled.detach().cpu().numpy()
    pixels = np.transpose(channels_first, (1, 2, 0)).astype(np.uint8)
    return PIL.Image.fromarray(pixels)

In [None]:
# Setup dataset to get a random image and read it.
# The relative paths below are resolved against the project root. It defaults
# to the original location but can be overridden with the SPOOF_ROOT
# environment variable so the notebook also runs on other machines.
os.chdir(os.environ.get("SPOOF_ROOT", "/Users/motorbreath/mipt/spoof/"))
ds = FaceDataset("data/casia/test/annotations.csv")
idx = np.random.randint(0, len(ds))
# The first annotation column is used as the image path.
img_path = ds.annotations.iloc[idx, 0]
img = cv2.imread(img_path)
# cv2.imread signals failure by returning None instead of raising, which
# would otherwise surface as an opaque TypeError on the slice below.
if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
img = img[..., ::-1]  # reverse the last axis (OpenCV loads BGR -> RGB)

#### Load our model 

In [None]:
# Instantiate the project's ViT wrapper with its default constructor arguments.
# NOTE(review): the visible cells neither load a checkpoint nor call
# model.eval() -- confirm that ViT() takes care of both before trusting scores.
model = ViT()

#### Visualize the image using `timm` transform

In [None]:
# Build the timm preprocessing pipeline from the data config resolved for
# `model` (no overrides are passed: the first argument is an empty dict).
config = resolve_data_config({}, model=model)
timm_transform = create_transform(**config)

# Re-open the sampled image with PIL (this rebinds `img`, replacing the array
# read with OpenCV earlier) and transform it to a tensor.
img = PIL.Image.open(img_path).convert("RGB")
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# later cells pass it to predict() / predict_with_attn().
input = timm_transform(img)
# Last expression of the cell: shows the transformed tensor mapped back to an image.
inv_transform(input)

#### Visualize images using `FaceRegionRCXT` class

In [None]:
# Fetch the same sample (same `idx`) through the dataset's own __getitem__
# and take its "image" entry; this rebinds `img` once more.
img = ds[idx]["image"]
# Last expression of the cell: shows the dataset-produced tensor as an image.
inv_transform(img)

#### Predict the scores

In [None]:
def predict(input):
    """Compute the liveness score for ``input`` with gradient tracking disabled.

    Args:
        input: Whatever the notebook-level ``model`` accepts when called; it
            is forwarded unchanged.
            NOTE(review): no batch dimension is added here -- confirm the
            model accepts the tensor as passed.

    Returns:
        The result of ``model.get_liveness_score`` applied to the model output.
    """
    with torch.no_grad():
        model_output = model(input)
        return model.get_liveness_score(model_output)

In [None]:
# Score the timm-transformed tensor; .item() turns the result into a plain
# Python number (assumes predict() returns a single-element tensor).
score = predict(input).item()
print(f"score: {score:.3f}")

#### Capture and plot the attention weights

In [None]:
def plot_attention(input, attn):
    """Plot the input image once per entry of ``attn``, dimmed by that entry's weights.

    Args:
        input: Image tensor, forwarded to ``plot_weights``.
        attn: Iterable of weight tensors, presumably one per attention head.
            No averaging is performed here, so each entry is expected to
            already hold one weight per token -- TODO confirm against the caller.
    """
    with torch.no_grad():
        for head_weights in attn:
            # Drop index 0, which is taken to be the [class] token, so the
            # remaining weights line up with the image patches.
            plot_weights(input, head_weights[1:])


def plot_weights(input, patch_weights, patch_size=16):
    """Display the image, dimming each patch according to the given weight.

    Args:
        input: Image tensor indexed as [channel, row, column]. It is cloned,
            so the caller's tensor is left untouched.
        patch_weights: Indexable with ``shape[0]`` entries, one weight per
            patch in row-major order (left to right, then top to bottom).
        patch_size: Side length in pixels of each square patch. Defaults to
            16, which together with a 224-pixel-wide image reproduces the
            previously hard-coded 14-patches-per-row grid.

    Raises:
        ValueError: If the image is narrower than a single patch.

    NOTE(review): ``inv_transform`` as defined in this notebook rescales the
    tensor regardless of ``normalize``, so the displayed dimming is relative
    to the brightest remaining pixel rather than absolute.
    """
    plot = inv_normalize(input.clone())
    # Derive the grid width from the image itself instead of assuming 224 px.
    patches_per_row = plot.shape[-1] // patch_size
    if patches_per_row == 0:
        raise ValueError(
            f"image width {plot.shape[-1]} is smaller than patch_size {patch_size}"
        )
    # Multiply each patch of the input image by the corresponding weight
    for i in range(patch_weights.shape[0]):
        row, col = divmod(i, patches_per_row)
        y = row * patch_size
        x = col * patch_size
        plot[:, y : y + patch_size, x : x + patch_size] *= patch_weights[i]
    ipd.display(inv_transform(plot, normalize=False))


def predict_with_attn(input, layer_idx):
    """Score ``input`` while recording one encoder layer's self-attention call.

    Args:
        input: Forwarded unchanged to ``predict``.
        layer_idx: Index into ``model.extractor.encoder.layers`` selecting the
            layer whose ``self_attention.forward`` call is captured.

    Returns:
        A ``(score, attn_call)`` tuple: the score converted with ``.item()``
        and the capture object yielded by ``nopdb.capture_call``.
    """
    attention_forward = model.extractor.encoder.layers[layer_idx].self_attention.forward
    with nopdb.capture_call(attention_forward) as attn_call:
        score = predict(input).item()
    return score, attn_call

In [None]:
# Run the prediction again while capturing the self-attention call of the
# first encoder layer (layer_idx=0).
score, attn_call = predict_with_attn(input, 0)

print(f"score: {score:.3f}")
# List the names of the local variables recorded for the captured call.
print(f"captured call: {attn_call.locals.keys()}")