### Import Libraries

In [8]:
from tools.imports import *

from tools.pca import *

###  Configuration  (Loading PCA and Frames)

In [25]:
RGB_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\RGB\P01_01\Original")
FLOW_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\OpticalFlow\P01_01\P01_01")
LABEL_CSV = Path(r"EPIC-Kitchens\Labels\P01_01.csv")

OUTPUT_FUSED_CSV = Path(r"EPIC-Kitchens\Features\P01_01_fused_features_PCA.csv")

PCA_PATH = "pca_2048_to_512.pkl"

SAMPLE_RATE = 1
FEAT_DIM = 512
W_RGB = 0.6
W_FLOW = 0.4

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

OUTPUT_FUSED_CSV.parent.mkdir(parents=True, exist_ok=True)

_frame_number_re = re.compile(r"(\d+)(?=\.[^.]+$)")

def parse_frame_index(fname: str):
    m = _frame_number_re.search(fname)
    if m:
        return int(m.group(1))
    digs = re.findall(r"\d+", fname)
    return int(digs[-1]) if digs else 0

In [26]:
_resnet = resnet50(weights=True)
_resnet = nn.Sequential(*list(_resnet.children())[:-1])
_resnet = _resnet.to(DEVICE).eval()   # output: (B,2048,1,1)

_transform = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],
                [0.229,0.224,0.225])
])

# ------------------ LOAD PCA --------------------------------

pca = joblib.load(PCA_PATH)
assert pca.components_.shape == (FEAT_DIM, 2048), "PCA dimension mismatch"

# ---------------- FEATURE EXTRACTION ------------------------

@torch.no_grad()
def extract_feature_from_pil(pil_img: Image.Image):
    x = _transform(pil_img).unsqueeze(0).to(DEVICE)  # (1,3,224,224)
    feat = _resnet(x).view(-1).cpu().numpy()         # (2048,)
    feat = pca.transform(feat[None, :])[0]          # (512,)
    return feat.astype(np.float32)

# ---------------- MAIN EXTRACTION FUNCTION ------------------

def extract_and_save_fused(csv_labels_path: Path,
                           rgb_folder: Path,
                           flow_folder: Path or None,
                           out_fused_csv: Path,
                           sample_rate: int = 1,
                           w_rgb: float = 0.6,
                           w_flow: float = 0.4):

    labels_df = pd.read_csv(csv_labels_path)

    rgb_files = sorted([
        p for p in rgb_folder.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"]
    ])

    sampled = rgb_files[::sample_rate]
    if len(sampled) == 0:
        raise RuntimeError(f"No frames found in {rgb_folder}")

    fused_rows = []

    for fp in tqdm(sampled, desc="Extract & Fuse"):
        fname = fp.name
        frame_idx = parse_frame_index(fname)

        # -------- RGB --------
        try:
            pil = Image.open(fp).convert("RGB")
            rgb_feat = extract_feature_from_pil(pil)
        except Exception as e:
            print(f"[WARN] RGB skip {fname}: {e}")
            continue

        # -------- FLOW --------
        if flow_folder is not None:
            ffp = flow_folder / fname
            if not ffp.exists():
                ffp = fp
            try:
                pilf = Image.open(ffp).convert("RGB")
                flow_feat = extract_feature_from_pil(pilf)
            except Exception as e:
                print(f"[WARN] FLOW skip {fname}: {e}")
                flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)
        else:
            flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)

        # -------- FUSION --------
        fused_vec = w_rgb * rgb_feat + w_flow * flow_feat

        # -------- LABEL --------
        lr = labels_df[
            (labels_df["StartFrame"] <= frame_idx) &
            (labels_df["EndFrame"] >= frame_idx)
        ]

        if not lr.empty:
            action_label = int(lr.iloc[0].get("ActionLabel", -1))
            action_name = str(lr.iloc[0].get("ActionName", "Unknown"))
        else:
            action_label, action_name = -1, "Unknown"

        row = {
            "frame_idx": int(frame_idx),
            "frame_name": fname,
            "ActionLabel": action_label,
            "ActionName": action_name
        }

        for i, v in enumerate(fused_vec):
            row[f"feat_{i}"] = float(v)

        fused_rows.append(row)

    if len(fused_rows) == 0:
        raise RuntimeError("No fused rows extracted")

    df_fused = pd.DataFrame(fused_rows)
    df_fused.to_csv(out_fused_csv, index=False)

    print(f"[SAVED] {out_fused_csv}")
    return df_fused

# ---------------- RUN ---------------------------------------

df_fused = extract_and_save_fused(
    csv_labels_path=LABEL_CSV,
    rgb_folder=RGB_FOLDER,
    flow_folder=FLOW_FOLDER if FLOW_FOLDER.exists() else None,
    out_fused_csv=OUTPUT_FUSED_CSV,
    sample_rate=SAMPLE_RATE,
    w_rgb=W_RGB,
    w_flow=W_FLOW
)




Extract & Fuse:   0%|          | 0/99029 [00:00<?, ?it/s]

[SAVED] EPIC-Kitchens\Features\P01_01_fused_features_PCA.csv


In [27]:
data=pd.read_csv(r"EPIC-Kitchens\Features\P01_01_fused_features_PCA.csv")
data

Unnamed: 0,frame_idx,frame_name,ActionLabel,ActionName,feat_0,feat_1,feat_2,feat_3,feat_4,feat_5,...,feat_502,feat_503,feat_504,feat_505,feat_506,feat_507,feat_508,feat_509,feat_510,feat_511
0,0,frame_00000.jpg,-1,Unknown,-2.077142,-1.958742,1.458371,-0.939553,-0.097697,0.412270,...,-2.543323,-1.541425,0.645887,3.216516,-1.070724,-1.027977,0.753700,1.632008,-1.996307,0.598704
1,1,frame_00001.jpg,-1,Unknown,-2.049851,-1.923094,1.370188,-0.960223,-0.108530,0.477152,...,-2.415737,-0.603733,1.610335,2.387983,-0.953960,-1.429115,1.073136,0.046997,-1.593017,0.923354
2,2,frame_00002.jpg,-1,Unknown,-2.094827,-1.991464,1.457903,-0.970271,-0.096512,0.402129,...,-2.725905,-1.352319,0.944612,2.991761,-0.892881,-1.111937,0.493358,2.788770,-2.128530,0.231484
3,3,frame_00003.jpg,-1,Unknown,-2.060483,-1.953314,1.269848,-0.938713,-0.100473,0.480697,...,-2.177147,-1.289917,1.466615,1.870313,-0.761906,-0.583820,0.706807,1.122907,-2.200795,-0.132436
4,4,frame_00004.jpg,-1,Unknown,-2.082937,-1.923808,1.390279,-0.944767,-0.147835,0.407454,...,-2.051410,-2.412191,1.017748,2.036701,-1.681342,0.010523,1.532404,2.535606,-2.038063,0.125961
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99024,99024,frame_99024.jpg,-1,Unknown,-0.455013,-0.883913,0.765072,0.023528,0.247795,-0.539051,...,-1.500531,-3.974870,-0.731917,2.519051,-1.005597,0.482168,1.012747,2.992799,-3.198439,-1.515829
99025,99025,frame_99025.jpg,-1,Unknown,-0.475310,-0.839709,0.772038,0.142543,0.608118,-0.851301,...,-2.625215,-3.879248,2.434193,2.173644,-2.014056,-1.257069,1.406288,2.305203,-3.076490,-1.044251
99026,99026,frame_99026.jpg,-1,Unknown,-0.509065,-0.859920,0.765281,0.163925,0.600229,-0.804052,...,-2.636701,-3.654099,1.591920,1.877518,-1.479832,-0.907728,0.783232,2.020646,-2.762307,-0.418235
99027,99027,frame_99027.jpg,-1,Unknown,-0.417607,-0.765814,0.743480,0.247130,0.318030,-0.900123,...,-2.496804,-2.632463,2.228806,3.033227,-1.708652,-3.303389,1.849835,1.904002,-3.552166,1.303865


### Load the CSV file and Features

In [14]:
import pandas as pd
from pathlib import Path

def load_fused_csv_by_path(fused_csv_path: str):
    fp = Path(fused_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Fused features CSV not found: {fp}")
    df = pd.read_csv(fp)
    if "frame_idx" not in df.columns:
        raise KeyError("Fused CSV must contain 'frame_idx' column")
    df["frame_idx"] = df["frame_idx"].astype(int)
    df = df.sort_values("frame_idx").reset_index(drop=True)
    return df

def load_label_csv_by_path(label_csv_path: str):
    fp = Path(label_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Label CSV not found: {fp}")
    df = pd.read_csv(fp)
    return df

# ------------------------------------ Paths ---------------------------
fused_df = load_fused_csv_by_path(r"EPIC-Kitchens\Features\P01_03_fused_features_PCA.csv")
labels_df = load_label_csv_by_path(r"EPIC-Kitchens\Labels\P01_03.csv")

### Dataset Loader

In [15]:
IGNORE_INDEX = -1

class SingleVideoAnticipationDataset(Dataset):
    def __init__(self, fused_df_or_path, labels_df_or_path,
                 t_obs: int, k_fut: int, feat_dim: int,
                 fps: float, horizons_s: list[float]):
        
        # load paths
        if isinstance(fused_df_or_path, (str, Path)):
            fused_df = pd.read_csv(fused_df_or_path)
        else:
            fused_df = fused_df_or_path.copy()
        if isinstance(labels_df_or_path, (str, Path)):
            labels_df = pd.read_csv(labels_df_or_path)
        else:
            labels_df = labels_df_or_path.copy()

        if "frame_idx" not in fused_df.columns:
            raise KeyError("fused_df must contain 'frame_idx'")

        fused_df["frame_idx"] = fused_df["frame_idx"].astype(int)
        self.fused_df = fused_df.set_index("frame_idx", drop=False).sort_index()
        self.labels_df = labels_df.reset_index(drop=True)

        if not all(c in self.labels_df.columns for c in ["StartFrame", "EndFrame"]):
            raise KeyError("labels_df must contain StartFrame and EndFrame")

        self.t_obs = int(t_obs)
        self.k_fut = int(k_fut)
        self.feat_dim = int(feat_dim)
        self.feat_cols = [f"feat_{i}" for i in range(self.feat_dim)]

        # NEW: time info
        self.fps = float(fps)
        assert len(horizons_s) == self.k_fut, "len(horizons_s) must equal k_fut"
        self.horizons_s = list(horizons_s)

        # samples: one per label row (use EndFrame as obs_end)
        self.samples = []
        for ridx, row in self.labels_df.iterrows():
            try:
                obs_end = int(row["EndFrame"])
            except:
                continue
            self.samples.append({"label_row_idx": int(ridx), "obs_end": obs_end})

        if len(self.samples) == 0:
            raise RuntimeError("No valid label rows found")

    def __len__(self):
        return len(self.samples)

    def _time_based_future_labels(self, obs_end: int):
        labels_df = self.labels_df

        def pick(cols):
            for c in cols:
                if c in labels_df.columns:
                    return c
            return None

        vcol = pick(["Verb_class","verb","Verb","verb_class"])
        ncol = pick(["Noun_class","noun","Noun","noun_class"])
        acol = pick(["Action_class","action","Action","ActionLabel"])

        verb_targets   = []
        noun_targets   = []
        action_targets = []

        for h_sec in self.horizons_s:
            future_frame = obs_end + int(round(h_sec * self.fps))
            seg = labels_df[(labels_df["StartFrame"] <= future_frame) &
                            (labels_df["EndFrame"]   >= future_frame)]
            if seg.empty:
                # nothing happening at that exact time
                verb_targets.append(IGNORE_INDEX)
                noun_targets.append(IGNORE_INDEX)
                action_targets.append(IGNORE_INDEX)
            else:
                row = seg.iloc[0]
                if vcol is not None and not pd.isna(row[vcol]):
                    verb_targets.append(int(row[vcol]))
                else:
                    verb_targets.append(IGNORE_INDEX)

                if ncol is not None and not pd.isna(row[ncol]):
                    noun_targets.append(int(row[ncol]))
                else:
                    noun_targets.append(IGNORE_INDEX)

                if acol is not None and not pd.isna(row[acol]):
                    action_targets.append(int(row[acol]))
                else:
                    action_targets.append(IGNORE_INDEX)

        return {
            "verb":   torch.LongTensor(verb_targets),
            "noun":   torch.LongTensor(noun_targets),
            "action": torch.LongTensor(action_targets)
        }

    def __getitem__(self, idx):
        rec = self.samples[idx]
        obs_end = rec["obs_end"]
        obs_start = obs_end - (self.t_obs - 1)
        if obs_start < 0:
            obs_start = 0
            obs_end = obs_start + (self.t_obs - 1)

        fused_idx_min = int(self.fused_df.index.min())
        fused_idx_max = int(self.fused_df.index.max())
        obs_end = min(obs_end, fused_idx_max)
        obs_start = max(obs_end - (self.t_obs - 1), fused_idx_min)

        desired = list(range(obs_start, obs_end + 1))
        sel = self.fused_df.reindex(desired).fillna(method="ffill").fillna(method="bfill").fillna(0.0)

        if sel.shape[0] < self.t_obs:
            if sel.shape[0] == 0:
                zero_row = {c:0.0 for c in self.feat_cols}
                sel = pd.DataFrame([zero_row] * self.t_obs)
            else:
                first = sel.iloc[[0]]
                pads = pd.concat([first] * (self.t_obs - sel.shape[0]), ignore_index=True)
                sel = pd.concat([pads, sel.reset_index(drop=True)], ignore_index=True)

        for c in self.feat_cols:
            if c not in sel.columns:
                sel[c] = 0.0

        F_window = torch.from_numpy(sel[self.feat_cols].values).float()   # (T_obs, FEAT_DIM)

        y_multi = self._time_based_future_labels(obs_end)

        meta = {"obs_start": int(obs_start),
                "obs_end":   int(obs_end),
                "label_row_idx": int(rec["label_row_idx"])}
        return F_window, y_multi, meta

### Graph Construction
_Graph Construction using kNN strategy and then apply GAT_

In [16]:
K=5  
DROP=0.1

def build_topk_edge_index(features: torch.Tensor, k=K):
    Tn = int(features.size(0))
    x = F.normalize(features, dim=1)
    sim = torch.matmul(x, x.t())   # (T,T) T is the number of frame-> features
    sim.fill_diagonal_(-1.0)
    vals, idxs = torch.topk(sim, k, dim=1)
    src = torch.arange(Tn).unsqueeze(1).expand(-1, k).reshape(-1)
    dst = idxs.reshape(-1)
    edge = torch.stack([src, dst], dim=0)
    edge_rev = torch.stack([dst, src], dim=0)
    return torch.cat([edge, edge_rev], dim=1).long()

class BatchedGAT(nn.Module):
    def __init__(self, in_dim, hid_dim=None, num_layers=3, heads=8, dropout=DROP):
        super().__init__()
        hid = hid_dim or in_dim
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_dim if i==0 else hid
            self.convs.append(GATConv(in_ch, hid//heads, heads=heads, concat=True, dropout=dropout))
        self.proj = nn.Linear(hid, in_dim)
        self.norm = nn.LayerNorm(in_dim)
        self.act = nn.GELU()
    def forward(self, pyg_batch: PyGBatch, T_per_sample: int):
        x = pyg_batch.x; edge_index = pyg_batch.edge_index
        h = x
        for conv in self.convs:
            h = conv(h, edge_index); h = self.act(h)
        h = self.proj(h)
        node_feats, mask = to_dense_batch(h, batch=pyg_batch.batch)
        B, max_nodes, D = node_feats.shape
        if max_nodes < T_per_sample:
            pad = torch.zeros(B, T_per_sample - max_nodes, D, device=node_feats.device)
            node_feats = torch.cat([node_feats, pad], dim=1)
        elif max_nodes > T_per_sample:
            node_feats = node_feats[:, :T_per_sample, :]
        return self.norm(node_feats) 

### Fusion (Graph +Transformer) + Transfomer Decoder


In [17]:
class GETR(nn.Module):
    def __init__(self, d_model, nhead=8, num_layers=3, dim_feedforward=2048, dropout=DROP, max_len=1000):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation='gelu', batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))
    def forward(self, x):
        B,T,D = x.shape
        pos = self.pos_emb[:, :T, :].to(x.device)
        return self.encoder(x + pos)

class AnticipationModel(nn.Module):
    def __init__(self, feat_dim, num_classes: dict, k_fut=5, gat_layers=3, gat_heads=8, dec_layers=3, dec_heads=8, dropout=DROP):
        super().__init__()
        self.feat_dim = feat_dim; self.k_fut = k_fut
        self.gat = BatchedGAT(in_dim=feat_dim, hid_dim=feat_dim, num_layers=gat_layers, heads=gat_heads, dropout=dropout)
        self.encoder = GETR(d_model=feat_dim, nhead=dec_heads, num_layers=3)
        dec_layer = nn.TransformerDecoderLayer(d_model=feat_dim, nhead=dec_heads, dim_feedforward=feat_dim*4, dropout=dropout, activation='gelu', batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=dec_layers)
        self.queries = nn.Parameter(torch.randn(1, k_fut, feat_dim))
        assert isinstance(num_classes, dict)
        self.verb_head = nn.Linear(feat_dim, num_classes["verb"])
        self.noun_head = nn.Linear(feat_dim, num_classes["noun"])
        self.action_head = nn.Linear(feat_dim, num_classes["action"])

    def forward(self, F_batch):
        # F_batch: (B, T, D)
        B,T,D = F_batch.shape; device = F_batch.device
        data_list=[]
        for b in range(B):
            x = F_batch[b]
            edge_index = build_topk_edge_index(x.detach().cpu(), k=K).to(device)
            data_list.append(PyGData(x=x, edge_index=edge_index))
        pyg_batch = PyGBatch.from_data_list(data_list).to(device)
        G = self.gat(pyg_batch, T_per_sample=T)   # (B,T,D)
        H = self.encoder(F_batch)                 # (B,T,D)
        U = H + G
        q = self.queries.expand(B, -1, -1).to(device)
        dec_out = self.decoder(tgt=q, memory=U)   # (B, K_fut, D)
        return {"verb": self.verb_head(dec_out),
                "noun": self.noun_head(dec_out),
                "action": self.action_head(dec_out)}

### Cross-Entropy (Masked means Loss for Verb, Noun, Action)

In [18]:
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    B, K, C = logits.shape
    logits_flat = logits.view(B * K, C)      # (B*K, C)
    labels_flat = labels.view(B * K)         # (B*K,)

    loss_flat = F.cross_entropy(
        logits_flat,
        labels_flat,
        reduction='none',
        ignore_index=ignore_index
    )  # (B*K,)

    mask = (labels_flat != ignore_index).float()
    valid = mask.sum()

    if valid == 0:
        return (logits_flat * 0.0).sum()

    return (loss_flat * mask).sum() / valid


def topk_accuracy_per_task(logits, labels, topk=(1,5), ignore_index=IGNORE_INDEX):
    B,K,C = logits.shape
    res = {}
    overall = {k:0 for k in topk}
    total_cnt = 0
    preds_topk = logits.topk(max(topk), dim=-1)[1]  # (B,K,maxk)
    for h in range(K):
        lab = labels[:,h]; mask = (lab != ignore_index); cnt = int(mask.sum().item())
        for k in topk:
            if cnt == 0:
                res.setdefault(f"per_h{h+1}_top{k}", None)
                continue
            predk = preds_topk[:,h,:k]  # (B,k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            hits = (predk == lab_exp)
            hit = int(hits[mask].any(dim=1).float().sum().item())
            res[f"per_h{h+1}_top{k}"] = hit / cnt
            overall[k] += hit
        total_cnt += cnt
    for k in topk:
        res[f"overall_top{k}"] = overall[k] / total_cnt if total_cnt>0 else None
    return res

### Training and Validation

In [19]:

FUSED_CSV_PATH = r"EPIC-Kitchens\Features\P01_03_fused_features_PCA.csv"
LABEL_CSV_PATH = r"EPIC-Kitchens\Labels\P01_03.csv"
BEST_MODEL_PATH = Path(r"EPIC-Kitchens\Model\P01_03_fused_model_PCA.pth")


# Hyperparams
T_OBS = 90
FEAT_DIM = 512
BATCH_SIZE = 8
NUM_EPOCHS = 20
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 0

# === Time-based anticipation config ===
FPS = 30.0 
HORIZONS_S = [0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75, 2.0] 
K_FUT = len(HORIZONS_S) 

In [28]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def detect_num_classes_from_labels_df(labels_df):
    verbs = set()
    nouns = set()
    actions = set()
    for cand in ["Verb_class", "verb", "Verb", "verb_class"]:
        if cand in labels_df.columns:
            verbs.update(labels_df[cand].dropna().astype(int).tolist())
            break
    for cand in ["Noun_class", "noun", "Noun", "noun_class"]:
        if cand in labels_df.columns:
            nouns.update(labels_df[cand].dropna().astype(int).tolist())
            break
    for cand in ["Action_class", "action", "Action", "ActionLabel"]:
        if cand in labels_df.columns:
            actions.update(labels_df[cand].dropna().astype(int).tolist())
            break
    nv = (max(verbs) + 1) if len(verbs) > 0 else 1
    nn_ = (max(nouns) + 1) if len(nouns) > 0 else 1
    na = (max(actions) + 1) if len(actions) > 0 else 1
    return {"verb": int(nv), "noun": int(nn_), "action": int(na)}


def topk_counts(logits, labels, k):
    # logits: (B, K_fut, C); labels: (B, K_fut)
    with torch.no_grad():
        B, K, C = logits.shape
        topk_preds = logits.topk(k, dim=-1)[1]  # (B, K, k)
        hits = 0
        total = 0
        for h in range(K):
            lab = labels[:, h]  # (B,)
            mask = (lab != IGNORE_INDEX)
            if int(mask.sum().item()) == 0:
                continue
            predk = topk_preds[:, h, :]  # (B, k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            masked_pred = predk[mask]   # (M, k)
            masked_lab = lab_exp[mask]  # (M, k)
            hit_vec = (masked_pred == masked_lab).any(dim=1).float()
            hits += int(hit_vec.sum().item())
            total += int(mask.sum().item())
        return hits, total


# Load fused and labels
fused_df = pd.read_csv(FUSED_CSV_PATH)
labels_df = pd.read_csv(LABEL_CSV_PATH)

# Dataset
dataset = SingleVideoAnticipationDataset(
    fused_df,
    labels_df,
    t_obs=T_OBS,
    k_fut=K_FUT,        # must equal len(HORIZONS_S)
    feat_dim=FEAT_DIM,
    fps=FPS,
    horizons_s=HORIZONS_S
)

# split indices for train/val (60/40)
indices = list(range(len(dataset)))
random.seed(42)
random.shuffle(indices)
split_at = int(0.6 * len(indices))
train_idx = indices[:split_at]
val_idx = indices[split_at:]

train_ds = Subset(dataset, train_idx)
val_ds = Subset(dataset, val_idx)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == "cuda")
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == "cuda")
)

num_classes = detect_num_classes_from_labels_df(labels_df)
print("Detected classes:", num_classes)

model = AnticipationModel(
    feat_dim=FEAT_DIM,
    num_classes=num_classes,
    k_fut=K_FUT
).to(DEVICE)



# ================= MODEL PARAMETER COUNT =================
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

total_params, trainable_params = count_parameters(model)

print("\n========== MODEL PARAMETERS ==========")
print(f"Total parameters     : {total_params/1e6:.1f} M")
print(f"Trainable parameters : {trainable_params/1e6:.2f} M")
print("=====================================\n")


# ================= MODEL FLOPs =================
from thop import profile

model.eval()

dummy_input = torch.randn(
    1, T_OBS, FEAT_DIM
).to(DEVICE)

macs, params = profile(
    model,
    inputs=(dummy_input,),
    verbose=False
)

print("========== MODEL FLOPs ==========")
print(f"MACs  : {macs/1e9:.2f} G")
print(f"FLOPs : {(2*macs)/1e9:.2f} G")
print("================================\n")



# ================= INFERENCE LATENCY =================
def measure_latency(model, device, runs=100):
    model.eval()
    dummy = torch.randn(1, T_OBS, FEAT_DIM).to(device)

    # warm-up
    for _ in range(10):
        _ = model(dummy)

    if device == "cuda":
        torch.cuda.synchronize()

    start = time.time()
    for _ in range(runs):
        _ = model(dummy)

    if device == "cuda":
        torch.cuda.synchronize()

    avg_latency = (time.time() - start) / runs
    return avg_latency * 1000  # ms


latency_ms = measure_latency(model, DEVICE)

print("========== INFERENCE LATENCY ==========")
print(f"Average latency per sample: {latency_ms:.2f} ms")
print("======================================\n")



optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

best_val_loss = float("inf")




Detected classes: {'verb': 81, 'noun': 154, 'action': 32}

Total parameters     : 23.8 M
Trainable parameters : 23.78 M

MACs  : 0.08 G
FLOPs : 0.15 G

Average latency per sample: 20.26 ms



In [45]:
for epoch in range(1, NUM_EPOCHS + 1):
    t0 = time.time()

    # ------------- TRAIN -------------
    model.train()
    train_loss_sum = 0.0
    train_samples = 0
    train_counts = {
        "verb_top1": [0, 0], "verb_top5": [0, 0],
        "noun_top1": [0, 0], "noun_top5": [0, 0],
        "action_top1": [0, 0], "action_top5": [0, 0]
    }

    pbar = tqdm(train_loader, desc=f"Epoch {epoch} Train", leave=False)
    for F_batch, y_multi, meta in pbar:
        F_batch = F_batch.to(DEVICE)               # (B, T, D)
        y_v = y_multi["verb"].to(DEVICE)           # (B, K_fut)
        y_n = y_multi["noun"].to(DEVICE)
        y_a = y_multi["action"].to(DEVICE)

        opt.zero_grad()
        logits = model(F_batch)   # dict: "verb"/"noun"/"action" -> (B, K_fut, C)

        loss_v = masked_cross_entropy(logits["verb"], y_v)
        loss_n = masked_cross_entropy(logits["noun"], y_n)
        loss_a = masked_cross_entropy(logits["action"], y_a)
        loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

        loss.backward()
        opt.step()

        b = F_batch.size(0)
        train_loss_sum += float(loss.item()) * b
        train_samples += b

        for (task, lab, lg) in [
            ("verb", y_v, logits["verb"]),
            ("noun", y_n, logits["noun"]),
            ("action", y_a, logits["action"])
        ]:
            h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
            h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
            train_counts[f"{task}_top1"][0] += h1
            train_counts[f"{task}_top1"][1] += t1
            train_counts[f"{task}_top5"][0] += h5
            train_counts[f"{task}_top5"][1] += t5

    train_loss = train_loss_sum / max(1, train_samples)
    train_metrics = {}
    for task in ["verb", "noun", "action"]:
        h1, t1 = train_counts[f"{task}_top1"]
        h5, t5 = train_counts[f"{task}_top5"]
        train_metrics[f"{task}_top1"] = (h1 / t1) if t1 > 0 else None
        train_metrics[f"{task}_top5"] = (h5 / t5) if t5 > 0 else None

    # ------------- VALIDATION -------------
    model.eval()
    val_loss_sum = 0.0
    val_samples = 0
    val_counts = {
        "verb_top1": [0, 0], "verb_top5": [0, 0],
        "noun_top1": [0, 0], "noun_top5": [0, 0],
        "action_top1": [0, 0], "action_top5": [0, 0]
    }

    # store logits/labels for per-horizon + P/R/F1 metrics
    val_logits_store = {"verb": [], "noun": [], "action": []}
    val_labels_store = {"verb": [], "noun": [], "action": []}

    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Epoch {epoch} Val", leave=False)
        for F_batch, y_multi, meta in pbar:
            F_batch = F_batch.to(DEVICE)
            y_v = y_multi["verb"].to(DEVICE)
            y_n = y_multi["noun"].to(DEVICE)
            y_a = y_multi["action"].to(DEVICE)

            logits = model(F_batch)
            loss_v = masked_cross_entropy(logits["verb"], y_v)
            loss_n = masked_cross_entropy(logits["noun"], y_n)
            loss_a = masked_cross_entropy(logits["action"], y_a)
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

            b = F_batch.size(0)
            val_loss_sum += float(loss.item()) * b
            val_samples += b

            for (task, lab, lg) in [
                ("verb", y_v, logits["verb"]),
                ("noun", y_n, logits["noun"]),
                ("action", y_a, logits["action"])
            ]:
                h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
                h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
                val_counts[f"{task}_top1"][0] += h1
                val_counts[f"{task}_top1"][1] += t1
                val_counts[f"{task}_top5"][0] += h5
                val_counts[f"{task}_top5"][1] += t5

            # store for per-horizon + P/R/F1 metrics
            val_logits_store["verb"].append(logits["verb"].detach().cpu())
            val_logits_store["noun"].append(logits["noun"].detach().cpu())
            val_logits_store["action"].append(logits["action"].detach().cpu())
            val_labels_store["verb"].append(y_v.detach().cpu())
            val_labels_store["noun"].append(y_n.detach().cpu())
            val_labels_store["action"].append(y_a.detach().cpu())

    val_loss = val_loss_sum / max(1, val_samples)

    # overall val metrics (top-1/top-5 over all horizons)
    val_metrics = {}
    for task in ["verb", "noun", "action"]:
        h1, t1 = val_counts[f"{task}_top1"]
        h5, t5 = val_counts[f"{task}_top5"]
        val_metrics[f"{task}_top1"] = (h1 / t1) if t1 > 0 else None
        val_metrics[f"{task}_top5"] = (h5 / t5) if t5 > 0 else None

    # per-horizon metrics (time-based)
    per_horizon_metrics = {"verb": {}, "noun": {}, "action": {}}
    for task in ["verb", "noun", "action"]:
        if len(val_logits_store[task]) == 0:
            continue
        logits_all = torch.cat(val_logits_store[task], dim=0)  # (N, K_fut, C)
        labels_all = torch.cat(val_labels_store[task], dim=0)  # (N, K_fut)
        m = topk_accuracy_per_task(
            logits_all,
            labels_all,
            topk=(1, 5),
            ignore_index=IGNORE_INDEX
        )
        per_horizon_metrics[task] = m

    # macro precision / recall / F1 over all horizons (validation)
    prf_metrics = {"verb": {}, "noun": {}, "action": {}}
    for task in ["verb", "noun", "action"]:
        if len(val_logits_store[task]) == 0:
            continue

        logits_all = torch.cat(val_logits_store[task], dim=0)
        labels_all = torch.cat(val_labels_store[task], dim=0)

        preds_all = logits_all.argmax(dim=-1)  # (N, K_fut)
        mask = (labels_all != IGNORE_INDEX)
        if mask.sum().item() == 0:
            continue

        y_true = labels_all[mask].numpy()
        y_pred = preds_all[mask].numpy()

        p, r, f1, _ = precision_recall_fscore_support(
            y_true,
            y_pred,
            average="macro",
            zero_division=0
        )
        prf_metrics[task]["precision"] = p
        prf_metrics[task]["recall"] = r
        prf_metrics[task]["f1"] = f1

    # mean Top-5 recall across horizons for each task
    mean_top5_recall = {}
    for task in ["verb", "noun", "action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            mean_top5_recall[task] = None
            continue

        vals = []
        for h_idx in range(K_FUT):
            key = f"per_h{h_idx+1}_top5"
            if key in mh and mh[key] is not None:
                vals.append(mh[key])
        mean_top5_recall[task] = float(np.mean(vals)) if len(vals) > 0 else None

    # scheduler + logging
    sched.step(val_loss)
    elapsed = time.time() - t0

    print(f"Epoch {epoch}/{NUM_EPOCHS} | Time {elapsed:.1f}s")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    for task in ["verb", "noun", "action"]:
        print(
            f"  {task.upper():6s} Train Top1: {train_metrics[f'{task}_top1']}, "
            f"Top5: {train_metrics[f'{task}_top5']}; "
            f"Val Top1: {val_metrics[f'{task}_top1']}, "
            f"Top5: {val_metrics[f'{task}_top5']}"
        )

    # print macro precision / recall / F1 (validation)
    for task in ["verb", "noun", "action"]:
        if prf_metrics[task]:
            p = prf_metrics[task]["precision"]
            r = prf_metrics[task]["recall"]
            f1 = prf_metrics[task]["f1"]
            print(
                f"  {task.upper():6s} Val Precision: {p:.4f}, "
                f"Recall: {r:.4f}, F1: {f1:.4f}"
            )

    # print mean Top-5 recall
    print("  ---- Mean Top-5 Recall (validation) ----")
    for task in ["verb", "noun", "action"]:
        print(
            f"     {task.upper():6s}  Mean Top-5 Recall: {mean_top5_recall[task]}"
        )

    # print per-horizon by seconds
    for task in ["verb", "noun", "action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            continue
        print(f"  {task.upper():6s} per-horizon (time-based):")
        for h_idx, t_sec in enumerate(HORIZONS_S):
            key1 = f"per_h{h_idx+1}_top1"
            key5 = f"per_h{h_idx+1}_top5"
            v1 = mh.get(key1, None)
            v5 = mh.get(key5, None)
            print(f"    @ {t_sec:4.2f}s  Top1: {v1}  Top5: {v5}")
        print(
            f"    overall_top1: {mh.get('overall_top1', None)}, "
            f"overall_top5: {mh.get('overall_top5', None)}"
        )

    # optional: save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(
            {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'opt_state': opt.state_dict(),
                'val_loss': val_loss
            },
            BEST_MODEL_PATH
        )
        print(f"[SAVED BEST] -> {BEST_MODEL_PATH}")

print("Training finished.")

Epoch 1 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1/20 | Time 0.8s
  Train Loss: 8.3991 | Val Loss: 8.6126
  VERB   Train Top1: 0.010416666666666666, Top5: 0.08333333333333333; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  

Epoch 2 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 2 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2/20 | Time 0.8s
  Train Loss: 8.7402 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.09375; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0.1
  

Epoch 3 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 3 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3/20 | Time 0.7s
  Train Loss: 8.6364 | Val Loss: 8.6126
  VERB   Train Top1: 0.03125, Top5: 0.08333333333333333; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.10416666666666667; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0.1
   

Epoch 4 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 4 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4/20 | Time 0.8s
  Train Loss: 8.7174 | Val Loss: 8.6126
  VERB   Train Top1: 0.03125, Top5: 0.09375; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.010416666666666666, Top5: 0.09375; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top

Epoch 5 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 5 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5/20 | Time 0.7s
  Train Loss: 8.3792 | Val Loss: 8.6126
  VERB   Train Top1: 0.03125, Top5: 0.07291666666666667; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.13541666666666666; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0

Epoch 6 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 6 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6/20 | Time 0.8s
  Train Loss: 8.3419 | Val Loss: 8.6126
  VERB   Train Top1: 0.03125, Top5: 0.07291666666666667; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.10416666666666667; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0.1
   

Epoch 7 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 7 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7/20 | Time 0.9s
  Train Loss: 8.3931 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.08333333333333333; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  

Epoch 8 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 8 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8/20 | Time 0.8s
  Train Loss: 8.3326 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.0625; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0

Epoch 9 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 9 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9/20 | Time 0.8s
  Train Loss: 8.7181 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.0625; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0.1
   

Epoch 10 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 10 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 10/20 | Time 0.8s
  Train Loss: 8.7479 | Val Loss: 8.6126
  VERB   Train Top1: 0.041666666666666664, Top5: 0.08333333333333333; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.010416666666666666, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0 

Epoch 11 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 11 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 11/20 | Time 0.8s
  Train Loss: 8.3571 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.07291666666666667; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.010416666666666666, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0 

Epoch 12 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 12 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 12/20 | Time 0.8s
  Train Loss: 8.7597 | Val Loss: 8.6126
  VERB   Train Top1: 0.03125, Top5: 0.07291666666666667; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0.1
  

Epoch 13 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 13 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 13/20 | Time 0.8s
  Train Loss: 8.6114 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.0625; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.07291666666666667; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 

Epoch 14 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 14 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 14/20 | Time 0.9s
  Train Loss: 8.7145 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.09375; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0.1
 

Epoch 15 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 15 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 15/20 | Time 0.8s
  Train Loss: 8.3788 | Val Loss: 8.6126
  VERB   Train Top1: 0.020833333333333332, Top5: 0.0625; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.08333333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 

Epoch 16 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 16 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 16/20 | Time 0.8s
  Train Loss: 8.6929 | Val Loss: 8.6126
  VERB   Train Top1: 0.010416666666666666, Top5: 0.08333333333333333; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.125; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0

Epoch 17 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 17 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 17/20 | Time 0.7s
  Train Loss: 8.3643 | Val Loss: 8.6126
  VERB   Train Top1: 0.03125, Top5: 0.09375; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.010416666666666666; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0

Epoch 18 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 18 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 18/20 | Time 0.8s
  Train Loss: 8.6333 | Val Loss: 8.6126
  VERB   Train Top1: 0.010416666666666666, Top5: 0.08333333333333333; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.10416666666666667; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0 

Epoch 19 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 19 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 19/20 | Time 0.8s
  Train Loss: 8.6734 | Val Loss: 8.6126
  VERB   Train Top1: 0.03125, Top5: 0.10416666666666667; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.020833333333333332; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.10416666666666667; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 

Epoch 20 Train:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 20 Val:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 20/20 | Time 0.8s
  Train Loss: 8.3347 | Val Loss: 8.6126
  VERB   Train Top1: 0.010416666666666666, Top5: 0.09375; Val Top1: 0.019230769230769232, Top5: 0.09615384615384616
  NOUN   Train Top1: 0.0, Top5: 0.0; Val Top1: 0.0, Top5: 0.0
  ACTION Train Top1: 0.0, Top5: 0.11458333333333333; Val Top1: 0.057692307692307696, Top5: 0.17307692307692307
  VERB   Val Precision: 0.0667, Recall: 0.0021, F1: 0.0040
  NOUN   Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ACTION Val Precision: 0.0097, Recall: 0.0714, F1: 0.0171
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.13125
     NOUN    Mean Top-5 Recall: 0.0
     ACTION  Mean Top-5 Recall: 0.134375
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.0  Top5: 0.0
    @ 0.50s  Top1: 0.0  Top5: 0.0
    @ 0.75s  Top1: 0.25  Top5: 0.75
    @ 1.00s  Top1: 0.0  Top5: 0.2
    @ 1.25s  Top1: 0.0  Top5: 0.0
    @ 1.50s  Top1: 0.0  Top5: 0.0
    @ 1.75s  Top1: 0.0  Top5: 0.0
    @ 2.00s  Top1: 0.0  Top5: 0.1
 