In [2]:
# run only once to create frames
# import os
# import re
# import cv2

# ROOT_DIR = "../datasets/Avenue"
# OUT_ROOT = os.path.join(ROOT_DIR, "frames")

# def natural_key(s: str):
#     return [int(x) if x.isdigit() else x.lower() for x in re.split(r"(\d+)", s)]

# def extract_video_to_frames(video_path, out_dir):
#     os.makedirs(out_dir, exist_ok=True)
#     cap = cv2.VideoCapture(video_path)
#     if not cap.isOpened():
#         cap.release()
#         raise RuntimeError(f"Could not open video: {video_path}")

#     idx = 0
#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break
#         cv2.imwrite(os.path.join(out_dir, f"{idx:06d}.jpg"), frame)
#         idx += 1

#     cap.release()
#     return idx

# def collect_videos(video_dir):
#     video_paths = []
#     for name in sorted(os.listdir(video_dir), key=natural_key):
#         p = os.path.join(video_dir, name)
#         if os.path.isdir(p):
#             for f in sorted(os.listdir(p), key=natural_key):
#                 if f.lower().endswith(".avi"):
#                     video_paths.append(os.path.join(p, f))
#         else:
#             if name.lower().endswith(".avi"):
#                 video_paths.append(p)
#     return video_paths

# def extract_split(split):
#     in_dir = os.path.join(ROOT_DIR, split)
#     out_dir = os.path.join(OUT_ROOT, split)
#     os.makedirs(out_dir, exist_ok=True)

#     videos = collect_videos(in_dir)
#     if len(videos) == 0:
#         raise ValueError(f"No .avi files found in: {in_dir}")

#     for vp in videos:
#         base = os.path.splitext(os.path.basename(vp))[0]
#         out_folder = os.path.join(out_dir, base)
#         if os.path.isdir(out_folder) and len(os.listdir(out_folder)) > 0:
#             # Skip if already extracted
#             continue

#         n = extract_video_to_frames(vp, out_folder)
#         print(f"[{split}] {base}: extracted {n} frames")

# extract_split("train")
# extract_split("test")

# print(f"Done. Frames saved under: {OUT_ROOT}")


In [3]:
import numpy as np
from scipy.io import loadmat

def avenue_labels_from_volLabel(mat_path: str):
    """
    Avenue GT: volLabel is (1, T) object array.
    Each entry is a (H, W) mask (uint8).
    Returns: labels (T,) where 1 means anomalous frame.
    """
    mat = loadmat(mat_path)
    v = mat["volLabel"]          # object array (1, T)
    v = v.reshape(-1)            # -> (T,)

    labels = np.zeros(len(v), dtype=np.int64)
    for i in range(len(v)):
        mask = np.array(v[i])    # (H, W)
        labels[i] = 1 if np.any(mask > 0) else 0

    return labels

labels = avenue_labels_from_volLabel("../datasets/Avenue/test_gt/1_label.mat")
print("Frames:", len(labels))
print("Anomalous frames:", labels.sum())
print("First 50 labels:", labels[:50])


Frames: 1439
Anomalous frames: 408
First 50 labels: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [2]:
import numpy as np
from scipy.io import loadmat

mat_path = "../datasets/Avenue/test_gt/1_label.mat"
mat = loadmat(mat_path)

v = mat["volLabel"]
print("raw type:", type(v), "dtype:", v.dtype, "shape:", v.shape)

# unwrap possible nesting
vv = v
depth = 0
while isinstance(vv, np.ndarray) and vv.dtype == object:
    depth += 1
    vv = vv.reshape(-1)
    print(f"unwrap depth {depth}: object array len =", len(vv))
    vv = vv[0]
    if isinstance(vv, np.ndarray):
        print("  inner dtype:", vv.dtype, "shape:", vv.shape)
    else:
        print("  inner type:", type(vv))

vv = np.array(vv)
print("\nfinal vv dtype:", vv.dtype, "shape:", vv.shape, "ndim:", vv.ndim)

# print a few unique values to verify mask nature
flat = vv.reshape(-1)
print("min:", flat.min(), "max:", flat.max(), "unique sample:", np.unique(flat)[:10], "unique_count:", len(np.unique(flat)))


raw type: <class 'numpy.ndarray'> dtype: object shape: (1, 1439)
unwrap depth 1: object array len = 1439
  inner dtype: uint8 shape: (360, 640)

final vv dtype: uint8 shape: (360, 640) ndim: 2
min: 0 max: 0 unique sample: [0] unique_count: 1


In [1]:
# ============================================================
# QUICK GT CHECK (Avenue) - NO TRAINING / NO TESTING
# Purpose:
# 1) Verify your GT .mat files load correctly (volLabel)
# 2) Verify GT length matches extracted frame counts
# 3) Print anomaly stats per video
#
# Expected:
#   ../datasets/Avenue/frames/test/<video_folder>/*.jpg
#   ../datasets/Avenue/test_gt/1_label.mat, 2_label.mat, ...
# ============================================================

import os
import re
import numpy as np
from scipy.io import loadmat

# ----------------------------
# CONFIG
# ----------------------------
ROOT_DIR = "../datasets/Avenue"
FRAMES_TEST_DIR = os.path.join(ROOT_DIR, "frames", "test")
GT_DIR = os.path.join(ROOT_DIR, "test_gt")   # contains 1_label.mat, 2_label.mat, ...

# ----------------------------
# HELPERS
# ----------------------------
def natural_key(s: str):
    return [int(x) if x.isdigit() else x.lower() for x in re.split(r"(\d+)", s)]

def list_test_videos(frames_test_dir):
    vids = sorted(
        [d for d in os.listdir(frames_test_dir) if os.path.isdir(os.path.join(frames_test_dir, d))],
        key=natural_key
    )
    if len(vids) == 0:
        raise ValueError(f"No test video folders found in: {frames_test_dir}")
    return vids

def count_frames(video_folder_path):
    frames = sorted(
        [f for f in os.listdir(video_folder_path) if f.lower().endswith((".jpg",".jpeg",".png",".tif"))],
        key=natural_key
    )
    return len(frames)

def list_gt_matfiles(gt_dir):
    mats = sorted(
        [f for f in os.listdir(gt_dir) if f.lower().endswith(".mat") and "_label" in f.lower()],
        key=natural_key
    )
    if len(mats) == 0:
        raise ValueError(f"No *_label.mat files found in: {gt_dir}")
    return [os.path.join(gt_dir, f) for f in mats]

def load_volLabel(mat_path):
    """
    Avenue GT commonly has key: 'volLabel'
    This function robustly unwraps it into a 1D binary vector.
    """
    mat = loadmat(mat_path)
    if "volLabel" not in mat:
        keys = [k for k in mat.keys() if not k.startswith("__")]
        raise KeyError(f"'volLabel' not found in {mat_path}. Keys={keys}")

    v = mat["volLabel"]

    # unwrap MATLAB cell/object nesting
    while isinstance(v, np.ndarray) and v.dtype == object:
        v = v.reshape(-1)
        if len(v) == 0:
            break
        v = v[0]

    v = np.array(v).squeeze()
    v = v.reshape(-1)

    # binarize
    v = (v > 0).astype(np.int64)
    return v

# ----------------------------
# MAIN CHECK
# ----------------------------
test_videos = list_test_videos(FRAMES_TEST_DIR)
gt_files = list_gt_matfiles(GT_DIR)

print("Test video folders:", len(test_videos))
print("GT .mat files:", len(gt_files))

# map by sorted order
n = min(len(test_videos), len(gt_files))
print("\n---- Checking first", n, "videos ----")

total_frames = 0
total_anom = 0

for i in range(n):
    vname = test_videos[i]
    vpath = os.path.join(FRAMES_TEST_DIR, vname)
    n_frames = count_frames(vpath)

    labels = load_volLabel(gt_files[i])

    # align lengths (just for checking, no fancy rules)
    if len(labels) != n_frames:
        print(f"[WARN] {vname}: frames={n_frames}, gt={len(labels)}  -> aligning to min length")
        m = min(n_frames, len(labels))
        labels = labels[:m]
        n_frames = m

    anom = int(labels.sum())
    total_frames += n_frames
    total_anom += anom

    # print per-video stats
    print(f"{i+1:02d}) video={vname:>8}  frames={n_frames:5d}  anomalous={anom:5d}  anom%={(anom/(n_frames+1e-8))*100:6.2f}")

print("\nTOTAL frames:", total_frames)
print("TOTAL anomalous frames:", total_anom)
print("Overall anom%:", (total_anom / (total_frames + 1e-8)) * 100)


Test video folders: 21
GT .mat files: 21

---- Checking first 21 videos ----
[WARN] 01: frames=1439, gt=230400  -> aligning to min length
01) video=      01  frames= 1439  anomalous=    0  anom%=  0.00
[WARN] 02: frames=1211, gt=230400  -> aligning to min length
02) video=      02  frames= 1211  anomalous=    0  anom%=  0.00
[WARN] 03: frames=923, gt=230400  -> aligning to min length
03) video=      03  frames=  923  anomalous=    0  anom%=  0.00
[WARN] 04: frames=947, gt=230400  -> aligning to min length
04) video=      04  frames=  947  anomalous=    0  anom%=  0.00
[WARN] 05: frames=1007, gt=230400  -> aligning to min length
05) video=      05  frames= 1007  anomalous=    0  anom%=  0.00
[WARN] 06: frames=1283, gt=230400  -> aligning to min length
06) video=      06  frames= 1283  anomalous=    0  anom%=  0.00
[WARN] 07: frames=605, gt=230400  -> aligning to min length
07) video=      07  frames=  605  anomalous=    0  anom%=  0.00
[WARN] 08: frames=36, gt=230400  -> aligning to min

In [3]:
import os
import re
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, roc_curve
from scipy.io import loadmat

# ============================================================
# CONFIG
# ============================================================
ROOT_DIR = "../datasets/Avenue"
FRAMES_DIR = os.path.join(ROOT_DIR, "frames")  # created by extraction script
SEQUENCE_LENGTH = 5
IMAGE_SIZE = 128
BATCH_SIZE = 4
LR = 1e-4
NUM_EPOCHS = 1  # change later (e.g., 30-50) once pipeline is verified
MODEL_PATH = "attention_Avenue.pth"

# Avenue GT folder structure assumed:
#   ROOT_DIR/test_gt/1_label/*.mat
#   ROOT_DIR/test_gt/2_label/*.mat
GT_ROOT = os.path.join(ROOT_DIR, "test_gt")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ============================================================
# HELPERS
# ============================================================
def natural_key(s: str):
    return [int(x) if x.isdigit() else x.lower() for x in re.split(r"(\d+)", s)]

def normalize(x):
    x = np.array(x, dtype=np.float64)
    return (x - x.min()) / (x.max() - x.min() + 1e-8)

# ============================================================
# DATASET (FRAME FOLDERS) - FAST
# Returns:
#   input_tensor: (T, 1, H, W)
#   target_img:   (1, H, W)
#   video_name:   folder name under frames/<split>/
#   target_idx:   integer frame index (0-based) inside that folder
# ============================================================
class AvenueFramesDataset(Dataset):
    def __init__(self, root_dir, sequence_length=5, image_size=128, mode="train"):
        self.sequence_length = sequence_length
        self.image_size = image_size
        self.mode = mode

        self.video_dir = os.path.join(root_dir, "frames", "train" if mode == "train" else "test")

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

        self.samples = []
        self.video_folders = []
        self._prepare_samples()

    def _prepare_samples(self):
        if not os.path.isdir(self.video_dir):
            raise FileNotFoundError(
                f"Frames directory not found: {self.video_dir}\n"
                f"Did you extract videos to frames under {FRAMES_DIR}?"
            )

        self.video_folders = sorted(
            [d for d in os.listdir(self.video_dir) if os.path.isdir(os.path.join(self.video_dir, d))],
            key=natural_key
        )

        if len(self.video_folders) == 0:
            raise ValueError(f"No frame folders found under: {self.video_dir}")

        for video in self.video_folders:
            video_path = os.path.join(self.video_dir, video)

            frames = sorted([
                f for f in os.listdir(video_path)
                if f.lower().endswith((".jpg", ".jpeg", ".png", ".tif"))
            ], key=natural_key)

            if len(frames) <= self.sequence_length:
                continue

            for start in range(len(frames) - self.sequence_length):
                # inputs: start..start+T-1, target: start+T
                self.samples.append((video_path, video, frames, start))

        if len(self.samples) == 0:
            raise ValueError(
                f"No samples created. Check that each folder has > {self.sequence_length} frames."
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, video_name, frames, start = self.samples[idx]

        input_tensor = []
        for t in range(self.sequence_length):
            img = Image.open(os.path.join(video_path, frames[start + t])).convert("L")
            img = self.transform(img)
            input_tensor.append(img)

        input_tensor = torch.stack(input_tensor, dim=0)

        target_idx = start + self.sequence_length
        target_img = Image.open(os.path.join(video_path, frames[target_idx])).convert("L")
        target_img = self.transform(target_img)

        return input_tensor, target_img, video_name, target_idx

# ============================================================
# MODEL (same architecture; ConvLSTM init fix included)
# ============================================================
class CNNEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU()
        )

    def forward(self, x):
        return self.encoder(x)

class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim, 3, padding=1)

    def forward(self, x, h, c):
        combined = torch.cat([x, h], dim=1)
        i, f, o, g = torch.split(self.conv(combined), self.hidden_dim, dim=1)

        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)

        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, c

class ConvLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.cell = ConvLSTMCell(input_dim, hidden_dim)
        self.hidden_dim = hidden_dim

    def forward(self, x):
        # x: (B, T, C, H, W)
        B, T, C, H, W = x.size()

        # FIX: hidden state channels must be hidden_dim, not C
        h = torch.zeros(B, self.hidden_dim, H, W, device=x.device)
        c = torch.zeros_like(h)

        for t in range(T):
            h, c = self.cell(x[:, t], h, c)

        return h

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.decoder(x)

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(x))

class SpatioTemporalAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channel_att = ChannelAttention(channels)
        self.spatial_att = SpatialAttention()

    def forward(self, x):
        x = x * self.channel_att(x)
        x = x * self.spatial_att(x)
        return x

class FutureFramePredictorWithAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = CNNEncoder()
        self.convlstm = ConvLSTM(128, 128)
        self.attention = SpatioTemporalAttention(128)
        self.decoder = Decoder()

    def forward(self, x):
        # x: (B, T, 1, H, W)
        encoded = torch.stack([self.encoder(x[:, t]) for t in range(x.size(1))], dim=1)
        h = self.convlstm(encoded)
        h = self.attention(h)
        return self.decoder(h)

Device: cuda


In [4]:
# ============================================================
# TRAIN
# ============================================================
train_dataset = AvenueFramesDataset(
    root_dir=ROOT_DIR,
    sequence_length=SEQUENCE_LENGTH,
    image_size=IMAGE_SIZE,
    mode="train"
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,       # minimal speedup
    pin_memory=True
)

model = FutureFramePredictorWithAttention().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    for inputs, target, _, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        inputs = inputs.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad()
        pred = model(inputs)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} Avg Loss: {running_loss / len(train_loader):.6f}")

torch.save(model.state_dict(), MODEL_PATH)
print(f"Saved model -> {MODEL_PATH}")

Epoch 1/1: 100%|███████████████████████████████████████████████████████████████████| 3812/3812 [09:12<00:00,  6.90it/s]

Epoch 1 Avg Loss: 0.003305
Saved model -> attention_Avenue.pth





In [None]:
# ============================================================
# TEST: Compute combined anomaly scores + store meta
# ============================================================
test_dataset = AvenueFramesDataset(
    root_dir=ROOT_DIR,
    sequence_length=SEQUENCE_LENGTH,
    image_size=IMAGE_SIZE,
    mode="test"
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

attention_model = FutureFramePredictorWithAttention().to(device)
attention_model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
attention_model.eval()

alpha = 0.3
beta = 0.7

combined_scores = []
meta = []  # (video_name, target_frame_idx)

with torch.no_grad():
    for inputs, target, video_name, target_idx in tqdm(test_loader, desc="Testing (Pixel+Feature)"):
        inputs = inputs.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        pred = attention_model(inputs)

        pixel_err = torch.mean((pred - target) ** 2)
        feat_pred = attention_model.encoder(pred)
        feat_gt = attention_model.encoder(target)
        feature_err = torch.mean((feat_pred - feat_gt) ** 2)

        score = alpha * pixel_err + beta * feature_err

        combined_scores.append(score.item())
        meta.append((video_name[0], int(target_idx.item())))

scores = normalize(combined_scores)

plt.figure(figsize=(12, 4))
plt.plot(scores)
plt.title("Normalized Anomaly Scores (Avenue: Attention + Pixel/Feature)")
plt.xlabel("Sample Index")
plt.ylabel("Score")
plt.grid(True)
plt.show()

  attention_model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
Testing (Pixel+Feature):   0%|                                                               | 0/15219 [00:00<?, ?it/s]

In [None]:
# ============================================================
# GT LOADING (Avenue .mat folders) - robust inference
# Assumes:
#   test_gt/1_label/*.mat corresponds to test video folder order
#   test_gt/2_label/*.mat corresponds to next test video folder, etc.
# ============================================================
def find_first_mat(folder_path: str):
    mats = sorted([f for f in os.listdir(folder_path) if f.lower().endswith(".mat")], key=natural_key)
    if not mats:
        return None
    return os.path.join(folder_path, mats[0])

def to_1d_int_array(x):
    x = np.array(x)
    if x.dtype == object:
        flat = x.reshape(-1)
        if len(flat) == 1:
            x = np.array(flat[0])
        else:
            raise ValueError("Object array has multiple entries; cannot convert directly to vector.")
    return x.reshape(-1).astype(np.int64)

def infer_labels_from_mat(mat_path: str, n_frames: int):
    mat = loadmat(mat_path)
    keys = [k for k in mat.keys() if not k.startswith("__")]

    arrays = [(k, mat[k]) for k in keys if isinstance(mat[k], np.ndarray) and mat[k].size > 1]
    if not arrays:
        raise ValueError(f"No usable arrays in {mat_path}. Keys: {keys}")

    def key_rank(k):
        lk = k.lower()
        prefs = ["label", "labels", "gt", "ground", "truth", "mask"]
        return 0 if any(p in lk for p in prefs) else 1

    arrays = sorted(arrays, key=lambda kv: (key_rank(kv[0]), -kv[1].size))

    for k, v in arrays:
        # Case 1: per-frame binary vector length == n_frames
        try:
            vec = to_1d_int_array(v)
            uniq = np.unique(vec)
            if set(uniq.tolist()).issubset({0, 1}) and len(vec) == n_frames:
                return vec
        except Exception:
            pass

        # Case 2: intervals (K,2) or (2,K)
        if isinstance(v, np.ndarray) and v.dtype != object and v.ndim == 2 and (v.shape[1] == 2 or v.shape[0] == 2):
            intervals = v
            if intervals.shape[1] != 2 and intervals.shape[0] == 2:
                intervals = intervals.T
            intervals = intervals.astype(np.int64)

            labels = np.zeros(n_frames, dtype=np.int64)
            starts = intervals[:, 0]
            one_based = (np.any(starts == 1) and not np.any(starts == 0))

            for s, e in intervals:
                if one_based:
                    s -= 1
                    e -= 1
                s0 = max(int(s), 0)
                e0 = min(int(e), n_frames - 1)
                if e0 >= s0:
                    labels[s0:e0 + 1] = 1
            return labels

        # Case 3: object/cell containing intervals
        if isinstance(v, np.ndarray) and v.dtype == object:
            flat = v.reshape(-1)
            if len(flat) == 1:
                entry = np.array(flat[0])
                if entry.ndim == 2 and (entry.shape[1] == 2 or entry.shape[0] == 2):
                    intervals = entry
                    if intervals.shape[1] != 2 and intervals.shape[0] == 2:
                        intervals = intervals.T
                    intervals = intervals.astype(np.int64)

                    labels = np.zeros(n_frames, dtype=np.int64)
                    starts = intervals[:, 0]
                    one_based = (np.any(starts == 1) and not np.any(starts == 0))

                    for s, e in intervals:
                        if one_based:
                            s -= 1
                            e -= 1
                        s0 = max(int(s), 0)
                        e0 = min(int(e), n_frames - 1)
                        if e0 >= s0:
                            labels[s0:e0 + 1] = 1
                    return labels

    raise ValueError(f"Could not infer labels from {mat_path}. Keys={keys}")

def build_per_video_gt_from_folders(test_video_names, gt_root):
    gt_folders = sorted(
        [d for d in os.listdir(gt_root) if os.path.isdir(os.path.join(gt_root, d)) and d.endswith("_label")],
        key=natural_key
    )
    if len(gt_folders) < len(test_video_names):
        raise ValueError(
            f"Not enough GT folders in {gt_root}: gt={len(gt_folders)} test_videos={len(test_video_names)}"
        )

    per_video = {}
    for i, vname in enumerate(test_video_names):
        video_path = os.path.join(ROOT_DIR, "frames", "test", vname)
        if not os.path.isdir(video_path):
            raise FileNotFoundError(f"Test frame folder missing: {video_path}")

        frames = sorted([f for f in os.listdir(video_path) if f.lower().endswith((".jpg",".jpeg",".png",".tif"))], key=natural_key)
        n_frames = len(frames)

        gt_folder = os.path.join(gt_root, gt_folders[i])
        mat_path = find_first_mat(gt_folder)
        if mat_path is None:
            raise FileNotFoundError(f"No .mat found in {gt_folder}")

        labels = infer_labels_from_mat(mat_path, n_frames)
        per_video[vname] = labels

    return per_video

# ============================================================
# AUC
# ============================================================
test_video_names = test_dataset.video_folders
per_video_gt = build_per_video_gt_from_folders(test_video_names, GT_ROOT)

gt_labels = []
for (vname, frame_idx) in meta:
    labels = per_video_gt[vname]
    if frame_idx < 0 or frame_idx >= len(labels):
        raise ValueError(f"Frame idx out of range for {vname}: idx={frame_idx}, len={len(labels)}")
    gt_labels.append(int(labels[frame_idx]))

gt_labels = np.array(gt_labels, dtype=np.int64)

auc = roc_auc_score(gt_labels, scores)
print(f"\nAvenue AUC (Attention + Pixel/Feature): {auc:.4f}")

fpr, tpr, _ = roc_curve(gt_labels, scores)

plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, label=f"AUC={auc:.3f}")
plt.plot([0, 1], [0, 1], "k--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Avenue ROC (Attention + Pixel/Feature)")
plt.legend()
plt.grid(True)
plt.show()

# ============================================================
# OPTIONAL: Visualize a sample prediction (from test set)
# ============================================================
idx = min(200, len(test_dataset) - 1)
with torch.no_grad():
    x, y, vname, fidx = test_dataset[idx]
    x = x.unsqueeze(0).to(device)
    y = y.unsqueeze(0).to(device)

    pred = attention_model(x)

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.title(f"Ground Truth\n{vname} frame {fidx}")
    plt.imshow(y[0, 0].cpu(), cmap="gray")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.title("Predicted")
    plt.imshow(pred[0, 0].cpu(), cmap="gray")
    plt.axis("off")
    plt.show()