In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
# train_chainrule_head.py
# Train ChainRule head on frozen Hierarchical embeddings (300D), using combined feature+label NPY files
# assuming files are in the CURRENT WORKING DIRECTORY (e.g., where this script is run from).
import os, math, time, random
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
import json

# ============ Config ============
# --- Data Paths: These files are assumed to be in the current working directory. ---
FEATURE_TRAIN    = "/content/drive/MyDrive/data/yelp_polar_glove_final_train.npy"
FEATURE_VAL      = "/content/drive/MyDrive/data/yelp_polar_glove_final_val.npy"
FEATURE_TEST     = "/content/drive/MyDrive/data/yelp_polar_glove_final_test.npy"
# NOTE: The last column of these files is assumed to contain the rating score (0-4).

TARGET_MIN, TARGET_MAX = 0.0, 1.0

BATCH_TRAIN = 4096
BATCH_TEST  = 6144
EPOCHS      = 250
LR          = 4e-4
WDECAY      = 6e-4
DREG        = 1e-5
PATIENCE    = 25
USE_AMP     = True
hidden_features = 3000
num_layers  = 2
degrees     = 2
dropout     = 0.05




DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

SEED = 1337
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# ============ Metrics ============
def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))

def qwk(y_true, y_pred, min_rating=0, max_rating=4):
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    M = max_rating - min_rating + 1
    O = np.zeros((M, M), dtype=np.float64)
    for t, p in zip(y_true, y_pred):
        if min_rating <= t <= max_rating and min_rating <= p <= max_rating:
            O[t - min_rating, p - min_rating] += 1.0
    act_hist = O.sum(axis=1)
    pred_hist = O.sum(axis=0)
    E = np.outer(act_hist, pred_hist) / max(1.0, O.sum())
    W = np.zeros((M, M))
    for i in range(M):
        for j in range(M):
            W[i, j] = ((i - j) ** 2) / ((M - 1) ** 2)
    num = (W * O).sum()
    den = (W * E).sum() if (W * E).sum() != 0 else 1.0
    return 1.0 - num / den

def round_clip(x):
    return (x >= 0.5).astype(int)

def apply_thresholds(p_cont, ts):
    """ts = [t1,t2,t3,t4] on 0..4; returns integer 0..4 via np.digitize."""
    ts = np.asarray(ts, dtype=np.float32)
    return np.digitize(p_cont, ts).astype(int)

def greedy_search_thresholds(y_true_int, p_cont, metric="qwk",
                             start_ts=(0.5,1.5,2.5,3.5),
                             sweeps=3, step=0.10, margin=0.40):
    """Coordinate-descent tuner for 4 thresholds on 0..4."""
    ts = np.array(start_ts, dtype=np.float32)

    def score_with(ts_local):
        pr = apply_thresholds(p_cont, ts_local)
        if metric == "acc":
            return (pr == y_true_int).mean()
        else:
            return qwk(y_true_int, pr)

    best_score = score_with(ts)
    for _ in range(sweeps):
        for i in range(len(ts)):
            low = (ts[i-1] + 1e-3) if i > 0 else 0.0
            high = (ts[i+1] - 1e-3) if i < 3 else 4.0
            cand_grid = np.arange(max(low, ts[i]-margin),
                                  min(high, ts[i]+margin)+1e-9, step, dtype=np.float32)
            local_best = ts[i]; local_best_score = best_score
            for c in cand_grid:
                ts_try = ts.copy(); ts_try[i] = c
                s = score_with(ts_try)
                if s > local_best_score:
                    local_best_score = s
                    local_best = c
            ts[i] = local_best
            best_score = local_best_score
    return ts.tolist(), float(best_score)

@torch.no_grad()
def _collect_preds_cont(model, loader):
    """Return (y_true_int, y_pred_cont_0to4) from a loader (uses EMA if present)."""
    eval_model = model.module if isinstance(model, AveragedModel) else model
    eval_model.eval()
    ys_true, ys_pred = [], []
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=True)
        if xb.ndim != 2 or xb.shape[-1] != in_dim:
            xb = xb.view(-1, in_dim)
        out, _ = eval_model(xb)
        yhat_s = torch.sigmoid(out)
        yhat   = TARGET_MIN + (TARGET_MAX - TARGET_MIN) * yhat_s
        ys_true.append(yb.numpy())
        ys_pred.append(yhat.squeeze(-1).cpu().numpy())
    y = np.concatenate(ys_true).astype(int)
    p = np.concatenate(ys_pred).astype(np.float32)
    return y, p

# ============ ChainRule Model ============
class PolyLayerStable(nn.Module):
    def __init__(self, features: int, degree: int):
        super().__init__()
        self.features = features
        self.degree = degree
        self.raw_coeffs = nn.Parameter(torch.randn(features, degree + 1) * 0.05)
        self.gamma = nn.Parameter(torch.zeros(features))
        self.scale = 0.5
    def forward(self, h, dh):
        coeffs = self.scale * torch.tanh(self.raw_coeffs)
        # polynomial forward
        powers = [torch.ones_like(h)]
        for _ in range(1, self.degree + 1):
            powers.append(powers[-1] * h)
        H = torch.stack(powers, dim=-1)         # (B,F,deg+1)
        y = torch.sum(H * coeffs, dim=-1)       # (B,F)
        # derivative part
        if self.degree >= 1:
            ar = torch.arange(1, self.degree + 1, device=h.device, dtype=h.dtype)
            deriv_coeffs = coeffs[:, 1:] * ar    # (F,deg)
            dlocal = torch.sum(H[..., :-1] * deriv_coeffs, dim=-1) # (B,F)
        else:
            dlocal = torch.zeros_like(h)
        y = y + 0.1 * self.gamma * dh
        dh_out = dlocal * dh
        return y, dh_out

class ChainRulePolyNetStable(nn.Module):
    def __init__(self, in_features, hidden_features=256, num_layers=2, degree=3):
        super().__init__()
        self.in_map = nn.Linear(in_features, hidden_features)
        self.in_norm = nn.LayerNorm(hidden_features)         # NEW
        self.in_drop = nn.Dropout(p=dropout)                     # NEW
        self.layers = nn.ModuleList([PolyLayerStable(hidden_features, degree) for _ in range(num_layers)])
        self.out_map = nn.Linear(hidden_features, 1)
    def forward(self, x):
        h = self.in_map(x)
        h = self.in_drop(self.in_norm(torch.nn.functional.gelu(h)))  # NEW
        dh = torch.ones_like(h)
        reg = 0.0
        for layer in self.layers:
            h, dh = layer(h, dh)
            reg = reg + torch.mean(dh ** 2)
        y = self.out_map(h) # (B,1)
        return y, reg / max(1, len(self.layers))

# ============ Datasets (Combined Feature/Label Logic) ============

def _safe_load_npy(path: str):
    """
    Try memory-mapped load first; if the file was saved/truncated oddly,
    fall back to a normal load (no mmap). Raises a clear error if both fail.
    """
    try:
        return np.load(path, mmap_mode="r")
    except Exception as e_mmap:
        print(f"⚠️ mmap load failed for '{path}': {e_mmap}\n   → Retrying without mmap...")
        try:
            return np.load(path)  # no mmap
        except Exception as e_plain:
            raise RuntimeError(
                f"Failed to load '{path}' even without mmap. "
                f"The file may be truncated or corrupted. Original error: {e_plain}"
            )

class CombinedNumpyDataset(Dataset):
    """
    Dataset for pre-split NumPy files where the final column is the label.
    Accepts either a proper 2D numeric array (N, D+1) or a flat 1D numeric array
    that can be reshaped into (N, D+1) when divisible by a plausible column count.
    """
    def __init__(self, full_path: str, prefer_cols: int | None = None):
        print(f"Loading data from {full_path}...")
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File not found: {full_path}")

        data = _safe_load_npy(full_path)

        if data.dtype == np.object_:
            raise ValueError(
                f"{full_path} is an object array (likely pickled). "
                f"Re-save as numeric (e.g., float32) with shape (N, D+1)."
            )

        # Case 1: Already 2D
        if data.ndim == 2:
            self.features = np.array(data[:, :-1], dtype=np.float32, copy=True)  # was: np.asarray(...)
            self.labels   = np.array(data[:,  -1], dtype=np.float32, copy=True).reshape(-1)
            self.dim = self.features.shape[1]
            return

        # Case 2: Flat 1D → try to infer and reshape
        if data.ndim == 1:
            total = data.size
            candidates = []

            # Prefer a known columns count if provided (e.g., 300D + 1 label = 301)
            if prefer_cols is not None and total % prefer_cols == 0:
                candidates.append(prefer_cols)

            # Otherwise search a reasonable range of divisors
            if not candidates:
                for c in range(32, 4097):
                    if total % c == 0:
                        candidates.append(c)

            if not candidates:
                raise ValueError(
                    f"{full_path}: flat size {total:,} has no reasonable divisors in [32, 4096]; "
                    f"cannot infer (features+label) column count."
                )

            # Pick the candidate closest to common shapes like 301/309
            candidates.sort(key=lambda c: min(abs(c - 301), abs(c - 309)))
            cols = candidates[0]
            n = total // cols
            print(f"ℹ️  Inferred shape for {full_path}: ({n:,}, {cols}) → features={cols-1}, label=1.")

            data2d = np.array(data, dtype=np.float32, copy=True).reshape(n, cols)
            self.features = data2d[:, :-1]
            self.labels   = data2d[:,  -1].reshape(-1)
            self.dim = self.features.shape[1]
            return

        raise ValueError(f"{full_path} has ndim={data.ndim}; expected 1D (flat) or 2D numeric array.")

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        x = torch.from_numpy(self.features[idx])
        y = torch.tensor(self.labels[idx], dtype=torch.float32)
        return x, y

def build_loaders_from_files():
    """Loads pre-split data from combined NPY files and creates DataLoaders."""
    print(f"--- Data Loading ---")

    # If you know your exact columns (e.g., 300D + 1 label = 301), set prefer_cols=301.
    # Otherwise leave None to auto-infer for each file.
    prefer_cols = None  # e.g., set to 301 if you want to enforce 300D+label

    try:
        train_dataset = CombinedNumpyDataset(FEATURE_TRAIN, prefer_cols=prefer_cols)
        val_dataset   = CombinedNumpyDataset(FEATURE_VAL,   prefer_cols=prefer_cols)

        if os.path.exists(FEATURE_TEST):
            test_dataset = CombinedNumpyDataset(FEATURE_TEST, prefer_cols=prefer_cols)
        else:
            test_dataset = None
            print(f"Warning: Test file {FEATURE_TEST} not found. Test set will be skipped.")

    except Exception as e:
        raise RuntimeError(
            f"Error loading combined NPY files. Ensure files are numeric (features + label) and not truncated. "
            f"Error: {e}"
        )

    in_dim = train_dataset.dim
    if val_dataset.dim != in_dim or (test_dataset and test_dataset.dim != in_dim):
        raise ValueError(
            f"Embedding dimensions must match across splits: "
            f"train={in_dim}, val={val_dataset.dim}, "
            f"test={(test_dataset.dim if test_dataset else '—')}."
        )

    print(f"Total Train samples: {len(train_dataset):,} | Val: {len(val_dataset):,} | Dim: {in_dim}")

    pin = (DEVICE.type == "cuda")
    # cap at 2 when on this runtime
    n_workers = 2 if pin else 0
    persistent = (n_workers > 0)


    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_TRAIN, shuffle=True,
        pin_memory=pin, num_workers=n_workers, persistent_workers=persistent
    )
    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_TEST, shuffle=False,
        pin_memory=pin, num_workers=n_workers, persistent_workers=persistent
    )
    test_loader = None
    if test_dataset:
        test_loader = DataLoader(
            test_dataset, batch_size=BATCH_TEST, shuffle=False,
            pin_memory=pin, num_workers=n_workers, persistent_workers=persistent
        )

    return train_loader, val_loader, test_loader, in_dim

# ============ Training / Eval ============

def train_loop(
    model, train_loader, val_loader,
    epochs=EPOCHS, lr=LR, wdecay=WDECAY, dreg=DREG,
    use_amp=USE_AMP, grad_clip=0.5, train_probe_batches=10
):
    model = model.to(DEVICE)

    scaler = torch.amp.GradScaler('cuda', enabled=(use_amp and DEVICE.type == "cuda"))
    opt = AdamW(model.parameters(), lr=lr, weight_decay=wdecay, fused=torch.cuda.is_available())

    total_steps = max(1, epochs * len(train_loader))
    warmup_steps = max(10, len(train_loader))

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step + 1) / warmup_steps
        t = (step - warmup_steps) / max(1, (total_steps - warmup_steps))
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * t))

    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    ema = AveragedModel(model, avg_fn=get_ema_avg_fn(0.999))

    best_val = -math.inf
    best = None
    best_ema = None
    no_improve = 0
    global in_dim

    for ep in range(1, epochs + 1):
        model.train()
        tot, n = 0.0, 0

        for xb, yb in train_loader:
            xb = xb.to(DEVICE, non_blocking=True)
            if xb.ndim != 2 or xb.shape[-1] != in_dim:
                xb = xb.view(-1, in_dim)

            # IMPORTANT: BCEWithLogits expects float targets
            ys = yb.to(DEVICE, non_blocking=True).float().unsqueeze(-1)

            opt.zero_grad(set_to_none=True)

            if use_amp and DEVICE.type == "cuda":
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    logits, reg = model(xb)
                    reg = torch.clamp(reg, max=10.0)
                    loss = nn.functional.binary_cross_entropy_with_logits(logits, ys) + dreg * reg

                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                logits, reg = model(xb)
                reg = torch.clamp(reg, max=10.0)
                loss = nn.functional.binary_cross_entropy_with_logits(logits, ys) + dreg * reg
                loss.backward()
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
                opt.step()

            ema.update_parameters(model)
            scheduler.step()

            bs = xb.size(0)
            tot += loss.item() * bs
            n += bs

        # ---- fast train accuracy probe (first N batches only) ----
        train_metrics = eval_loop(ema, train_loader, max_batches=train_probe_batches)
        train_acc_fast = train_metrics["acc_round"]

        val_metrics = eval_loop(ema, val_loader)     # thresholds=None => 0.5
        val_acc_05 = val_metrics["acc_round"]

        print(f"Epoch {ep:02d}/{epochs} | train_loss={tot/max(1,n):.5f} | val_acc={val_acc_05:.4f}")

        metric = val_acc_05

        if metric > best_val + 1e-6:
            best_val = metric
            best = {k: v.cpu() for k, v in model.state_dict().items()}
            best_ema = {k: v.cpu() for k, v in ema.module.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                print("Early stopping.")
                break

    if best is not None:
        model.load_state_dict(best)
        if best_ema is not None:
            ema.module.load_state_dict(best_ema)

    return ema


@torch.no_grad()
def eval_loop(model, loader, thresholds=None, max_batches=None, return_best_thresh=False):
    """
    Returns:
        dict with:
            - acc_round: accuracy using threshold(s) (0.5 if thresholds is None)
            - best_thresh (optional): best scalar threshold on this loader
            - best_acc (optional): best accuracy achieved by scalar threshold sweep
    """
    eval_model = model.module if isinstance(model, AveragedModel) else model
    eval_model.eval()

    ys_true, ys_pred = [], []

    for bidx, (xb, yb) in enumerate(loader):
        if (max_batches is not None) and (bidx >= max_batches):
            break

        xb = xb.to(DEVICE, non_blocking=True)
        if xb.ndim != 2 or xb.shape[-1] != in_dim:
            xb = xb.view(-1, in_dim)

        out, _ = eval_model(xb)                 # logits
        p = torch.sigmoid(out).squeeze(-1)      # probs in [0,1]

        ys_true.append(yb.detach().cpu().numpy())
        ys_pred.append(p.detach().cpu().numpy())

    y = np.concatenate(ys_true).astype(int)
    p = np.concatenate(ys_pred)

    if thresholds is None:
        pr = (p >= 0.5).astype(int)
    elif isinstance(thresholds, (float, int)):
        pr = (p >= thresholds).astype(int)
    else:
        pr = apply_thresholds(p, thresholds)

    acc = float(np.mean(pr == y))
    out_dict = {"acc_round": acc}

    # Optional: find best scalar threshold on this loader
    if return_best_thresh:
        best_t, best_acc = 0.5, acc
        # sweep 0.05..0.95
        for t in np.linspace(0.05, 0.95, 91):
            pr_t = (p >= t).astype(int)
            acc_t = float(np.mean(pr_t == y))
            if acc_t > best_acc:
                best_acc, best_t = acc_t, float(t)
        out_dict["best_thresh"] = best_t
        out_dict["best_acc"] = best_acc

    return out_dict


if __name__ == "__main__":
    print(f"Device: {DEVICE.type}")

    # --- Initial File Check ---
    if not os.path.exists(FEATURE_TRAIN):
        raise FileNotFoundError(f"Training file '{FEATURE_TRAIN}' not found. Ensure it is in the current working directory.")
    if not os.path.exists(FEATURE_VAL):
        raise FileNotFoundError(f"Validation file '{FEATURE_VAL}' not found. Ensure it is in the current working directory.")

    train_loader, val_loader, test_loader, in_dim = build_loaders_from_files()

    # Initialize model using the dynamically determined input dimension
    model = ChainRulePolyNetStable(in_features=in_dim, hidden_features=hidden_features, num_layers=num_layers, degree=degrees)
    print(f"Model initialized with in_features={in_dim}")

    t0 = time.time()
    model = train_loop(model, train_loader, val_loader)
    dur = time.time() - t0

    print("\n=== Training Complete ===")

val_metrics = eval_loop(model, val_loader, return_best_thresh=True)
t_star = val_metrics["best_thresh"]

if test_loader:
    test_05 = eval_loop(model, test_loader)["acc_round"]
    test_t  = eval_loop(model, test_loader, thresholds=t_star)["acc_round"]
    print(f"Test acc@0.5: {test_05:.4f}")
    print(f"Test acc@val-threshold (t={t_star:.2f}): {test_t:.4f}")
else:
    print("Test metrics skipped because the test data file was not found.")

Device: cuda
--- Data Loading ---
Loading data from /content/drive/MyDrive/data/yelp_polar_glove_final_train.npy...
Loading data from /content/drive/MyDrive/data/yelp_polar_glove_final_val.npy...
Loading data from /content/drive/MyDrive/data/yelp_polar_glove_final_test.npy...
Total Train samples: 418,600 | Val: 89,700 | Dim: 308
Model initialized with in_features=308
Epoch 01/250 | train_loss=0.63062 | val_acc=0.5235
Epoch 02/250 | train_loss=0.31416 | val_acc=0.5235
Epoch 03/250 | train_loss=0.27646 | val_acc=0.5235
Epoch 04/250 | train_loss=0.25974 | val_acc=0.5294
Epoch 05/250 | train_loss=0.24641 | val_acc=0.5660
Epoch 06/250 | train_loss=0.23412 | val_acc=0.6354
Epoch 07/250 | train_loss=0.22370 | val_acc=0.7152
Epoch 08/250 | train_loss=0.21317 | val_acc=0.7836
Epoch 09/250 | train_loss=0.20323 | val_acc=0.8327
Epoch 10/250 | train_loss=0.19355 | val_acc=0.8646
Epoch 11/250 | train_loss=0.18355 | val_acc=0.8808
Epoch 12/250 | train_loss=0.17347 | val_acc=0.8877
Epoch 13/250 | tra

In [6]:
# ---- Parameter accounting ----

def count_params(model):
    # Access the underlying model if it's an AveragedModel
    if isinstance(model, AveragedModel):
        model = model.module

    total = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Per-submodule tallies
    sub = {}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        top = name.split('.')[0]  # e.g., 'in_map', 'in_norm', 'layers', 'out_map'
        sub[top] = sub.get(top, 0) + p.numel()

    # Poly-layer "embedded" params = all raw_coeffs + gamma across layers
    poly_embedded = 0
    # Check if the model has layers before iterating
    if hasattr(model, 'layers'):
        for l in model.layers:
            # Check if the layer has raw_coeffs and gamma before accessing
            if hasattr(l, 'raw_coeffs'):
                 poly_embedded += l.raw_coeffs.numel()
            if hasattr(l, 'gamma'):
                 poly_embedded += l.gamma.numel()


    return total, sub, poly_embedded

total, sub_breakdown, poly_embedded = count_params(model)

print("\n=== Parameter Counts ===")
print(f"Total trainable params: {total:,}")
for k in sorted(sub_breakdown.keys()):
    print(f"  {k:8s}: {sub_breakdown[k]:,}")
print(f"Embedded params (poly raw_coeffs + gamma): {poly_embedded:,}")
print("External frozen embeddings in .npy (not part of model): 0 trainable params")
print(f"Your KPI is {test_t*100/np.log10(total)}")



=== Parameter Counts ===
Total trainable params: 960,001
  in_map  : 927,000
  in_norm : 6,000
  layers  : 24,000
  out_map : 3,001
Embedded params (poly raw_coeffs + gamma): 24,000
External frozen embeddings in .npy (not part of model): 0 trainable params
Your KPI is 16.078536917166687


In [7]:
import numpy as np
import torch
import json
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
# Assuming the necessary utility functions like apply_thresholds and
# _collect_preds_cont, along with global variables (model, test_loader,
# TARGET_MIN, TARGET_MAX) are defined from the main script execution.

# Re-defining necessary utility from user's script for robust execution:
def apply_thresholds(p_cont, ts):
    """ts = [t] on 0..1; returns integer 0..1 via np.digitize."""
    ts = np.asarray(ts, dtype=np.float32)
    return np.digitize(p_cont, ts).astype(int)

# ============ FINAL EVALUATION CELL ============
print("\n=== Calculating F1, Precision, and Recall on Test Set ===")

# 1. Load Tuned Thresholds (saved during the main script's final steps)
try:
    with open("rating_thresholds.json", "r") as f:
        threshold_data = json.load(f)
        BEST_THRESHOLDS = threshold_data["thresholds"]
    print(f"Loaded optimal thresholds: {[round(t, 3) for t in BEST_THRESHOLDS]}")
except FileNotFoundError:
    print("Error: 'rating_thresholds.json' not found. Ensure the main training script ran successfully and saved the thresholds.")
    BEST_THRESHOLDS = None
except Exception as e:
    print(f"Error loading thresholds: {e}")
    BEST_THRESHOLDS = None

if BEST_THRESHOLDS and 'test_loader' in locals():
    # 2. Collect True Labels and Continuous Predictions
    # NOTE: This line requires the _collect_preds_cont function and the EMA model object ('model')
    y_true_int, p_cont = _collect_preds_cont(model, test_loader)

    # 3. Apply Tuned Thresholds to get Discrete Predictions (0, 1, 2, 3, 4)
    y_pred_int = apply_thresholds(p_cont, BEST_THRESHOLDS)

    # 4. Calculate Classification Metrics (Multi-class 0-4)
    # Using 'weighted' average accounts for class imbalance
    f1 = f1_score(y_true_int, y_pred_int, average='weighted')
    precision = precision_score(y_true_int, y_pred_int, average='weighted')
    recall = recall_score(y_true_int, y_pred_int, average='weighted')

    print("\n--- Weighted Classification Metrics ---")
    print(f"Test Precision (Weighted): {precision:.4f}")
    print(f"Test Recall (Weighted):    {recall:.4f}")
    print(f"Test F1-Score (Weighted):  {f1:.4f}")

    # Optional: Detailed report for each class (0, 1, 2, 3, 4)
    print("\nDetailed Multi-Class Classification Report:")
    # Assuming TARGET_MIN and TARGET_MAX are available globally
    target_names = ["negative", "positive"]
    print(classification_report(y_true_int, y_pred_int, target_names=target_names))

else:
    print("Skipping F1/P/R calculation: Thresholds file or Test Loader not available.")


=== Calculating F1, Precision, and Recall on Test Set ===
Error: 'rating_thresholds.json' not found. Ensure the main training script ran successfully and saved the thresholds.
Skipping F1/P/R calculation: Thresholds file or Test Loader not available.
