### Import Libraries

In [2]:
from tools.imports import *

from tools.pca import *

###  Configuration  (Loading PCA and Frames)

In [25]:
RGB_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\RGB\P01_01\Original")
FLOW_FOLDER = Path(r"D:\Datasets\Datasets\EPIC_Kitchen\OpticalFlow\P01_01\P01_01")
LABEL_CSV = Path(r"EPIC-Kitchens\Labels\P01_01.csv")

OUTPUT_FUSED_CSV = Path(r"EPIC-Kitchens\Features\P01_01_fused_features_PCA.csv")

PCA_PATH = "pca_2048_to_512.pkl"

SAMPLE_RATE = 1
FEAT_DIM = 512
W_RGB = 0.6
W_FLOW = 0.4

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

OUTPUT_FUSED_CSV.parent.mkdir(parents=True, exist_ok=True)

_frame_number_re = re.compile(r"(\d+)(?=\.[^.]+$)")

def parse_frame_index(fname: str):
    m = _frame_number_re.search(fname)
    if m:
        return int(m.group(1))
    digs = re.findall(r"\d+", fname)
    return int(digs[-1]) if digs else 0

In [26]:
_resnet = resnet50(weights=True)
_resnet = nn.Sequential(*list(_resnet.children())[:-1])
_resnet = _resnet.to(DEVICE).eval()   # output: (B,2048,1,1)

_transform = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],
                [0.229,0.224,0.225])
])

# ------------------ LOAD PCA --------------------------------

pca = joblib.load(PCA_PATH)
assert pca.components_.shape == (FEAT_DIM, 2048), "PCA dimension mismatch"

# ---------------- FEATURE EXTRACTION ------------------------

@torch.no_grad()
def extract_feature_from_pil(pil_img: Image.Image):
    x = _transform(pil_img).unsqueeze(0).to(DEVICE)  # (1,3,224,224)
    feat = _resnet(x).view(-1).cpu().numpy()         # (2048,)
    feat = pca.transform(feat[None, :])[0]          # (512,)
    return feat.astype(np.float32)

# ---------------- MAIN EXTRACTION FUNCTION ------------------

def extract_and_save_fused(csv_labels_path: Path,
                           rgb_folder: Path,
                           flow_folder: Path or None,
                           out_fused_csv: Path,
                           sample_rate: int = 1,
                           w_rgb: float = 0.6,
                           w_flow: float = 0.4):

    labels_df = pd.read_csv(csv_labels_path)

    rgb_files = sorted([
        p for p in rgb_folder.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"]
    ])

    sampled = rgb_files[::sample_rate]
    if len(sampled) == 0:
        raise RuntimeError(f"No frames found in {rgb_folder}")

    fused_rows = []

    for fp in tqdm(sampled, desc="Extract & Fuse"):
        fname = fp.name
        frame_idx = parse_frame_index(fname)

        # -------- RGB --------
        try:
            pil = Image.open(fp).convert("RGB")
            rgb_feat = extract_feature_from_pil(pil)
        except Exception as e:
            print(f"[WARN] RGB skip {fname}: {e}")
            continue

        # -------- FLOW --------
        if flow_folder is not None:
            ffp = flow_folder / fname
            if not ffp.exists():
                ffp = fp
            try:
                pilf = Image.open(ffp).convert("RGB")
                flow_feat = extract_feature_from_pil(pilf)
            except Exception as e:
                print(f"[WARN] FLOW skip {fname}: {e}")
                flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)
        else:
            flow_feat = np.zeros(FEAT_DIM, dtype=np.float32)

        # -------- FUSION --------
        fused_vec = w_rgb * rgb_feat + w_flow * flow_feat

        # -------- LABEL --------
        lr = labels_df[
            (labels_df["StartFrame"] <= frame_idx) &
            (labels_df["EndFrame"] >= frame_idx)
        ]

        if not lr.empty:
            action_label = int(lr.iloc[0].get("ActionLabel", -1))
            action_name = str(lr.iloc[0].get("ActionName", "Unknown"))
        else:
            action_label, action_name = -1, "Unknown"

        row = {
            "frame_idx": int(frame_idx),
            "frame_name": fname,
            "ActionLabel": action_label,
            "ActionName": action_name
        }

        for i, v in enumerate(fused_vec):
            row[f"feat_{i}"] = float(v)

        fused_rows.append(row)

    if len(fused_rows) == 0:
        raise RuntimeError("No fused rows extracted")

    df_fused = pd.DataFrame(fused_rows)
    df_fused.to_csv(out_fused_csv, index=False)

    print(f"[SAVED] {out_fused_csv}")
    return df_fused

# ---------------- RUN ---------------------------------------

df_fused = extract_and_save_fused(
    csv_labels_path=LABEL_CSV,
    rgb_folder=RGB_FOLDER,
    flow_folder=FLOW_FOLDER if FLOW_FOLDER.exists() else None,
    out_fused_csv=OUTPUT_FUSED_CSV,
    sample_rate=SAMPLE_RATE,
    w_rgb=W_RGB,
    w_flow=W_FLOW
)




Extract & Fuse:   0%|          | 0/99029 [00:00<?, ?it/s]

[SAVED] EPIC-Kitchens\Features\P01_01_fused_features_PCA.csv


In [5]:
data=pd.read_csv(r"EPIC-Kitchens\Features\FusedFeatures\P01_05_fused_features.csv")
data

Unnamed: 0,frame_idx,frame_name,ActionLabel,ActionName,feat_0,feat_1,feat_2,feat_3,feat_4,feat_5,...,feat_502,feat_503,feat_504,feat_505,feat_506,feat_507,feat_508,feat_509,feat_510,feat_511
0,0,frame_00000.jpg,-1,Unknown,-0.255704,0.014638,-0.127248,0.135866,-0.104432,0.529116,...,-0.081831,-0.088528,-0.030375,0.093051,0.060815,0.002188,-0.084385,0.218704,0.416650,-0.089111
1,1,frame_00001.jpg,-1,Unknown,-0.230471,-0.016865,-0.117778,0.106307,-0.078303,0.522713,...,-0.096457,-0.081578,-0.038350,0.075853,0.064037,-0.006416,-0.072980,0.234338,0.377129,-0.094787
2,2,frame_00002.jpg,-1,Unknown,-0.273896,-0.004563,-0.079866,0.153029,-0.142856,0.486159,...,-0.039064,-0.073958,-0.103700,0.016556,0.052797,0.036062,-0.094503,0.249313,0.400744,-0.055098
3,3,frame_00003.jpg,-1,Unknown,-0.186261,-0.046241,-0.045938,0.125934,-0.146717,0.536510,...,-0.048182,-0.073675,-0.129074,-0.007902,0.046740,0.017751,-0.068501,0.309177,0.469210,-0.033593
4,4,frame_00004.jpg,-1,Unknown,-0.155096,-0.016367,-0.074818,0.140323,-0.188490,0.487527,...,-0.094458,-0.096937,-0.110825,-0.026800,0.013834,-0.002211,-0.115081,0.297364,0.386572,-0.028801
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
76320,76320,frame_76320.jpg,-1,Unknown,0.247844,-0.085575,0.053120,0.611242,-0.299913,0.715905,...,-0.032531,0.135446,0.012278,-0.044854,0.086764,0.046305,0.082122,0.253486,0.450076,0.017609
76321,76321,frame_76321.jpg,-1,Unknown,0.045791,-0.003731,-0.000484,0.464047,-0.210662,0.553370,...,-0.001068,0.213597,-0.024942,-0.065709,0.017595,-0.005818,-0.079159,0.176065,0.269000,0.035629
76322,76322,frame_76322.jpg,-1,Unknown,0.090800,-0.070876,-0.058114,0.529759,-0.270593,0.681007,...,0.013677,0.212171,0.000859,-0.044839,0.090881,-0.009402,0.097548,0.200796,0.435586,0.218567
76323,76323,frame_76323.jpg,-1,Unknown,0.094027,0.004024,0.074413,0.527673,-0.240405,0.576615,...,0.037703,0.203596,-0.049638,-0.022510,-0.021834,-0.098123,-0.090395,0.206394,0.312601,0.091531


### Load the CSV file and Features

In [7]:
import pandas as pd
from pathlib import Path

def load_fused_csv_by_path(fused_csv_path: str):
    fp = Path(fused_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Fused features CSV not found: {fp}")
    df = pd.read_csv(fp)
    if "frame_idx" not in df.columns:
        raise KeyError("Fused CSV must contain 'frame_idx' column")
    df["frame_idx"] = df["frame_idx"].astype(int)
    df = df.sort_values("frame_idx").reset_index(drop=True)
    return df

def load_label_csv_by_path(label_csv_path: str):
    fp = Path(label_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Label CSV not found: {fp}")
    df = pd.read_csv(fp)
    return df

# ------------------------------------ Paths ---------------------------
fused_df = load_fused_csv_by_path(r"EPIC-Kitchens\Features\FusedFeatures\P01_05_fused_features.csv")
labels_df = load_label_csv_by_path(r"EPIC-Kitchens\Labels\P01_05.csv")

### Dataset Loader

In [8]:
IGNORE_INDEX = -1

class SingleVideoAnticipationDataset(Dataset):
    def __init__(self, fused_df_or_path, labels_df_or_path,
                 t_obs: int, k_fut: int, feat_dim: int,
                 fps: float, horizons_s: list[float]):
        
        # load paths
        if isinstance(fused_df_or_path, (str, Path)):
            fused_df = pd.read_csv(fused_df_or_path)
        else:
            fused_df = fused_df_or_path.copy()
        if isinstance(labels_df_or_path, (str, Path)):
            labels_df = pd.read_csv(labels_df_or_path)
        else:
            labels_df = labels_df_or_path.copy()

        if "frame_idx" not in fused_df.columns:
            raise KeyError("fused_df must contain 'frame_idx'")

        fused_df["frame_idx"] = fused_df["frame_idx"].astype(int)
        self.fused_df = fused_df.set_index("frame_idx", drop=False).sort_index()
        self.labels_df = labels_df.reset_index(drop=True)

        if not all(c in self.labels_df.columns for c in ["StartFrame", "EndFrame"]):
            raise KeyError("labels_df must contain StartFrame and EndFrame")

        self.t_obs = int(t_obs)
        self.k_fut = int(k_fut)
        self.feat_dim = int(feat_dim)
        self.feat_cols = [f"feat_{i}" for i in range(self.feat_dim)]

        # NEW: time info
        self.fps = float(fps)
        assert len(horizons_s) == self.k_fut, "len(horizons_s) must equal k_fut"
        self.horizons_s = list(horizons_s)

        # samples: one per label row (use EndFrame as obs_end)
        self.samples = []
        for ridx, row in self.labels_df.iterrows():
            try:
                obs_end = int(row["EndFrame"])
            except:
                continue
            self.samples.append({"label_row_idx": int(ridx), "obs_end": obs_end})

        if len(self.samples) == 0:
            raise RuntimeError("No valid label rows found")

    def __len__(self):
        return len(self.samples)

    def _time_based_future_labels(self, obs_end: int):
        labels_df = self.labels_df

        def pick(cols):
            for c in cols:
                if c in labels_df.columns:
                    return c
            return None

        vcol = pick(["Verb_class","verb","Verb","verb_class"])
        ncol = pick(["Noun_class","noun","Noun","noun_class"])
        acol = pick(["Action_class","action","Action","ActionLabel"])

        verb_targets   = []
        noun_targets   = []
        action_targets = []

        for h_sec in self.horizons_s:
            future_frame = obs_end + int(round(h_sec * self.fps))
            seg = labels_df[(labels_df["StartFrame"] <= future_frame) &
                            (labels_df["EndFrame"]   >= future_frame)]
            if seg.empty:
                # nothing happening at that exact time
                verb_targets.append(IGNORE_INDEX)
                noun_targets.append(IGNORE_INDEX)
                action_targets.append(IGNORE_INDEX)
            else:
                row = seg.iloc[0]
                if vcol is not None and not pd.isna(row[vcol]):
                    verb_targets.append(int(row[vcol]))
                else:
                    verb_targets.append(IGNORE_INDEX)

                if ncol is not None and not pd.isna(row[ncol]):
                    noun_targets.append(int(row[ncol]))
                else:
                    noun_targets.append(IGNORE_INDEX)

                if acol is not None and not pd.isna(row[acol]):
                    action_targets.append(int(row[acol]))
                else:
                    action_targets.append(IGNORE_INDEX)

        return {
            "verb":   torch.LongTensor(verb_targets),
            "noun":   torch.LongTensor(noun_targets),
            "action": torch.LongTensor(action_targets)
        }

    def __getitem__(self, idx):
        rec = self.samples[idx]
        obs_end = rec["obs_end"]
        obs_start = obs_end - (self.t_obs - 1)
        if obs_start < 0:
            obs_start = 0
            obs_end = obs_start + (self.t_obs - 1)

        fused_idx_min = int(self.fused_df.index.min())
        fused_idx_max = int(self.fused_df.index.max())
        obs_end = min(obs_end, fused_idx_max)
        obs_start = max(obs_end - (self.t_obs - 1), fused_idx_min)

        desired = list(range(obs_start, obs_end + 1))
        sel = self.fused_df.reindex(desired).fillna(method="ffill").fillna(method="bfill").fillna(0.0)

        if sel.shape[0] < self.t_obs:
            if sel.shape[0] == 0:
                zero_row = {c:0.0 for c in self.feat_cols}
                sel = pd.DataFrame([zero_row] * self.t_obs)
            else:
                first = sel.iloc[[0]]
                pads = pd.concat([first] * (self.t_obs - sel.shape[0]), ignore_index=True)
                sel = pd.concat([pads, sel.reset_index(drop=True)], ignore_index=True)

        for c in self.feat_cols:
            if c not in sel.columns:
                sel[c] = 0.0

        F_window = torch.from_numpy(sel[self.feat_cols].values).float()   # (T_obs, FEAT_DIM)

        y_multi = self._time_based_future_labels(obs_end)

        meta = {"obs_start": int(obs_start),
                "obs_end":   int(obs_end),
                "label_row_idx": int(rec["label_row_idx"])}
        return F_window, y_multi, meta

### Graph Construction
_Graph Construction using kNN strategy and then apply GAT_

In [9]:
K=5  
DROP=0.1

def build_topk_edge_index(features: torch.Tensor, k=K):
    Tn = int(features.size(0))
    x = F.normalize(features, dim=1)
    sim = torch.matmul(x, x.t())   # (T,T) T is the number of frame-> features
    sim.fill_diagonal_(-1.0)
    vals, idxs = torch.topk(sim, k, dim=1)
    src = torch.arange(Tn).unsqueeze(1).expand(-1, k).reshape(-1)
    dst = idxs.reshape(-1)
    edge = torch.stack([src, dst], dim=0)
    edge_rev = torch.stack([dst, src], dim=0)
    return torch.cat([edge, edge_rev], dim=1).long()

class BatchedGAT(nn.Module):
    def __init__(self, in_dim, hid_dim=None, num_layers=3, heads=8, dropout=DROP):
        super().__init__()
        hid = hid_dim or in_dim
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_dim if i==0 else hid
            self.convs.append(GATConv(in_ch, hid//heads, heads=heads, concat=True, dropout=dropout))
        self.proj = nn.Linear(hid, in_dim)
        self.norm = nn.LayerNorm(in_dim)
        self.act = nn.GELU()
    def forward(self, pyg_batch: PyGBatch, T_per_sample: int):
        x = pyg_batch.x; edge_index = pyg_batch.edge_index
        h = x
        for conv in self.convs:
            h = conv(h, edge_index); h = self.act(h)
        h = self.proj(h)
        node_feats, mask = to_dense_batch(h, batch=pyg_batch.batch)
        B, max_nodes, D = node_feats.shape
        if max_nodes < T_per_sample:
            pad = torch.zeros(B, T_per_sample - max_nodes, D, device=node_feats.device)
            node_feats = torch.cat([node_feats, pad], dim=1)
        elif max_nodes > T_per_sample:
            node_feats = node_feats[:, :T_per_sample, :]
        return self.norm(node_feats) 

### Fusion (Graph +Transformer) + Transfomer Decoder


In [10]:
class GETR(nn.Module):
    def __init__(self, d_model, nhead=8, num_layers=3, dim_feedforward=2048, dropout=DROP, max_len=1000):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation='gelu', batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))
    def forward(self, x):
        B,T,D = x.shape
        pos = self.pos_emb[:, :T, :].to(x.device)
        return self.encoder(x + pos)

class AnticipationModel(nn.Module):
    def __init__(self, feat_dim, num_classes: dict, k_fut=5, gat_layers=3, gat_heads=8, dec_layers=3, dec_heads=8, dropout=DROP):
        super().__init__()
        self.feat_dim = feat_dim; self.k_fut = k_fut
        self.gat = BatchedGAT(in_dim=feat_dim, hid_dim=feat_dim, num_layers=gat_layers, heads=gat_heads, dropout=dropout)
        self.encoder = GETR(d_model=feat_dim, nhead=dec_heads, num_layers=3)
        dec_layer = nn.TransformerDecoderLayer(d_model=feat_dim, nhead=dec_heads, dim_feedforward=feat_dim*4, dropout=dropout, activation='gelu', batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=dec_layers)
        self.queries = nn.Parameter(torch.randn(1, k_fut, feat_dim))
        assert isinstance(num_classes, dict)
        self.verb_head = nn.Linear(feat_dim, num_classes["verb"])
        self.noun_head = nn.Linear(feat_dim, num_classes["noun"])
        self.action_head = nn.Linear(feat_dim, num_classes["action"])

    def forward(self, F_batch):
        # F_batch: (B, T, D)
        B,T,D = F_batch.shape; device = F_batch.device
        data_list=[]
        for b in range(B):
            x = F_batch[b]
            edge_index = build_topk_edge_index(x.detach().cpu(), k=K).to(device)
            data_list.append(PyGData(x=x, edge_index=edge_index))
        pyg_batch = PyGBatch.from_data_list(data_list).to(device)
        G = self.gat(pyg_batch, T_per_sample=T)   # (B,T,D)
        H = self.encoder(F_batch)                 # (B,T,D)
        U = H + G
        q = self.queries.expand(B, -1, -1).to(device)
        dec_out = self.decoder(tgt=q, memory=U)   # (B, K_fut, D)
        return {"verb": self.verb_head(dec_out),
                "noun": self.noun_head(dec_out),
                "action": self.action_head(dec_out)}

### Cross-Entropy (Masked means Loss for Verb, Noun, Action)

In [11]:
import torch.nn.functional as F

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    B, K, C = logits.shape
    logits_flat = logits.view(B * K, C)      # (B*K, C)
    labels_flat = labels.view(B * K)         # (B*K,)

    loss_flat = F.cross_entropy(
        logits_flat,
        labels_flat,
        reduction='none',
        ignore_index=ignore_index
    )  # (B*K,)

    mask = (labels_flat != ignore_index).float()
    valid = mask.sum()

    if valid == 0:
        return (logits_flat * 0.0).sum()

    return (loss_flat * mask).sum() / valid


def topk_accuracy_per_task(logits, labels, topk=(1,5), ignore_index=IGNORE_INDEX):
    B,K,C = logits.shape
    res = {}
    overall = {k:0 for k in topk}
    total_cnt = 0
    preds_topk = logits.topk(max(topk), dim=-1)[1]  # (B,K,maxk)
    for h in range(K):
        lab = labels[:,h]; mask = (lab != ignore_index); cnt = int(mask.sum().item())
        for k in topk:
            if cnt == 0:
                res.setdefault(f"per_h{h+1}_top{k}", None)
                continue
            predk = preds_topk[:,h,:k]  # (B,k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            hits = (predk == lab_exp)
            hit = int(hits[mask].any(dim=1).float().sum().item())
            res[f"per_h{h+1}_top{k}"] = hit / cnt
            overall[k] += hit
        total_cnt += cnt
    for k in topk:
        res[f"overall_top{k}"] = overall[k] / total_cnt if total_cnt>0 else None
    return res

### Training and Validation

In [22]:
FUSED_CSV_PATH = r"EPIC-Kitchens\Features\FusedFeatures\P01_05_fused_features.csv"
LABEL_CSV_PATH = r"EPIC-Kitchens\Labels\P01_05.csv"
BEST_MODEL_PATH = Path(r"Model\P01_05_fused_model.pth")


# Hyperparams
T_OBS = 90
FEAT_DIM = 512
BATCH_SIZE = 8
NUM_EPOCHS = 20
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 0

# === Time-based anticipation config ===
FPS = 30.0 
HORIZONS_S = [0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75, 2.0] 
K_FUT = len(HORIZONS_S) 

In [23]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def detect_num_classes_from_labels_df(labels_df):
    verbs = set()
    nouns = set()
    actions = set()
    for cand in ["Verb_class", "verb", "Verb", "verb_class"]:
        if cand in labels_df.columns:
            verbs.update(labels_df[cand].dropna().astype(int).tolist())
            break
    for cand in ["Noun_class", "noun", "Noun", "noun_class"]:
        if cand in labels_df.columns:
            nouns.update(labels_df[cand].dropna().astype(int).tolist())
            break
    for cand in ["Action_class", "action", "Action", "ActionLabel"]:
        if cand in labels_df.columns:
            actions.update(labels_df[cand].dropna().astype(int).tolist())
            break
    nv = (max(verbs) + 1) if len(verbs) > 0 else 1
    nn_ = (max(nouns) + 1) if len(nouns) > 0 else 1
    na = (max(actions) + 1) if len(actions) > 0 else 1
    return {"verb": int(nv), "noun": int(nn_), "action": int(na)}


def topk_counts(logits, labels, k):
    # logits: (B, K_fut, C); labels: (B, K_fut)
    with torch.no_grad():
        B, K, C = logits.shape
        topk_preds = logits.topk(k, dim=-1)[1]  # (B, K, k)
        hits = 0
        total = 0
        for h in range(K):
            lab = labels[:, h]  # (B,)
            mask = (lab != IGNORE_INDEX)
            if int(mask.sum().item()) == 0:
                continue
            predk = topk_preds[:, h, :]  # (B, k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            masked_pred = predk[mask]   # (M, k)
            masked_lab = lab_exp[mask]  # (M, k)
            hit_vec = (masked_pred == masked_lab).any(dim=1).float()
            hits += int(hit_vec.sum().item())
            total += int(mask.sum().item())
        return hits, total


# Load fused and labels
fused_df = pd.read_csv(FUSED_CSV_PATH)
labels_df = pd.read_csv(LABEL_CSV_PATH)

# Dataset
dataset = SingleVideoAnticipationDataset(
    fused_df,
    labels_df,
    t_obs=T_OBS,
    k_fut=K_FUT,        # must equal len(HORIZONS_S)
    feat_dim=FEAT_DIM,
    fps=FPS,
    horizons_s=HORIZONS_S
)

# split indices for train/val (60/40)
indices = list(range(len(dataset)))
random.seed(42)
random.shuffle(indices)
split_at = int(0.6 * len(indices))
train_idx = indices[:split_at]
val_idx = indices[split_at:]

train_ds = Subset(dataset, train_idx)
val_ds = Subset(dataset, val_idx)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == "cuda")
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == "cuda")
)

num_classes = detect_num_classes_from_labels_df(labels_df)
print("Detected classes:", num_classes)

model = AnticipationModel(
    feat_dim=FEAT_DIM,
    num_classes=num_classes,
    k_fut=K_FUT
).to(DEVICE)



# ================= MODEL PARAMETER COUNT =================
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

total_params, trainable_params = count_parameters(model)

print("\n========== MODEL PARAMETERS ==========")
print(f"Total parameters     : {total_params/1e6:.1f} M")
print(f"Trainable parameters : {trainable_params/1e6:.2f} M")
print("=====================================\n")


# ================= MODEL FLOPs =================
from thop import profile

model.eval()

dummy_input = torch.randn(
    1, T_OBS, FEAT_DIM
).to(DEVICE)

macs, params = profile(
    model,
    inputs=(dummy_input,),
    verbose=False
)

print("========== MODEL FLOPs ==========")
print(f"MACs  : {macs/1e9:.2f} G")
print(f"FLOPs : {(2*macs)/1e9:.2f} G")
print("================================\n")



# ================= INFERENCE LATENCY =================
def measure_latency(model, device, runs=100):
    model.eval()
    dummy = torch.randn(1, T_OBS, FEAT_DIM).to(device)

    # warm-up
    for _ in range(10):
        _ = model(dummy)

    if device == "cuda":
        torch.cuda.synchronize()

    start = time.time()
    for _ in range(runs):
        _ = model(dummy)

    if device == "cuda":
        torch.cuda.synchronize()

    avg_latency = (time.time() - start) / runs
    return avg_latency * 1000  # ms


latency_ms = measure_latency(model, DEVICE)

print("========== INFERENCE LATENCY ==========")
print(f"Average latency per sample: {latency_ms:.2f} ms")
print("======================================\n")

Detected classes: {'verb': 81, 'noun': 332, 'action': 123}

Total parameters     : 23.9 M
Trainable parameters : 23.92 M

MACs  : 0.08 G
FLOPs : 0.15 G

Average latency per sample: 21.58 ms



In [24]:
opt = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
sched = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

best_val_loss = float("inf")

In [25]:
opt

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0.0001
)

### Training

In [26]:
import torch.optim as optim
for epoch in range(1, NUM_EPOCHS + 1):
    t0 = time.time()

    # ------------- TRAIN -------------
    model.train()
    train_loss_sum = 0.0
    train_samples = 0
    train_counts = {
        "verb_top1": [0, 0], "verb_top5": [0, 0],
        "noun_top1": [0, 0], "noun_top5": [0, 0],
        "action_top1": [0, 0], "action_top5": [0, 0]
    }

    pbar = tqdm(train_loader, desc=f"Epoch {epoch} Train", leave=False)
    for F_batch, y_multi, meta in pbar:
        F_batch = F_batch.to(DEVICE)               # (B, T, D)
        y_v = y_multi["verb"].to(DEVICE)           # (B, K_fut)
        y_n = y_multi["noun"].to(DEVICE)
        y_a = y_multi["action"].to(DEVICE)

        opt.zero_grad()
        logits = model(F_batch)   # dict: "verb"/"noun"/"action" -> (B, K_fut, C)

        loss_v = masked_cross_entropy(logits["verb"], y_v)
        loss_n = masked_cross_entropy(logits["noun"], y_n)
        loss_a = masked_cross_entropy(logits["action"], y_a)
        loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

        loss.backward()
        opt.step()

        b = F_batch.size(0)
        train_loss_sum += float(loss.item()) * b
        train_samples += b

        for (task, lab, lg) in [
            ("verb", y_v, logits["verb"]),
            ("noun", y_n, logits["noun"]),
            ("action", y_a, logits["action"])
        ]:
            h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
            h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
            train_counts[f"{task}_top1"][0] += h1
            train_counts[f"{task}_top1"][1] += t1
            train_counts[f"{task}_top5"][0] += h5
            train_counts[f"{task}_top5"][1] += t5

    train_loss = train_loss_sum / max(1, train_samples)
    train_metrics = {}
    for task in ["verb", "noun", "action"]:
        h1, t1 = train_counts[f"{task}_top1"]
        h5, t5 = train_counts[f"{task}_top5"]
        train_metrics[f"{task}_top1"] = (h1 / t1) if t1 > 0 else None
        train_metrics[f"{task}_top5"] = (h5 / t5) if t5 > 0 else None

    # ------------- VALIDATION -------------
    model.eval()
    val_loss_sum = 0.0
    val_samples = 0
    val_counts = {
        "verb_top1": [0, 0], "verb_top5": [0, 0],
        "noun_top1": [0, 0], "noun_top5": [0, 0],
        "action_top1": [0, 0], "action_top5": [0, 0]
    }

    # store logits/labels for per-horizon + P/R/F1 metrics
    val_logits_store = {"verb": [], "noun": [], "action": []}
    val_labels_store = {"verb": [], "noun": [], "action": []}

    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Epoch {epoch} Val", leave=False)
        for F_batch, y_multi, meta in pbar:
            F_batch = F_batch.to(DEVICE)
            y_v = y_multi["verb"].to(DEVICE)
            y_n = y_multi["noun"].to(DEVICE)
            y_a = y_multi["action"].to(DEVICE)

            logits = model(F_batch)
            loss_v = masked_cross_entropy(logits["verb"], y_v)
            loss_n = masked_cross_entropy(logits["noun"], y_n)
            loss_a = masked_cross_entropy(logits["action"], y_a)
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

            b = F_batch.size(0)
            val_loss_sum += float(loss.item()) * b
            val_samples += b

            for (task, lab, lg) in [
                ("verb", y_v, logits["verb"]),
                ("noun", y_n, logits["noun"]),
                ("action", y_a, logits["action"])
            ]:
                h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
                h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
                val_counts[f"{task}_top1"][0] += h1
                val_counts[f"{task}_top1"][1] += t1
                val_counts[f"{task}_top5"][0] += h5
                val_counts[f"{task}_top5"][1] += t5

            # store for per-horizon + P/R/F1 metrics
            val_logits_store["verb"].append(logits["verb"].detach().cpu())
            val_logits_store["noun"].append(logits["noun"].detach().cpu())
            val_logits_store["action"].append(logits["action"].detach().cpu())
            val_labels_store["verb"].append(y_v.detach().cpu())
            val_labels_store["noun"].append(y_n.detach().cpu())
            val_labels_store["action"].append(y_a.detach().cpu())

    val_loss = val_loss_sum / max(1, val_samples)

    # overall val metrics (top-1/top-5 over all horizons)
    val_metrics = {}
    for task in ["verb", "noun", "action"]:
        h1, t1 = val_counts[f"{task}_top1"]
        h5, t5 = val_counts[f"{task}_top5"]
        val_metrics[f"{task}_top1"] = (h1 / t1) if t1 > 0 else None
        val_metrics[f"{task}_top5"] = (h5 / t5) if t5 > 0 else None

    # per-horizon metrics (time-based)
    per_horizon_metrics = {"verb": {}, "noun": {}, "action": {}}
    for task in ["verb", "noun", "action"]:
        if len(val_logits_store[task]) == 0:
            continue
        logits_all = torch.cat(val_logits_store[task], dim=0)  # (N, K_fut, C)
        labels_all = torch.cat(val_labels_store[task], dim=0)  # (N, K_fut)
        m = topk_accuracy_per_task(
            logits_all,
            labels_all,
            topk=(1, 5),
            ignore_index=IGNORE_INDEX
        )
        per_horizon_metrics[task] = m

    # macro precision / recall / F1 over all horizons (validation)
    prf_metrics = {"verb": {}, "noun": {}, "action": {}}
    for task in ["verb", "noun", "action"]:
        if len(val_logits_store[task]) == 0:
            continue

        logits_all = torch.cat(val_logits_store[task], dim=0)
        labels_all = torch.cat(val_labels_store[task], dim=0)

        preds_all = logits_all.argmax(dim=-1)  # (N, K_fut)
        mask = (labels_all != IGNORE_INDEX)
        if mask.sum().item() == 0:
            continue

        y_true = labels_all[mask].numpy()
        y_pred = preds_all[mask].numpy()

        p, r, f1, _ = precision_recall_fscore_support(
            y_true,
            y_pred,
            average="macro",
            zero_division=0
        )
        prf_metrics[task]["precision"] = p
        prf_metrics[task]["recall"] = r
        prf_metrics[task]["f1"] = f1

    # mean Top-5 recall across horizons for each task
    mean_top5_recall = {}
    for task in ["verb", "noun", "action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            mean_top5_recall[task] = None
            continue

        vals = []
        for h_idx in range(K_FUT):
            key = f"per_h{h_idx+1}_top5"
            if key in mh and mh[key] is not None:
                vals.append(mh[key])
        mean_top5_recall[task] = float(np.mean(vals)) if len(vals) > 0 else None

    # scheduler + logging
    sched.step(val_loss)
    elapsed = time.time() - t0

    print(f"Epoch {epoch}/{NUM_EPOCHS} | Time {elapsed:.1f}s")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    for task in ["verb", "noun", "action"]:
        print(
            f"  {task.upper():6s} Train Top1: {train_metrics[f'{task}_top1']}, "
            f"Top5: {train_metrics[f'{task}_top5']}; "
            f"Val Top1: {val_metrics[f'{task}_top1']}, "
            f"Top5: {val_metrics[f'{task}_top5']}"
        )

    # print macro precision / recall / F1 (validation)
    for task in ["verb", "noun", "action"]:
        if prf_metrics[task]:
            p = prf_metrics[task]["precision"]
            r = prf_metrics[task]["recall"]
            f1 = prf_metrics[task]["f1"]
            print(
                f"  {task.upper():6s} Val Precision: {p:.4f}, "
                f"Recall: {r:.4f}, F1: {f1:.4f}"
            )

    # print mean Top-5 recall
    print("  ---- Mean Top-5 Recall (validation) ----")
    for task in ["verb", "noun", "action"]:
        print(
            f"     {task.upper():6s}  Mean Top-5 Recall: {mean_top5_recall[task]}"
        )

    # print per-horizon by seconds
    for task in ["verb", "noun", "action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            continue
        print(f"  {task.upper():6s} per-horizon (time-based):")
        for h_idx, t_sec in enumerate(HORIZONS_S):
            key1 = f"per_h{h_idx+1}_top1"
            key5 = f"per_h{h_idx+1}_top5"
            v1 = mh.get(key1, None)
            v5 = mh.get(key5, None)
            print(f"    @ {t_sec:4.2f}s  Top1: {v1}  Top5: {v5}")
        print(
            f"    overall_top1: {mh.get('overall_top1', None)}, "
            f"overall_top5: {mh.get('overall_top5', None)}"
        )

    # optional: save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(
            {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'opt_state': opt.state_dict(),
                'val_loss': val_loss
            },
            BEST_MODEL_PATH
        )
        print(f"[SAVED BEST] -> {BEST_MODEL_PATH}")

print("Training finished.")

Epoch 1 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1/20 | Time 5.7s
  Train Loss: 9.1144 | Val Loss: 9.2078
  VERB   Train Top1: 0.14446227929373998, Top5: 0.5601926163723917; Val Top1: 0.05511811023622047, Top5: 0.6404199475065617
  NOUN   Train Top1: 0.08186195826645265, Top5: 0.2504012841091493; Val Top1: 0.015748031496062992, Top5: 0.2992125984251969
  ACTION Train Top1: 0.033707865168539325, Top5: 0.16051364365971107; Val Top1: 0.0, Top5: 0.015748031496062992
  VERB   Val Precision: 0.0037, Recall: 0.0667, F1: 0.0070
  NOUN   Val Precision: 0.0005, Recall: 0.0323, F1: 0.0010
  ACTION Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6323922134275302
     NOUN    Mean Top-5 Recall: 0.2991466414430677
     ACTION  Mean Top-5 Recall: 0.015830903523067503
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.058823529411764705  Top5: 0.5882352941176471
    @ 0.50s  Top1: 0.05405405405405406  Top5: 0.5675675675675675
    @ 0.75s  Top1: 0.04545454545454

Epoch 2 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 2 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 2/20 | Time 5.6s
  Train Loss: 7.6256 | Val Loss: 8.7153
  VERB   Train Top1: 0.14125200642054575, Top5: 0.7078651685393258; Val Top1: 0.2125984251968504, Top5: 0.6272965879265092
  NOUN   Train Top1: 0.1492776886035313, Top5: 0.42215088282504015; Val Top1: 0.015748031496062992, Top5: 0.19160104986876642
  ACTION Train Top1: 0.056179775280898875, Top5: 0.24398073836276082; Val Top1: 0.0, Top5: 0.07611548556430446
  VERB   Val Precision: 0.0299, Recall: 0.0553, F1: 0.0370
  NOUN   Val Precision: 0.0005, Recall: 0.0323, F1: 0.0010
  ACTION Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6194906940299676
     NOUN    Mean Top-5 Recall: 0.19366243469611677
     ACTION  Mean Top-5 Recall: 0.07775373082218257
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.5675675675675675
    @ 0.75s  Top1: 0.1590909090909091

Epoch 3 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 3 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 3/20 | Time 5.6s
  Train Loss: 7.0397 | Val Loss: 8.8476
  VERB   Train Top1: 0.18138041733547353, Top5: 0.7158908507223114; Val Top1: 0.14435695538057744, Top5: 0.6272965879265092
  NOUN   Train Top1: 0.13001605136436598, Top5: 0.4446227929373997; Val Top1: 0.05774278215223097, Top5: 0.18110236220472442
  ACTION Train Top1: 0.07062600321027288, Top5: 0.2953451043338684; Val Top1: 0.015748031496062992, Top5: 0.023622047244094488
  VERB   Val Precision: 0.0235, Recall: 0.0693, F1: 0.0303
  NOUN   Val Precision: 0.0269, Recall: 0.0566, F1: 0.0199
  ACTION Val Precision: 0.0003, Recall: 0.0196, F1: 0.0006
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6194906940299676
     NOUN    Mean Top-5 Recall: 0.18386920406860757
     ACTION  Mean Top-5 Recall: 0.022451785665728228
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.08108108108108109  Top5: 0.5675675675675675
    @ 0.75s  Top1: 0

Epoch 4 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 4 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 4/20 | Time 5.5s
  Train Loss: 6.6104 | Val Loss: 8.7230
  VERB   Train Top1: 0.2102728731942215, Top5: 0.7255216693418941; Val Top1: 0.18110236220472442, Top5: 0.6981627296587927
  NOUN   Train Top1: 0.20545746388443017, Top5: 0.4478330658105939; Val Top1: 0.05511811023622047, Top5: 0.28346456692913385
  ACTION Train Top1: 0.13162118780096307, Top5: 0.3563402889245586; Val Top1: 0.0, Top5: 0.06561679790026247
  VERB   Val Precision: 0.0483, Recall: 0.0760, F1: 0.0399
  NOUN   Val Precision: 0.0105, Recall: 0.0566, F1: 0.0133
  ACTION Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6912864967731787
     NOUN    Mean Top-5 Recall: 0.2869464685486024
     ACTION  Mean Top-5 Recall: 0.06742303090720407
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.6470588235294118
    @ 0.50s  Top1: 0.08108108108108109  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.18181818181818182  T

Epoch 5 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 5 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 5/20 | Time 5.3s
  Train Loss: 5.8192 | Val Loss: 8.5952
  VERB   Train Top1: 0.2102728731942215, Top5: 0.7624398073836276; Val Top1: 0.25196850393700787, Top5: 0.6797900262467191
  NOUN   Train Top1: 0.29213483146067415, Top5: 0.5585874799357945; Val Top1: 0.16272965879265092, Top5: 0.29133858267716534
  ACTION Train Top1: 0.18459069020866772, Top5: 0.43659711075441415; Val Top1: 0.07874015748031496, Top5: 0.13910761154855644
  VERB   Val Precision: 0.1160, Recall: 0.1458, F1: 0.1101
  NOUN   Val Precision: 0.0261, Recall: 0.0735, F1: 0.0335
  ACTION Val Precision: 0.0173, Recall: 0.0406, F1: 0.0236
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6721656899063286
     NOUN    Mean Top-5 Recall: 0.29231187476540177
     ACTION  Mean Top-5 Recall: 0.1400080727910476
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.2647058823529412  Top5: 0.6176470588235294
    @ 0.50s  Top1: 0.21621621621621623  Top5: 0.6216216216216216
    @ 0.75s  Top1: 0.2727

Epoch 6 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 6 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 6/20 | Time 5.6s
  Train Loss: 5.6809 | Val Loss: 8.4189
  VERB   Train Top1: 0.26163723916532905, Top5: 0.8073836276083467; Val Top1: 0.2992125984251969, Top5: 0.7847769028871391
  NOUN   Train Top1: 0.32263242375601925, Top5: 0.6356340288924559; Val Top1: 0.15748031496062992, Top5: 0.3123359580052493
  ACTION Train Top1: 0.20064205457463885, Top5: 0.5553772070626003; Val Top1: 0.05774278215223097, Top5: 0.2755905511811024
  VERB   Val Precision: 0.1154, Recall: 0.2025, F1: 0.1441
  NOUN   Val Precision: 0.0588, Recall: 0.1026, F1: 0.0610
  ACTION Val Precision: 0.0170, Recall: 0.0290, F1: 0.0199
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7813192966433732
     NOUN    Mean Top-5 Recall: 0.31366759886263196
     ACTION  Mean Top-5 Recall: 0.2804676318604871
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.38235294117647056  Top5: 0.7647058823529411
    @ 0.50s  Top1: 0.2972972972972973  Top5: 0.7567567567567568
    @ 0.75s  Top1: 0.3181818

Epoch 7 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 7 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 7/20 | Time 6.0s
  Train Loss: 4.7991 | Val Loss: 8.1295
  VERB   Train Top1: 0.3258426966292135, Top5: 0.8715890850722311; Val Top1: 0.2992125984251969, Top5: 0.8162729658792651
  NOUN   Train Top1: 0.3884430176565008, Top5: 0.6966292134831461; Val Top1: 0.17060367454068243, Top5: 0.3648293963254593
  ACTION Train Top1: 0.3402889245585875, Top5: 0.7207062600321027; Val Top1: 0.08923884514435695, Top5: 0.2650918635170604
  VERB   Val Precision: 0.1600, Recall: 0.2124, F1: 0.1502
  NOUN   Val Precision: 0.0827, Recall: 0.1161, F1: 0.0671
  ACTION Val Precision: 0.0463, Recall: 0.0599, F1: 0.0388
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.8113540126723042
     NOUN    Mean Top-5 Recall: 0.36987789544921634
     ACTION  Mean Top-5 Recall: 0.27076909820267486
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.7647058823529411
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.7837837837837838
    @ 0.75s  Top1: 0.25  Top5

Epoch 8 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 8 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 8/20 | Time 5.8s
  Train Loss: 4.0443 | Val Loss: 8.1580
  VERB   Train Top1: 0.4333868378812199, Top5: 0.9020866773675762; Val Top1: 0.24146981627296588, Top5: 0.8031496062992126
  NOUN   Train Top1: 0.4959871589085072, Top5: 0.8523274478330658; Val Top1: 0.17585301837270342, Top5: 0.4094488188976378
  ACTION Train Top1: 0.43659711075441415, Top5: 0.8523274478330658; Val Top1: 0.13123359580052493, Top5: 0.25984251968503935
  VERB   Val Precision: 0.0779, Recall: 0.1357, F1: 0.0930
  NOUN   Val Precision: 0.1106, Recall: 0.1118, F1: 0.0879
  ACTION Val Precision: 0.0763, Recall: 0.0978, F1: 0.0773
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7980270387651353
     NOUN    Mean Top-5 Recall: 0.4107772704506918
     ACTION  Mean Top-5 Recall: 0.26326525498318765
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.7647058823529411
    @ 0.50s  Top1: 0.24324324324324326  Top5: 0.7567567567567568
    @ 0.75s  Top1: 0.227272

Epoch 9 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 9 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 9/20 | Time 5.7s
  Train Loss: 3.2306 | Val Loss: 8.1727
  VERB   Train Top1: 0.5521669341894061, Top5: 0.9486356340288925; Val Top1: 0.2020997375328084, Top5: 0.7926509186351706
  NOUN   Train Top1: 0.6067415730337079, Top5: 0.8860353130016051; Val Top1: 0.1968503937007874, Top5: 0.4409448818897638
  ACTION Train Top1: 0.6500802568218299, Top5: 0.9438202247191011; Val Top1: 0.11023622047244094, Top5: 0.2992125984251969
  VERB   Val Precision: 0.0845, Recall: 0.1331, F1: 0.0906
  NOUN   Val Precision: 0.1290, Recall: 0.1525, F1: 0.1124
  ACTION Val Precision: 0.0613, Recall: 0.0758, F1: 0.0657
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7841672763324239
     NOUN    Mean Top-5 Recall: 0.4443558747800034
     ACTION  Mean Top-5 Recall: 0.30210339741072634
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.14705882352941177  Top5: 0.7352941176470589
    @ 0.50s  Top1: 0.13513513513513514  Top5: 0.7027027027027027
    @ 0.75s  Top1: 0.1363636363

Epoch 10 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 10 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 10/20 | Time 5.6s
  Train Loss: 2.7615 | Val Loss: 7.9064
  VERB   Train Top1: 0.6051364365971108, Top5: 0.9823434991974318; Val Top1: 0.28608923884514437, Top5: 0.7480314960629921
  NOUN   Train Top1: 0.6597110754414125, Top5: 0.9598715890850722; Val Top1: 0.26246719160104987, Top5: 0.5118110236220472
  ACTION Train Top1: 0.6789727126805778, Top5: 0.9759229534510433; Val Top1: 0.15485564304461943, Top5: 0.29396325459317585
  VERB   Val Precision: 0.0908, Recall: 0.1868, F1: 0.1152
  NOUN   Val Precision: 0.1804, Recall: 0.1561, F1: 0.1398
  ACTION Val Precision: 0.0883, Recall: 0.1161, F1: 0.0921
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7367416892444917
     NOUN    Mean Top-5 Recall: 0.5194117276554289
     ACTION  Mean Top-5 Recall: 0.2973876258108803
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.6764705882352942
    @ 0.50s  Top1: 0.21621621621621623  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.2727272

Epoch 11 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 11 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 11/20 | Time 5.8s
  Train Loss: 2.3846 | Val Loss: 8.2674
  VERB   Train Top1: 0.6163723916532905, Top5: 0.985553772070626; Val Top1: 0.16010498687664043, Top5: 0.7139107611548556
  NOUN   Train Top1: 0.7351524879614767, Top5: 0.9727126805778491; Val Top1: 0.2388451443569554, Top5: 0.4776902887139108
  ACTION Train Top1: 0.812199036918138, Top5: 0.9743178170144462; Val Top1: 0.14698162729658792, Top5: 0.31758530183727035
  VERB   Val Precision: 0.1255, Recall: 0.1842, F1: 0.0976
  NOUN   Val Precision: 0.1444, Recall: 0.1577, F1: 0.1233
  ACTION Val Precision: 0.0778, Recall: 0.1121, F1: 0.0854
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7074110296478384
     NOUN    Mean Top-5 Recall: 0.47975678106906755
     ACTION  Mean Top-5 Recall: 0.3180307058856755
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.6764705882352942
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.6756756756756757
    @ 0.75s  Top1: 0.159090909

Epoch 12 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 12 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 12/20 | Time 5.6s
  Train Loss: 2.0940 | Val Loss: 7.9282
  VERB   Train Top1: 0.7399678972712681, Top5: 0.985553772070626; Val Top1: 0.32020997375328086, Top5: 0.7007874015748031
  NOUN   Train Top1: 0.8218298555377207, Top5: 0.9919743178170144; Val Top1: 0.26246719160104987, Top5: 0.5118110236220472
  ACTION Train Top1: 0.8475120385232745, Top5: 0.985553772070626; Val Top1: 0.14698162729658792, Top5: 0.28346456692913385
  VERB   Val Precision: 0.1441, Recall: 0.2039, F1: 0.1542
  NOUN   Val Precision: 0.1849, Recall: 0.1594, F1: 0.1474
  ACTION Val Precision: 0.0825, Recall: 0.1121, F1: 0.0873
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6881080258101984
     NOUN    Mean Top-5 Recall: 0.5118273555760211
     ACTION  Mean Top-5 Recall: 0.2860775598671937
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.29411764705882354  Top5: 0.5882352941176471
    @ 0.50s  Top1: 0.24324324324324326  Top5: 0.5945945945945946
    @ 0.75s  Top1: 0.318181818

Epoch 13 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 13 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 13/20 | Time 5.4s
  Train Loss: 1.7619 | Val Loss: 8.0634
  VERB   Train Top1: 0.7768860353130016, Top5: 0.9919743178170144; Val Top1: 0.2073490813648294, Top5: 0.7165354330708661
  NOUN   Train Top1: 0.8202247191011236, Top5: 0.985553772070626; Val Top1: 0.28083989501312334, Top5: 0.4881889763779528
  ACTION Train Top1: 0.8764044943820225, Top5: 0.9903691813804173; Val Top1: 0.15223097112860892, Top5: 0.31758530183727035
  VERB   Val Precision: 0.1809, Recall: 0.2080, F1: 0.1600
  NOUN   Val Precision: 0.2149, Recall: 0.2021, F1: 0.1858
  ACTION Val Precision: 0.1138, Recall: 0.1198, F1: 0.0973
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7078372412872853
     NOUN    Mean Top-5 Recall: 0.48668631999224105
     ACTION  Mean Top-5 Recall: 0.3187715704133047
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.6470588235294118
    @ 0.50s  Top1: 0.21621621621621623  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.25  Top5

Epoch 14 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 14 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 14/20 | Time 5.7s
  Train Loss: 1.2395 | Val Loss: 8.3215
  VERB   Train Top1: 0.8587479935794543, Top5: 0.9919743178170144; Val Top1: 0.2440944881889764, Top5: 0.6299212598425197
  NOUN   Train Top1: 0.9197431781701445, Top5: 0.9887640449438202; Val Top1: 0.24671916010498687, Top5: 0.4881889763779528
  ACTION Train Top1: 0.9373996789727127, Top5: 0.9903691813804173; Val Top1: 0.16010498687664043, Top5: 0.31758530183727035
  VERB   Val Precision: 0.1871, Recall: 0.2105, F1: 0.1599
  NOUN   Val Precision: 0.1922, Recall: 0.1573, F1: 0.1598
  ACTION Val Precision: 0.0988, Recall: 0.1181, F1: 0.1010
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6227805973737501
     NOUN    Mean Top-5 Recall: 0.4928231285346023
     ACTION  Mean Top-5 Recall: 0.3199018384850295
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.20588235294117646  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.1891891891891892  Top5: 0.5675675675675675
    @ 0.75s  Top1: 0.204545454

Epoch 15 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 15 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 15/20 | Time 5.6s
  Train Loss: 1.0222 | Val Loss: 8.2169
  VERB   Train Top1: 0.9004815409309791, Top5: 0.9951845906902087; Val Top1: 0.2204724409448819, Top5: 0.6824146981627297
  NOUN   Train Top1: 0.9486356340288925, Top5: 0.9935794542536116; Val Top1: 0.3123359580052493, Top5: 0.5091863517060368
  ACTION Train Top1: 0.9582664526484751, Top5: 0.9951845906902087; Val Top1: 0.2099737532808399, Top5: 0.30446194225721784
  VERB   Val Precision: 0.2191, Recall: 0.1957, F1: 0.1655
  NOUN   Val Precision: 0.2561, Recall: 0.2031, F1: 0.2172
  ACTION Val Precision: 0.1286, Recall: 0.1388, F1: 0.1232
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6745909519284465
     NOUN    Mean Top-5 Recall: 0.5140930122506067
     ACTION  Mean Top-5 Recall: 0.30768497854177634
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.5882352941176471
    @ 0.50s  Top1: 0.1891891891891892  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.2045454545

Epoch 16 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 16 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 16/20 | Time 5.4s
  Train Loss: 0.7975 | Val Loss: 8.1119
  VERB   Train Top1: 0.9390048154093098, Top5: 0.9951845906902087; Val Top1: 0.2440944881889764, Top5: 0.7086614173228346
  NOUN   Train Top1: 0.9695024077046549, Top5: 0.9983948635634029; Val Top1: 0.31496062992125984, Top5: 0.5144356955380578
  ACTION Train Top1: 0.9727126805778491, Top5: 0.9983948635634029; Val Top1: 0.22572178477690288, Top5: 0.28346456692913385
  VERB   Val Precision: 0.1789, Recall: 0.1767, F1: 0.1642
  NOUN   Val Precision: 0.2264, Recall: 0.1964, F1: 0.2038
  ACTION Val Precision: 0.1311, Recall: 0.1463, F1: 0.1329
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.69967598408518
     NOUN    Mean Top-5 Recall: 0.5161834672767335
     ACTION  Mean Top-5 Recall: 0.2871832109728448
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.6176470588235294
    @ 0.50s  Top1: 0.21621621621621623  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.2272727272

Epoch 17 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 17 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 17/20 | Time 5.6s
  Train Loss: 0.6419 | Val Loss: 8.1922
  VERB   Train Top1: 0.9390048154093098, Top5: 0.9951845906902087; Val Top1: 0.27296587926509186, Top5: 0.7716535433070866
  NOUN   Train Top1: 0.9727126805778491, Top5: 0.9967897271268058; Val Top1: 0.31758530183727035, Top5: 0.48556430446194226
  ACTION Train Top1: 0.9791332263242376, Top5: 0.9983948635634029; Val Top1: 0.1968503937007874, Top5: 0.2887139107611549
  VERB   Val Precision: 0.2133, Recall: 0.2061, F1: 0.1726
  NOUN   Val Precision: 0.2251, Recall: 0.1977, F1: 0.1972
  ACTION Val Precision: 0.1205, Recall: 0.1299, F1: 0.1180
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7648509046857304
     NOUN    Mean Top-5 Recall: 0.4869150061287352
     ACTION  Mean Top-5 Recall: 0.29190920286922206
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.2647058823529412  Top5: 0.7058823529411765
    @ 0.50s  Top1: 0.24324324324324326  Top5: 0.7297297297297297
    @ 0.75s  Top1: 0.25  Top5

Epoch 18 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 18 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 18/20 | Time 5.6s
  Train Loss: 0.5380 | Val Loss: 8.1342
  VERB   Train Top1: 0.9743178170144462, Top5: 0.9983948635634029; Val Top1: 0.24671916010498687, Top5: 0.7585301837270341
  NOUN   Train Top1: 0.9791332263242376, Top5: 0.9983948635634029; Val Top1: 0.31758530183727035, Top5: 0.5118110236220472
  ACTION Train Top1: 0.9807383627608347, Top5: 1.0; Val Top1: 0.2152230971128609, Top5: 0.31496062992125984
  VERB   Val Precision: 0.1847, Recall: 0.1623, F1: 0.1497
  NOUN   Val Precision: 0.2611, Recall: 0.2063, F1: 0.2136
  ACTION Val Precision: 0.1329, Recall: 0.1429, F1: 0.1310
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7549645345190877
     NOUN    Mean Top-5 Recall: 0.5160130749465315
     ACTION  Mean Top-5 Recall: 0.31988681105610883
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.7352941176470589
    @ 0.50s  Top1: 0.24324324324324326  Top5: 0.7567567567567568
    @ 0.75s  Top1: 0.20454545454545456  Top

Epoch 19 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 19 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 19/20 | Time 5.6s
  Train Loss: 0.4797 | Val Loss: 8.2350
  VERB   Train Top1: 0.9678972712680578, Top5: 1.0; Val Top1: 0.28083989501312334, Top5: 0.6902887139107612
  NOUN   Train Top1: 0.9807383627608347, Top5: 1.0; Val Top1: 0.30708661417322836, Top5: 0.5459317585301837
  ACTION Train Top1: 0.9839486356340289, Top5: 1.0; Val Top1: 0.2125984251968504, Top5: 0.32020997375328086
  VERB   Val Precision: 0.2212, Recall: 0.2273, F1: 0.1873
  NOUN   Val Precision: 0.2670, Recall: 0.2011, F1: 0.2163
  ACTION Val Precision: 0.1269, Recall: 0.1407, F1: 0.1275
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6762647579387865
     NOUN    Mean Top-5 Recall: 0.548948417145612
     ACTION  Mean Top-5 Recall: 0.3235783090732648
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.21621621621621623  Top5: 0.5945945945945946
    @ 0.75s  Top1: 0.25  Top5: 0.6363636363636364
    @ 1.00s  Top1: 0.25  

Epoch 20 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 20 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 20/20 | Time 5.3s
  Train Loss: 0.4060 | Val Loss: 8.2631
  VERB   Train Top1: 0.9823434991974318, Top5: 1.0; Val Top1: 0.23622047244094488, Top5: 0.6745406824146981
  NOUN   Train Top1: 0.9823434991974318, Top5: 1.0; Val Top1: 0.2992125984251969, Top5: 0.5065616797900262
  ACTION Train Top1: 0.9839486356340289, Top5: 1.0; Val Top1: 0.2204724409448819, Top5: 0.33858267716535434
  VERB   Val Precision: 0.2607, Recall: 0.1736, F1: 0.1703
  NOUN   Val Precision: 0.2456, Recall: 0.1876, F1: 0.2052
  ACTION Val Precision: 0.1473, Recall: 0.1475, F1: 0.1360
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6594500255287461
     NOUN    Mean Top-5 Recall: 0.5140606670135202
     ACTION  Mean Top-5 Recall: 0.3415092571596121
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.20588235294117646  Top5: 0.5
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.5945945945945946
    @ 0.75s  Top1: 0.18181818181818182  Top5: 0.6136363636363636
    @ 1.00s  Top1: 0.2083