In [1]:
from tools.imports import *

### User Config

In [2]:
DO_EXTRACT = False 
SAMPLE_RATE = 1  
FEAT_DIM = 512
W_RGB = 0.6
W_FLOW = 0.4
K = 5             
DROP = 0.1

In [3]:
LABEL_CSV   = Path(r"EPIC-Kitchens\Labels\P01_05.csv")
OUTPUT_FUSED_CSV = Path(r"EPIC-Kitchens\Features\FusedFeatures\P01_05_fused_features.csv")

### Loss mode (choose one): 
_"ce","focal","smooth","contrast","graph","combined"_

In [4]:
# Loss mode (choose one): "ce","focal","smooth","contrast","graph","combined"
LOSS_MODE = "combined"

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Outputs
OUTPUT_FUSED_CSV.parent.mkdir(parents=True, exist_ok=True)
BEST_MODEL_PATH = Path(r"./Model/P01_05_best_model.pth")

# Repro
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

<torch._C.Generator at 0x1e958f231f0>

### LOAD FUSED CSV & LABELS

In [5]:

def load_fused_csv_by_path(fused_csv_path: str):
    fp = Path(fused_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Fused features CSV not found: {fp}")
    df = pd.read_csv(fp)
    if "frame_idx" not in df.columns:
        raise KeyError("Fused CSV must contain 'frame_idx' column")
    df["frame_idx"] = df["frame_idx"].astype(int)
    df = df.sort_values("frame_idx").reset_index(drop=True)
    return df

def load_label_csv_by_path(label_csv_path: str):
    fp = Path(label_csv_path)
    if not fp.exists():
        raise FileNotFoundError(f"Label CSV not found: {fp}")
    df = pd.read_csv(fp)
    return df

### DATASET Loader

In [6]:

IGNORE_INDEX = -1

class SingleVideoAnticipationDataset(Dataset):
    def __init__(self
                 , fused_df_or_path, labels_df_or_path,
                 t_obs: int, k_fut: int, feat_dim: int,
                 fps: float, horizons_s):
        # load
        if isinstance(fused_df_or_path, (str, Path)):
            fused_df = pd.read_csv(fused_df_or_path)
        else:
            fused_df = fused_df_or_path.copy()
        if isinstance(labels_df_or_path, (str, Path)):
            labels_df = pd.read_csv(labels_df_or_path)
        else:
            labels_df = labels_df_or_path.copy()

        if "frame_idx" not in fused_df.columns:
            raise KeyError("fused_df must contain 'frame_idx'")
        fused_df["frame_idx"] = fused_df["frame_idx"].astype(int)
        self.fused_df = fused_df.set_index("frame_idx", drop=False).sort_index()
        self.labels_df = labels_df.reset_index(drop=True)

        if not all(c in self.labels_df.columns for c in ["StartFrame", "EndFrame"]):
            raise KeyError("labels_df must contain StartFrame and EndFrame")

        self.t_obs = int(t_obs)
        self.k_fut = int(k_fut)
        self.feat_dim = int(feat_dim)
        self.feat_cols = [f"feat_{i}" for i in range(self.feat_dim)]
        self.fps = float(fps)
        assert len(horizons_s) == self.k_fut, "len(horizons_s) must equal k_fut"
        self.horizons_s = list(horizons_s)

        # build samples: one per label row (use EndFrame as obs_end)
        self.samples = []
        for ridx, row in self.labels_df.iterrows():
            try:
                obs_end = int(row["EndFrame"])
            except:
                continue
            self.samples.append({"label_row_idx": int(ridx), "obs_end": obs_end})

        if len(self.samples) == 0:
            raise RuntimeError("No valid label rows found")

    def __len__(self):
        return len(self.samples)

    def _time_based_future_labels(self, obs_end: int):
        labels_df = self.labels_df
        def pick(cols):
            for c in cols:
                if c in labels_df.columns:
                    return c
            return None
        vcol = pick(["Verb_class","verb","Verb","verb_class"])
        ncol = pick(["Noun_class","noun","Noun","noun_class"])
        acol = pick(["Action_class","action","Action","ActionLabel"])
        verb_targets   = []
        noun_targets   = []
        action_targets = []
        for h_sec in self.horizons_s:
            future_frame = obs_end + int(round(h_sec * self.fps))
            seg = labels_df[(labels_df["StartFrame"] <= future_frame) &
                            (labels_df["EndFrame"]   >= future_frame)]
            if seg.empty:
                verb_targets.append(IGNORE_INDEX)
                noun_targets.append(IGNORE_INDEX)
                action_targets.append(IGNORE_INDEX)
            else:
                row = seg.iloc[0]
                if vcol is not None and not pd.isna(row[vcol]):
                    verb_targets.append(int(row[vcol]))
                else:
                    verb_targets.append(IGNORE_INDEX)
                if ncol is not None and not pd.isna(row[ncol]):
                    noun_targets.append(int(row[ncol]))
                else:
                    noun_targets.append(IGNORE_INDEX)
                if acol is not None and not pd.isna(row[acol]):
                    action_targets.append(int(row[acol]))
                else:
                    action_targets.append(IGNORE_INDEX)
        return {
            "verb":   torch.LongTensor(verb_targets),
            "noun":   torch.LongTensor(noun_targets),
            "action": torch.LongTensor(action_targets)
        }

    def __getitem__(self, idx):
        rec = self.samples[idx]
        obs_end = rec["obs_end"]
        obs_start = obs_end - (self.t_obs - 1)
        if obs_start < 0:
            obs_start = 0
            obs_end = obs_start + (self.t_obs - 1)

        fused_idx_min = int(self.fused_df.index.min())
        fused_idx_max = int(self.fused_df.index.max())
        obs_end = min(obs_end, fused_idx_max)
        obs_start = max(obs_end - (self.t_obs - 1), fused_idx_min)

        desired = list(range(obs_start, obs_end + 1))
        sel = self.fused_df.reindex(desired).fillna(method="ffill").fillna(method="bfill").fillna(0.0)

        if sel.shape[0] < self.t_obs:
            if sel.shape[0] == 0:
                zero_row = {c:0.0 for c in self.feat_cols}
                sel = pd.DataFrame([zero_row] * self.t_obs)
            else:
                first = sel.iloc[[0]]
                pads = pd.concat([first] * (self.t_obs - sel.shape[0]), ignore_index=True)
                sel = pd.concat([pads, sel.reset_index(drop=True)], ignore_index=True)

        for c in self.feat_cols:
            if c not in sel.columns:
                sel[c] = 0.0

        F_window = torch.from_numpy(sel[self.feat_cols].values).float()   # (T_obs, FEAT_DIM)
        y_multi = self._time_based_future_labels(obs_end)
        meta = {"obs_start": int(obs_start),
                "obs_end":   int(obs_end),
                "label_row_idx": int(rec["label_row_idx"])}
        return F_window, y_multi, meta


### GRAPH

In [7]:
def build_topk_edge_index(features: torch.Tensor, k=K, device=None):
    """
    Compute top-k edges on the device of `features`.
    - features: (T, D) float tensor (can be on CPU or GPU)
    - returns: edge_index (2, E) long tensor on same device as features (or 'device' arg)
    """
    if device is None:
        device = features.device
    features = features.to(device)
    Tn = int(features.size(0))
    if Tn == 0:
        return torch.zeros((2, 0), dtype=torch.long, device=device)
    x = F.normalize(features, dim=1)          # (T, D)
    sim = torch.matmul(x, x.t())              # (T, T)
    sim.fill_diagonal_(-1.0)
    kk = min(max(0, int(k)), max(0, Tn - 1))
    if kk <= 0:
        return torch.zeros((2, 0), dtype=torch.long, device=device)
    vals, idxs = torch.topk(sim, kk, dim=1)   # (T, kk)
    src = torch.arange(Tn, device=device).unsqueeze(1).expand(-1, kk).reshape(-1)
    dst = idxs.reshape(-1)
    edge = torch.stack([src, dst], dim=0)
    edge_rev = torch.stack([dst, src], dim=0)
    return torch.cat([edge, edge_rev], dim=1).long()

class BatchedGAT(nn.Module):
    def __init__(self, in_dim, hid_dim=None, num_layers=3, heads=8, dropout=DROP):
        super().__init__()
        hid = hid_dim or in_dim
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_ch = in_dim if i==0 else hid
            self.convs.append(GATConv(in_ch, hid//heads, heads=heads, concat=True, dropout=dropout))
        self.proj = nn.Linear(hid, in_dim)
        self.norm = nn.LayerNorm(in_dim)
        self.act = nn.GELU()
    def forward(self, pyg_batch: PyGBatch, T_per_sample: int):
        x = pyg_batch.x; edge_index = pyg_batch.edge_index
        h = x
        for conv in self.convs:
            h = conv(h, edge_index); h = self.act(h)
        h = self.proj(h)
        node_feats, mask = to_dense_batch(h, batch=pyg_batch.batch)  # (B, max_nodes, D)
        B, max_nodes, D = node_feats.shape
        if max_nodes < T_per_sample:
            pad = torch.zeros(B, T_per_sample - max_nodes, D, device=node_feats.device)
            node_feats = torch.cat([node_feats, pad], dim=1)
        elif max_nodes > T_per_sample:
            node_feats = node_feats[:, :T_per_sample, :]
        return self.norm(node_feats)  # (B, T_per_sample, D)



### Encoder, Decoder, AnticipationModel

In [8]:

class SimpleTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead=8, num_layers=3, dim_feedforward=2048, dropout=DROP, max_len=1000):
        super().__init__()
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation='gelu', batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))
    def forward(self, x):
        B,T,D = x.shape
        pos = self.pos_emb[:, :T, :].to(x.device)
        return self.encoder(x + pos)

class AnticipationModel(nn.Module):
    def __init__(self, feat_dim, num_classes: dict, k_fut=5, gat_layers=3, gat_heads=8, dec_layers=3, dec_heads=8, dropout=DROP):
        super().__init__()
        self.feat_dim = feat_dim; self.k_fut = k_fut
        self.gat = BatchedGAT(in_dim=feat_dim, hid_dim=feat_dim, num_layers=gat_layers, heads=gat_heads, dropout=dropout)
        self.encoder = SimpleTransformerEncoder(d_model=feat_dim, nhead=dec_heads, num_layers=3)
        dec_layer = nn.TransformerDecoderLayer(d_model=feat_dim, nhead=dec_heads, dim_feedforward=feat_dim*4, dropout=dropout, activation='gelu', batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=dec_layers)
        self.queries = nn.Parameter(torch.randn(1, k_fut, feat_dim))
        assert isinstance(num_classes, dict)
        self.verb_head = nn.Linear(feat_dim, num_classes["verb"])
        self.noun_head = nn.Linear(feat_dim, num_classes["noun"])
        self.action_head = nn.Linear(feat_dim, num_classes["action"])

    def forward(self, F_batch):
        # F_batch: (B, T, D)
        B,T,D = F_batch.shape; device = F_batch.device
        data_list=[]
        for b in range(B):
            x = F_batch[b]
            edge_index = build_topk_edge_index(x, k=K, device=device)
            data_list.append(PyGData(x=x, edge_index=edge_index))
        pyg_batch = PyGBatch.from_data_list(data_list).to(device)
        gat_out = self.gat(pyg_batch, T_per_sample=T)   # (B,T,D)
        enc_out = self.encoder(F_batch)                 # (B,T,D)
        U = enc_out + gat_out
        q = self.queries.expand(B, -1, -1).to(device)
        dec_out = self.decoder(tgt=q, memory=U)         # (B, K_fut, D)
        logits = {
            "verb": self.verb_head(dec_out),
            "noun": self.noun_head(dec_out),
            "action": self.action_head(dec_out)
        }
        return logits, dec_out, gat_out

### Loss Functions

In [9]:
class FocalLoss(nn.Module):
    def __init__(self, gamma: float = 2.0, reduction: str = 'mean', ignore_index: int = -1, eps: float = 1e-8):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.eps = eps
    def forward(self, logits: torch.Tensor, target: torch.Tensor):
        B, K, C = logits.shape
        logits_flat = logits.view(-1, C)
        target_flat = target.view(-1)
        mask = (target_flat != self.ignore_index)
        if int(mask.sum().item()) == 0:
            return logits.new_tensor(0.0)
        probs = F.softmax(logits_flat, dim=-1)
        idx = target_flat.clamp_min(0).long()
        pt = probs[torch.arange(probs.size(0), device=probs.device), idx]
        pt = torch.clamp(pt, min=self.eps)
        loss = -((1.0 - pt) ** self.gamma) * torch.log(pt)
        loss = loss[mask]
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

def temporal_smoothness_loss(dec_outs: torch.Tensor):
    if dec_outs is None:
        return torch.tensor(0.0)
    diff = dec_outs[:, 1:, :] - dec_outs[:, :-1, :]
    return diff.pow(2).mean()

def supervised_contrastive_loss(features: torch.Tensor, labels: torch.Tensor, temperature: float = 0.07):
    device = features.device
    features = F.normalize(features, dim=1)
    logits = torch.matmul(features, features.t()) / temperature  # (N, N)
    logits_max, _ = logits.max(dim=1, keepdim=True)
    logits = logits - logits_max.detach()
    labels = labels.contiguous().view(-1, 1)
    mask = torch.eq(labels, labels.t()).float().to(device)
    diag = torch.eye(mask.size(0), device=device)
    mask_non_self = mask * (1.0 - diag)
    exp_logits = torch.exp(logits) * (1.0 - diag)
    log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)
    mean_log_prob_pos = (mask_non_self * log_prob).sum(1) / (mask_non_self.sum(1) + 1e-12)
    loss = - mean_log_prob_pos
    loss = loss.mean()
    return loss

def build_batch_adjacency_from_features(features_batch: torch.Tensor, k: int, symmetric: bool = True):
    device = features_batch.device
    B, T, D = features_batch.shape
    N = B * T
    A_blocks = torch.zeros((N, N), dtype=torch.float32, device=device)
    for b in range(B):
        feats = features_batch[b]  # (T, D)
        x = F.normalize(feats, dim=1)
        sim = torch.matmul(x, x.t())
        sim.fill_diagonal_(-1.0)
        kk = min(k, max(0, T-1))
        if kk <= 0:
            continue
        vals, idxs = torch.topk(sim, kk, dim=1)
        for i in range(T):
            neighbors = idxs[i]
            for nbr in neighbors.tolist():
                A_blocks[b*T + i, b*T + nbr] = 1.0
                if symmetric:
                    A_blocks[b*T + nbr, b*T + i] = 1.0
    return A_blocks

def graph_reconstruction_loss(gat_out: torch.Tensor, features_batch: torch.Tensor, k: int):
    if gat_out is None:
        return torch.tensor(0.0, device=features_batch.device)
    B, T, D = gat_out.shape
    losses = []
    for b in range(B):
        emb = F.normalize(gat_out[b], dim=1)  # (T, D)
        A_hat = torch.sigmoid(torch.matmul(emb, emb.t()))  # (T, T)
        A_gt_block = build_batch_adjacency_from_features(features_batch[b:b+1].detach(), k=k)
        A_gt = A_gt_block.to(A_hat.device)
        losses.append(F.binary_cross_entropy(A_hat, A_gt))
    return torch.stack(losses).mean() if len(losses) > 0 else torch.tensor(0.0, device=gat_out.device)

# -------------------------
# masked CE & metrics helpers
# -------------------------
def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    B, K, C = logits.shape
    logits_flat = logits.view(B * K, C)      # (B*K, C)
    labels_flat = labels.view(B * K)         # (B*K,)
    loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none', ignore_index=ignore_index)
    mask = (labels_flat != ignore_index).float()
    valid = mask.sum()
    if valid == 0:
        return (logits_flat * 0.0).sum()
    return (loss_flat * mask).sum() / valid

def topk_accuracy_per_task(logits, labels, topk=(1,5), ignore_index=IGNORE_INDEX):
    B,K,C = logits.shape
    res = {}
    overall = {k:0 for k in topk}
    total_cnt = 0
    preds_topk = logits.topk(max(topk), dim=-1)[1]  # (B,K,maxk)
    for h in range(K):
        lab = labels[:,h]; mask = (lab != ignore_index); cnt = int(mask.sum().item())
        for k in topk:
            if cnt == 0:
                res.setdefault(f"per_h{h+1}_top{k}", None)
                continue
            predk = preds_topk[:,h,:k]  # (B,k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            hits = (predk == lab_exp)
            hit = int(hits[mask].any(dim=1).float().sum().item())
            res[f"per_h{h+1}_top{k}"] = hit / cnt
            overall[k] += hit
        total_cnt += cnt
    for k in topk:
        res[f"overall_top{k}"] = overall[k] / total_cnt if total_cnt>0 else None
    return res

def topk_counts(logits, labels, k):
    with torch.no_grad():
        B, K, C = logits.shape
        topk_preds = logits.topk(k, dim=-1)[1]  # (B, K, k)
        hits = 0
        total = 0
        for h in range(K):
            lab = labels[:, h]  # (B,)
            mask = (lab != IGNORE_INDEX)
            if int(mask.sum().item()) == 0:
                continue
            predk = topk_preds[:, h, :]  # (B, k)
            lab_exp = lab.unsqueeze(1).expand(-1, k)
            masked_pred = predk[mask]
            masked_lab = lab_exp[mask]
            hit_vec = (masked_pred == masked_lab).any(dim=1).float()
            hits += int(hit_vec.sum().item())
            total += int(mask.sum().item())
        return hits, total

### MAIN: Prepare dataset, model, and training

In [10]:
FUSED_CSV_PATH = str(OUTPUT_FUSED_CSV) 
LABEL_CSV_PATH = str(LABEL_CSV)

# load fused and labels
fused_df = load_fused_csv_by_path(FUSED_CSV_PATH)
labels_df = load_label_csv_by_path(LABEL_CSV_PATH)

# Training & dataset
T_OBS = 90
FPS = 30.0
HORIZONS_S = [0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75, 2.0]
K_FUT = len(HORIZONS_S)

BATCH_SIZE = 8
NUM_EPOCHS = 20
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 0

# dataset
dataset = SingleVideoAnticipationDataset(
    fused_df,
    labels_df,
    t_obs=T_OBS,
    k_fut=K_FUT,
    feat_dim=FEAT_DIM,
    fps=FPS,
    horizons_s=HORIZONS_S
)


# split (60/40)
indices = list(range(len(dataset)))
random.seed(SEED)
random.shuffle(indices)
split_at = int(0.6 * len(indices))
train_idx = indices[:split_at]; val_idx = indices[split_at:]
train_ds = Subset(dataset, train_idx); val_ds = Subset(dataset, val_idx)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=(DEVICE=="cuda"))
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=(DEVICE=="cuda"))

# detect num classes
def detect_num_classes_from_labels_df(labels_df):
    verbs = set(); nouns = set(); actions = set()
    for cand in ["Verb_class","verb","Verb","verb_class"]:
        if cand in labels_df.columns:
            verbs.update(labels_df[cand].dropna().astype(int).tolist()); break
    for cand in ["Noun_class","noun","Noun","noun_class"]:
        if cand in labels_df.columns:
            nouns.update(labels_df[cand].dropna().astype(int).tolist()); break
    for cand in ["Action_class","action","Action","ActionLabel"]:
        if cand in labels_df.columns:
            actions.update(labels_df[cand].dropna().astype(int).tolist()); break
    nv = (max(verbs) + 1) if len(verbs) > 0 else 1
    nn_ = (max(nouns) + 1) if len(nouns) > 0 else 1
    na = (max(actions) + 1) if len(actions) > 0 else 1
    return {"verb": int(nv), "noun": int(nn_), "action": int(na)}

num_classes = detect_num_classes_from_labels_df(labels_df)
print("Detected num_classes:", num_classes)

# instantiate model, optimizer, scheduler
model = AnticipationModel(feat_dim=FEAT_DIM, num_classes=num_classes, k_fut=K_FUT).to(DEVICE)
opt = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)

# instantiate auxiliary helpers and weights (used by LOSS_MODE)
focal_fn = FocalLoss(gamma=2.0, ignore_index=IGNORE_INDEX)
focal_alpha = 0.3
smooth_weight = 0.1
contrast_weight = 0.1
contrast_temperature = 0.07
graph_rec_weight = 0.05

# training loop
best_val_loss = float("inf")
for epoch in range(1, NUM_EPOCHS + 1):
    t0 = time.time()
    model.train()
    train_loss_sum = 0.0; train_samples = 0
    train_counts = {"verb_top1":[0,0],"verb_top5":[0,0],"noun_top1":[0,0],"noun_top5":[0,0],"action_top1":[0,0],"action_top5":[0,0]}

    for F_batch, y_multi, meta in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        F_batch = F_batch.to(DEVICE)
        y_v = y_multi["verb"].to(DEVICE)
        y_n = y_multi["noun"].to(DEVICE)
        y_a = y_multi["action"].to(DEVICE)

        opt.zero_grad()
        logits, dec_outs, gat_out = model(F_batch)

        # base CE
        loss_v_ce = masked_cross_entropy(logits["verb"], y_v)
        loss_n_ce = masked_cross_entropy(logits["noun"], y_n)
        loss_a_ce = masked_cross_entropy(logits["action"], y_a)
        base_loss = loss_a_ce + 0.5 * loss_v_ce + 0.5 * loss_n_ce

        # init aux terms
        focal_term = torch.tensor(0.0, device=F_batch.device)
        smooth_term = torch.tensor(0.0, device=F_batch.device)
        contrast_term = torch.tensor(0.0, device=F_batch.device)
        graph_term = torch.tensor(0.0, device=F_batch.device)

        # LOSS_MODE branches
        if LOSS_MODE == "ce":
            loss = base_loss

        elif LOSS_MODE == "focal":
            f_v = focal_fn(logits["verb"], y_v)
            f_n = focal_fn(logits["noun"], y_n)
            f_a = focal_fn(logits["action"], y_a)
            loss_v = (1.0 - focal_alpha) * loss_v_ce + focal_alpha * f_v
            loss_n = (1.0 - focal_alpha) * loss_n_ce + focal_alpha * f_n
            loss_a = (1.0 - focal_alpha) * loss_a_ce + focal_alpha * f_a
            focal_term = (f_v + f_n + f_a) / 3.0
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

        elif LOSS_MODE == "smooth":
            loss = base_loss
            if dec_outs is not None:
                smooth_term = temporal_smoothness_loss(dec_outs)
                loss = loss + smooth_weight * smooth_term

        elif LOSS_MODE == "contrast":
            loss = base_loss
            feats_for_contrast = F.normalize(F_batch[:, -1, :], dim=1)
            labels_for_contrast = y_a[:, 0].clone().detach()
            valid_mask = (labels_for_contrast != IGNORE_INDEX)
            if int(valid_mask.sum().item()) > 1:
                contrast_term = supervised_contrastive_loss(feats_for_contrast[valid_mask], labels_for_contrast[valid_mask], temperature=contrast_temperature)
                loss = loss + contrast_weight * contrast_term

        elif LOSS_MODE == "graph":
            loss = base_loss
            if gat_out is not None:
                graph_term = graph_reconstruction_loss(gat_out, F_batch.detach(), k=K)
                loss = loss + graph_rec_weight * graph_term

        elif LOSS_MODE == "combined":
            f_v = focal_fn(logits["verb"], y_v)
            f_n = focal_fn(logits["noun"], y_n)
            f_a = focal_fn(logits["action"], y_a)
            loss_v = (1.0 - focal_alpha) * loss_v_ce + focal_alpha * f_v
            loss_n = (1.0 - focal_alpha) * loss_n_ce + focal_alpha * f_n
            loss_a = (1.0 - focal_alpha) * loss_a_ce + focal_alpha * f_a
            focal_term = (f_v + f_n + f_a) / 3.0
            if dec_outs is not None:
                smooth_term = temporal_smoothness_loss(dec_outs)
            feats_for_contrast = F.normalize(F_batch[:, -1, :], dim=1)
            labels_for_contrast = y_a[:, 0].clone().detach()
            if int((labels_for_contrast != IGNORE_INDEX).sum().item()) > 1:
                contrast_term = supervised_contrastive_loss(feats_for_contrast[labels_for_contrast != IGNORE_INDEX], labels_for_contrast[labels_for_contrast != IGNORE_INDEX], temperature=contrast_temperature)
            if gat_out is not None:
                graph_term = graph_reconstruction_loss(gat_out, F_batch.detach(), k=K)
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n
            loss = loss + smooth_weight * smooth_term + contrast_weight * contrast_term + graph_rec_weight * graph_term

        else:
            raise ValueError(f"Unknown LOSS_MODE: {LOSS_MODE}")

        # backward + step
        loss.backward()
        opt.step()

        # bookkeeping
        b = F_batch.size(0)
        train_loss_sum += float(loss.item()) * b
        train_samples += b

        for (task, lab, lg) in [("verb", y_v, logits["verb"]), ("noun", y_n, logits["noun"]), ("action", y_a, logits["action"])]:
            h1, t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
            h5, t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
            train_counts[f"{task}_top1"][0] += h1; train_counts[f"{task}_top1"][1] += t1
            train_counts[f"{task}_top5"][0] += h5; train_counts[f"{task}_top5"][1] += t5

    # train metrics
    train_loss = train_loss_sum / max(1, train_samples)
    train_metrics = {}
    for task in ["verb","noun","action"]:
        h1,t1 = train_counts[f"{task}_top1"]; h5,t5 = train_counts[f"{task}_top5"]
        train_metrics[f"{task}_top1"] = (h1 / t1) if t1>0 else None
        train_metrics[f"{task}_top5"] = (h5 / t5) if t5>0 else None

    # ------------- VALIDATION -------------
    model.eval()
    val_loss_sum = 0.0; val_samples = 0
    val_counts = {"verb_top1":[0,0],"verb_top5":[0,0],"noun_top1":[0,0],"noun_top5":[0,0],"action_top1":[0,0],"action_top5":[0,0]}
    val_logits_store = {"verb": [], "noun": [], "action": []}
    val_labels_store = {"verb": [], "noun": [], "action": []}

    with torch.no_grad():
        for F_batch, y_multi, meta in tqdm(val_loader, desc=f"Epoch {epoch} Val"):
            F_batch = F_batch.to(DEVICE)
            y_v = y_multi["verb"].to(DEVICE)
            y_n = y_multi["noun"].to(DEVICE)
            y_a = y_multi["action"].to(DEVICE)

            logits, _, _ = model(F_batch)
            loss_v = masked_cross_entropy(logits["verb"], y_v)
            loss_n = masked_cross_entropy(logits["noun"], y_n)
            loss_a = masked_cross_entropy(logits["action"], y_a)
            loss = loss_a + 0.5 * loss_v + 0.5 * loss_n

            b = F_batch.size(0)
            val_loss_sum += float(loss.item()) * b
            val_samples += b

            for (task, lab, lg) in [("verb", y_v, logits["verb"]), ("noun", y_n, logits["noun"]), ("action", y_a, logits["action"])]:
                h1,t1 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=1)
                h5,t5 = topk_counts(lg.detach().cpu(), lab.detach().cpu(), k=5)
                val_counts[f"{task}_top1"][0] += h1; val_counts[f"{task}_top1"][1] += t1
                val_counts[f"{task}_top5"][0] += h5; val_counts[f"{task}_top5"][1] += t5

            val_logits_store["verb"].append(logits["verb"].detach().cpu())
            val_logits_store["noun"].append(logits["noun"].detach().cpu())
            val_logits_store["action"].append(logits["action"].detach().cpu())
            val_labels_store["verb"].append(y_v.detach().cpu())
            val_labels_store["noun"].append(y_n.detach().cpu())
            val_labels_store["action"].append(y_a.detach().cpu())

    val_loss = val_loss_sum / max(1, val_samples)

    # overall val metrics
    val_metrics = {}
    for task in ["verb","noun","action"]:
        h1,t1 = val_counts[f"{task}_top1"]; h5,t5 = val_counts[f"{task}_top5"]
        val_metrics[f"{task}_top1"] = (h1 / t1) if t1>0 else None
        val_metrics[f"{task}_top5"] = (h5 / t5) if t5>0 else None

    # per-horizon metrics
    per_horizon_metrics = {"verb":{},"noun":{},"action":{}}
    for task in ["verb","noun","action"]:
        if len(val_logits_store[task]) == 0:
            continue
        logits_all = torch.cat(val_logits_store[task], dim=0)
        labels_all = torch.cat(val_labels_store[task], dim=0)
        m = topk_accuracy_per_task(logits_all, labels_all, topk=(1,5), ignore_index=IGNORE_INDEX)
        per_horizon_metrics[task] = m

    # PRF macro
    prf_metrics = {"verb":{}, "noun":{}, "action":{}}
    for task in ["verb","noun","action"]:
        if len(val_logits_store[task]) == 0:
            continue
        logits_all = torch.cat(val_logits_store[task], dim=0)
        labels_all = torch.cat(val_labels_store[task], dim=0)
        preds_all = logits_all.argmax(dim=-1)
        mask = (labels_all != IGNORE_INDEX)
        if mask.sum().item() == 0:
            continue
        y_true = labels_all[mask].numpy()
        y_pred = preds_all[mask].numpy()
        p,r,f1,_ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
        prf_metrics[task]["precision"] = p; prf_metrics[task]["recall"] = r; prf_metrics[task]["f1"] = f1

    # mean top-5 recall
    mean_top5_recall = {}
    for task in ["verb","noun","action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            mean_top5_recall[task] = None; continue
        vals = []
        for h_idx in range(K_FUT):
            key = f"per_h{h_idx+1}_top5"
            if key in mh and mh[key] is not None:
                vals.append(mh[key])
        mean_top5_recall[task] = float(np.mean(vals)) if len(vals) > 0 else None

    sched.step(val_loss)
    elapsed = time.time() - t0

    # print summary
    print(f"Epoch {epoch}/{NUM_EPOCHS} | Time {elapsed:.1f}s")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    for task in ["verb","noun","action"]:
        print(f"  {task.upper():6s} Train Top1: {train_metrics[f'{task}_top1']}, Top5: {train_metrics[f'{task}_top5']}; Val Top1: {val_metrics[f'{task}_top1']}, Top5: {val_metrics[f'{task}_top5']}")
    for task in ["verb","noun","action"]:
        if prf_metrics[task]:
            print(f"  {task.upper():6s} Val Precision: {prf_metrics[task]['precision']:.4f}, Recall: {prf_metrics[task]['recall']:.4f}, F1: {prf_metrics[task]['f1']:.4f}")
    print("  ---- Mean Top-5 Recall (validation) ----")
    for task in ["verb","noun","action"]:
        print(f"     {task.upper():6s}  Mean Top-5 Recall: {mean_top5_recall[task]}")
    # per-horizon
    for task in ["verb","noun","action"]:
        mh = per_horizon_metrics[task]
        if not mh:
            continue
        print(f"  {task.upper():6s} per-horizon (time-based):")
        for h_idx, t_sec in enumerate(HORIZONS_S):
            key1 = f"per_h{h_idx+1}_top1"; key5 = f"per_h{h_idx+1}_top5"
            v1 = mh.get(key1, None); v5 = mh.get(key5, None)
            print(f"    @ {t_sec:4.2f}s  Top1: {v1}  Top5: {v5}")
        print(f"    overall_top1: {mh.get('overall_top1', None)}, overall_top5: {mh.get('overall_top5', None)}")

    # save best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'opt_state': opt.state_dict(), 'val_loss': val_loss}, BEST_MODEL_PATH)
        print(f"[SAVED BEST] -> {BEST_MODEL_PATH}")

print("Training finished.")

Detected num_classes: {'verb': 81, 'noun': 332, 'action': 123}


Epoch 1 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 1/20 | Time 12.0s
  Train Loss: 9.0053 | Val Loss: 8.8257
  VERB   Train Top1: 0.1508828250401284, Top5: 0.5971107544141252; Val Top1: 0.2440944881889764, Top5: 0.6719160104986877
  NOUN   Train Top1: 0.10914927768860354, Top5: 0.32102728731942215; Val Top1: 0.015748031496062992, Top5: 0.17060367454068243
  ACTION Train Top1: 0.0449438202247191, Top5: 0.1781701444622793; Val Top1: 0.0, Top5: 0.03937007874015748
  VERB   Val Precision: 0.0163, Recall: 0.0667, F1: 0.0262
  NOUN   Val Precision: 0.0005, Recall: 0.0323, F1: 0.0010
  ACTION Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6650958135107943
     NOUN    Mean Top-5 Recall: 0.17193746116259995
     ACTION  Mean Top-5 Recall: 0.04016004372325099
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.6176470588235294
    @ 0.50s  Top1: 0.1891891891891892  Top5: 0.6216216216216216
    @ 0.75s  Top1: 0.2727272727272727  T

Epoch 2 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 2 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 2/20 | Time 11.4s
  Train Loss: 7.7290 | Val Loss: 8.7477
  VERB   Train Top1: 0.19743178170144463, Top5: 0.7062600321027287; Val Top1: 0.05511811023622047, Top5: 0.6272965879265092
  NOUN   Train Top1: 0.12359550561797752, Top5: 0.36597110754414125; Val Top1: 0.015748031496062992, Top5: 0.30971128608923887
  ACTION Train Top1: 0.02247191011235955, Top5: 0.18619582664526485; Val Top1: 0.005249343832020997, Top5: 0.023622047244094488
  VERB   Val Precision: 0.0037, Recall: 0.0667, F1: 0.0070
  NOUN   Val Precision: 0.0005, Recall: 0.0323, F1: 0.0010
  ACTION Val Precision: 0.0003, Recall: 0.0062, F1: 0.0005
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6194906940299676
     NOUN    Mean Top-5 Recall: 0.3082541353934609
     ACTION  Mean Top-5 Recall: 0.022451785665728228
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.058823529411764705  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.05405405405405406  Top5: 0.5675675675675675
    @ 0.75s  Top

Epoch 3 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 3 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 3/20 | Time 11.6s
  Train Loss: 6.9743 | Val Loss: 8.8058
  VERB   Train Top1: 0.17335473515248795, Top5: 0.6918138041733547; Val Top1: 0.26246719160104987, Top5: 0.6272965879265092
  NOUN   Train Top1: 0.1187800963081862, Top5: 0.42375601926163725; Val Top1: 0.015748031496062992, Top5: 0.28608923884514437
  ACTION Train Top1: 0.06741573033707865, Top5: 0.2712680577849117; Val Top1: 0.015748031496062992, Top5: 0.015748031496062992
  VERB   Val Precision: 0.0175, Recall: 0.0667, F1: 0.0277
  NOUN   Val Precision: 0.0005, Recall: 0.0323, F1: 0.0011
  ACTION Val Precision: 0.0003, Recall: 0.0192, F1: 0.0006
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6194906940299676
     NOUN    Mean Top-5 Recall: 0.2886414244659948
     ACTION  Mean Top-5 Recall: 0.015830903523067503
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.2647058823529412  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.2702702702702703  Top5: 0.5675675675675675
    @ 0.75s  Top1: 0.

Epoch 4 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 4 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 4/20 | Time 10.8s
  Train Loss: 6.6533 | Val Loss: 8.9167
  VERB   Train Top1: 0.18940609951845908, Top5: 0.7191011235955056; Val Top1: 0.06561679790026247, Top5: 0.6272965879265092
  NOUN   Train Top1: 0.15730337078651685, Top5: 0.42696629213483145; Val Top1: 0.05774278215223097, Top5: 0.2178477690288714
  ACTION Train Top1: 0.0754414125200642, Top5: 0.2825040128410915; Val Top1: 0.0, Top5: 0.015748031496062992
  VERB   Val Precision: 0.0097, Recall: 0.0548, F1: 0.0134
  NOUN   Val Precision: 0.0231, Recall: 0.0458, F1: 0.0183
  ACTION Val Precision: 0.0000, Recall: 0.0000, F1: 0.0000
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6194906940299676
     NOUN    Mean Top-5 Recall: 0.21779435049658974
     ACTION  Mean Top-5 Recall: 0.015830903523067503
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.08823529411764706  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.08108108108108109  Top5: 0.5675675675675675
    @ 0.75s  Top1: 0.0454545454545454

Epoch 5 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 5 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 5/20 | Time 10.5s
  Train Loss: 6.5472 | Val Loss: 8.6198
  VERB   Train Top1: 0.21829855537720708, Top5: 0.723916532905297; Val Top1: 0.24671916010498687, Top5: 0.6614173228346457
  NOUN   Train Top1: 0.2247191011235955, Top5: 0.48796147672552165; Val Top1: 0.06561679790026247, Top5: 0.2047244094488189
  ACTION Train Top1: 0.12680577849117175, Top5: 0.33707865168539325; Val Top1: 0.04199475065616798, Top5: 0.12860892388451445
  VERB   Val Precision: 0.0307, Recall: 0.0637, F1: 0.0397
  NOUN   Val Precision: 0.0396, Recall: 0.0566, F1: 0.0313
  ACTION Val Precision: 0.0117, Recall: 0.0175, F1: 0.0140
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6541355819920195
     NOUN    Mean Top-5 Recall: 0.20537409531317669
     ACTION  Mean Top-5 Recall: 0.12946688012427443
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.5882352941176471
    @ 0.50s  Top1: 0.21621621621621623  Top5: 0.6216216216216216
    @ 0.75s  Top1: 0.22

Epoch 6 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 6 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 6/20 | Time 10.8s
  Train Loss: 5.7454 | Val Loss: 8.6301
  VERB   Train Top1: 0.22792937399678972, Top5: 0.7929373996789727; Val Top1: 0.2440944881889764, Top5: 0.6745406824146981
  NOUN   Train Top1: 0.30658105939004815, Top5: 0.5698234349919743; Val Top1: 0.14960629921259844, Top5: 0.2755905511811024
  ACTION Train Top1: 0.23274478330658105, Top5: 0.5040128410914928; Val Top1: 0.03674540682414698, Top5: 0.11286089238845144
  VERB   Val Precision: 0.0681, Recall: 0.1182, F1: 0.0849
  NOUN   Val Precision: 0.0700, Recall: 0.0792, F1: 0.0655
  ACTION Val Precision: 0.0182, Recall: 0.0159, F1: 0.0170
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6688100941413768
     NOUN    Mean Top-5 Recall: 0.27522009786325335
     ACTION  Mean Top-5 Recall: 0.11381738952866315
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.20588235294117646  Top5: 0.6470588235294118
    @ 0.50s  Top1: 0.24324324324324326  Top5: 0.6216216216216216
    @ 0.75s  Top1: 0.227

Epoch 7 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 7 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 7/20 | Time 11.0s
  Train Loss: 5.2387 | Val Loss: 8.3512
  VERB   Train Top1: 0.27447833065810595, Top5: 0.8154093097913323; Val Top1: 0.13385826771653545, Top5: 0.7244094488188977
  NOUN   Train Top1: 0.3563402889245586, Top5: 0.6324237560192616; Val Top1: 0.15748031496062992, Top5: 0.30971128608923887
  ACTION Train Top1: 0.29373996789727125, Top5: 0.6035313001605136; Val Top1: 0.06036745406824147, Top5: 0.24671916010498687
  VERB   Val Precision: 0.0783, Recall: 0.1039, F1: 0.0699
  NOUN   Val Precision: 0.1130, Recall: 0.1044, F1: 0.0756
  ACTION Val Precision: 0.0268, Recall: 0.0294, F1: 0.0251
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7174727171772984
     NOUN    Mean Top-5 Recall: 0.31337565331206574
     ACTION  Mean Top-5 Recall: 0.2480436087333973
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.14705882352941177  Top5: 0.6470588235294118
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.7027027027027027
    @ 0.75s  Top1: 0.090

Epoch 8 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 8 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 8/20 | Time 11.1s
  Train Loss: 4.4822 | Val Loss: 8.1327
  VERB   Train Top1: 0.3402889245585875, Top5: 0.884430176565008; Val Top1: 0.25984251968503935, Top5: 0.7611548556430446
  NOUN   Train Top1: 0.39646869983948635, Top5: 0.7736757624398074; Val Top1: 0.1679790026246719, Top5: 0.4094488188976378
  ACTION Train Top1: 0.40770465489566615, Top5: 0.78330658105939; Val Top1: 0.08136482939632546, Top5: 0.2677165354330709
  VERB   Val Precision: 0.0746, Recall: 0.1205, F1: 0.0901
  NOUN   Val Precision: 0.0491, Recall: 0.0844, F1: 0.0568
  ACTION Val Precision: 0.0370, Recall: 0.0516, F1: 0.0384
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7564614173737292
     NOUN    Mean Top-5 Recall: 0.40794679955997615
     ACTION  Mean Top-5 Recall: 0.2713144974074301
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.2647058823529412  Top5: 0.7352941176470589
    @ 0.50s  Top1: 0.2972972972972973  Top5: 0.7297297297297297
    @ 0.75s  Top1: 0.25  Top5: 0

Epoch 9 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 9 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 9/20 | Time 10.9s
  Train Loss: 4.0011 | Val Loss: 8.0384
  VERB   Train Top1: 0.3595505617977528, Top5: 0.9020866773675762; Val Top1: 0.27296587926509186, Top5: 0.7506561679790026
  NOUN   Train Top1: 0.4622792937399679, Top5: 0.7688603531300161; Val Top1: 0.2388451443569554, Top5: 0.43832020997375326
  ACTION Train Top1: 0.48796147672552165, Top5: 0.8330658105939005; Val Top1: 0.14173228346456693, Top5: 0.2755905511811024
  VERB   Val Precision: 0.1393, Recall: 0.1312, F1: 0.1165
  NOUN   Val Precision: 0.1333, Recall: 0.1560, F1: 0.1271
  ACTION Val Precision: 0.0707, Recall: 0.1030, F1: 0.0787
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7437493671913711
     NOUN    Mean Top-5 Recall: 0.44260403193526127
     ACTION  Mean Top-5 Recall: 0.2779353795500908
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.3235294117647059  Top5: 0.6764705882352942
    @ 0.50s  Top1: 0.2972972972972973  Top5: 0.7027027027027027
    @ 0.75s  Top1: 0.25  Top5

Epoch 10 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 10 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 10/20 | Time 10.9s
  Train Loss: 3.2969 | Val Loss: 8.0135
  VERB   Train Top1: 0.4654895666131621, Top5: 0.9373996789727127; Val Top1: 0.2230971128608924, Top5: 0.7716535433070866
  NOUN   Train Top1: 0.5666131621187801, Top5: 0.8956661316211878; Val Top1: 0.2047244094488189, Top5: 0.4435695538057743
  ACTION Train Top1: 0.593900481540931, Top5: 0.9117174959871589; Val Top1: 0.14173228346456693, Top5: 0.2782152230971129
  VERB   Val Precision: 0.1388, Recall: 0.1527, F1: 0.1165
  NOUN   Val Precision: 0.1039, Recall: 0.1281, F1: 0.1035
  ACTION Val Precision: 0.0891, Recall: 0.1218, F1: 0.0902
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7641897849746881
     NOUN    Mean Top-5 Recall: 0.4452002417711489
     ACTION  Mean Top-5 Recall: 0.28012836200623115
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.6764705882352942
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.7297297297297297
    @ 0.75s  Top1: 0.204545454

Epoch 11 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 11 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 11/20 | Time 10.6s
  Train Loss: 2.5139 | Val Loss: 7.8821
  VERB   Train Top1: 0.5842696629213483, Top5: 0.9486356340288925; Val Top1: 0.2230971128608924, Top5: 0.7585301837270341
  NOUN   Train Top1: 0.6918138041733547, Top5: 0.9454253611556982; Val Top1: 0.27034120734908135, Top5: 0.4304461942257218
  ACTION Train Top1: 0.7303370786516854, Top5: 0.9534510433386838; Val Top1: 0.14698162729658792, Top5: 0.28608923884514437
  VERB   Val Precision: 0.1151, Recall: 0.1346, F1: 0.0998
  NOUN   Val Precision: 0.1936, Recall: 0.1798, F1: 0.1636
  ACTION Val Precision: 0.0823, Recall: 0.1059, F1: 0.0875
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7522852875403718
     NOUN    Mean Top-5 Recall: 0.4333624224476085
     ACTION  Mean Top-5 Recall: 0.2881529874643999
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.7058823529411765
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.7027027027027027
    @ 0.75s  Top1: 0.1818181

Epoch 12 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 12 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 12/20 | Time 11.0s
  Train Loss: 2.1195 | Val Loss: 7.5287
  VERB   Train Top1: 0.6548956661316212, Top5: 0.9887640449438202; Val Top1: 0.4120734908136483, Top5: 0.8031496062992126
  NOUN   Train Top1: 0.7672552166934189, Top5: 0.9791332263242376; Val Top1: 0.28083989501312334, Top5: 0.45144356955380577
  ACTION Train Top1: 0.7784911717495987, Top5: 0.9807383627608347; Val Top1: 0.2073490813648294, Top5: 0.28608923884514437
  VERB   Val Precision: 0.2110, Recall: 0.2305, F1: 0.2084
  NOUN   Val Precision: 0.1530, Recall: 0.1860, F1: 0.1602
  ACTION Val Precision: 0.0976, Recall: 0.1463, F1: 0.1090
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7958009270285746
     NOUN    Mean Top-5 Recall: 0.45461486198498446
     ACTION  Mean Top-5 Recall: 0.28716042835941813
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.35294117647058826  Top5: 0.7352941176470589
    @ 0.50s  Top1: 0.2972972972972973  Top5: 0.7567567567567568
    @ 0.75s  Top1: 0.431818

Epoch 13 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 13 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 13/20 | Time 10.7s
  Train Loss: 1.7302 | Val Loss: 7.8411
  VERB   Train Top1: 0.7126805778491172, Top5: 0.9919743178170144; Val Top1: 0.29658792650918636, Top5: 0.7112860892388452
  NOUN   Train Top1: 0.8041733547351525, Top5: 0.9919743178170144; Val Top1: 0.2755905511811024, Top5: 0.4540682414698163
  ACTION Train Top1: 0.8876404494382022, Top5: 0.9871589085072231; Val Top1: 0.18110236220472442, Top5: 0.28346456692913385
  VERB   Val Precision: 0.2109, Recall: 0.1542, F1: 0.1370
  NOUN   Val Precision: 0.1939, Recall: 0.1823, F1: 0.1780
  ACTION Val Precision: 0.1163, Recall: 0.1447, F1: 0.1101
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7004090181524926
     NOUN    Mean Top-5 Recall: 0.4589392592079885
     ACTION  Mean Top-5 Recall: 0.28459407173509876
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.6176470588235294
    @ 0.50s  Top1: 0.2702702702702703  Top5: 0.6216216216216216
    @ 0.75s  Top1: 0.25  Top

Epoch 14 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 14 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 14/20 | Time 10.7s
  Train Loss: 1.4029 | Val Loss: 7.7438
  VERB   Train Top1: 0.8073836276083467, Top5: 0.9967897271268058; Val Top1: 0.25196850393700787, Top5: 0.7112860892388452
  NOUN   Train Top1: 0.8747993579454254, Top5: 0.9919743178170144; Val Top1: 0.2388451443569554, Top5: 0.48293963254593175
  ACTION Train Top1: 0.8940609951845907, Top5: 0.9935794542536116; Val Top1: 0.23622047244094488, Top5: 0.29396325459317585
  VERB   Val Precision: 0.1771, Recall: 0.1678, F1: 0.1497
  NOUN   Val Precision: 0.2394, Recall: 0.1720, F1: 0.1712
  ACTION Val Precision: 0.1498, Recall: 0.1676, F1: 0.1431
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6988499099357032
     NOUN    Mean Top-5 Recall: 0.482908855223397
     ACTION  Mean Top-5 Recall: 0.2947360595647133
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.20588235294117646  Top5: 0.5882352941176471
    @ 0.50s  Top1: 0.1891891891891892  Top5: 0.6216216216216216
    @ 0.75s  Top1: 0.22727272

Epoch 15 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 15 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 15/20 | Time 10.8s
  Train Loss: 1.1285 | Val Loss: 7.6277
  VERB   Train Top1: 0.884430176565008, Top5: 0.9935794542536116; Val Top1: 0.30183727034120733, Top5: 0.7401574803149606
  NOUN   Train Top1: 0.9197431781701445, Top5: 0.9951845906902087; Val Top1: 0.28083989501312334, Top5: 0.4671916010498688
  ACTION Train Top1: 0.942215088282504, Top5: 0.9967897271268058; Val Top1: 0.1889763779527559, Top5: 0.30446194225721784
  VERB   Val Precision: 0.1861, Recall: 0.1821, F1: 0.1625
  NOUN   Val Precision: 0.2249, Recall: 0.1937, F1: 0.1883
  ACTION Val Precision: 0.1064, Recall: 0.1310, F1: 0.1080
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7283683065842175
     NOUN    Mean Top-5 Recall: 0.4720468084990811
     ACTION  Mean Top-5 Recall: 0.309144926930604
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.2647058823529412  Top5: 0.6176470588235294
    @ 0.50s  Top1: 0.24324324324324326  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.29545454545

Epoch 16 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 16 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 16/20 | Time 10.6s
  Train Loss: 0.8482 | Val Loss: 7.6182
  VERB   Train Top1: 0.9101123595505618, Top5: 0.9967897271268058; Val Top1: 0.2755905511811024, Top5: 0.7244094488188977
  NOUN   Train Top1: 0.9486356340288925, Top5: 0.9983948635634029; Val Top1: 0.28608923884514437, Top5: 0.47244094488188976
  ACTION Train Top1: 0.9598715890850722, Top5: 0.9983948635634029; Val Top1: 0.2152230971128609, Top5: 0.3123359580052493
  VERB   Val Precision: 0.1774, Recall: 0.1714, F1: 0.1544
  NOUN   Val Precision: 0.2070, Recall: 0.1902, F1: 0.1754
  ACTION Val Precision: 0.1461, Recall: 0.1489, F1: 0.1285
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7138726238035982
     NOUN    Mean Top-5 Recall: 0.4787780727132307
     ACTION  Mean Top-5 Recall: 0.31458041315102675
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.6176470588235294
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.2045454

Epoch 17 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 17 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 17/20 | Time 11.2s
  Train Loss: 0.7391 | Val Loss: 7.7998
  VERB   Train Top1: 0.9309791332263242, Top5: 0.9967897271268058; Val Top1: 0.2887139107611549, Top5: 0.7559055118110236
  NOUN   Train Top1: 0.9743178170144462, Top5: 1.0; Val Top1: 0.30183727034120733, Top5: 0.4540682414698163
  ACTION Train Top1: 0.971107544141252, Top5: 0.9983948635634029; Val Top1: 0.2073490813648294, Top5: 0.32020997375328086
  VERB   Val Precision: 0.2675, Recall: 0.2433, F1: 0.2075
  NOUN   Val Precision: 0.2405, Recall: 0.2099, F1: 0.2003
  ACTION Val Precision: 0.1232, Recall: 0.1437, F1: 0.1231
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.747474086022147
     NOUN    Mean Top-5 Recall: 0.4578801483626935
     ACTION  Mean Top-5 Recall: 0.3216922243208007
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.23529411764705882  Top5: 0.6764705882352942
    @ 0.50s  Top1: 0.24324324324324326  Top5: 0.7027027027027027
    @ 0.75s  Top1: 0.25  Top5: 0.7272727272727

Epoch 18 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 18 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 18/20 | Time 10.9s
  Train Loss: 0.5870 | Val Loss: 7.6714
  VERB   Train Top1: 0.9550561797752809, Top5: 0.9983948635634029; Val Top1: 0.27296587926509186, Top5: 0.7322834645669292
  NOUN   Train Top1: 0.9775280898876404, Top5: 0.9983948635634029; Val Top1: 0.30183727034120733, Top5: 0.5039370078740157
  ACTION Train Top1: 0.9887640449438202, Top5: 0.9983948635634029; Val Top1: 0.2099737532808399, Top5: 0.32020997375328086
  VERB   Val Precision: 0.2024, Recall: 0.1861, F1: 0.1708
  NOUN   Val Precision: 0.2366, Recall: 0.2159, F1: 0.2045
  ACTION Val Precision: 0.1280, Recall: 0.1349, F1: 0.1186
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7228197283631255
     NOUN    Mean Top-5 Recall: 0.5067513113510698
     ACTION  Mean Top-5 Recall: 0.3232090596369781
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.20588235294117646  Top5: 0.6470588235294118
    @ 0.50s  Top1: 0.1891891891891892  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.25  Top5

Epoch 19 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 19 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 19/20 | Time 10.8s
  Train Loss: 0.5449 | Val Loss: 7.7575
  VERB   Train Top1: 0.9646869983948636, Top5: 0.9983948635634029; Val Top1: 0.2677165354330709, Top5: 0.7139107611548556
  NOUN   Train Top1: 0.9791332263242376, Top5: 1.0; Val Top1: 0.30183727034120733, Top5: 0.49343832020997375
  ACTION Train Top1: 0.985553772070626, Top5: 1.0; Val Top1: 0.2047244094488189, Top5: 0.31496062992125984
  VERB   Val Precision: 0.2016, Recall: 0.1687, F1: 0.1667
  NOUN   Val Precision: 0.2402, Recall: 0.2243, F1: 0.2060
  ACTION Val Precision: 0.1327, Recall: 0.1391, F1: 0.1225
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.7038120134969447
     NOUN    Mean Top-5 Recall: 0.4973835352331669
     ACTION  Mean Top-5 Recall: 0.31714676977534617
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.17647058823529413  Top5: 0.5882352941176471
    @ 0.50s  Top1: 0.16216216216216217  Top5: 0.6486486486486487
    @ 0.75s  Top1: 0.22727272727272727  Top5: 0.6818181818

Epoch 20 Train:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 20 Val:   0%|          | 0/13 [00:00<?, ?it/s]

Epoch 20/20 | Time 11.4s
  Train Loss: 0.5014 | Val Loss: 7.8269
  VERB   Train Top1: 0.9695024077046549, Top5: 1.0; Val Top1: 0.26246719160104987, Top5: 0.6902887139107612
  NOUN   Train Top1: 0.9775280898876404, Top5: 1.0; Val Top1: 0.2887139107611549, Top5: 0.48556430446194226
  ACTION Train Top1: 0.9823434991974318, Top5: 1.0; Val Top1: 0.2073490813648294, Top5: 0.3123359580052493
  VERB   Val Precision: 0.2274, Recall: 0.1909, F1: 0.1893
  NOUN   Val Precision: 0.2563, Recall: 0.1915, F1: 0.2025
  ACTION Val Precision: 0.1225, Recall: 0.1418, F1: 0.1258
  ---- Mean Top-5 Recall (validation) ----
     VERB    Mean Top-5 Recall: 0.6761850131221996
     NOUN    Mean Top-5 Recall: 0.4898268545424118
     ACTION  Mean Top-5 Recall: 0.3145426031086795
  VERB   per-horizon (time-based):
    @ 0.25s  Top1: 0.20588235294117646  Top5: 0.5588235294117647
    @ 0.50s  Top1: 0.21621621621621623  Top5: 0.5945945945945946
    @ 0.75s  Top1: 0.22727272727272727  Top5: 0.6363636363636364
    @ 1.0