In [13]:
"""
Full pipeline â€” single-file.

How to use:
1) Edit USER CONFIG section: paths, hyperparams.
2) Optionally run extraction step to produce fused CSV (set DO_EXTRACT=True).
3) Set LOSS_MODE to the loss you want to run: "ce","focal","smooth","contrast","graph","combined".
4) Run script.

Note: test with small sizes (NUM_EPOCHS=1, BATCH_SIZE=1, T_OBS small) first.
"""
# -------------------------
# IMPORTS
# -------------------------
import os, re, math, random, time
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.models import resnet50
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
from torch_geometric.data import Data as PyGData, Batch as PyGBatch
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_dense_batch
from sklearn.metrics import precision_recall_fscore_support
import warnings

In [22]:
# -------------------------
# USER CONFIG (EDIT THESE)
# -------------------------
# Paths (example)
RGB_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\RGB\P01_04\Original")
FLOW_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\OpticalFlow\P01_04\P01_04")
LABEL_CSV   = Path(r"D:\Datasets\Datasets\EPIC\Labels\P01_04.csv")
OUTPUT_FUSED_CSV = Path(r"D:\Datasets\Datasets\EPIC\Features\FusedFeatures\P01_04_fused_features.csv")

# Extraction control
DO_EXTRACT = False    # set True to run extraction step now

# Sampling & features
SAMPLE_RATE = 1     # sample every S-th frame
FEAT_DIM = 512
W_RGB = 0.6
W_FLOW = 0.4

# Graph & model
K = 5               # KNN neighbors
DROP = 0.1



In [63]:
LABEL_CSV   = Path(r"D:\Datasets\Datasets\EPIC\Labels\P01_01.csv")
OUTPUT_FUSED_CSV = Path(r"D:\Datasets\Datasets\EPIC\Features\FusedFeatures\P01_01_fused_features.csv")

In [64]:
# Loss mode (choose one): "ce","focal","smooth","contrast","graph","combined"
LOSS_MODE = "combined"

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Outputs
OUTPUT_FUSED_CSV.parent.mkdir(parents=True, exist_ok=True)
BEST_MODEL_PATH = Path(r"./best_model.pth")

# Repro
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)



<torch._C.Generator at 0x18d163edd30>

In [65]:
# # -------------------------
# # Utilities
# # -------------------------
# _frame_number_re = re.compile(r"(\d+)(?=\.[^.]+$)")
# def parse_frame_index(fname: str):
#     m = _frame_number_re.search(fname)
#     if m:
#         return int(m.group(1))
#     digs = re.findall(r"\d+", fname)
#     return int(digs[-1]) if digs else 0

# def ensure_dir(p: Path):
#     p.mkdir(parents=True, exist_ok=True)
#     return p

# # -------------------------
# # FEATURE EXTRACTOR (ResNet50 -> proj)
# # -------------------------
# _resnet = resnet50(weights=True)
# _resnet = nn.Sequential(*list(_resnet.children())[:-1]).to(DEVICE).eval()   # outputs (B,2048,1,1)
# _proj = nn.Linear(2048, FEAT_DIM).to(DEVICE).eval()
# _transform = T.Compose([T.Resize((224,224)), T.ToTensor(),
#                         T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

# @torch.no_grad()
# def extract_feature_from_pil(pil_img: Image.Image):
#     x = _transform(pil_img).unsqueeze(0).to(DEVICE)   # (1,3,224,224)
#     feat = _resnet(x).view(1, -1)                     # (1,2048)
#     feat = _proj(feat)                                # (1,FEAT_DIM)
#     return feat.squeeze(0).cpu().numpy()              # (FEAT_DIM,)

# # -------------------------
# # EXTRACT & SAVE FUSED FEATURES
# # -------------------------
# def extract_and_save_fused(csv_labels_path: Path,
#                            rgb_folder: Path,
#                            flow_folder: Path or None,
#                            out_fused_csv: Path,
#                            sample_rate: int = 1,
#                            w_rgb: float = 0.6,
#                            w_flow: float = 0.4):
#     """
#     Extract & fuse features for one video folder. Saves fused CSV with columns:
#     frame_idx, frame_name, ActionLabel, ActionName, feat_0..feat_{FEAT_DIM-1}
#     """
#     if not csv_labels_path.exists():
#         raise FileNotFoundError(f"Labels CSV not found: {csv_labels_path}")
#     labels_df = pd.read_csv(csv_labels_path)
#     rgb_files = sorted([p for p in rgb_folder.iterdir() if p.suffix.lower() in [".jpg",".png",".jpeg"]])
#     sampled = rgb_files[::sample_rate]
#     if len(sampled) == 0:
#         raise RuntimeError(f"No frames found in {rgb_folder}")

#     fused_rows = []
#     feat_cols = [f"feat_{i}" for i in range(FEAT_DIM)]

#     for fp in tqdm(sampled, desc=f"Extract & fuse {rgb_folder.name}"):
#         fname = fp.name
#         frame_idx = parse_frame_index(fname)

#         # RGB
#         try:
#             pil = Image.open(fp).convert("RGB")
#             rgb_feat = extract_feature_from_pil(pil)
#         except Exception as e:
#             warnings.warn(f"[WARN] RGB skip {fname}: {e}")
#             continue

#         # Flow (fallback to zeros if missing)
#         if flow_folder is not None:
#             ffp = Path(flow_folder) / fname
#             if not ffp.exists():
#                 # fallback: use RGB image for alignment (keeps pipeline working)
#                 ffp = fp
#             try:
#                 pilf = Image.open(ffp).convert("RGB")
#                 flow_feat = extract_feature_from_pil(pilf)
#             except Exception as e:
#                 warnings.warn(f"[WARN] FLOW skip {fname}: {e}; using zeros")
#                 flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)
#         else:
#             flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)

#         # weighted fusion
#         fused_vec = w_rgb * rgb_feat.astype(np.float32) + w_flow * flow_feat.astype(np.float32)

#         # label lookup
#         lr = labels_df[(labels_df["StartFrame"] <= frame_idx) & (labels_df["EndFrame"] >= frame_idx)]
#         if not lr.empty:
#             action_label = int(lr.iloc[0].get("ActionLabel", -1))
#             action_name  = str(lr.iloc[0].get("ActionName", "Unknown"))
#         else:
#             action_label, action_name = -1, "Unknown"

#         row = {"frame_idx": int(frame_idx), "frame_name": fname, "ActionLabel": int(action_label), "ActionName": action_name}
#         for i_val, v in enumerate(fused_vec):
#             row[f"feat_{i_val}"] = float(v)
#         fused_rows.append(row)

#     if len(fused_rows) == 0:
#         raise RuntimeError("No fused rows extracted; check paths and files.")
#     df_fused = pd.DataFrame(fused_rows)
#     df_fused.to_csv(out_fused_csv, index=False)
#     print(f"[SAVED] fused CSV -> {out_fused_csv}")
#     return df_fused

# # optionally run extraction
# if DO_EXTRACT:
#     df_fused = extract_and_save_fused(
#         csv_labels_path = LABEL_CSV,
#         rgb_folder = RGB_FOLDER,
#         flow_folder = FLOW_FOLDER if (FLOW_FOLDER is not None and FLOW_FOLDER.exists()) else None,
#         out_fused_csv = OUTPUT_FUSED_CSV,
#         sample_rate = SAMPLE_RATE,
#         w_rgb = W_RGB,
#         w_flow = W_FLOW
#     )


In [66]:
# -------------------------
# LOAD FUSED CSV & LABELS
# -------------------------
def load_fused_csv_by_path(fused_csv_path: str):
    fp = Path(fused_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Fused features CSV not found: {fp}")
    df = pd.read_csv(fp)
    if "frame_idx" not in df.columns:
        raise KeyError("Fused CSV must contain 'frame_idx' column")
    df["frame_idx"] = df["frame_idx"].astype(int)
    df = df.sort_values("frame_idx").reset_index(drop=True)
    return df

def load_label_csv_by_path(label_csv_path: str):
    fp = Path(label_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Label CSV not found: {fp}")
    df = pd.read_csv(fp)
    return df



In [67]:
# -------------------------
# DATASET
# -------------------------
IGNORE_INDEX = -1

class SingleVideoAnticipationDataset(Dataset):
    def __init__(self
                 , fused_df_or_path, labels_df_or_path,
                 t_obs: int, k_fut: int, feat_dim: int,
                 fps: float, horizons_s):
        # load
        if isinstance(fused_df_or_path, (str, Path)):
            fused_df = pd.read_csv(fused_df_or_path)
        else:
            fused_df = fused_df_or_path.copy()
        if isinstance(labels_df_or_path, (str, Path)):
            labels_df = pd.read_csv(labels_df_or_path)
        else:
            labels_df = labels_df_or_path.copy()

        if "frame_idx" not in fused_df.columns:
            raise KeyError("fused_df must contain 'frame_idx'")
        fused_df["frame_idx"] = fused_df["frame_idx"].astype(int)
        self.fused_df = fused_df.set_index("frame_idx", drop=False).sort_index()
        self.labels_df = labels_df.reset_index(drop=True)

        if not all(c in self.labels_df.columns for c in ["StartFrame", "EndFrame"]):
            raise KeyError("labels_df must contain StartFrame and EndFrame")

        self.t_obs = int(t_obs)
        self.k_fut = int(k_fut)
        self.feat_dim = int(feat_dim)
        self.feat_cols = [f"feat_{i}" for i in range(self.feat_dim)]
        self.fps = float(fps)
        assert len(horizons_s) == self.k_fut, "len(horizons_s) must equal k_fut"
        self.horizons_s = list(horizons_s)

        # build samples: one per label row (use EndFrame as obs_end)
        self.samples = []
        for ridx, row in self.labels_df.iterrows():
            try:
                obs_end = int(row["EndFrame"])
            except:
                continue
            self.samples.append({"label_row_idx": int(ridx), "obs_end": obs_end})

        if len(self.samples) == 0:
            raise RuntimeError("No valid label rows found")

    def __len__(self):
        return len(self.samples)

    def _time_based_future_labels(self, obs_end: int):
        labels_df = self.labels_df
        def pick(cols):
            for c in cols:
                if c in labels_df.columns:
                    return c
            return None
        vcol = pick(["Verb_class","verb","Verb","verb_class"])
        ncol = pick(["Noun_class","noun","Noun","noun_class"])
        acol = pick(["Action_class","action","Action","ActionLabel"])
        verb_targets   = []
        noun_targets   = []
        action_targets = []
        for h_sec in self.horizons_s:
            future_frame = obs_end + int(round(h_sec * self.fps))
            seg = labels_df[(labels_df["StartFrame"] <= future_frame) &
                            (labels_df["EndFrame"]   >= future_frame)]
            if seg.empty:
                verb_targets.append(IGNORE_INDEX)
                noun_targets.append(IGNORE_INDEX)
                action_targets.append(IGNORE_INDEX)
            else:
                row = seg.iloc[0]
                if vcol is not None and not pd.isna(row[vcol]):
                    verb_targets.append(int(row[vcol]))
                else:
                    verb_targets.append(IGNORE_INDEX)
                if ncol is not None and not pd.isna(row[ncol]):
                    noun_targets.append(int(row[ncol]))
                else:
                    noun_targets.append(IGNORE_INDEX)
                if acol is not None and not pd.isna(row[acol]):
                    action_targets.append(int(row[acol]))
                else:
                    action_targets.append(IGNORE_INDEX)
        return {
            "verb":   torch.LongTensor(verb_targets),
            "noun":   torch.LongTensor(noun_targets),
            "action": torch.LongTensor(action_targets)
        }

    def __getitem__(self, idx):
        rec = self.samples[idx]
        obs_end = rec["obs_end"]
        obs_start = obs_end - (self.t_obs - 1)
        if obs_start < 0:
            obs_start = 0
            obs_end = obs_start + (self.t_obs - 1)

        fused_idx_min = int(self.fused_df.index.min())
        fused_idx_max = int(self.fused_df.index.max())
        obs_end = min(obs_end, fused_idx_max)
        obs_start = max(obs_end - (self.t_obs - 1), fused_idx_min)

        desired = list(range(obs_start, obs_end + 1))
        sel = self.fused_df.reindex(desired).fillna(method="ffill").fillna(method="bfill").fillna(0.0)

        if sel.shape[0] < self.t_obs:
            if sel.shape[0] == 0:
                zero_row = {c:0.0 for c in self.feat_cols}
                sel = pd.DataFrame([zero_row] * self.t_obs)
            else:
                first = sel.iloc[[0]]
                pads = pd.concat([first] * (self.t_obs - sel.shape[0]), ignore_index=True)
                sel = pd.concat([pads, sel.reset_index(drop=True)], ignore_index=True)

        for c in self.feat_cols:
            if c not in sel.columns:
                sel[c] = 0.0

        F_window = torch.from_numpy(sel[self.feat_cols].values).float()   # (T_obs, FEAT_DIM)
        y_multi = self._time_based_future_labels(obs_end)
        meta = {"obs_start": int(obs_start),
                "obs_end":   int(obs_end),
                "label_row_idx": int(rec["label_row_idx"])}
        return F_window, y_multi, meta


In [68]:
# -------------------------
# GRAPH: device-friendly top-k
# -------------------------
def build_topk_edge_index(features: torch.Tensor, k=K, device=None):
    """
    Compute top-k edges on the device of `features`.
    - features: (T, D) float tensor (can be on CPU or GPU)
    - returns: edge_index (2, E) long tensor on same device as features (or 'device' arg)
    """
    if device is None:
        device = features.device
    features = features.to(device)
    Tn = int(features.size(0))
    if Tn == 0:
        return torch.zeros((2, 0), dtype=torch.long, device=device)
    x = F.normalize(features, dim=1)          # (T, D)
    sim = torch.matmul(x, x.t())              # (T, T)
    sim.fill_diagonal_(-1.0)
    kk = min(max(0, int(k)), max(0, Tn - 1))
    if kk <= 0:
        return torch.zeros((2, 0), dtype=torch.long, device=device)
    vals, idxs = torch.topk(sim, kk, dim=1)   # (T, kk)
    src = torch.arange(Tn, device=device).unsqueeze(1).expand(-1, kk).reshape(-1)
    dst = idxs.reshape(-1)
    edge = torch.stack([src, dst], dim=0)
    edge_rev = torch.stack([dst, src], dim=0)
    return torch.cat([edge, edge_rev], dim=1).long()

class BatchedGAT(nn.Module):
    def __init__(self, in_dim, hid_dim=None, num_layers=3, heads=8, dropout=DROP):
        super().__init__()
        hid = hid_dim or in_dim
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_dim if i==0 else hid
            self.convs.append(GATConv(in_ch, hid//heads, heads=heads, concat=True, dropout=dropout))
        self.proj = nn.Linear(hid, in_dim)
        self.norm = nn.LayerNorm(in_dim)
        self.act = nn.GELU()
    def forward(self, pyg_batch: PyGBatch, T_per_sample: int):
        x = pyg_batch.x; edge_index = pyg_batch.edge_index
        h = x
        for conv in self.convs:
            h = conv(h, edge_index); h = self.act(h)
        h = self.proj(h)
        node_feats, mask = to_dense_batch(h, batch=pyg_batch.batch)  # (B, max_nodes, D)
        B, max_nodes, D = node_feats.shape
        if max_nodes < T_per_sample:
            pad = torch.zeros(B, T_per_sample - max_nodes, D, device=node_feats.device)
            node_feats = torch.cat([node_feats, pad], dim=1)
        elif max_nodes > T_per_sample:
            node_feats = node_feats[:, :T_per_sample, :]
        return self.norm(node_feats)  # (B, T_per_sample, D)



In [69]:
# -------------------------
# Encoder, Decoder, AnticipationModel
# -------------------------
class SimpleTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead=8, num_layers=3, dim_feedforward=2048, dropout=DROP, max_len=1000):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation='gelu', batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))
    def forward(self, x):
        B,T,D = x.shape
        pos = self.pos_emb[:, :T, :].to(x.device)
        return self.encoder(x + pos)

class AnticipationModel(nn.Module):
    def __init__(self, feat_dim, num_classes: dict, k_fut=5, gat_layers=3, gat_heads=8, dec_layers=3, dec_heads=8, dropout=DROP):
        super().__init__()
        self.feat_dim = feat_dim; self.k_fut = k_fut
        self.gat = BatchedGAT(in_dim=feat_dim, hid_dim=feat_dim, num_layers=gat_layers, heads=gat_heads, dropout=dropout)
        self.encoder = SimpleTransformerEncoder(d_model=feat_dim, nhead=dec_heads, num_layers=3)
        dec_layer = nn.TransformerDecoderLayer(d_model=feat_dim, nhead=dec_heads, dim_feedforward=feat_dim*4, dropout=dropout, activation='gelu', batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=dec_layers)
        self.queries = nn.Parameter(torch.randn(1, k_fut, feat_dim))
        assert isinstance(num_classes, dict)
        self.verb_head = nn.Linear(feat_dim, num_classes["verb"])
        self.noun_head = nn.Linear(feat_dim, num_classes["noun"])
        self.action_head = nn.Linear(feat_dim, num_classes["action"])

    def forward(self, F_batch):
        # F_batch: (B, T, D)
        B,T,D = F_batch.shape; device = F_batch.device
        data_list=[]
        for b in range(B):
            x = F_batch[b]
            edge_index = build_topk_edge_index(x, k=K, device=device)
            data_list.append(PyGData(x=x, edge_index=edge_index))
        pyg_batch = PyGBatch.from_data_list(data_list).to(device)
        gat_out = self.gat(pyg_batch, T_per_sample=T)   # (B,T,D)
        enc_out = self.encoder(F_batch)                 # (B,T,D)
        U = enc_out + gat_out
        q = self.queries.expand(B, -1, -1).to(device)
        dec_out = self.decoder(tgt=q, memory=U)         # (B, K_fut, D)
        logits = {
            "verb": self.verb_head(dec_out),
            "noun": self.noun_head(dec_out),
            "action": self.action_head(dec_out)
        }
        return logits, dec_out, gat_out

In [70]:
# -------------------------
# LOSS HELPERS (inline)
# -------------------------
class FocalLoss(nn.Module):
    def __init__(self, gamma: float = 2.0, reduction: str = 'mean', ignore_index: int = -1, eps: float = 1e-8):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.eps = eps
    def forward(self, logits: torch.Tensor, target: torch.Tensor):
        B, K, C = logits.shape
        logits_flat = logits.view(-1, C)
        target_flat = target.view(-1)
        mask = (target_flat != self.ignore_index)
        if int(mask.sum().item()) == 0:
            return logits.new_tensor(0.0)
        probs = F.softmax(logits_flat, dim=-1)
        idx = target_flat.clamp_min(0).long()
        pt = probs[torch.arange(probs.size(0), device=probs.device), idx]
        pt = torch.clamp(pt, min=self.eps)
        loss = -((1.0 - pt) ** self.gamma) * torch.log(pt)
        loss = loss[mask]
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

def temporal_smoothness_loss(dec_outs: torch.Tensor):
    if dec_outs is None:
        return torch.tensor(0.0)
    diff = dec_outs[:, 1:, :] - dec_outs[:, :-1, :]
    return diff.pow(2).mean()

def supervised_contrastive_loss(features: torch.Tensor, labels: torch.Tensor, temperature: float = 0.07):
    device = features.device
    features = F.normalize(features, dim=1)
    logits = torch.matmul(features, features.t()) / temperature  # (N, N)
    logits_max, _ = logits.max(dim=1, keepdim=True)
    logits = logits - logits_max.detach()
    labels = labels.contiguous().view(-1, 1)
    mask = torch.eq(labels, labels.t()).float().to(device)
    diag = torch.eye(mask.size(0), device=device)
    mask_non_self = mask * (1.0 - diag)
    exp_logits = torch.exp(logits) * (1.0 - diag)
    log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)
    mean_log_prob_pos = (mask_non_self * log_prob).sum(1) / (mask_non_self.sum(1) + 1e-12)
    loss = - mean_log_prob_pos
    loss = loss.mean()
    return loss

def build_batch_adjacency_from_features(features_batch: torch.Tensor, k: int, symmetric: bool = True):
    """
    features_batch: (B, T, D)
    returns: A_blocks (B*T, B*T) block-diagonal with per-sample top-k adjacency
    For B=1 it returns (T,T) adjacency for that sample.
    """
    device = features_batch.device
    B, T, D = features_batch.shape
    N = B * T
    A_blocks = torch.zeros((N, N), dtype=torch.float32, device=device)
    for b in range(B):
        feats = features_batch[b]  # (T, D)
        x = F.normalize(feats, dim=1)
        sim = torch.matmul(x, x.t())
        sim.fill_diagonal_(-1.0)
        kk = min(k, max(0, T-1))
        if kk <= 0:
            continue
        vals, idxs = torch.topk(sim, kk, dim=1)
        for i in range(T):
            neighbors = idxs[i]
            for nbr in neighbors.tolist():
                A_blocks[b*T + i, b*T + nbr] = 1.0
                if symmetric:
                    A_blocks[b*T + nbr, b*T + i] = 1.0
    return A_blocks

def graph_reconstruction_loss(gat_out: torch.Tensor, features_batch: torch.Tensor, k: int):
    if gat_out is None:
        return torch.tensor(0.0, device=features_batch.device)
    B, T, D = gat_out.shape
    losses = []
    for b in range(B):
        emb = F.normalize(gat_out[b], dim=1)  # (T, D)
        A_hat = torch.sigmoid(torch.matmul(emb, emb.t()))  # (T, T)
        A_gt_block = build_batch_adjacency_from_features(features_batch[b:b+1].detach(), k=k)
        A_gt = A_gt_block.to(A_hat.device)
        losses.append(F.binary_cross_entropy(A_hat, A_gt))
    return torch.stack(losses).mean() if len(losses) > 0 else torch.tensor(0.0, device=gat_out.device)

# -------------------------
# masked CE & metrics helpers
# -------------------------
def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    B, K, C = logits.shape
    logits_flat = logits.view(B * K, C)      # (B*K, C)
    labels_flat = labels.view(B * K)         # (B*K,)
    loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none', ignore_index=ignore_index)
    mask = (labels_flat != ignore_index).float()
    valid = mask.sum()
    if valid == 0:
        return (logits_flat * 0.0).sum()
    return (loss_flat * mask).sum() / valid

def topk_accuracy_per_task(logits, labels, topk=(1,5), ignore_index=IGNORE_INDEX):
    B,K,C = logits.shape
    res = {}
    overall = {k:0 for k in topk}
    total_cnt = 0
    preds_topk = logits.topk(max(topk), dim=-1)[1]  # (B,K,maxk)
    for h in range(K):
        lab = labels[:,h]; mask = (lab != ignore_index); cnt = int(mask.sum().item())
        for k in topk:
            if cnt == 0:
                res.setdefault(f"per_h{h+1}_top{k}", None)
                continue
            predk = preds_topk[:,h,:k]  # (B,k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            hits = (predk == lab_exp)
            hit = int(hits[mask].any(dim=1).float().sum().item())
            res[f"per_h{h+1}_top{k}"] = hit / cnt
            overall[k] += hit
        total_cnt += cnt
    for k in topk:
        res[f"overall_top{k}"] = overall[k] / total_cnt if total_cnt>0 else None
    return res

def topk_counts(logits, labels, k):
    with torch.no_grad():
        B, K, C = logits.shape
        topk_preds = logits.topk(k, dim=-1)[1]  # (B, K, k)
        hits = 0
        total = 0
        for h in range(K):
            lab = labels[:, h]  # (B,)
            mask = (lab != IGNORE_INDEX)
            if int(mask.sum().item()) == 0:
                continue
            predk = topk_preds[:, h, :]  # (B, k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            masked_pred = predk[mask]
            masked_lab = lab_exp[mask]
            hit_vec = (masked_pred == masked_lab).any(dim=1).float()
            hits += int(hit_vec.sum().item())
            total += int(mask.sum().item())
        return hits, total

In [71]:
# -------------------------
# MAIN: Prepare dataset, model, and training
# -------------------------


FUSED_CSV_PATH = str(OUTPUT_FUSED_CSV) 
LABEL_CSV_PATH = str(LABEL_CSV)

# load fused and labels
fused_df = load_fused_csv_by_path(FUSED_CSV_PATH)
labels_df = load_label_csv_by_path(LABEL_CSV_PATH)

# Training & dataset
T_OBS = 90
FPS = 30.0
HORIZONS_S = [0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75, 2.0]
K_FUT = len(HORIZONS_S)

BATCH_SIZE = 8
NUM_EPOCHS = 20
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 0

# dataset
dataset = SingleVideoAnticipationDataset(
    fused_df,
    labels_df,
    t_obs=T_OBS,
    k_fut=K_FUT,
    feat_dim=FEAT_DIM,
    fps=FPS,
    horizons_s=HORIZONS_S
)


# split (60/40)
indices = list(range(len(dataset)))
random.seed(SEED)
random.shuffle(indices)
split_at = int(0.6 * len(indices))
train_idx = indices[:split_at]; val_idx = indices[split_at:]
train_ds = Subset(dataset, train_idx); val_ds = Subset(dataset, val_idx)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=(DEVICE=="cuda"))
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=(DEVICE=="cuda"))

# detect num classes
def detect_num_classes_from_labels_df(labels_df):
    verbs = set(); nouns = set(); actions = set()
    for cand in ["Verb_class","verb","Verb","verb_class"]:
        if cand in labels_df.columns:
            verbs.update(labels_df[cand].dropna().astype(int).tolist()); break
    for cand in ["Noun_class","noun","Noun","noun_class"]:
        if cand in labels_df.columns:
            nouns.update(labels_df[cand].dropna().astype(int).tolist()); break
    for cand in ["Action_class","action","Action","ActionLabel"]:
        if cand in labels_df.columns:
            actions.update(labels_df[cand].dropna().astype(int).tolist()); break
    nv = (max(verbs) + 1) if len(verbs) > 0 else 1
    nn_ = (max(nouns) + 1) if len(nouns) > 0 else 1
    na = (max(actions) + 1) if len(actions) > 0 else 1
    return {"verb": int(nv), "noun": int(nn_), "action": int(na)}

num_classes = detect_num_classes_from_labels_df(labels_df)
print("Detected num_classes:", num_classes)

# instantiate model, optimizer, scheduler
model = AnticipationModel(feat_dim=FEAT_DIM, num_classes=num_classes, k_fut=K_FUT).to(DEVICE)
opt = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)

# instantiate auxiliary helpers and weights (used by LOSS_MODE)
focal_fn = FocalLoss(gamma=2.0, ignore_index=IGNORE_INDEX)
focal_alpha = 0.3
smooth_weight = 0.1
contrast_weight = 0.1
contrast_temperature = 0.07
graph_rec_weight = 0.05

# training loop
best_val_loss = float("inf")
for epoch in range(1, NUM_EPOCHS + 1):
    t0 = time.time()
    model.train()
    train_loss_sum = 0.0; train_samples = 0
    train_counts = {"verb_top1":[0,0],"verb_top5":[0,0],"noun_top1":[0,0],"noun_top5":[0,0],"action_top1":[0,0],"action_top5":[0,0]}

    for F_batch, y_multi, meta in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        F_batch = F_batch.to(DEVICE)
        y_v = y_multi["verb"].to(DEVICE)
        y_n = y_multi["noun"].to(DEVICE)
        y_a = y_multi["action"].to(DEVICE)

        opt.zero_grad()
        logits, dec_outs, gat_out = model(F_batch)

        # base CE
        loss_v_ce = masked_cross_entropy(logits["verb"], y_v)
        loss_n_ce = masked_cross_entropy(logits["noun"], y_n)
        loss_a_ce = masked_cross_entropy(logits["action"], y_a)
        base_loss = loss_a_ce + 0.5 * loss_v_ce + 0.5 * loss_n_ce

        # init aux terms
        focal_term = torch.tensor(0.0, device=F_batch.device)
        smooth_term = torch.tensor(0.0, device=F_batch.device)
        contrast_term = torch.tensor(0.0, device=F_batch.device)
        graph_term = torch.tensor(0.0, device=F_batch.device)

        # LOSS_MODE branches
        if LOSS_MODE == "ce":
            loss = base_loss

        elif LOSS_MODE == "focal":
            f_v = focal_fn(logits["verb"], y_v)
            f_n = focal_fn(logits["noun"], y_n)
            f_a = focal_fn(logits["action"], y_a)
            loss_v = (1.0 - focal_alpha) * loss_v_ce + focal_alpha * f_v
            loss_n = (1.0 - focal_alpha) * loss_n_ce + focal_alpha * f_n
            loss_a = (1.0 - focal_alpha) * loss_a_ce + focal_alpha * f_a
            focal_term = (f_v + f_n + f_a) / 3.0
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

        elif LOSS_MODE == "smooth":
            loss = base_loss
            if dec_outs is not None:
                smooth_term = temporal_smoothness_loss(dec_outs)
                loss = loss + smooth_weight * smooth_term

        elif LOSS_MODE == "contrast":
            loss = base_loss
            feats_for_contrast = F.normalize(F_batch[:, -1, :], dim=1)
            labels_for_contrast = y_a[:, 0].clone().detach()
            valid_mask = (labels_for_contrast != IGNORE_INDEX)
            if int(valid_mask.sum().item()) > 1:
                contrast_term = supervised_contrastive_loss(feats_for_contrast[valid_mask], labels_for_contrast[valid_mask], temperature=contrast_temperature)
                loss = loss + contrast_weight * contrast_term

        elif LOSS_MODE == "graph":
            loss = base_loss
            if gat_out is not None:
                graph_term = graph_reconstruction_loss(gat_out, F_batch.detach(), k=K)
                loss = loss + graph_rec_weight * graph_term

        elif LOSS_MODE == "combined":
            f_v = focal_fn(logits["verb"], y_v)
            f_n = focal_fn(logits["noun"], y_n)
            f_a = focal_fn(logits["action"], y_a)
            loss_v = (1.0 - focal_alpha) * loss_v_ce + focal_alpha * f_v
            loss_n = (1.0 - focal_alpha) * loss_n_ce + focal_alpha * f_n
            loss_a = (1.0 - focal_alpha) * loss_a_ce + focal_alpha * f_a
            focal_term = (f_v + f_n + f_a) / 3.0
            if dec_outs is not None:
                smooth_term = temporal_smoothness_loss(dec_outs)
            feats_for_contrast = F.normalize(F_batch[:, -1, :], dim=1)
            labels_for_contrast = y_a[:, 0].clone().detach()
            if int((labels_for_contrast != IGNORE_INDEX).sum().item()) > 1:
                contrast_term = supervised_contrastive_loss(feats_for_contrast[labels_for_contrast != IGNORE_INDEX], labels_for_contrast[labels_for_contrast != IGNORE_INDEX], temperature=contrast_temperature)
            if gat_out is not None:
                graph_term = graph_reconstruction_loss(gat_out, F_batch.detach(), k=K)
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n
            loss = loss + smooth_weight * smooth_term + contrast_weight * contrast_term + graph_rec_weight * graph_term

        else:
            raise ValueError(f"Unknown LOSS_MODE: {LOSS_MODE}")

        # backward + step
        loss.backward()
        opt.step()

        # bookkeeping
        b = F_batch.size(0)
        train_loss_sum += float(loss.item()) * b
        train_samples += b

        for (task, lab, lg) in [("verb", y_v, logits["verb"]), ("noun", y_n, logits["noun"]), ("action", y_a, logits["action"])]:
            h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
            h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
            train_counts[f"{task}_top1"][0] += h1; train_counts[f"{task}_top1"][1] += t1
            train_counts[f"{task}_top5"][0] += h5; train_counts[f"{task}_top5"][1] += t5

    # train metrics
    train_loss = train_loss_sum / max(1, train_samples)
    train_metrics = {}
    for task in ["verb","noun","action"]:
        h1,t1 = train_counts[f"{task}_top1"]; h5,t5 = train_counts[f"{task}_top5"]
        train_metrics[f"{task}_top1"] = (h1 / t1) if t1>0 else None
        train_metrics[f"{task}_top5"] = (h5 / t5) if t5>0 else None

    # ------------- VALIDATION -------------
    model.eval()
    val_loss_sum = 0.0; val_samples = 0
    val_counts = {"verb_top1":[0,0],"verb_top5":[0,0],"noun_top1":[0,0],"noun_top5":[0,0],"action_top1":[0,0],"action_top5":[0,0]}
    val_logits_store = {"verb": [], "noun": [], "action": []}
    val_labels_store = {"verb": [], "noun": [], "action": []}

    with torch.no_grad():
        for F_batch, y_multi, meta in tqdm(val_loader, desc=f"Epoch {epoch} Val"):
            F_batch = F_batch.to(DEVICE)
            y_v = y_multi["verb"].to(DEVICE)
            y_n = y_multi["noun"].to(DEVICE)
            y_a = y_multi["action"].to(DEVICE)

            logits, _, _ = model(F_batch)
            loss_v = masked_cross_entropy(logits["verb"], y_v)
            loss_n = masked_cross_entropy(logits["noun"], y_n)
            loss_a = masked_cross_entropy(logits["action"], y_a)
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

            b = F_batch.size(0)
            val_loss_sum += float(loss.item()) * b
            val_samples += b

            for (task, lab, lg) in [("verb", y_v, logits["verb"]), ("noun", y_n, logits["noun"]), ("action", y_a, logits["action"])]:
                h1,t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
                h5,t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
                val_counts[f"{task}_top1"][0] += h1; val_counts[f"{task}_top1"][1] += t1
                val_counts[f"{task}_top5"][0] += h5; val_counts[f"{task}_top5"][1] += t5

            val_logits_store["verb"].append(logits["verb"].detach().cpu())
            val_logits_store["noun"].append(logits["noun"].detach().cpu())
            val_logits_store["action"].append(logits["action"].detach().cpu())
            val_labels_store["verb"].append(y_v.detach().cpu())
            val_labels_store["noun"].append(y_n.detach().cpu())
            val_labels_store["action"].append(y_a.detach().cpu())

    val_loss = val_loss_sum / max(1, val_samples)

    # overall val metrics
    val_metrics = {}
    for task in ["verb","noun","action"]:
        h1,t1 = val_counts[f"{task}_top1"]; h5,t5 = val_counts[f"{task}_top5"]
        val_metrics[f"{task}_top1"] = (h1 / t1) if t1>0 else None
        val_metrics[f"{task}_top5"] = (h5 / t5) if t5>0 else None

    # per-horizon metrics
    per_horizon_metrics = {"verb":{},"noun":{},"action":{}}
    for task in ["verb","noun","action"]:
        if len(val_logits_store[task]) == 0:
            continue
        logits_all = torch.cat(val_logits_store[task], dim=0)
        labels_all = torch.cat(val_labels_store[task], dim=0)
        m = topk_accuracy_per_task(logits_all, labels_all, topk=(1,5), ignore_index=IGNORE_INDEX)
        per_horizon_metrics[task] = m

    # PRF macro
    prf_metrics = {"verb":{}, "noun":{}, "action":{}}
    for task in ["verb","noun","action"]:
        if len(val_logits_store[task]) == 0:
            continue
        logits_all = torch.cat(val_logits_store[task], dim=0)
        labels_all = torch.cat(val_labels_store[task], dim=0)
        preds_all = logits_all.argmax(dim=-1)
        mask = (labels_all != IGNORE_INDEX)
        if mask.sum().item() == 0:
            continue
        y_true = labels_all[mask].numpy()
        y_pred = preds_all[mask].numpy()
        p,r,f1,_ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
        prf_metrics[task]["precision"] = p; prf_metrics[task]["recall"] = r; prf_metrics[task]["f1"] = f1

    # mean top-5 recall
    mean_top5_recall = {}
    for task in ["verb","noun","action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            mean_top5_recall[task] = None; continue
        vals = []
        for h_idx in range(K_FUT):
            key = f"per_h{h_idx+1}_top5"
            if key in mh and mh[key] is not None:
                vals.append(mh[key])
        mean_top5_recall[task] = float(np.mean(vals)) if len(vals) > 0 else None

    sched.step(val_loss)
    elapsed = time.time() - t0

    # print summary
    print(f"Epoch {epoch}/{NUM_EPOCHS} | Time {elapsed:.1f}s")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    for task in ["verb","noun","action"]:
        print(f"  {task.upper():6s} Train Top1: {train_metrics[f'{task}_top1']}, Top5: {train_metrics[f'{task}_top5']}; Val Top1: {val_metrics[f'{task}_top1']}, Top5: {val_metrics[f'{task}_top5']}")
    for task in ["verb","noun","action"]:
        if prf_metrics[task]:
            print(f"  {task.upper():6s} Val Precision: {prf_metrics[task]['precision']:.4f}, Recall: {prf_metrics[task]['recall']:.4f}, F1: {prf_metrics[task]['f1']:.4f}")
    print("  ---- Mean Top-5 Recall (validation) ----")
    for task in ["verb","noun","action"]:
        print(f"     {task.upper():6s}  Mean Top-5 Recall: {mean_top5_recall[task]}")
    # per-horizon
    for task in ["verb","noun","action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            continue
        print(f"  {task.upper():6s} per-horizon (time-based):")
        for h_idx, t_sec in enumerate(HORIZONS_S):
            key1 = f"per_h{h_idx+1}_top1"; key5 = f"per_h{h_idx+1}_top5"
            v1 = mh.get(key1, None); v5 = mh.get(key5, None)
            print(f"    @ {t_sec:4.2f}s  Top1: {v1}  Top5: {v5}")
        print(f"    overall_top1: {mh.get('overall_top1', None)}, overall_top5: {mh.get('overall_top5', None)}")

    # save best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'opt_state': opt.state_dict(), 'val_loss': val_loss}, BEST_MODEL_PATH)
        print(f"[SAVED BEST] -> {BEST_MODEL_PATH}")

print("Training finished.")

Detected num_classes: {'verb': 62, 'noun': 186, 'action': 114}


Epoch 1 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/20 | Time 12.6s
  Train Loss: 8.3516 | Val Loss: 7.4672
  VERB   Train Top1: 0.18342541436464088, Top5: 0.7204419889502762; Val Top1: 0.2309027777777778, Top5: 0.7552083333333334
  NOUN   Train Top1: 0.07624309392265194, Top5: 0.3027624309392265; Val Top1: 0.140625, Top5: 0.4565972222222222
  ACTION Train Top1: 0.0718232044198895, Top5: 0.2132596685082873; Val Top1: 0.07118055555555555, Top5: 0.1996527777777778
  VERB   Val Precision: 0.0333, Recall: 0.0782, F1: 0.0448
  NOUN   Val Precision: 0.0047, Recall: 0.0333, F1: 0.0082
  ACTION Val Precision: 0.0025, Recall: 0.0172, F1: 0.0044
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7566290368994246
     NOUN    Mean Top-5 Recall: 0.456086398107553
     ACTION  Mean Top-5 Recall: 0.19842277857813442
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.2413793103448276  Top5: 0.7931034482758621
    @ 0.50s  Top1: 0.21875  Top5: 0.78125
    @ 0.75s  Top1: 0.23943661971830985  Top5: 0.746478873239436

Epoch 2 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 2 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 2/20 | Time 13.2s
  Train Loss: 7.0398 | Val Loss: 7.3928
  VERB   Train Top1: 0.19447513812154696, Top5: 0.7944751381215469; Val Top1: 0.20833333333333334, Top5: 0.6979166666666666
  NOUN   Train Top1: 0.10718232044198896, Top5: 0.4419889502762431; Val Top1: 0.08854166666666667, Top5: 0.4756944444444444
  ACTION Train Top1: 0.13922651933701657, Top5: 0.292817679558011; Val Top1: 0.0763888888888889, Top5: 0.2638888888888889
  VERB   Val Precision: 0.0149, Recall: 0.0714, F1: 0.0246
  NOUN   Val Precision: 0.0362, Recall: 0.0337, F1: 0.0061
  ACTION Val Precision: 0.0192, Recall: 0.0188, F1: 0.0043
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6979106800488531
     NOUN    Mean Top-5 Recall: 0.47489848707872795
     ACTION  Mean Top-5 Recall: 0.2628790631798046
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.20689655172413793  Top5: 0.6896551724137931
    @ 0.50s  Top1: 0.21875  Top5: 0.703125
    @ 0.75s  Top1: 0.2112676056338028  Top5: 0.71

Epoch 3 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 3/20 | Time 12.7s
  Train Loss: 6.8516 | Val Loss: 7.4227
  VERB   Train Top1: 0.22209944751381216, Top5: 0.7933701657458564; Val Top1: 0.13368055555555555, Top5: 0.7552083333333334
  NOUN   Train Top1: 0.10718232044198896, Top5: 0.46298342541436466; Val Top1: 0.13020833333333334, Top5: 0.4618055555555556
  ACTION Train Top1: 0.08729281767955802, Top5: 0.3270718232044199; Val Top1: 0.07465277777777778, Top5: 0.22916666666666666
  VERB   Val Precision: 0.0299, Recall: 0.0755, F1: 0.0229
  NOUN   Val Precision: 0.0068, Recall: 0.0324, F1: 0.0109
  ACTION Val Precision: 0.0013, Recall: 0.0179, F1: 0.0025
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7566290368994246
     NOUN    Mean Top-5 Recall: 0.4617379053978894
     ACTION  Mean Top-5 Recall: 0.2284961772810221
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1206896551724138  Top5: 0.7931034482758621
    @ 0.50s  Top1: 0.125  Top5: 0.78125
    @ 0.75s  Top1: 0.1267605633802817  Top5: 0.746

Epoch 4 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 4 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 4/20 | Time 12.4s
  Train Loss: 6.6432 | Val Loss: 7.3830
  VERB   Train Top1: 0.29613259668508285, Top5: 0.8088397790055248; Val Top1: 0.2465277777777778, Top5: 0.7725694444444444
  NOUN   Train Top1: 0.1425414364640884, Top5: 0.4685082872928177; Val Top1: 0.1371527777777778, Top5: 0.4652777777777778
  ACTION Train Top1: 0.15138121546961325, Top5: 0.3138121546961326; Val Top1: 0.078125, Top5: 0.2065972222222222
  VERB   Val Precision: 0.0365, Recall: 0.0837, F1: 0.0495
  NOUN   Val Precision: 0.0113, Recall: 0.0325, F1: 0.0095
  ACTION Val Precision: 0.0077, Recall: 0.0248, F1: 0.0101
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7733316126528404
     NOUN    Mean Top-5 Recall: 0.46606355631542096
     ACTION  Mean Top-5 Recall: 0.2061797835665074
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.2413793103448276  Top5: 0.7931034482758621
    @ 0.50s  Top1: 0.265625  Top5: 0.78125
    @ 0.75s  Top1: 0.23943661971830985  Top5: 0.77464788732394

Epoch 5 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 5 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 5/20 | Time 13.1s
  Train Loss: 6.3798 | Val Loss: 7.4591
  VERB   Train Top1: 0.2662983425414365, Top5: 0.8198895027624309; Val Top1: 0.2569444444444444, Top5: 0.6736111111111112
  NOUN   Train Top1: 0.13812154696132597, Top5: 0.47624309392265196; Val Top1: 0.041666666666666664, Top5: 0.4236111111111111
  ACTION Train Top1: 0.11712707182320442, Top5: 0.36464088397790057; Val Top1: 0.043402777777777776, Top5: 0.1909722222222222
  VERB   Val Precision: 0.0607, Recall: 0.0949, F1: 0.0689
  NOUN   Val Precision: 0.0158, Recall: 0.0148, F1: 0.0057
  ACTION Val Precision: 0.0185, Recall: 0.0123, F1: 0.0078
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6729615310473903
     NOUN    Mean Top-5 Recall: 0.42376181734531915
     ACTION  Mean Top-5 Recall: 0.19146275494905873
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.25862068965517243  Top5: 0.6551724137931034
    @ 0.50s  Top1: 0.265625  Top5: 0.671875
    @ 0.75s  Top1: 0.23943661971830985  Top

Epoch 6 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 6 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 6/20 | Time 12.1s
  Train Loss: 6.0825 | Val Loss: 7.5633
  VERB   Train Top1: 0.27624309392265195, Top5: 0.8265193370165745; Val Top1: 0.11284722222222222, Top5: 0.7638888888888888
  NOUN   Train Top1: 0.1723756906077348, Top5: 0.5646408839779006; Val Top1: 0.171875, Top5: 0.5277777777777778
  ACTION Train Top1: 0.13370165745856355, Top5: 0.40331491712707185; Val Top1: 0.08854166666666667, Top5: 0.22743055555555555
  VERB   Val Precision: 0.0443, Recall: 0.0714, F1: 0.0288
  NOUN   Val Precision: 0.0214, Recall: 0.0480, F1: 0.0254
  ACTION Val Precision: 0.0103, Recall: 0.0254, F1: 0.0107
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7647158190012469
     NOUN    Mean Top-5 Recall: 0.5267768374071113
     ACTION  Mean Top-5 Recall: 0.22632822122189278
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1206896551724138  Top5: 0.7758620689655172
    @ 0.50s  Top1: 0.125  Top5: 0.796875
    @ 0.75s  Top1: 0.1267605633802817  Top5: 0.7605633802816

Epoch 7 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 7 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 7/20 | Time 12.5s
  Train Loss: 5.6276 | Val Loss: 7.2488
  VERB   Train Top1: 0.33370165745856356, Top5: 0.8585635359116022; Val Top1: 0.125, Top5: 0.78125
  NOUN   Train Top1: 0.1867403314917127, Top5: 0.5988950276243094; Val Top1: 0.16493055555555555, Top5: 0.5069444444444444
  ACTION Train Top1: 0.2154696132596685, Top5: 0.46740331491712706; Val Top1: 0.08854166666666667, Top5: 0.2638888888888889
  VERB   Val Precision: 0.0657, Recall: 0.0613, F1: 0.0556
  NOUN   Val Precision: 0.0568, Recall: 0.0901, F1: 0.0603
  ACTION Val Precision: 0.0150, Recall: 0.0345, F1: 0.0176
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7792901374477141
     NOUN    Mean Top-5 Recall: 0.5072520309128103
     ACTION  Mean Top-5 Recall: 0.2639481158455447
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1206896551724138  Top5: 0.7241379310344828
    @ 0.50s  Top1: 0.125  Top5: 0.765625
    @ 0.75s  Top1: 0.11267605633802817  Top5: 0.8028169014084507
    @ 1.00s 

Epoch 8 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 8 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 8/20 | Time 12.5s
  Train Loss: 5.0239 | Val Loss: 7.4734
  VERB   Train Top1: 0.3524861878453039, Top5: 0.881767955801105; Val Top1: 0.0920138888888889, Top5: 0.7534722222222222
  NOUN   Train Top1: 0.31712707182320443, Top5: 0.6895027624309392; Val Top1: 0.1996527777777778, Top5: 0.5347222222222222
  ACTION Train Top1: 0.2861878453038674, Top5: 0.5933701657458563; Val Top1: 0.08854166666666667, Top5: 0.2534722222222222
  VERB   Val Precision: 0.0375, Recall: 0.0625, F1: 0.0398
  NOUN   Val Precision: 0.0714, Recall: 0.0743, F1: 0.0583
  ACTION Val Precision: 0.0143, Recall: 0.0344, F1: 0.0156
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7513623797784934
     NOUN    Mean Top-5 Recall: 0.5347065608434479
     ACTION  Mean Top-5 Recall: 0.2528172489873071
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.08620689655172414  Top5: 0.7068965517241379
    @ 0.50s  Top1: 0.09375  Top5: 0.734375
    @ 0.75s  Top1: 0.08450704225352113  Top5: 0.73239

Epoch 9 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 9 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 9/20 | Time 12.4s
  Train Loss: 4.5273 | Val Loss: 7.3966
  VERB   Train Top1: 0.43867403314917125, Top5: 0.9337016574585635; Val Top1: 0.1753472222222222, Top5: 0.7552083333333334
  NOUN   Train Top1: 0.3237569060773481, Top5: 0.7303867403314918; Val Top1: 0.16319444444444445, Top5: 0.53125
  ACTION Train Top1: 0.36685082872928176, Top5: 0.7027624309392265; Val Top1: 0.11805555555555555, Top5: 0.3107638888888889
  VERB   Val Precision: 0.1425, Recall: 0.1178, F1: 0.1025
  NOUN   Val Precision: 0.0843, Recall: 0.0968, F1: 0.0752
  ACTION Val Precision: 0.0428, Recall: 0.0578, F1: 0.0399
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7546253566998117
     NOUN    Mean Top-5 Recall: 0.5296121494241753
     ACTION  Mean Top-5 Recall: 0.3110777455319731
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1724137931034483  Top5: 0.7413793103448276
    @ 0.50s  Top1: 0.21875  Top5: 0.765625
    @ 0.75s  Top1: 0.16901408450704225  Top5: 0.74647887323943

Epoch 10 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 10 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 10/20 | Time 11.8s
  Train Loss: 4.0530 | Val Loss: 7.4350
  VERB   Train Top1: 0.481767955801105, Top5: 0.9370165745856354; Val Top1: 0.2569444444444444, Top5: 0.7361111111111112
  NOUN   Train Top1: 0.4287292817679558, Top5: 0.8; Val Top1: 0.21006944444444445, Top5: 0.5416666666666666
  ACTION Train Top1: 0.4895027624309392, Top5: 0.7679558011049724; Val Top1: 0.16319444444444445, Top5: 0.2899305555555556
  VERB   Val Precision: 0.1546, Recall: 0.1414, F1: 0.1389
  NOUN   Val Precision: 0.1566, Recall: 0.1609, F1: 0.1330
  ACTION Val Precision: 0.0375, Recall: 0.0747, F1: 0.0473
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7352853432372439
     NOUN    Mean Top-5 Recall: 0.5411522208353896
     ACTION  Mean Top-5 Recall: 0.2902033514911241
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1896551724137931  Top5: 0.7241379310344828
    @ 0.50s  Top1: 0.234375  Top5: 0.71875
    @ 0.75s  Top1: 0.19718309859154928  Top5: 0.7183098591549296
   

Epoch 11 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 11 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 11/20 | Time 12.4s
  Train Loss: 3.3494 | Val Loss: 7.6260
  VERB   Train Top1: 0.5922651933701657, Top5: 0.9712707182320443; Val Top1: 0.17708333333333334, Top5: 0.7256944444444444
  NOUN   Train Top1: 0.4950276243093923, Top5: 0.8574585635359117; Val Top1: 0.16319444444444445, Top5: 0.5260416666666666
  ACTION Train Top1: 0.5668508287292817, Top5: 0.8574585635359117; Val Top1: 0.11631944444444445, Top5: 0.2760416666666667
  VERB   Val Precision: 0.0745, Recall: 0.0938, F1: 0.0767
  NOUN   Val Precision: 0.0784, Recall: 0.0749, F1: 0.0638
  ACTION Val Precision: 0.0503, Recall: 0.0578, F1: 0.0476
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7256392977254041
     NOUN    Mean Top-5 Recall: 0.5274495534265817
     ACTION  Mean Top-5 Recall: 0.27564342189597557
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.15517241379310345  Top5: 0.7241379310344828
    @ 0.50s  Top1: 0.1875  Top5: 0.734375
    @ 0.75s  Top1: 0.18309859154929578  Top5: 0.71

Epoch 12 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 12 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 12/20 | Time 12.6s
  Train Loss: 2.8736 | Val Loss: 7.5044
  VERB   Train Top1: 0.5779005524861879, Top5: 0.9812154696132597; Val Top1: 0.2048611111111111, Top5: 0.7916666666666666
  NOUN   Train Top1: 0.523756906077348, Top5: 0.9193370165745857; Val Top1: 0.21180555555555555, Top5: 0.4618055555555556
  ACTION Train Top1: 0.630939226519337, Top5: 0.9458563535911603; Val Top1: 0.1527777777777778, Top5: 0.3020833333333333
  VERB   Val Precision: 0.1522, Recall: 0.1109, F1: 0.1117
  NOUN   Val Precision: 0.1735, Recall: 0.1852, F1: 0.1348
  ACTION Val Precision: 0.0709, Recall: 0.0930, F1: 0.0742
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7926618130337015
     NOUN    Mean Top-5 Recall: 0.4637590745172073
     ACTION  Mean Top-5 Recall: 0.3037301567574026
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.15517241379310345  Top5: 0.8103448275862069
    @ 0.50s  Top1: 0.203125  Top5: 0.8125
    @ 0.75s  Top1: 0.18309859154929578  Top5: 0.8169014

Epoch 13 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 13 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 13/20 | Time 12.7s
  Train Loss: 2.3620 | Val Loss: 7.4503
  VERB   Train Top1: 0.6994475138121546, Top5: 0.9900552486187846; Val Top1: 0.1753472222222222, Top5: 0.7690972222222222
  NOUN   Train Top1: 0.6850828729281768, Top5: 0.938121546961326; Val Top1: 0.1545138888888889, Top5: 0.5572916666666666
  ACTION Train Top1: 0.7414364640883978, Top5: 0.9569060773480663; Val Top1: 0.11805555555555555, Top5: 0.3315972222222222
  VERB   Val Precision: 0.1268, Recall: 0.1042, F1: 0.1052
  NOUN   Val Precision: 0.1076, Recall: 0.0922, F1: 0.0903
  ACTION Val Precision: 0.0804, Recall: 0.0629, F1: 0.0584
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7690428475100187
     NOUN    Mean Top-5 Recall: 0.5559794901170072
     ACTION  Mean Top-5 Recall: 0.3313344651255936
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.13793103448275862  Top5: 0.7758620689655172
    @ 0.50s  Top1: 0.15625  Top5: 0.75
    @ 0.75s  Top1: 0.14084507042253522  Top5: 0.760563380

Epoch 14 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 14 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 14/20 | Time 12.1s
  Train Loss: 2.1691 | Val Loss: 7.5110
  VERB   Train Top1: 0.7325966850828729, Top5: 0.9845303867403314; Val Top1: 0.16319444444444445, Top5: 0.7604166666666666
  NOUN   Train Top1: 0.7016574585635359, Top5: 0.9657458563535911; Val Top1: 0.21006944444444445, Top5: 0.5208333333333334
  ACTION Train Top1: 0.7756906077348066, Top5: 0.9756906077348066; Val Top1: 0.1527777777777778, Top5: 0.3246527777777778
  VERB   Val Precision: 0.1692, Recall: 0.1306, F1: 0.1402
  NOUN   Val Precision: 0.1507, Recall: 0.1648, F1: 0.1425
  ACTION Val Precision: 0.0890, Recall: 0.0820, F1: 0.0718
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7611113106906393
     NOUN    Mean Top-5 Recall: 0.5214889787486515
     ACTION  Mean Top-5 Recall: 0.3244207361359738
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.13793103448275862  Top5: 0.7758620689655172
    @ 0.50s  Top1: 0.140625  Top5: 0.765625
    @ 0.75s  Top1: 0.14084507042253522  Top5: 0.77

Epoch 15 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 15 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 15/20 | Time 12.4s
  Train Loss: 1.8212 | Val Loss: 7.6282
  VERB   Train Top1: 0.8022099447513812, Top5: 0.9900552486187846; Val Top1: 0.1753472222222222, Top5: 0.8159722222222222
  NOUN   Train Top1: 0.7502762430939226, Top5: 0.9834254143646409; Val Top1: 0.13194444444444445, Top5: 0.5815972222222222
  ACTION Train Top1: 0.830939226519337, Top5: 0.9856353591160221; Val Top1: 0.109375, Top5: 0.2829861111111111
  VERB   Val Precision: 0.2143, Recall: 0.1064, F1: 0.1176
  NOUN   Val Precision: 0.1295, Recall: 0.1124, F1: 0.1101
  ACTION Val Precision: 0.0394, Recall: 0.0435, F1: 0.0335
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8159064782281695
     NOUN    Mean Top-5 Recall: 0.5826181231890387
     ACTION  Mean Top-5 Recall: 0.28252060472573515
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1724137931034483  Top5: 0.7931034482758621
    @ 0.50s  Top1: 0.171875  Top5: 0.84375
    @ 0.75s  Top1: 0.16901408450704225  Top5: 0.816901408450704

Epoch 16 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 16 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 16/20 | Time 11.8s
  Train Loss: 1.5064 | Val Loss: 7.5597
  VERB   Train Top1: 0.8298342541436464, Top5: 0.9955801104972376; Val Top1: 0.1736111111111111, Top5: 0.7951388888888888
  NOUN   Train Top1: 0.8198895027624309, Top5: 0.9911602209944751; Val Top1: 0.21180555555555555, Top5: 0.5607638888888888
  ACTION Train Top1: 0.9038674033149171, Top5: 0.988950276243094; Val Top1: 0.1284722222222222, Top5: 0.3420138888888889
  VERB   Val Precision: 0.1986, Recall: 0.1402, F1: 0.1571
  NOUN   Val Precision: 0.1689, Recall: 0.1831, F1: 0.1539
  ACTION Val Precision: 0.0681, Recall: 0.0712, F1: 0.0602
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7944474132448138
     NOUN    Mean Top-5 Recall: 0.5615584250225175
     ACTION  Mean Top-5 Recall: 0.3416407335444293
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1724137931034483  Top5: 0.7758620689655172
    @ 0.50s  Top1: 0.171875  Top5: 0.78125
    @ 0.75s  Top1: 0.16901408450704225  Top5: 0.816901

Epoch 17 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 17 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 17/20 | Time 11.9s
  Train Loss: 1.3812 | Val Loss: 7.6437
  VERB   Train Top1: 0.8453038674033149, Top5: 0.9988950276243094; Val Top1: 0.1909722222222222, Top5: 0.7899305555555556
  NOUN   Train Top1: 0.861878453038674, Top5: 0.994475138121547; Val Top1: 0.19791666666666666, Top5: 0.5520833333333334
  ACTION Train Top1: 0.9359116022099447, Top5: 0.994475138121547; Val Top1: 0.109375, Top5: 0.296875
  VERB   Val Precision: 0.1494, Recall: 0.1107, F1: 0.1136
  NOUN   Val Precision: 0.1546, Recall: 0.1814, F1: 0.1407
  ACTION Val Precision: 0.0570, Recall: 0.0550, F1: 0.0481
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7905241294238311
     NOUN    Mean Top-5 Recall: 0.5525020007926426
     ACTION  Mean Top-5 Recall: 0.2961489106119707
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1724137931034483  Top5: 0.8275862068965517
    @ 0.50s  Top1: 0.1875  Top5: 0.765625
    @ 0.75s  Top1: 0.2112676056338028  Top5: 0.8028169014084507
    @ 1.00s  

Epoch 18 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 18 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 18/20 | Time 11.9s
  Train Loss: 1.3327 | Val Loss: 7.6304
  VERB   Train Top1: 0.8861878453038674, Top5: 1.0; Val Top1: 0.1753472222222222, Top5: 0.7621527777777778
  NOUN   Train Top1: 0.881767955801105, Top5: 0.9966850828729282; Val Top1: 0.22395833333333334, Top5: 0.5850694444444444
  ACTION Train Top1: 0.9558011049723757, Top5: 0.9966850828729282; Val Top1: 0.13541666666666666, Top5: 0.3159722222222222
  VERB   Val Precision: 0.2767, Recall: 0.1514, F1: 0.1722
  NOUN   Val Precision: 0.1687, Recall: 0.1674, F1: 0.1493
  ACTION Val Precision: 0.0719, Recall: 0.0729, F1: 0.0624
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7599094640922794
     NOUN    Mean Top-5 Recall: 0.5855353482268649
     ACTION  Mean Top-5 Recall: 0.3153095766627219
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.15517241379310345  Top5: 0.7241379310344828
    @ 0.50s  Top1: 0.15625  Top5: 0.71875
    @ 0.75s  Top1: 0.16901408450704225  Top5: 0.7464788732394366
   

Epoch 19 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 19 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 19/20 | Time 11.6s
  Train Loss: 1.2184 | Val Loss: 7.6643
  VERB   Train Top1: 0.9104972375690608, Top5: 1.0; Val Top1: 0.1892361111111111, Top5: 0.7673611111111112
  NOUN   Train Top1: 0.8972375690607735, Top5: 0.994475138121547; Val Top1: 0.2013888888888889, Top5: 0.5364583333333334
  ACTION Train Top1: 0.9646408839779006, Top5: 0.9955801104972376; Val Top1: 0.11979166666666667, Top5: 0.3177083333333333
  VERB   Val Precision: 0.2289, Recall: 0.1187, F1: 0.1329
  NOUN   Val Precision: 0.1335, Recall: 0.1761, F1: 0.1405
  ACTION Val Precision: 0.0723, Recall: 0.0560, F1: 0.0502
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7646387879975085
     NOUN    Mean Top-5 Recall: 0.5392246578721707
     ACTION  Mean Top-5 Recall: 0.31863099642541043
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1896551724137931  Top5: 0.7068965517241379
    @ 0.50s  Top1: 0.203125  Top5: 0.734375
    @ 0.75s  Top1: 0.19718309859154928  Top5: 0.7605633802816901
  

Epoch 20 Train:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 20 Val:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 20/20 | Time 12.1s
  Train Loss: 1.0638 | Val Loss: 7.6150
  VERB   Train Top1: 0.9038674033149171, Top5: 0.9988950276243094; Val Top1: 0.1684027777777778, Top5: 0.7777777777777778
  NOUN   Train Top1: 0.907182320441989, Top5: 0.9966850828729282; Val Top1: 0.1909722222222222, Top5: 0.578125
  ACTION Train Top1: 0.9767955801104973, Top5: 1.0; Val Top1: 0.1284722222222222, Top5: 0.3020833333333333
  VERB   Val Precision: 0.2137, Recall: 0.1130, F1: 0.1275
  NOUN   Val Precision: 0.1417, Recall: 0.1588, F1: 0.1398
  ACTION Val Precision: 0.0681, Recall: 0.0676, F1: 0.0596
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7775365970238333
     NOUN    Mean Top-5 Recall: 0.5797692450220056
     ACTION  Mean Top-5 Recall: 0.3021642783379425
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.1896551724137931  Top5: 0.7758620689655172
    @ 0.50s  Top1: 0.1875  Top5: 0.765625
    @ 0.75s  Top1: 0.16901408450704225  Top5: 0.8028169014084507
    @ 1.00s  Top