# ViViT Eye Blink Detection — Evaluation on Test Videos

This notebook evaluates a fine-tuned ViViT model on raw `.mp4` eye videos
and reports a precision–recall curve, a threshold analysis, and Accuracy,
Precision, Recall, and the Confusion Matrix at a chosen decision threshold.

No metadata or MediaPipe preprocessing is used during evaluation.


🟦 1. Imports & Device Setup

In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    precision_score,
    recall_score,
    confusion_matrix,
    precision_recall_curve
)
from tqdm import tqdm
from transformers import VivitForVideoClassification


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Data / checkpoint locations (absolute Google Drive paths).
# TEST_ROOT is expected to contain `blink/` and `no_blink/` sub-folders.
MODEL_PATH = "/content/drive/MyDrive/best_model(1).pth"
TEST_ROOT = "/content/drive/MyDrive/eye_dataset_test"

# Clip geometry: frames sampled per video, and square frame side in pixels.
IMG_SIZE = 224
TARGET_FRAMES = 32


🟦 2. Load Trained ViViT Model

In [None]:
# Build the ViViT-B architecture with a 2-way classification head.
# ignore_mismatched_sizes lets the pretrained checkpoint load even though its
# head size differs from num_labels=2; the fine-tuned weights below then
# overwrite the whole model.
model = VivitForVideoClassification.from_pretrained(
    "google/vivit-b-16x2-kinetics400",
    num_labels=2,
    ignore_mismatched_sizes=True
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code on load. The file is
# passed straight to load_state_dict, i.e. it is a plain state_dict, which this
# mode supports.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
# Inference mode for dropout etc.; trailing `;` suppresses the large module repr.
model.eval();


🟦 3. Video Loading Utility

In [None]:
def load_video_tensor(video_path, target_frames=32, size=224):
    """Decode a video file into a ViViT-ready clip tensor.

    Every frame is read with OpenCV, converted BGR -> RGB, and resized to
    ``size x size``. Then ``target_frames`` frames are picked at evenly spaced
    indices; frames repeat when the clip is shorter than ``target_frames``.
    Pixel values are divided by 255.

    NOTE(review): no mean/std normalisation is applied here. This must match
    whatever preprocessing the model was fine-tuned with — confirm against the
    training code.

    Args:
        video_path: Path to the video file (e.g. ``.mp4``).
        target_frames: Number of frames in the returned clip; must be >= 1.
        size: Output height and width in pixels.

    Returns:
        A float32 tensor of shape ``(1, target_frames, 3, size, size)``
        (batch, time, channels, height, width), or ``None`` if no frame could
        be decoded (e.g. unreadable or empty file).

    Raises:
        ValueError: If ``target_frames`` < 1.
    """
    if target_frames < 1:
        raise ValueError(f"target_frames must be >= 1, got {target_frames}")

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(frame, (size, size)))
    finally:
        # Release the decoder even if a frame fails to convert or resize.
        cap.release()

    if not frames:
        return None

    clip = np.stack(frames)  # T_all x H x W x C

    # Temporal sampling: evenly spaced indices over the whole clip
    # (astype(int) truncates the fractional positions).
    idx = np.linspace(0, len(clip) - 1, target_frames).astype(int)
    clip = clip[idx]

    tensor = torch.from_numpy(clip).float() / 255.0
    tensor = tensor.permute(0, 3, 1, 2)  # T C H W
    return tensor.unsqueeze(0)           # 1 T C H W


🟦 4. Evaluation Dataset (MP4-based, FAST)

In [None]:
# Result buffers filled by the evaluation loop.
all_probs = []
all_labels = []

# Class folder name -> integer label; 1 is the positive ("blink") class.
label_map = {"no_blink": 0, "blink": 1}

# Collect (path, label) pairs. Sorting makes the order deterministic
# (os.listdir order is arbitrary), and the extension check is
# case-insensitive so `.MP4` files are not silently skipped.
video_list = []
for cls in ["blink", "no_blink"]:
    cls_dir = os.path.join(TEST_ROOT, cls)
    cls_videos = [
        os.path.join(cls_dir, f)
        for f in sorted(os.listdir(cls_dir))
        if f.lower().endswith(".mp4")
    ]
    print(f"{cls}: {len(cls_videos)} videos")
    video_list.extend((path, label_map[cls]) for path in cls_videos)

print(f"Total test videos: {len(video_list)}")


🟦 5. Run Evaluation

In [None]:
# Reset the result buffers so re-running this cell does not append duplicates
# (or fail after the metrics cell has converted them to numpy arrays).
all_probs = []
all_labels = []
skipped_videos = []

with torch.no_grad():
    for video_path, label in tqdm(video_list, desc="Evaluating"):
        # Use the configured clip geometry rather than the function defaults.
        video_tensor = load_video_tensor(
            video_path, target_frames=TARGET_FRAMES, size=IMG_SIZE
        )

        if video_tensor is None:
            # No decodable frames: record it instead of dropping it silently.
            skipped_videos.append(video_path)
            continue

        logits = model(video_tensor.to(device)).logits
        # Column 1 = "blink" probability; assumes training used the same
        # label_map (blink -> 1) — TODO confirm.
        prob = torch.softmax(logits, dim=1)[0, 1].item()

        all_probs.append(prob)
        all_labels.append(label)

print(f"Evaluated {len(all_probs)} videos, skipped {len(skipped_videos)} unreadable")
for path in skipped_videos:
    print(f"  skipped: {path}")


🟦 6. Metrics

In [None]:
all_probs = np.array(all_probs)
all_labels = np.array(all_labels)

# Precision/recall at every distinct score threshold (positive class = 1).
precision, recall, thresholds = precision_recall_curve(all_labels, all_probs)
# Single-number summary of the curve.
avg_precision = average_precision_score(all_labels, all_probs)

fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(recall, precision, label=f"AP = {avg_precision:.3f}")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision–Recall Curve (ViViT Blink Detection)")
ax.grid(True)
ax.legend()
plt.show()


In [None]:
# precision_recall_curve returns one more precision/recall point than
# thresholds (the final recall=0, precision=1 point has no threshold), so drop
# it to keep indices aligned with `thresholds`.
precision_t, recall_t = precision[:-1], recall[:-1]
f1_scores = 2 * (precision_t * recall_t) / (precision_t + recall_t + 1e-8)
best_idx = int(np.argmax(f1_scores))

# A sample counts as positive when its score >= best_thresh (sklearn convention).
best_thresh = thresholds[best_idx]
print(f"Best threshold (F1): {best_thresh:.3f}")
print(f"F1       : {f1_scores[best_idx]:.3f}")
print(f"Precision: {precision_t[best_idx]:.3f}")
print(f"Recall   : {recall_t[best_idx]:.3f}")


In [None]:
# Decision threshold on the blink probability, picked by hand from the PR
# curve (the max-F1 `best_thresh` printed above is one alternative).
THRESH = 0.05

# `>=` matches precision_recall_curve, which counts score >= threshold as
# positive, so these metrics line up with points on the curve.
preds = (all_probs >= THRESH).astype(int)

acc = accuracy_score(all_labels, preds)
# zero_division=0: report 0 instead of warning when nothing is predicted positive.
prec = precision_score(all_labels, preds, zero_division=0)
rec = recall_score(all_labels, preds, zero_division=0)
# Fix the label order so the matrix is always 2x2: [[TN, FP], [FN, TP]].
cm = confusion_matrix(all_labels, preds, labels=[0, 1])

print(f"Accuracy : {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall   : {rec:.4f}")
print("Confusion Matrix:\n", cm)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(
    y_true,
    y_pred,
    class_names=("No Blink", "Blink"),
    normalize=True,
    threshold=None,
    save_path="confusion_matrix.png"
):
    """Plot, save, and show a confusion matrix for blink predictions.

    Args:
        y_true: Ground-truth labels. Assumed to be integer class indices
            0..len(class_names)-1 (here 0 = no_blink, 1 = blink).
        y_pred: Predicted labels, same encoding as ``y_true``.
        class_names: Display names, in label-index order.
        normalize: If True, colour cells by row-normalised rate (per true
            class) and annotate with both rate and raw count. If False,
            colour and annotate with raw counts.
        threshold: Optional decision threshold, shown in the title only.
        save_path: File the figure is written to (200 dpi).
    """
    # Fix the label set so the matrix is always square and matches the tick
    # labels, even when a class is absent from both y_true and y_pred.
    labels = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Rows with no true samples get 0 instead of NaN (0/0).
        cm_display = np.divide(
            cm.astype(np.float32),
            row_sums,
            out=np.zeros(cm.shape, dtype=np.float32),
            where=row_sums != 0,
        )
        # Rates are in [0, 1]; keep the 0.5 cut-off for white text.
        text_color_cutoff = 0.5
    else:
        cm_display = cm
        # Raw counts: white text only on the darker half of the value range
        # (a fixed 0.5 would make every nonzero count white).
        text_color_cutoff = (cm.max() + cm.min()) / 2

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm_display, cmap="Blues")

    tick_positions = np.arange(len(class_names))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(class_names, fontsize=12)
    ax.set_yticklabels(class_names, fontsize=12)

    ax.set_xlabel("Predicted Label", fontsize=13)
    ax.set_ylabel("True Label", fontsize=13)

    title = "Confusion Matrix – ViViT Blink Detection"
    if threshold is not None:
        title += f" (Threshold = {threshold:.2f})"
    if normalize:
        title += "\n(Row-Normalized)"

    ax.set_title(title, fontsize=14, pad=12)

    # Annotate cells
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if normalize:
                text = f"{cm_display[i, j]:.2f}\n({cm[i, j]})"
            else:
                text = str(cm[i, j])

            ax.text(
                j, i, text,
                ha="center", va="center",
                fontsize=12,
                color="white" if cm_display[i, j] > text_color_cutoff else "black"
            )

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=11)

    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    plt.show()

    print(f"Saved confusion matrix to {save_path}")


In [None]:
# Quick sensitivity check of precision/recall across candidate thresholds.
# Loop-local names keep the `preds`/`prec`/`rec` computed at THRESH intact
# (the original loop left them holding the t=0.02 values).
for t in [0.5, 0.2, 0.1, 0.05, 0.02]:
    # `>=` matches precision_recall_curve's threshold convention.
    sweep_preds = (all_probs >= t).astype(int)
    rec_t = recall_score(all_labels, sweep_preds, zero_division=0)
    prec_t = precision_score(all_labels, sweep_preds, zero_division=0)
    print(f"thr={t:.2f} | precision={prec_t:.3f} | recall={rec_t:.3f}")


In [None]:
# Recompute predictions at THRESH here instead of reusing the global `preds`,
# which other cells (e.g. a threshold sweep) may have overwritten.
preds_at_thresh = (all_probs >= THRESH).astype(int)

plot_confusion_matrix(
    all_labels,
    preds_at_thresh,
    threshold=THRESH,
    # Derive the file name from THRESH so it cannot drift from the plotted value.
    save_path=f"confusion_matrix_thresh_{THRESH:g}.png"
);
