# ViViT Inference and Attention Visualization for Eye Blink Detection

This notebook performs inference on unseen videos and visualizes
spatiotemporal attention maps produced by ViViT.


## 1. Imports and Model Loading

We load the trained ViViT model with attention outputs enabled
to analyze temporal and spatial focus during blink detection.


In [None]:
import os
import cv2
import csv
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import VivitForVideoClassification


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Start from the Kinetics-400 checkpoint with a freshly initialised 2-way
# head (ignore_mismatched_sizes), then overwrite it with the fine-tuned
# weights below. output_attentions=True makes the forward pass also return
# per-layer attention maps, which Section 3 reads.
model = VivitForVideoClassification.from_pretrained(
    "google/vivit-b-16x2-kinetics400",
    num_labels=2,
    ignore_mismatched_sizes=True,
    output_attentions=True,
).to(device)

checkpoint_path = "/content/drive/MyDrive/best_model(1).pth"
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state = torch.load(checkpoint_path, map_location="cpu")

# Accept both a bare state dict and a {"state_dict": ...} wrapper.
if "state_dict" in state:
    state = state["state_dict"]

# strict=False tolerates key mismatches, but a checkpoint whose keys do not
# match at all (e.g. saved with a wrapper prefix) would load nothing and
# leave the random classifier head in place -- so report what did not match.
incompatible = model.load_state_dict(state, strict=False)
if incompatible.missing_keys:
    print(
        f"[load_state_dict] {len(incompatible.missing_keys)} missing keys, "
        f"e.g. {incompatible.missing_keys[:5]}"
    )
if incompatible.unexpected_keys:
    print(
        f"[load_state_dict] {len(incompatible.unexpected_keys)} unexpected keys, "
        f"e.g. {incompatible.unexpected_keys[:5]}"
    )
model.eval()


Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


VivitForVideoClassification(
  (vivit): VivitModel(
    (embeddings): VivitEmbeddings(
      (patch_embeddings): VivitTubeletEmbeddings(
        (projection): Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): VivitEncoder(
      (layer): ModuleList(
        (0-11): 12 x VivitLayer(
          (attention): VivitAttention(
            (attention): VivitSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): VivitSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): VivitIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bi

## 2. Video Loading and Preprocessing

Input videos are converted to RGB, resized to 96×192, uniformly subsampled
(or padded by repeating the last frame) to 32 frames, and zero-padded to
224×224 to match ViViT’s expected input format.


In [None]:
def load_eye_video(
    video_path,
    target_frames=32,
    output_hw=(96, 192),
):
    """Read a video as RGB frames resized to ``output_hw`` and fitted to ``target_frames``.

    Videos with more than ``target_frames`` frames are subsampled at indices
    from ``np.linspace``; shorter videos are padded by repeating their final
    frame.

    Args:
        video_path: path handed to ``cv2.VideoCapture``.
        target_frames: number of frames in the returned array.
        output_hw: ``(height, width)`` of every returned frame.

    Returns:
        ``(frames, real_frames)`` -- a uint8 array of shape
        ``(target_frames, height, width, 3)`` in RGB order, and the number of
        frames that are not padding (never more than ``target_frames``).

    Raises:
        ValueError: if no frame could be decoded.
    """
    capture = cv2.VideoCapture(video_path)
    decoded = []

    ok, bgr = capture.read()
    while ok:
        # OpenCV decodes to BGR; the model pipeline works in RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # cv2.resize takes (width, height).
        decoded.append(cv2.resize(rgb, (output_hw[1], output_hw[0])))
        ok, bgr = capture.read()

    capture.release()

    if not decoded:
        raise ValueError("No frames loaded")

    n_read = len(decoded)
    clip = np.array(decoded, dtype=np.uint8)

    if n_read > target_frames:
        keep = np.linspace(0, n_read - 1, target_frames).astype(int)
        return clip[keep], target_frames

    if n_read < target_frames:
        filler = np.repeat(clip[-1:], target_frames - n_read, axis=0)
        clip = np.concatenate([clip, filler], axis=0)

    return clip, n_read


In [None]:
def pad_to_224(frames, size=224):
    """Zero-pad a clip spatially to ``size x size``, keeping the content centred.

    Args:
        frames: array of shape ``(T, H, W, C)`` with ``H <= size`` and
            ``W <= size``.
        size: target height and width (default 224, the ViViT input size).

    Returns:
        Array of shape ``(T, size, size, C)`` with the same dtype. The original
        pixels occupy rows ``(size - H) // 2`` onward and columns
        ``(size - W) // 2`` onward; an odd remainder goes to the bottom/right.

    Raises:
        ValueError: if a frame is larger than ``size`` in either dimension.
    """
    H, W = frames.shape[1], frames.shape[2]
    if H > size or W > size:
        raise ValueError(f"frames of size {H}x{W} do not fit in {size}x{size}")

    pad_top = (size - H) // 2
    pad_bottom = size - H - pad_top
    pad_left = (size - W) // 2
    pad_right = size - W - pad_left

    return np.pad(
        frames,
        ((0, 0), (pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
        mode="constant",
        constant_values=0,
    )


In [None]:
video_path = "/content/drive/MyDrive/eye_dataset_test/blink/164.mp4"

frames_orig, real_frames = load_eye_video(video_path)  # (32, 96, 192, 3) uint8, RGB
frames_224 = pad_to_224(frames_orig)                   # (32, 224, 224, 3)

# NOTE(review): pixels are only scaled to [0, 1] here; confirm this matches
# the preprocessing used during fine-tuning (the stock ViViT image processor
# may normalise differently).
video_tensor = torch.from_numpy(frames_224).float() / 255.0
video_tensor = video_tensor.permute(0, 3, 1, 2)      # (T, 3, H, W)
# Hugging Face ViViT takes pixel_values as (batch, frames, channels, H, W),
# so the batch axis is prepended: (1, T, 3, H, W).
video_tensor = video_tensor.unsqueeze(0).to(device)

with torch.no_grad():
    outputs = model(pixel_values=video_tensor)


## 3. Attention Rollout

We take the head-averaged self-attention of the final transformer layer (a single-layer map, not a multi-layer rollout), drop the CLS token, and compute:
- Temporal attention per tubelet
- Spatial attention per frame


In [None]:
def attention_rollout_vivit(outputs, layer=-1):
    """Split one layer's head-averaged ViViT self-attention into per-time-step maps.

    Despite the name, this is not a multi-layer attention rollout. It reads a
    single layer from ``outputs.attentions`` (the final one by default, as the
    section text describes), averages over heads for batch item 0, drops the
    CLS row/column, and keeps only the diagonal (within-time-step) blocks.

    Args:
        outputs: model output whose ``attentions`` is a per-layer sequence of
            tensors indexed ``[batch, head, query, key]``.
        layer: index into ``outputs.attentions``; ``-1`` selects the last layer.

    Returns:
        ``(temporal_attn, spatial_attn)``:
        ``temporal_attn`` holds one float per temporal token (the mean of its
        within-step attention block). ``spatial_attn`` has shape
        ``(T_tokens, 14, 14)`` and holds, for each temporal token, the
        attention each patch receives, averaged over that step's queries.
    """
    num_patches = 14 * 14  # 224 / 16 patches per side

    # Head-averaged (query, key) attention for batch item 0, without CLS.
    attn = outputs.attentions[layer][0].mean(0)[1:, 1:].detach().float().cpu()

    # NOTE(review): assumes tokens are grouped by time step, num_patches each,
    # in tubelet-embedding order -- confirm against the ViViT implementation.
    T_tokens = attn.shape[0] // num_patches

    temporal_attn = []
    spatial_attn = []

    for t in range(T_tokens):
        step = slice(t * num_patches, (t + 1) * num_patches)
        block = attn[step, step]

        temporal_attn.append(block.mean().item())
        spatial_attn.append(block.mean(0).reshape(14, 14).numpy())

    return np.array(temporal_attn), np.stack(spatial_attn)


In [None]:
temporal_attn, spatial_attn = attention_rollout_vivit(outputs)

# Each temporal token spans `tubelet_size` input frames. Keep only the tokens
# that contain at least one real (non-padded) frame.
tubelet_size = 2
real_tokens = -(-real_frames // tubelet_size)  # integer ceil division

temporal_attn, spatial_attn = temporal_attn[:real_tokens], spatial_attn[:real_tokens]


In [None]:
def crop_spatial_attention(attn_14, orig_hw=(96, 192), padded_hw=(224, 224)):
    """Upsample a patch-grid attention map and cut out the un-padded frame area.

    The map is resized to ``padded_hw`` with ``cv2.resize`` (default bilinear
    interpolation). The centred ``orig_hw`` window is then returned, using the
    same ``(padded - orig) // 2`` offsets as ``pad_to_224``, so the crop lines
    up with the original frame.

    Args:
        attn_14: 2-D float attention map (14x14 in this notebook).
        orig_hw: ``(height, width)`` of the frame before padding.
        padded_hw: ``(height, width)`` the frame was padded to.

    Returns:
        The cropped, upsampled map, ``orig_hw`` in size.
    """
    out_h, out_w = padded_hw
    top = (out_h - orig_hw[0]) // 2
    left = (out_w - orig_hw[1]) // 2

    # cv2.resize takes its target size as (width, height).
    upsampled = cv2.resize(attn_14, (out_w, out_h))

    rows = slice(top, top + orig_hw[0])
    cols = slice(left, left + orig_hw[1])
    return upsampled[rows, cols]


## 4. Temporal Saliency Estimation

Temporal saliency is computed by averaging each temporal token's spatial attention map over all positions and min-max normalizing the results across tokens.
This provides a soft indication of blink likelihood over time.


In [None]:
def compute_temporal_saliency(spatial_attn):
    """Collapse per-token spatial maps into a min-max normalised 1-D curve.

    Args:
        spatial_attn: array of shape ``(T, H, W)``.

    Returns:
        float32 array of length ``T``. The smallest value maps to 0 and the
        largest to just under 1 (a 1e-6 epsilon guards against dividing by 0).
    """
    # Mean attention of every temporal token, as float32.
    per_token = spatial_attn.mean(axis=(1, 2)).astype(np.float32)

    floor = per_token.min()
    per_token -= floor
    per_token /= per_token.max() + 1e-6

    return per_token


In [None]:
# One normalised value per real temporal token.
temporal_saliency = compute_temporal_saliency(spatial_attn=spatial_attn)


In [None]:
def detect_blinks(temporal_saliency, smoothing=3, threshold_factor=0.5):
    """Find blink candidates as local maxima of a smoothed saliency curve.

    The curve is smoothed with a ``smoothing``-wide moving average. Token ``t``
    is reported when its smoothed value exceeds ``threshold_factor`` times
    the smoothed maximum and is strictly greater than both neighbours. The
    first and last tokens are never reported, because each has only one
    neighbour.

    The ends are edge-replicated before averaging instead of zero-padded
    (as ``np.convolve(..., mode="same")`` does). Zero padding pulls the end
    values down, which can make the token next to an end look like a peak
    (e.g. on a rising ramp). For inputs shorter than the kernel it would also
    return a curve longer than the input.

    Args:
        temporal_saliency: 1-D sequence of per-token saliency values.
        smoothing: moving-average window length, at least 1.
        threshold_factor: fraction of the smoothed maximum a peak must exceed.

    Returns:
        ``(blink_tokens, sal, threshold)``: the peak indices, the smoothed
        curve (same length as the input) and the absolute threshold used.

    Raises:
        ValueError: if ``smoothing < 1``.
    """
    if smoothing < 1:
        raise ValueError(f"smoothing must be >= 1, got {smoothing}")

    saliency = np.asarray(temporal_saliency, dtype=float)
    if saliency.size == 0:
        return [], saliency, 0.0

    kernel = np.ones(smoothing) / smoothing
    # Pad so that a "valid" convolution yields exactly len(saliency) outputs,
    # centred the same way as mode="same".
    left = smoothing // 2
    right = smoothing - 1 - left
    padded = np.pad(saliency, (left, right), mode="edge")
    sal = np.convolve(padded, kernel, mode="valid")

    threshold = threshold_factor * sal.max()

    blink_tokens = [
        t for t in range(1, len(sal) - 1)
        if sal[t] > threshold and sal[t] > sal[t - 1] and sal[t] > sal[t + 1]
    ]

    return blink_tokens, sal, threshold


In [None]:
blink_tokens, temporal_smooth, blink_threshold = detect_blinks(temporal_saliency)
# First input frame of each detected token. Use the tubelet size that defines
# the token/frame mapping rather than a hard-coded 2.
blink_frames = [t * tubelet_size for t in blink_tokens]


## 5. Attention Overlay on Eye Frames

Spatial attention maps are overlaid onto eye-region frames and modulated
by temporal saliency to highlight blink-relevant frames.


In [None]:
def overlay_attention_single_frame(
    frame,
    spatial_map,
    temporal_saliency_t,
    alpha=0.35,
    gamma=2.5,
    temporal_gain=3.0,
):
    """Blend an INFERNO heat map of ``spatial_map`` onto ``frame``, scaled by saliency.

    The spatial map is divided by its maximum, sharpened with ``gamma`` and
    min-max rescaled to [0, 1]. It is then multiplied by the temporal weight
    ``exp(temporal_gain * (s - 1))``, where ``s`` is ``temporal_saliency_t``
    clipped to [0, 1]. The weight is 1 for the most salient token and falls
    to ``exp(-temporal_gain)`` for the least salient, so the heat map is
    brightest on blink-relevant frames. The weight must be applied after the
    min-max rescale: a constant factor applied before it would simply be
    divided back out.

    Args:
        frame: uint8 image of shape ``(H, W, 3)``. ``cv2.applyColorMap``
            produces BGR colours, so ``frame`` should be BGR for a
            colour-consistent blend.
        spatial_map: attention map of shape ``(H, W)``, expected non-negative
            (a fractional ``gamma`` turns negative values into NaN).
        temporal_saliency_t: saliency of the frame's temporal token, nominally in [0, 1].
        alpha: heat-map opacity.
        gamma: exponent that emphasises the strongest attention regions.
        temporal_gain: how sharply low-saliency frames are dimmed.

    Returns:
        uint8 image ``(1 - alpha) * frame + alpha * heat``.
    """
    attn = spatial_map.astype(np.float32)
    attn /= (attn.max() + 1e-6)
    attn = attn ** gamma

    attn -= attn.min()
    attn /= (attn.max() + 1e-6)

    saliency = float(np.clip(temporal_saliency_t, 0.0, 1.0))
    attn *= np.float32(np.exp(temporal_gain * (saliency - 1.0)))

    heat = cv2.applyColorMap(
        (np.clip(attn, 0.0, 1.0) * 255).astype(np.uint8),
        cv2.COLORMAP_INFERNO,
    )

    return cv2.addWeighted(frame, 1 - alpha, heat, alpha, 0)


In [None]:
def overlay_blink_marker(
    frame,
    is_blink,
    frame_idx,
):
    """Return a copy of ``frame`` labelled with its index and an optional blink dot.

    The index is drawn in very small white text near the bottom-left corner.
    When ``is_blink`` is truthy, a filled dot of radius 4 is drawn 12 px in
    from the top-right corner. ``frame`` itself is left untouched.
    """
    annotated = frame.copy()
    height, width = annotated.shape[:2]

    cv2.putText(
        annotated,
        f"{frame_idx}",
        (4, height - 4),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.28,
        (255, 255, 255),
        1,
        cv2.LINE_AA,
    )

    if not is_blink:
        return annotated

    # (0, 255, 255) is yellow when the image is interpreted as BGR (as
    # cv2.VideoWriter does); drawn on an RGB image it would appear cyan.
    cv2.circle(annotated, (width - 12, 12), 4, (0, 255, 255), -1)
    return annotated


## 6. Visualization Video Export

We save:
- Attention overlays
- Side-by-side comparisons of original and attention-enhanced videos


In [None]:
# Build one overlay frame per real (non-padded) input frame.
# frames_orig is RGB (load_eye_video converts BGR -> RGB for the model), while
# cv2.applyColorMap, the marker colour and cv2.VideoWriter all work in BGR.
# Convert each frame to BGR before drawing, so vis_frames is BGR throughout.
vis_frames = []

for i, frame in enumerate(frames_orig[:real_frames]):
    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    # Frame i belongs to temporal token i // tubelet_size.
    t = min(i // tubelet_size, real_tokens - 1)

    spatial_map = crop_spatial_attention(
        spatial_attn[t],
        orig_hw=frame_bgr.shape[:2],
        padded_hw=(224, 224),
    )

    vis = overlay_attention_single_frame(
        frame_bgr,
        spatial_map,
        temporal_saliency[t],
    )

    vis = overlay_blink_marker(vis, t in blink_tokens, i)
    vis_frames.append(vis)


In [None]:
def save_video(frames, path, fps=12):
    """Write same-sized 3-channel frames to ``path`` with the ``mp4v`` codec.

    ``cv2.VideoWriter`` interprets frames as BGR, so convert RGB frames
    before calling this.

    Args:
        frames: non-empty sequence of ``(H, W, 3)`` arrays, all the size of
            the first frame; each is cast to uint8 before writing.
        path: output file path.
        fps: frame rate stored in the file.

    Raises:
        ValueError: if ``frames`` is empty.
        IOError: if the writer cannot be opened (e.g. codec unavailable or
            unwritable path); without this check, writes are silently dropped.
    """
    if len(frames) == 0:
        raise ValueError("save_video: no frames to write")

    h, w, _ = frames[0].shape
    writer = cv2.VideoWriter(
        path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (w, h)
    )
    if not writer.isOpened():
        raise IOError(f"could not open video writer for {path!r}")

    try:
        for f in frames:
            writer.write(f.astype(np.uint8))
    finally:
        writer.release()


In [None]:
# Frame rate falls back to save_video's default of 12 fps.
save_video(vis_frames, path="blink_attention_overlay.mp4")


In [None]:
def stack_and_save_video(orig, overlay, path, fps=3):
    """Write a side-by-side video: ``orig`` on the left, ``overlay`` on the right.

    Frames are paired by position. ``zip`` stops at the shorter sequence, so
    extra frames in either input are dropped. ``cv2.hconcat`` requires each
    pair to have the same height and dtype.
    """
    side_by_side = []
    for left, right in zip(orig, overlay):
        side_by_side.append(cv2.hconcat([left, right]))
    save_video(side_by_side, path, fps)


In [None]:
# frames_orig is RGB (converted in load_eye_video), but cv2.VideoWriter, used
# by save_video, expects BGR. Convert so the left half is not written with
# red and blue swapped.
stack_and_save_video(
    [cv2.cvtColor(f, cv2.COLOR_RGB2BGR) for f in frames_orig],
    vis_frames,
    "blink_attention_side_by_side.mp4"
)


## Output Artifacts

- Attention overlay videos (`.mp4`)
- Side-by-side visual comparisons
- Temporal saliency curves (optional; not plotted in this notebook)
