## Import Dependencies

In [1]:
import os, re
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet50
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch_geometric.data import Data as PyGData, Batch as PyGBatch
from sklearn.metrics import precision_recall_fscore_support
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_dense_batch
from torch.utils.data import Dataset
import math, random, time
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

## USER CONFIG


In [2]:
RGB_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\RGB\P01_01\Original")   # folder with RGB frames
FLOW_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\OpticalFlow\P01_01\P01_01") # folder with flow frames (same filenames); set None to skip flow
LABEL_CSV   = Path(r"D:\Datasets\Datasets\EPIC\Labels\P01_01.csv")  # label file for this video
OUTPUT_FUSED_CSV = Path(r"D:\Datasets\Datasets\EPIC\Features\P01_01_fused_features.csv")  # where fused CSV will be saved

SAMPLE_RATE = 1     # take every S-th frame (1 => every frame)
FEAT_DIM = 512      # output projection dim
W_RGB = 0.6         # fusion weight for RGB
W_FLOW = 0.4        # fusion weight for Flow (ignored if FLOW_FOLDER is None)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

OUTPUT_FUSED_CSV.parent.mkdir(parents=True, exist_ok=True)


### helper: parse integer frame index from filename e.g. frame_000123.jpg -> 123


In [74]:
_frame_number_re = re.compile(r"(\d+)(?=\.[^.]+$)")
def parse_frame_index(fname: str):
    m = _frame_number_re.search(fname)
    if m:
        return int(m.group(1))
    digs = re.findall(r"\d+", fname)
    return int(digs[-1]) if digs else 0

## Feature extractor: ResNet50 backbone (pretrained) + linear proj to FEAT_DIM

In [75]:
_resnet = resnet50(weights=True)
_resnet = nn.Sequential(*list(_resnet.children())[:-1]).to(DEVICE).eval()   # outputs (B,2048,1,1)
_proj = nn.Linear(2048, FEAT_DIM).to(DEVICE).eval()
_transform = T.Compose([T.Resize((224,224)), T.ToTensor(),
                        T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

@torch.no_grad()
def extract_feature_from_pil(pil_img: Image.Image):
    x = _transform(pil_img).unsqueeze(0).to(DEVICE)   # (1,3,224,224)
    feat = _resnet(x).view(1, -1)                     # (1,2048)
    feat = _proj(feat)                                # (1,FEAT_DIM)
    return feat.squeeze(0).cpu().numpy()              # (FEAT_DIM,)



# -------------------------
# Extract RGB & Flow features, fuse, save fused CSV
# -------------------------
def extract_and_save_fused(csv_labels_path: Path,
                           rgb_folder: Path,
                           flow_folder: Path or None,
                           out_fused_csv: Path,
                           sample_rate: int = 1,
                           w_rgb: float = 0.6,
                           w_flow: float = 0.4):
    """
    Extract & fuse features for one video folder. Saves fused CSV with columns:
    frame_idx, frame_name, ActionLabel, ActionName, feat_0..feat_{FEAT_DIM-1}
    """
    
    # load labels
    labels_df = pd.read_csv(csv_labels_path)
    # list rgb frames (canonical filenames)
    rgb_files = sorted([p for p in rgb_folder.iterdir() if p.suffix.lower() in [".jpg",".png",".jpeg"]])
    sampled = rgb_files[::sample_rate]
    if len(sampled) == 0:
        raise RuntimeError(f"No frames found in {rgb_folder}")

    # containers
    fused_rows = []
    feat_cols = [f"feat_{i}" for i in range(FEAT_DIM)]

    # iterate frames
    for fp in tqdm(sampled, desc="Extract & fuse"):
        fname = fp.name
        frame_idx = parse_frame_index(fname)

        # --- RGB feature ---
        try:
            pil = Image.open(fp).convert("RGB")
            rgb_feat = extract_feature_from_pil(pil)
        except Exception as e:
            print(f"[WARN] RGB skip {fname}: {e}")
            continue

        # --- Flow feature ---
        if flow_folder is not None:
            ffp = Path(flow_folder) / fname
            if not ffp.exists():
                # fallback: use RGB image for alignment (keeps pipeline working)
                ffp = fp
            try:
                pilf = Image.open(ffp).convert("RGB")
                flow_feat = extract_feature_from_pil(pilf)
            except Exception as e:
                print(f"[WARN] FLOW skip {fname}: {e}; using zeros")
                flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)
        else:
            flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)
            
# -------------------------- FUSION Here ------------------------------------------------

        # --- fused vector (weighted sum) ---
        fused_vec = w_rgb * rgb_feat.astype(np.float32) + w_flow * flow_feat.astype(np.float32)

        # find label row that contains this frame (StartFrame <= idx <= EndFrame)
        lr = labels_df[(labels_df["StartFrame"] <= frame_idx) & (labels_df["EndFrame"] >= frame_idx)]
        if not lr.empty:
            action_label = int(lr.iloc[0].get("ActionLabel", -1))
            action_name  = str(lr.iloc[0].get("ActionName", "Unknown"))
        else:
            action_label, action_name = -1, "Unknown"

        # build row
        row = {"frame_idx": int(frame_idx), "frame_name": fname, "ActionLabel": int(action_label), "ActionName": action_name}
        for i_val, v in enumerate(fused_vec):
            row[f"feat_{i_val}"] = float(v)
        fused_rows.append(row)

    # save
    if len(fused_rows) == 0:
        raise RuntimeError("No fused rows extracted; check paths and files.")
    df_fused = pd.DataFrame(fused_rows)
    df_fused.to_csv(out_fused_csv, index=False)
    print(f"[SAVED] fused CSV -> {out_fused_csv}")
    return df_fused



## Run extraction

In [76]:
df_fused = extract_and_save_fused(
    csv_labels_path = LABEL_CSV,
    rgb_folder = RGB_FOLDER,
    flow_folder = FLOW_FOLDER if (FLOW_FOLDER is not None and FLOW_FOLDER.exists()) else None,
    out_fused_csv = OUTPUT_FUSED_CSV,
    sample_rate = SAMPLE_RATE,
    w_rgb = W_RGB,
    w_flow = W_FLOW
)

Extract & fuse:   0%|          | 0/99029 [00:00<?, ?it/s]

[SAVED] fused CSV -> D:\Datasets\Datasets\EPIC\Features\P01_01_fused_features.csv


In [33]:
data=pd.read_csv(r"D:\Datasets\Datasets\EPIC\Features\FusedFeatures\P01_03_fused_features.csv")
data

Unnamed: 0,frame_idx,frame_name,ActionLabel,ActionName,feat_0,feat_1,feat_2,feat_3,feat_4,feat_5,...,feat_502,feat_503,feat_504,feat_505,feat_506,feat_507,feat_508,feat_509,feat_510,feat_511
0,0,frame_00000.jpg,-1,Unknown,0.224212,-0.416339,0.093212,-0.107410,0.128738,-0.086018,...,-0.093053,-0.069731,-0.055004,-0.190981,-0.225788,-0.117915,0.111749,-0.181230,-0.133861,-0.064304
1,1,frame_00001.jpg,-1,Unknown,0.213923,-0.386899,0.059060,-0.069629,0.119920,-0.166948,...,-0.080017,-0.080227,-0.068668,-0.140048,-0.231339,-0.136919,0.105033,-0.226436,-0.147560,-0.031502
2,2,frame_00002.jpg,-1,Unknown,0.243420,-0.428746,0.088513,-0.109414,0.146825,-0.124339,...,-0.125939,-0.101465,-0.083090,-0.207229,-0.214796,-0.130478,0.117336,-0.226607,-0.148587,-0.036388
3,3,frame_00003.jpg,-1,Unknown,0.199176,-0.394861,0.059091,-0.070849,0.138222,-0.188922,...,-0.076485,-0.085018,-0.078830,-0.138564,-0.230030,-0.130138,0.084694,-0.234740,-0.113729,-0.036507
4,4,frame_00004.jpg,-1,Unknown,0.201879,-0.418372,0.067712,-0.092059,0.153830,-0.097980,...,-0.124027,-0.072635,-0.070298,-0.177150,-0.205052,-0.097257,0.119052,-0.231058,-0.133520,-0.042069
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7119,7119,frame_07119.jpg,-1,Unknown,0.319755,-0.180561,-0.020468,-0.662251,0.286641,-0.220271,...,-0.117545,0.197170,-0.497525,0.254925,-0.314211,-0.339845,0.104476,-0.593676,0.167948,0.032733
7120,7120,frame_07120.jpg,-1,Unknown,0.377411,-0.274535,-0.041729,-0.628119,0.238089,-0.184473,...,-0.133505,0.089116,-0.308133,0.235201,-0.286874,-0.237807,0.027914,-0.566877,0.290931,0.117833
7121,7121,frame_07121.jpg,-1,Unknown,0.350540,-0.272856,-0.002134,-0.596341,0.220748,-0.251758,...,-0.146171,0.129557,-0.440564,0.188763,-0.353000,-0.293067,0.075543,-0.538834,0.203080,0.097905
7122,7122,frame_07122.jpg,-1,Unknown,0.312043,-0.282243,-0.060604,-0.584190,0.237332,-0.254098,...,-0.134249,0.060154,-0.250969,0.223738,-0.298771,-0.242504,0.065116,-0.583450,0.259569,0.088858


### NOW From Here, We have to import the fused features and Label file

In [70]:
import pandas as pd
from pathlib import Path

def load_fused_csv_by_path(fused_csv_path: str):
    fp = Path(fused_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Fused features CSV not found: {fp}")
    df = pd.read_csv(fp)
    if "frame_idx" not in df.columns:
        raise KeyError("Fused CSV must contain 'frame_idx' column")
    df["frame_idx"] = df["frame_idx"].astype(int)
    df = df.sort_values("frame_idx").reset_index(drop=True)
    return df

def load_label_csv_by_path(label_csv_path: str):
    fp = Path(label_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Label CSV not found: {fp}")
    df = pd.read_csv(fp)
    return df

# ------------------------------------ Paths ---------------------------
fused_df = load_fused_csv_by_path(r"D:\Datasets\Datasets\EPIC\Features\FusedFeatures\P01_02_fused_features.csv")
labels_df = load_label_csv_by_path(r"D:\Datasets\Datasets\EPIC\Labels\P01_02.csv")


## Dataset Loader

In [71]:
IGNORE_INDEX = -1

class SingleVideoAnticipationDataset(Dataset):
    def __init__(self, fused_df_or_path, labels_df_or_path,
                 t_obs: int, k_fut: int, feat_dim: int,
                 fps: float, horizons_s: list[float]):
        
        # load paths
        if isinstance(fused_df_or_path, (str, Path)):
            fused_df = pd.read_csv(fused_df_or_path)
        else:
            fused_df = fused_df_or_path.copy()
        if isinstance(labels_df_or_path, (str, Path)):
            labels_df = pd.read_csv(labels_df_or_path)
        else:
            labels_df = labels_df_or_path.copy()

        if "frame_idx" not in fused_df.columns:
            raise KeyError("fused_df must contain 'frame_idx'")

        fused_df["frame_idx"] = fused_df["frame_idx"].astype(int)
        self.fused_df = fused_df.set_index("frame_idx", drop=False).sort_index()
        self.labels_df = labels_df.reset_index(drop=True)

        if not all(c in self.labels_df.columns for c in ["StartFrame", "EndFrame"]):
            raise KeyError("labels_df must contain StartFrame and EndFrame")

        self.t_obs = int(t_obs)
        self.k_fut = int(k_fut)
        self.feat_dim = int(feat_dim)
        self.feat_cols = [f"feat_{i}" for i in range(self.feat_dim)]

        # NEW: time info
        self.fps = float(fps)
        assert len(horizons_s) == self.k_fut, "len(horizons_s) must equal k_fut"
        self.horizons_s = list(horizons_s)

        # samples: one per label row (use EndFrame as obs_end)
        self.samples = []
        for ridx, row in self.labels_df.iterrows():
            try:
                obs_end = int(row["EndFrame"])
            except:
                continue
            self.samples.append({"label_row_idx": int(ridx), "obs_end": obs_end})

        if len(self.samples) == 0:
            raise RuntimeError("No valid label rows found")

    def __len__(self):
        return len(self.samples)

    # === NEW: time-based future labels instead of next segments ===
    def _time_based_future_labels(self, obs_end: int):
        """
        For each horizon t in self.horizons_s (seconds),
        compute future_frame = obs_end + round(t * fps),
        then find which action segment covers that frame.
        """
        labels_df = self.labels_df

        def pick(cols):
            for c in cols:
                if c in labels_df.columns:
                    return c
            return None

        vcol = pick(["Verb_class","verb","Verb","verb_class"])
        ncol = pick(["Noun_class","noun","Noun","noun_class"])
        acol = pick(["Action_class","action","Action","ActionLabel"])

        verb_targets   = []
        noun_targets   = []
        action_targets = []

        for h_sec in self.horizons_s:
            future_frame = obs_end + int(round(h_sec * self.fps))
            seg = labels_df[(labels_df["StartFrame"] <= future_frame) &
                            (labels_df["EndFrame"]   >= future_frame)]
            if seg.empty:
                # nothing happening at that exact time
                verb_targets.append(IGNORE_INDEX)
                noun_targets.append(IGNORE_INDEX)
                action_targets.append(IGNORE_INDEX)
            else:
                row = seg.iloc[0]
                if vcol is not None and not pd.isna(row[vcol]):
                    verb_targets.append(int(row[vcol]))
                else:
                    verb_targets.append(IGNORE_INDEX)

                if ncol is not None and not pd.isna(row[ncol]):
                    noun_targets.append(int(row[ncol]))
                else:
                    noun_targets.append(IGNORE_INDEX)

                if acol is not None and not pd.isna(row[acol]):
                    action_targets.append(int(row[acol]))
                else:
                    action_targets.append(IGNORE_INDEX)

        return {
            "verb":   torch.LongTensor(verb_targets),
            "noun":   torch.LongTensor(noun_targets),
            "action": torch.LongTensor(action_targets)
        }
    # ================================================================

    def __getitem__(self, idx):
        rec = self.samples[idx]
        obs_end = rec["obs_end"]
        obs_start = obs_end - (self.t_obs - 1)
        if obs_start < 0:
            obs_start = 0
            obs_end = obs_start + (self.t_obs - 1)

        fused_idx_min = int(self.fused_df.index.min())
        fused_idx_max = int(self.fused_df.index.max())
        obs_end = min(obs_end, fused_idx_max)
        obs_start = max(obs_end - (self.t_obs - 1), fused_idx_min)

        desired = list(range(obs_start, obs_end + 1))
        sel = self.fused_df.reindex(desired).fillna(method="ffill").fillna(method="bfill").fillna(0.0)

        if sel.shape[0] < self.t_obs:
            if sel.shape[0] == 0:
                zero_row = {c:0.0 for c in self.feat_cols}
                sel = pd.DataFrame([zero_row] * self.t_obs)
            else:
                first = sel.iloc[[0]]
                pads = pd.concat([first] * (self.t_obs - sel.shape[0]), ignore_index=True)
                sel = pd.concat([pads, sel.reset_index(drop=True)], ignore_index=True)

        for c in self.feat_cols:
            if c not in sel.columns:
                sel[c] = 0.0

        F_window = torch.from_numpy(sel[self.feat_cols].values).float()   # (T_obs, FEAT_DIM)

        # OLD:
        # y_multi = self._future_labels(self.labels_df, obs_end)
        # NEW: time-based labels
        y_multi = self._time_based_future_labels(obs_end)

        meta = {"obs_start": int(obs_start),
                "obs_end":   int(obs_end),
                "label_row_idx": int(rec["label_row_idx"])}
        return F_window, y_multi, meta


## GRAPH Construction

In [72]:
K=5   #----------- This K is to make KNN graph -------------
DROP=0.1

def build_topk_edge_index(features: torch.Tensor, k=K):
    """
    features: (T, D) torch tensor
    returns edge_index: (2, E) long tensor (undirected duplicated edges)
    """
    Tn = int(features.size(0))
    x = F.normalize(features, dim=1)
    sim = torch.matmul(x, x.t())   # (T,T)
    sim.fill_diagonal_(-1.0)
    vals, idxs = torch.topk(sim, k, dim=1)
    src = torch.arange(Tn).unsqueeze(1).expand(-1, k).reshape(-1)
    dst = idxs.reshape(-1)
    edge = torch.stack([src, dst], dim=0)
    edge_rev = torch.stack([dst, src], dim=0)
    return torch.cat([edge, edge_rev], dim=1).long()

class BatchedGAT(nn.Module):
    def __init__(self, in_dim, hid_dim=None, num_layers=3, heads=8, dropout=DROP):
        super().__init__()
        hid = hid_dim or in_dim
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_dim if i==0 else hid
            self.convs.append(GATConv(in_ch, hid//heads, heads=heads, concat=True, dropout=dropout))
        self.proj = nn.Linear(hid, in_dim)
        self.norm = nn.LayerNorm(in_dim)
        self.act = nn.GELU()
    def forward(self, pyg_batch: PyGBatch, T_per_sample: int):
        x = pyg_batch.x; edge_index = pyg_batch.edge_index
        h = x
        for conv in self.convs:
            h = conv(h, edge_index); h = self.act(h)
        h = self.proj(h)
        node_feats, mask = to_dense_batch(h, batch=pyg_batch.batch)  # (B, max_nodes, D)
        B, max_nodes, D = node_feats.shape
        if max_nodes < T_per_sample:
            pad = torch.zeros(B, T_per_sample - max_nodes, D, device=node_feats.device)
            node_feats = torch.cat([node_feats, pad], dim=1)
        elif max_nodes > T_per_sample:
            node_feats = node_feats[:, :T_per_sample, :]
        return self.norm(node_feats)  # (B, T_per_sample, D)

## Encoder, Decoder, Anticipation Model

In [73]:
class SimpleTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead=8, num_layers=3, dim_feedforward=2048, dropout=DROP, max_len=1000):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation='gelu', batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))
    def forward(self, x):
        B,T,D = x.shape
        pos = self.pos_emb[:, :T, :].to(x.device)
        return self.encoder(x + pos)

class AnticipationModel(nn.Module):
    def __init__(self, feat_dim, num_classes: dict, k_fut=5, gat_layers=3, gat_heads=8, dec_layers=3, dec_heads=8, dropout=DROP):
        super().__init__()
        self.feat_dim = feat_dim; self.k_fut = k_fut
        self.gat = BatchedGAT(in_dim=feat_dim, hid_dim=feat_dim, num_layers=gat_layers, heads=gat_heads, dropout=dropout)
        self.encoder = SimpleTransformerEncoder(d_model=feat_dim, nhead=dec_heads, num_layers=3)
        dec_layer = nn.TransformerDecoderLayer(d_model=feat_dim, nhead=dec_heads, dim_feedforward=feat_dim*4, dropout=dropout, activation='gelu', batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=dec_layers)
        self.queries = nn.Parameter(torch.randn(1, k_fut, feat_dim))
        assert isinstance(num_classes, dict)
        self.verb_head = nn.Linear(feat_dim, num_classes["verb"])
        self.noun_head = nn.Linear(feat_dim, num_classes["noun"])
        self.action_head = nn.Linear(feat_dim, num_classes["action"])

    def forward(self, F_batch):
        # F_batch: (B, T, D)
        B,T,D = F_batch.shape; device = F_batch.device
        data_list=[]
        for b in range(B):
            x = F_batch[b]
            edge_index = build_topk_edge_index(x.detach().cpu(), k=K).to(device)
            data_list.append(PyGData(x=x, edge_index=edge_index))
        pyg_batch = PyGBatch.from_data_list(data_list).to(device)
        G = self.gat(pyg_batch, T_per_sample=T)   # (B,T,D)
        H = self.encoder(F_batch)                 # (B,T,D)
        U = H + G
        q = self.queries.expand(B, -1, -1).to(device)
        dec_out = self.decoder(tgt=q, memory=U)   # (B, K_fut, D)
        return {"verb": self.verb_head(dec_out),
                "noun": self.noun_head(dec_out),
                "action": self.action_head(dec_out)}

## Novel Loss Function

In [74]:
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """
    logits: (B,K,C), labels: (B,K)
    Ignores positions where label == ignore_index.
    Returns a scalar loss that is always connected to the graph
    (even if there are no valid labels in a batch).
    """
    B, K, C = logits.shape
    logits_flat = logits.view(B * K, C)      # (B*K, C)
    labels_flat = labels.view(B * K)         # (B*K,)

    loss_flat = F.cross_entropy(
        logits_flat,
        labels_flat,
        reduction='none',
        ignore_index=ignore_index
    )  # (B*K,)

    mask = (labels_flat != ignore_index).float()
    valid = mask.sum()

    if valid == 0:
        # return a zero that still depends on logits so backward() is valid
        return (logits_flat * 0.0).sum()

    return (loss_flat * mask).sum() / valid


def topk_accuracy_per_task(logits, labels, topk=(1,5), ignore_index=IGNORE_INDEX):
    """
    logits: (B,K,C), labels: (B,K)
    returns dict with 'per_h' (dict of lists per topk) and 'overall_top{k}' floats
    """
    B,K,C = logits.shape
    res = {}
    overall = {k:0 for k in topk}
    total_cnt = 0
    preds_topk = logits.topk(max(topk), dim=-1)[1]  # (B,K,maxk)
    for h in range(K):
        lab = labels[:,h]; mask = (lab != ignore_index); cnt = int(mask.sum().item())
        for k in topk:
            if cnt == 0:
                res.setdefault(f"per_h{h+1}_top{k}", None)
                continue
            predk = preds_topk[:,h,:k]  # (B,k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            hits = (predk == lab_exp)
            hit = int(hits[mask].any(dim=1).float().sum().item())
            res[f"per_h{h+1}_top{k}"] = hit / cnt
            overall[k] += hit
        total_cnt += cnt
    for k in topk:
        res[f"overall_top{k}"] = overall[k] / total_cnt if total_cnt>0 else None
    return res



## Training and Testing (Validation) HERE

In [75]:
## === EDIT these paths ===
FUSED_CSV_PATH = r"D:\Datasets\Datasets\EPIC\Features\FusedFeatures\P01_02_fused_features.csv"
LABEL_CSV_PATH = r"D:\Datasets\Datasets\EPIC\Labels\P01_02.csv"
BEST_MODEL_PATH = Path(r"D:\Datasets\Datasets\EPIC\Model\P01_02_fused_model.pth")
# ========================

# Hyperparams
T_OBS = 90
FEAT_DIM = 512
BATCH_SIZE = 8
NUM_EPOCHS = 20
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 0

# === Time-based anticipation config ===
FPS = 30.0 
HORIZONS_S = [0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75, 2.0]   # seconds into the future
K_FUT = len(HORIZONS_S)               # model will output one label per horizon
# =====================================


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def detect_num_classes_from_labels_df(labels_df):
    verbs = set()
    nouns = set()
    actions = set()
    for cand in ["Verb_class", "verb", "Verb", "verb_class"]:
        if cand in labels_df.columns:
            verbs.update(labels_df[cand].dropna().astype(int).tolist())
            break
    for cand in ["Noun_class", "noun", "Noun", "noun_class"]:
        if cand in labels_df.columns:
            nouns.update(labels_df[cand].dropna().astype(int).tolist())
            break
    for cand in ["Action_class", "action", "Action", "ActionLabel"]:
        if cand in labels_df.columns:
            actions.update(labels_df[cand].dropna().astype(int).tolist())
            break
    nv = (max(verbs) + 1) if len(verbs) > 0 else 1
    nn_ = (max(nouns) + 1) if len(nouns) > 0 else 1
    na = (max(actions) + 1) if len(actions) > 0 else 1
    return {"verb": int(nv), "noun": int(nn_), "action": int(na)}


def topk_counts(logits, labels, k):
    # logits: (B, K_fut, C); labels: (B, K_fut)
    with torch.no_grad():
        B, K, C = logits.shape
        topk_preds = logits.topk(k, dim=-1)[1]  # (B, K, k)
        hits = 0
        total = 0
        for h in range(K):
            lab = labels[:, h]  # (B,)
            mask = (lab != IGNORE_INDEX)
            if int(mask.sum().item()) == 0:
                continue
            predk = topk_preds[:, h, :]  # (B, k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            masked_pred = predk[mask]   # (M, k)
            masked_lab = lab_exp[mask]  # (M, k)
            hit_vec = (masked_pred == masked_lab).any(dim=1).float()
            hits += int(hit_vec.sum().item())
            total += int(mask.sum().item())
        return hits, total


# Load fused and labels
fused_df = pd.read_csv(FUSED_CSV_PATH)
labels_df = pd.read_csv(LABEL_CSV_PATH)

# Dataset
dataset = SingleVideoAnticipationDataset(
    fused_df,
    labels_df,
    t_obs=T_OBS,
    k_fut=K_FUT,        # must equal len(HORIZONS_S)
    feat_dim=FEAT_DIM,
    fps=FPS,
    horizons_s=HORIZONS_S
)

# split indices for train/val (60/40)
indices = list(range(len(dataset)))
random.seed(42)
random.shuffle(indices)
split_at = int(0.6 * len(indices))
train_idx = indices[:split_at]
val_idx = indices[split_at:]

train_ds = Subset(dataset, train_idx)
val_ds = Subset(dataset, val_idx)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == "cuda")
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == "cuda")
)

# detect number of classes and instantiate model
num_classes = detect_num_classes_from_labels_df(labels_df)
print("Detected num_classes:", num_classes)
model = AnticipationModel(
    feat_dim=FEAT_DIM,
    num_classes=num_classes,
    k_fut=K_FUT
).to(DEVICE)

opt = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)

best_val_loss = float("inf")

for epoch in range(1, NUM_EPOCHS + 1):
    t0 = time.time()

    # ------------- TRAIN -------------
    model.train()
    train_loss_sum = 0.0
    train_samples = 0
    train_counts = {
        "verb_top1": [0, 0], "verb_top5": [0, 0],
        "noun_top1": [0, 0], "noun_top5": [0, 0],
        "action_top1": [0, 0], "action_top5": [0, 0]
    }

    pbar = tqdm(train_loader, desc=f"Epoch {epoch} Train", leave=False)
    for F_batch, y_multi, meta in pbar:
        F_batch = F_batch.to(DEVICE)               # (B, T, D)
        y_v = y_multi["verb"].to(DEVICE)           # (B, K_fut)
        y_n = y_multi["noun"].to(DEVICE)
        y_a = y_multi["action"].to(DEVICE)

        opt.zero_grad()
        logits = model(F_batch)   # dict: "verb"/"noun"/"action" -> (B, K_fut, C)

        loss_v = masked_cross_entropy(logits["verb"], y_v)
        loss_n = masked_cross_entropy(logits["noun"], y_n)
        loss_a = masked_cross_entropy(logits["action"], y_a)
        loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

        loss.backward()
        opt.step()

        b = F_batch.size(0)
        train_loss_sum += float(loss.item()) * b
        train_samples += b

        for (task, lab, lg) in [
            ("verb", y_v, logits["verb"]),
            ("noun", y_n, logits["noun"]),
            ("action", y_a, logits["action"])
        ]:
            h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
            h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
            train_counts[f"{task}_top1"][0] += h1
            train_counts[f"{task}_top1"][1] += t1
            train_counts[f"{task}_top5"][0] += h5
            train_counts[f"{task}_top5"][1] += t5

    train_loss = train_loss_sum / max(1, train_samples)
    train_metrics = {}
    for task in ["verb", "noun", "action"]:
        h1, t1 = train_counts[f"{task}_top1"]
        h5, t5 = train_counts[f"{task}_top5"]
        train_metrics[f"{task}_top1"] = (h1 / t1) if t1 > 0 else None
        train_metrics[f"{task}_top5"] = (h5 / t5) if t5 > 0 else None

    # ------------- VALIDATION -------------
    model.eval()
    val_loss_sum = 0.0
    val_samples = 0
    val_counts = {
        "verb_top1": [0, 0], "verb_top5": [0, 0],
        "noun_top1": [0, 0], "noun_top5": [0, 0],
        "action_top1": [0, 0], "action_top5": [0, 0]
    }

    # store logits/labels for per-horizon + P/R/F1 metrics
    val_logits_store = {"verb": [], "noun": [], "action": []}
    val_labels_store = {"verb": [], "noun": [], "action": []}

    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Epoch {epoch} Val", leave=False)
        for F_batch, y_multi, meta in pbar:
            F_batch = F_batch.to(DEVICE)
            y_v = y_multi["verb"].to(DEVICE)
            y_n = y_multi["noun"].to(DEVICE)
            y_a = y_multi["action"].to(DEVICE)

            logits = model(F_batch)
            loss_v = masked_cross_entropy(logits["verb"], y_v)
            loss_n = masked_cross_entropy(logits["noun"], y_n)
            loss_a = masked_cross_entropy(logits["action"], y_a)
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

            b = F_batch.size(0)
            val_loss_sum += float(loss.item()) * b
            val_samples += b

            for (task, lab, lg) in [
                ("verb", y_v, logits["verb"]),
                ("noun", y_n, logits["noun"]),
                ("action", y_a, logits["action"])
            ]:
                h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
                h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
                val_counts[f"{task}_top1"][0] += h1
                val_counts[f"{task}_top1"][1] += t1
                val_counts[f"{task}_top5"][0] += h5
                val_counts[f"{task}_top5"][1] += t5

            # store for per-horizon + P/R/F1 metrics
            val_logits_store["verb"].append(logits["verb"].detach().cpu())
            val_logits_store["noun"].append(logits["noun"].detach().cpu())
            val_logits_store["action"].append(logits["action"].detach().cpu())
            val_labels_store["verb"].append(y_v.detach().cpu())
            val_labels_store["noun"].append(y_n.detach().cpu())
            val_labels_store["action"].append(y_a.detach().cpu())

    val_loss = val_loss_sum / max(1, val_samples)

    # overall val metrics (top-1/top-5 over all horizons)
    val_metrics = {}
    for task in ["verb", "noun", "action"]:
        h1, t1 = val_counts[f"{task}_top1"]
        h5, t5 = val_counts[f"{task}_top5"]
        val_metrics[f"{task}_top1"] = (h1 / t1) if t1 > 0 else None
        val_metrics[f"{task}_top5"] = (h5 / t5) if t5 > 0 else None

    # per-horizon metrics (time-based)
    per_horizon_metrics = {"verb": {}, "noun": {}, "action": {}}
    for task in ["verb", "noun", "action"]:
        if len(val_logits_store[task]) == 0:
            continue
        logits_all = torch.cat(val_logits_store[task], dim=0)  # (N, K_fut, C)
        labels_all = torch.cat(val_labels_store[task], dim=0)  # (N, K_fut)
        m = topk_accuracy_per_task(
            logits_all,
            labels_all,
            topk=(1, 5),
            ignore_index=IGNORE_INDEX
        )
        per_horizon_metrics[task] = m

    # macro precision / recall / F1 over all horizons (validation)
    prf_metrics = {"verb": {}, "noun": {}, "action": {}}
    for task in ["verb", "noun", "action"]:
        if len(val_logits_store[task]) == 0:
            continue

        logits_all = torch.cat(val_logits_store[task], dim=0)
        labels_all = torch.cat(val_labels_store[task], dim=0)

        preds_all = logits_all.argmax(dim=-1)  # (N, K_fut)
        mask = (labels_all != IGNORE_INDEX)
        if mask.sum().item() == 0:
            continue

        y_true = labels_all[mask].numpy()
        y_pred = preds_all[mask].numpy()

        p, r, f1, _ = precision_recall_fscore_support(
            y_true,
            y_pred,
            average="macro",
            zero_division=0
        )
        prf_metrics[task]["precision"] = p
        prf_metrics[task]["recall"] = r
        prf_metrics[task]["f1"] = f1

    # mean Top-5 recall across horizons for each task
    mean_top5_recall = {}
    for task in ["verb", "noun", "action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            mean_top5_recall[task] = None
            continue

        vals = []
        for h_idx in range(K_FUT):
            key = f"per_h{h_idx+1}_top5"
            if key in mh and mh[key] is not None:
                vals.append(mh[key])
        mean_top5_recall[task] = float(np.mean(vals)) if len(vals) > 0 else None

    # scheduler + logging
    sched.step(val_loss)
    elapsed = time.time() - t0

    print(f"Epoch {epoch}/{NUM_EPOCHS} | Time {elapsed:.1f}s")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    for task in ["verb", "noun", "action"]:
        print(
            f"  {task.upper():6s} Train Top1: {train_metrics[f'{task}_top1']}, "
            f"Top5: {train_metrics[f'{task}_top5']}; "
            f"Val Top1: {val_metrics[f'{task}_top1']}, "
            f"Top5: {val_metrics[f'{task}_top5']}"
        )

    # print macro precision / recall / F1 (validation)
    for task in ["verb", "noun", "action"]:
        if prf_metrics[task]:
            p = prf_metrics[task]["precision"]
            r = prf_metrics[task]["recall"]
            f1 = prf_metrics[task]["f1"]
            print(
                f"  {task.upper():6s} Val Precision: {p:.4f}, "
                f"Recall: {r:.4f}, F1: {f1:.4f}"
            )

    # print mean Top-5 recall
    print("  ---- Mean Top-5 Recall (validation) ----")
    for task in ["verb", "noun", "action"]:
        print(
            f"     {task.upper():6s}  Mean Top-5 Recall: {mean_top5_recall[task]}"
        )

    # print per-horizon by seconds
    for task in ["verb", "noun", "action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            continue
        print(f"  {task.upper():6s} per-horizon (time-based):")
        for h_idx, t_sec in enumerate(HORIZONS_S):
            key1 = f"per_h{h_idx+1}_top1"
            key5 = f"per_h{h_idx+1}_top5"
            v1 = mh.get(key1, None)
            v5 = mh.get(key5, None)
            print(f"    @ {t_sec:4.2f}s  Top1: {v1}  Top5: {v5}")
        print(
            f"    overall_top1: {mh.get('overall_top1', None)}, "
            f"overall_top5: {mh.get('overall_top5', None)}"
        )

    # optional: save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(
            {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'opt_state': opt.state_dict(),
                'val_loss': val_loss
            },
            BEST_MODEL_PATH
        )
        print(f"[SAVED BEST] -> {BEST_MODEL_PATH}")

print("Training finished.")

Detected num_classes: {'verb': 59, 'noun': 210, 'action': 63}


Epoch 1 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 1/20 | Time 2.8s
  Train Loss: 8.0264 | Val Loss: 7.1594
  VERB   Train Top1: 0.3327205882352941, Top5: 0.7371323529411765; Val Top1: 0.5637393767705382, Top5: 0.9008498583569405
  NOUN   Train Top1: 0.15441176470588236, Top5: 0.3161764705882353; Val Top1: 0.16147308781869688, Top5: 0.3087818696883853
  ACTION Train Top1: 0.07904411764705882, Top5: 0.15808823529411764; Val Top1: 0.12181303116147309, Top5: 0.26628895184135976
  VERB   Val Precision: 0.0626, Recall: 0.1111, F1: 0.0801
  NOUN   Val Precision: 0.0090, Recall: 0.0556, F1: 0.0154
  ACTION Val Precision: 0.0036, Recall: 0.0294, F1: 0.0064
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.9023952897726757
     NOUN    Mean Top-5 Recall: 0.3101982768174444
     ACTION  Mean Top-5 Recall: 0.2670465411517904
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.9487179487179487
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.9512195121951219
    @ 0.75s  Top1: 0.56818181

Epoch 2 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2/20 | Time 2.2s
  Train Loss: 6.6309 | Val Loss: 7.0808
  VERB   Train Top1: 0.43933823529411764, Top5: 0.9025735294117647; Val Top1: 0.5637393767705382, Top5: 0.8583569405099151
  NOUN   Train Top1: 0.18566176470588236, Top5: 0.47058823529411764; Val Top1: 0.16147308781869688, Top5: 0.3314447592067989
  ACTION Train Top1: 0.13419117647058823, Top5: 0.24080882352941177; Val Top1: 0.12181303116147309, Top5: 0.23796033994334279
  VERB   Val Precision: 0.0626, Recall: 0.1111, F1: 0.0801
  NOUN   Val Precision: 0.0090, Recall: 0.0556, F1: 0.0154
  ACTION Val Precision: 0.0036, Recall: 0.0294, F1: 0.0064
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8607768663674019
     NOUN    Mean Top-5 Recall: 0.3325781110311863
     ACTION  Mean Top-5 Recall: 0.238404397645185
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.9230769230769231
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.926829268292683
    @ 0.75s  Top1: 0.56818181

Epoch 3 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 3/20 | Time 2.3s
  Train Loss: 6.3473 | Val Loss: 6.9476
  VERB   Train Top1: 0.4411764705882353, Top5: 0.9117647058823529; Val Top1: 0.5637393767705382, Top5: 0.8583569405099151
  NOUN   Train Top1: 0.18566176470588236, Top5: 0.45588235294117646; Val Top1: 0.16147308781869688, Top5: 0.29745042492917845
  ACTION Train Top1: 0.13419117647058823, Top5: 0.31066176470588236; Val Top1: 0.12181303116147309, Top5: 0.2747875354107649
  VERB   Val Precision: 0.0626, Recall: 0.1111, F1: 0.0801
  NOUN   Val Precision: 0.0090, Recall: 0.0556, F1: 0.0154
  ACTION Val Precision: 0.0036, Recall: 0.0294, F1: 0.0064
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8607768663674019
     NOUN    Mean Top-5 Recall: 0.29797032493546277
     ACTION  Mean Top-5 Recall: 0.2758072249124741
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.9230769230769231
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.926829268292683
    @ 0.75s  Top1: 0.5681818

Epoch 4 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 4/20 | Time 2.2s
  Train Loss: 6.0857 | Val Loss: 6.9624
  VERB   Train Top1: 0.43933823529411764, Top5: 0.9117647058823529; Val Top1: 0.5637393767705382, Top5: 0.8583569405099151
  NOUN   Train Top1: 0.18566176470588236, Top5: 0.4944852941176471; Val Top1: 0.16147308781869688, Top5: 0.3597733711048159
  ACTION Train Top1: 0.1488970588235294, Top5: 0.3272058823529412; Val Top1: 0.12464589235127478, Top5: 0.28611898016997167
  VERB   Val Precision: 0.0626, Recall: 0.1111, F1: 0.0801
  NOUN   Val Precision: 0.0090, Recall: 0.0556, F1: 0.0154
  ACTION Val Precision: 0.0130, Recall: 0.0364, F1: 0.0164
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8607768663674019
     NOUN    Mean Top-5 Recall: 0.3604738744303299
     ACTION  Mean Top-5 Recall: 0.287445822889553
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.9230769230769231
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.926829268292683
    @ 0.75s  Top1: 0.56818181818

Epoch 5 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 5/20 | Time 2.2s
  Train Loss: 5.9684 | Val Loss: 7.0624
  VERB   Train Top1: 0.43933823529411764, Top5: 0.9117647058823529; Val Top1: 0.5637393767705382, Top5: 0.8583569405099151
  NOUN   Train Top1: 0.18382352941176472, Top5: 0.5091911764705882; Val Top1: 0.16147308781869688, Top5: 0.3087818696883853
  ACTION Train Top1: 0.14338235294117646, Top5: 0.34558823529411764; Val Top1: 0.12181303116147309, Top5: 0.19263456090651557
  VERB   Val Precision: 0.0626, Recall: 0.1111, F1: 0.0801
  NOUN   Val Precision: 0.0106, Recall: 0.0556, F1: 0.0178
  ACTION Val Precision: 0.0036, Recall: 0.0294, F1: 0.0064
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8607768663674019
     NOUN    Mean Top-5 Recall: 0.3086307042712796
     ACTION  Mean Top-5 Recall: 0.19280352854550892
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.9230769230769231
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.926829268292683
    @ 0.75s  Top1: 0.5681818

Epoch 6 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 6/20 | Time 2.3s
  Train Loss: 5.6904 | Val Loss: 6.8768
  VERB   Train Top1: 0.4411764705882353, Top5: 0.9136029411764706; Val Top1: 0.5694050991501416, Top5: 0.8583569405099151
  NOUN   Train Top1: 0.22794117647058823, Top5: 0.5349264705882353; Val Top1: 0.18413597733711048, Top5: 0.42776203966005666
  ACTION Train Top1: 0.17463235294117646, Top5: 0.4025735294117647; Val Top1: 0.1558073654390935, Top5: 0.2577903682719547
  VERB   Val Precision: 0.1227, Recall: 0.1522, F1: 0.1307
  NOUN   Val Precision: 0.0356, Recall: 0.0704, F1: 0.0350
  ACTION Val Precision: 0.0464, Recall: 0.0450, F1: 0.0284
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8607768663674019
     NOUN    Mean Top-5 Recall: 0.4264823074224585
     ACTION  Mean Top-5 Recall: 0.25850259158185984
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.9230769230769231
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.926829268292683
    @ 0.75s  Top1: 0.5909090909

Epoch 7 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 7/20 | Time 2.3s
  Train Loss: 5.3004 | Val Loss: 6.9518
  VERB   Train Top1: 0.48161764705882354, Top5: 0.9338235294117647; Val Top1: 0.5779036827195467, Top5: 0.8640226628895185
  NOUN   Train Top1: 0.2757352941176471, Top5: 0.6397058823529411; Val Top1: 0.18413597733711048, Top5: 0.39943342776203966
  ACTION Train Top1: 0.19301470588235295, Top5: 0.4632352941176471; Val Top1: 0.141643059490085, Top5: 0.2521246458923513
  VERB   Val Precision: 0.1381, Recall: 0.1539, F1: 0.1364
  NOUN   Val Precision: 0.0588, Recall: 0.0704, F1: 0.0388
  ACTION Val Precision: 0.0088, Recall: 0.0373, F1: 0.0132
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8652264600899278
     NOUN    Mean Top-5 Recall: 0.399319423948268
     ACTION  Mean Top-5 Recall: 0.2518714257861262
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5897435897435898  Top5: 0.8974358974358975
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.9024390243902439
    @ 0.75s  Top1: 0.590909090909

Epoch 8 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 8/20 | Time 2.3s
  Train Loss: 4.8697 | Val Loss: 6.9778
  VERB   Train Top1: 0.5257352941176471, Top5: 0.9209558823529411; Val Top1: 0.5637393767705382, Top5: 0.8838526912181303
  NOUN   Train Top1: 0.32169117647058826, Top5: 0.71875; Val Top1: 0.20396600566572237, Top5: 0.40793201133144474
  ACTION Train Top1: 0.22058823529411764, Top5: 0.6801470588235294; Val Top1: 0.17280453257790368, Top5: 0.22662889518413598
  VERB   Val Precision: 0.1192, Recall: 0.1511, F1: 0.1303
  NOUN   Val Precision: 0.0953, Recall: 0.0842, F1: 0.0668
  ACTION Val Precision: 0.0745, Recall: 0.0519, F1: 0.0495
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8861334869619599
     NOUN    Mean Top-5 Recall: 0.4082036254955131
     ACTION  Mean Top-5 Recall: 0.2256270079043144
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.9487179487179487
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.9512195121951219
    @ 0.75s  Top1: 0.5681818181818182  T

Epoch 9 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 9/20 | Time 2.3s
  Train Loss: 4.1044 | Val Loss: 6.7722
  VERB   Train Top1: 0.5386029411764706, Top5: 0.9540441176470589; Val Top1: 0.5637393767705382, Top5: 0.8526912181303116
  NOUN   Train Top1: 0.4742647058823529, Top5: 0.8382352941176471; Val Top1: 0.2096317280453258, Top5: 0.46742209631728043
  ACTION Train Top1: 0.3952205882352941, Top5: 0.7536764705882353; Val Top1: 0.18980169971671387, Top5: 0.2804532577903683
  VERB   Val Precision: 0.1147, Recall: 0.1360, F1: 0.1206
  NOUN   Val Precision: 0.1159, Recall: 0.0926, F1: 0.0720
  ACTION Val Precision: 0.0737, Recall: 0.0799, F1: 0.0509
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8543994398879076
     NOUN    Mean Top-5 Recall: 0.467110271128365
     ACTION  Mean Top-5 Recall: 0.2792312515088894
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5641025641025641  Top5: 0.8974358974358975
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.9024390243902439
    @ 0.75s  Top1: 0.5681818181818

Epoch 10 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 10/20 | Time 2.3s
  Train Loss: 3.5929 | Val Loss: 6.8810
  VERB   Train Top1: 0.49816176470588236, Top5: 0.9669117647058824; Val Top1: 0.5609065155807366, Top5: 0.8441926345609065
  NOUN   Train Top1: 0.5091911764705882, Top5: 0.9080882352941176; Val Top1: 0.20679886685552407, Top5: 0.45042492917847027
  ACTION Train Top1: 0.5165441176470589, Top5: 0.875; Val Top1: 0.22946175637393768, Top5: 0.2776203966005666
  VERB   Val Precision: 0.1181, Recall: 0.1505, F1: 0.1292
  NOUN   Val Precision: 0.0944, Recall: 0.0864, F1: 0.0685
  ACTION Val Precision: 0.1029, Recall: 0.0972, F1: 0.0798
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.845078172828703
     NOUN    Mean Top-5 Recall: 0.4495950215611536
     ACTION  Mean Top-5 Recall: 0.27714908202878774
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.5384615384615384  Top5: 0.8717948717948718
    @ 0.50s  Top1: 0.5853658536585366  Top5: 0.8780487804878049
    @ 0.75s  Top1: 0.5681818181818182  Top5

Epoch 11 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 11/20 | Time 2.2s
  Train Loss: 3.0497 | Val Loss: 6.8453
  VERB   Train Top1: 0.5992647058823529, Top5: 0.9926470588235294; Val Top1: 0.5014164305949008, Top5: 0.8526912181303116
  NOUN   Train Top1: 0.5900735294117647, Top5: 0.9430147058823529; Val Top1: 0.2804532577903683, Top5: 0.46458923512747874
  ACTION Train Top1: 0.6507352941176471, Top5: 0.9466911764705882; Val Top1: 0.21246458923512748, Top5: 0.2577903682719547
  VERB   Val Precision: 0.1553, Recall: 0.1266, F1: 0.1304
  NOUN   Val Precision: 0.2279, Recall: 0.1528, F1: 0.1361
  ACTION Val Precision: 0.0641, Recall: 0.0839, F1: 0.0568
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8541729906125453
     NOUN    Mean Top-5 Recall: 0.46409179202292933
     ACTION  Mean Top-5 Recall: 0.25684301669521714
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.46153846153846156  Top5: 0.8974358974358975
    @ 0.50s  Top1: 0.4634146341463415  Top5: 0.9024390243902439
    @ 0.75s  Top1: 0.5  Top5:

Epoch 12 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 12/20 | Time 2.3s
  Train Loss: 2.5492 | Val Loss: 6.7522
  VERB   Train Top1: 0.6819852941176471, Top5: 0.9797794117647058; Val Top1: 0.3371104815864023, Top5: 0.7903682719546742
  NOUN   Train Top1: 0.6838235294117647, Top5: 0.96875; Val Top1: 0.32294617563739375, Top5: 0.5609065155807366
  ACTION Train Top1: 0.7481617647058824, Top5: 0.9650735294117647; Val Top1: 0.21529745042492918, Top5: 0.32011331444759206
  VERB   Val Precision: 0.2412, Recall: 0.2033, F1: 0.1954
  NOUN   Val Precision: 0.2347, Recall: 0.2268, F1: 0.1738
  ACTION Val Precision: 0.1193, Recall: 0.0994, F1: 0.0873
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7922827621594855
     NOUN    Mean Top-5 Recall: 0.5610213824042695
     ACTION  Mean Top-5 Recall: 0.3203193955729083
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.3333333333333333  Top5: 0.8461538461538461
    @ 0.50s  Top1: 0.34146341463414637  Top5: 0.8536585365853658
    @ 0.75s  Top1: 0.3181818181818182  To

Epoch 13 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 13/20 | Time 2.3s
  Train Loss: 2.1420 | Val Loss: 7.2606
  VERB   Train Top1: 0.6838235294117647, Top5: 0.9963235294117647; Val Top1: 0.3286118980169972, Top5: 0.7705382436260623
  NOUN   Train Top1: 0.78125, Top5: 0.9705882352941176; Val Top1: 0.2719546742209632, Top5: 0.4192634560906516
  ACTION Train Top1: 0.7720588235294118, Top5: 0.9816176470588235; Val Top1: 0.1813031161473088, Top5: 0.254957507082153
  VERB   Val Precision: 0.2283, Recall: 0.1687, F1: 0.1840
  NOUN   Val Precision: 0.2355, Recall: 0.2498, F1: 0.1993
  ACTION Val Precision: 0.1185, Recall: 0.0915, F1: 0.0875
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7725275650398264
     NOUN    Mean Top-5 Recall: 0.42169806212257166
     ACTION  Mean Top-5 Recall: 0.2549975896393579
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.3333333333333333  Top5: 0.8461538461538461
    @ 0.50s  Top1: 0.36585365853658536  Top5: 0.8048780487804879
    @ 0.75s  Top1: 0.3181818181818182  Top5:

Epoch 14 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 14 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 14/20 | Time 2.3s
  Train Loss: 1.6651 | Val Loss: 6.9589
  VERB   Train Top1: 0.7941176470588235, Top5: 1.0; Val Top1: 0.44192634560906513, Top5: 0.8441926345609065
  NOUN   Train Top1: 0.8566176470588235, Top5: 0.9852941176470589; Val Top1: 0.3342776203966006, Top5: 0.5155807365439093
  ACTION Train Top1: 0.8823529411764706, Top5: 0.9889705882352942; Val Top1: 0.20679886685552407, Top5: 0.32577903682719545
  VERB   Val Precision: 0.2396, Recall: 0.1871, F1: 0.2047
  NOUN   Val Precision: 0.2485, Recall: 0.2322, F1: 0.1891
  ACTION Val Precision: 0.1184, Recall: 0.0942, F1: 0.0870
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8453149152529456
     NOUN    Mean Top-5 Recall: 0.515596602753615
     ACTION  Mean Top-5 Recall: 0.32576447133048403
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.41025641025641024  Top5: 0.8717948717948718
    @ 0.50s  Top1: 0.4878048780487805  Top5: 0.8780487804878049
    @ 0.75s  Top1: 0.4318181818181818  Top5: 

Epoch 15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 15 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 15/20 | Time 2.2s
  Train Loss: 1.3237 | Val Loss: 7.0182
  VERB   Train Top1: 0.8033088235294118, Top5: 0.9981617647058824; Val Top1: 0.48725212464589235, Top5: 0.8243626062322946
  NOUN   Train Top1: 0.9264705882352942, Top5: 0.9908088235294118; Val Top1: 0.311614730878187, Top5: 0.5042492917847026
  ACTION Train Top1: 0.9227941176470589, Top5: 0.9889705882352942; Val Top1: 0.2237960339943343, Top5: 0.31444759206798867
  VERB   Val Precision: 0.2324, Recall: 0.2079, F1: 0.2065
  NOUN   Val Precision: 0.2047, Recall: 0.1845, F1: 0.1615
  ACTION Val Precision: 0.1055, Recall: 0.0996, F1: 0.0834
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8254965868201553
     NOUN    Mean Top-5 Recall: 0.5046496541707968
     ACTION  Mean Top-5 Recall: 0.3142836514257511
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.48717948717948717  Top5: 0.8717948717948718
    @ 0.50s  Top1: 0.5121951219512195  Top5: 0.8292682926829268
    @ 0.75s  Top1: 0.47727272727

Epoch 16 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 16 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 16/20 | Time 2.3s
  Train Loss: 1.1774 | Val Loss: 7.0835
  VERB   Train Top1: 0.8345588235294118, Top5: 0.9981617647058824; Val Top1: 0.37960339943342775, Top5: 0.8130311614730878
  NOUN   Train Top1: 0.9503676470588235, Top5: 0.9926470588235294; Val Top1: 0.3031161473087819, Top5: 0.546742209631728
  ACTION Train Top1: 0.9099264705882353, Top5: 0.9963235294117647; Val Top1: 0.23229461756373937, Top5: 0.3342776203966006
  VERB   Val Precision: 0.1911, Recall: 0.2008, F1: 0.1840
  NOUN   Val Precision: 0.2247, Recall: 0.1864, F1: 0.1640
  ACTION Val Precision: 0.1205, Recall: 0.1054, F1: 0.0946
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8134818955935075
     NOUN    Mean Top-5 Recall: 0.5468401680871856
     ACTION  Mean Top-5 Recall: 0.3344319386869758
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.358974358974359  Top5: 0.8461538461538461
    @ 0.50s  Top1: 0.3902439024390244  Top5: 0.8292682926829268
    @ 0.75s  Top1: 0.3409090909090

Epoch 17 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 17 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 17/20 | Time 2.3s
  Train Loss: 0.8870 | Val Loss: 7.0010
  VERB   Train Top1: 0.8952205882352942, Top5: 1.0; Val Top1: 0.47875354107648727, Top5: 0.8243626062322946
  NOUN   Train Top1: 0.9558823529411765, Top5: 0.9944852941176471; Val Top1: 0.33994334277620397, Top5: 0.5382436260623229
  ACTION Train Top1: 0.9595588235294118, Top5: 0.9963235294117647; Val Top1: 0.20113314447592068, Top5: 0.3342776203966006
  VERB   Val Precision: 0.2467, Recall: 0.2026, F1: 0.2147
  NOUN   Val Precision: 0.2918, Recall: 0.2501, F1: 0.2195
  ACTION Val Precision: 0.1139, Recall: 0.1153, F1: 0.0874
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8254033704159632
     NOUN    Mean Top-5 Recall: 0.5387305391895132
     ACTION  Mean Top-5 Recall: 0.3344319386869758
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.4358974358974359  Top5: 0.8461538461538461
    @ 0.50s  Top1: 0.5121951219512195  Top5: 0.8536585365853658
    @ 0.75s  Top1: 0.4772727272727273  Top5: 0

Epoch 18 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 18 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 18/20 | Time 2.3s
  Train Loss: 0.7932 | Val Loss: 7.0709
  VERB   Train Top1: 0.8860294117647058, Top5: 1.0; Val Top1: 0.4475920679886686, Top5: 0.7762039660056658
  NOUN   Train Top1: 0.9595588235294118, Top5: 0.9981617647058824; Val Top1: 0.29745042492917845, Top5: 0.5354107648725213
  ACTION Train Top1: 0.9705882352941176, Top5: 0.9963235294117647; Val Top1: 0.22096317280453256, Top5: 0.3342776203966006
  VERB   Val Precision: 0.1859, Recall: 0.1960, F1: 0.1864
  NOUN   Val Precision: 0.1835, Recall: 0.1932, F1: 0.1587
  ACTION Val Precision: 0.1232, Recall: 0.1223, F1: 0.0993
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7774849038966218
     NOUN    Mean Top-5 Recall: 0.5345690931255618
     ACTION  Mean Top-5 Recall: 0.3344319386869758
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.4358974358974359  Top5: 0.8205128205128205
    @ 0.50s  Top1: 0.4878048780487805  Top5: 0.8048780487804879
    @ 0.75s  Top1: 0.4772727272727273  Top5: 0.

Epoch 19 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 19 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 19/20 | Time 2.4s
  Train Loss: 0.6789 | Val Loss: 7.1082
  VERB   Train Top1: 0.9191176470588235, Top5: 1.0; Val Top1: 0.49008498583569404, Top5: 0.8130311614730878
  NOUN   Train Top1: 0.9632352941176471, Top5: 0.9981617647058824; Val Top1: 0.3654390934844193, Top5: 0.5410764872521246
  ACTION Train Top1: 0.9724264705882353, Top5: 0.9981617647058824; Val Top1: 0.23512747875354106, Top5: 0.3342776203966006
  VERB   Val Precision: 0.2610, Recall: 0.2176, F1: 0.2277
  NOUN   Val Precision: 0.2978, Recall: 0.2594, F1: 0.2225
  ACTION Val Precision: 0.1314, Recall: 0.1068, F1: 0.0989
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8134333832554272
     NOUN    Mean Top-5 Recall: 0.5406994840319337
     ACTION  Mean Top-5 Recall: 0.3344319386869758
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.4358974358974359  Top5: 0.8205128205128205
    @ 0.50s  Top1: 0.5121951219512195  Top5: 0.8048780487804879
    @ 0.75s  Top1: 0.4772727272727273  Top5: 0.

Epoch 20 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 20 Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 20/20 | Time 2.4s
  Train Loss: 0.6367 | Val Loss: 7.0497
  VERB   Train Top1: 0.9319852941176471, Top5: 1.0; Val Top1: 0.49291784702549574, Top5: 0.7932011331444759
  NOUN   Train Top1: 0.9577205882352942, Top5: 0.9963235294117647; Val Top1: 0.33994334277620397, Top5: 0.5807365439093485
  ACTION Train Top1: 0.9669117647058824, Top5: 1.0; Val Top1: 0.20679886685552407, Top5: 0.3342776203966006
  VERB   Val Precision: 0.2618, Recall: 0.2150, F1: 0.2273
  NOUN   Val Precision: 0.2684, Recall: 0.2425, F1: 0.2149
  ACTION Val Precision: 0.1087, Recall: 0.0964, F1: 0.0854
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7937467067073376
     NOUN    Mean Top-5 Recall: 0.5813905790509714
     ACTION  Mean Top-5 Recall: 0.3344319386869758
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.48717948717948717  Top5: 0.8205128205128205
    @ 0.50s  Top1: 0.5365853658536586  Top5: 0.8048780487804879
    @ 0.75s  Top1: 0.4772727272727273  Top5: 0.7727272727272