In [None]:
# Mount Google Drive at /content/drive (Colab-only helper module).
# The FEATURE_* paths used by the training cell point below this mount point,
# so this cell has to run first.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
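# train_chainrule_head.py
# Train a ChainRule classification head on frozen Hierarchical embeddings, using combined
# feature+label NPY files (the label is the last column; the input dimension is read from
# the data at load time). File locations are set by the FEATURE_* constants in the Config
# section below, not by the current working directory.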
import os, math, time, random
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
import json

# ============ Config ============
# --- Data Paths: absolute paths below the Google Drive mount point (/content/drive). ---
FEATURE_TRAIN    = "/content/drive/MyDrive/data/ag_glove_train.npy"
FEATURE_VAL      = "/content/drive/MyDrive/data/ag_glove_val.npy"
FEATURE_TEST     = "/content/drive/MyDrive/data/ag_glove_test.npy"
# NOTE: The last column of each file is read as the class label and cast to an integer.
# It is fed to CrossEntropyLoss as a class index, so it should lie in [0, NUM_CLASSES - 1].

NUM_CLASSES = 4     # AG News
# NUM_CLASSES = 14  # DBpedia


BATCH_TRAIN = 4096   # batch size of the training DataLoader
BATCH_TEST  = 6144   # batch size of the validation/test DataLoaders
EPOCHS      = 250    # maximum number of epochs (early stopping may end training sooner)
LR          = 3e-4   # peak learning rate handed to AdamW (scaled by the LR schedule)
WDECAY      = 6e-4 #8e-4
DREG        = 4e-5   # weight of the derivative regularizer added to the cross-entropy loss
PATIENCE    = 40     # epochs without validation-accuracy improvement before early stopping
USE_AMP     = False  # when True (and on CUDA), train under float16 autocast with a GradScaler
hidden_features = 5000   # width of the hidden representation in ChainRulePolyNetStable
num_layers  = 2      # number of stacked PolyLayerStable layers
degrees     = 2      # polynomial degree of each PolyLayerStable
dropout     = 0.05   # dropout probability applied after the input LayerNorm






# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Best-effort tuning only; presumably guards PyTorch builds that lack this setter.
        pass

# Seed the Python, NumPy and PyTorch RNGs.
SEED = 1337
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# ============ Metrics ============
from sklearn.metrics import accuracy_score, f1_score

# ============ ChainRule Model ============
class PolyLayerStable(nn.Module):
    """Element-wise learnable polynomial that also propagates a derivative signal.

    Every feature owns ``degree + 1`` coefficients. The raw parameters are passed
    through ``scale * tanh(.)`` so the effective coefficients stay inside
    ``(-scale, scale)``.
    """

    def __init__(self, features: int, degree: int):
        super().__init__()
        self.features = features
        self.degree = degree
        # One row of polynomial coefficients per feature, small random init.
        self.raw_coeffs = nn.Parameter(torch.randn(features, degree + 1) * 0.05)
        # Per-feature gate that mixes the incoming derivative into the output.
        self.gamma = nn.Parameter(torch.zeros(features))
        self.scale = 0.5

    def forward(self, h, dh):
        """Return ``(p(h) + 0.1 * gamma * dh, p'(h) * dh)``.

        ``h`` and ``dh`` are tensors of identical shape whose last axis indexes
        features; ``p`` is the per-feature polynomial with bounded coefficients.
        """
        bounded = self.scale * torch.tanh(self.raw_coeffs)

        # Monomial basis 1, h, h^2, ..., h^degree stacked on a new trailing axis.
        monomials = [torch.ones_like(h)]
        while len(monomials) <= self.degree:
            monomials.append(monomials[-1] * h)
        basis = torch.stack(monomials, dim=-1)          # (..., F, degree + 1)

        poly = torch.sum(basis * bounded, dim=-1)       # (..., F)

        # Local derivative p'(h) = sum_k k * c_k * h^(k-1); zero for a constant polynomial.
        if self.degree < 1:
            slope = torch.zeros_like(h)
        else:
            exponents = torch.arange(1, self.degree + 1, device=h.device, dtype=h.dtype)
            slope_coeffs = bounded[:, 1:] * exponents   # (F, degree)
            slope = torch.sum(basis[..., :-1] * slope_coeffs, dim=-1)

        out = poly + 0.1 * self.gamma * dh
        return out, slope * dh

class ChainRulePolyNetStable(nn.Module):
    """Classifier head: Linear -> GELU -> LayerNorm -> Dropout -> polynomial layers -> Linear.

    Args:
        in_features: Width of the input embedding.
        hidden_features: Width of the hidden representation.
        num_layers: Number of stacked ``PolyLayerStable`` layers (may be 0).
        degree: Polynomial degree used by every ``PolyLayerStable``.
        num_classes: Number of output logits. ``None`` (default) falls back to the
            module-level ``NUM_CLASSES``.
        dropout_p: Dropout probability after the input LayerNorm. ``None`` (default)
            falls back to the module-level ``dropout``.
    """

    def __init__(self, in_features, hidden_features=256, num_layers=2, degree=3,
                 num_classes=None, dropout_p=None):
        super().__init__()
        # Resolve the optional overrides at construction time so existing callers,
        # which rely on the module-level settings, behave exactly as before.
        if num_classes is None:
            num_classes = NUM_CLASSES
        if dropout_p is None:
            dropout_p = dropout
        self.in_map = nn.Linear(in_features, hidden_features)
        self.in_norm = nn.LayerNorm(hidden_features)
        self.in_drop = nn.Dropout(p=dropout_p)
        self.layers = nn.ModuleList([PolyLayerStable(hidden_features, degree) for _ in range(num_layers)])
        self.out_map = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return ``(logits, reg)``.

        ``logits`` has shape ``(B, num_classes)``. ``reg`` is a 0-dim tensor holding
        the mean squared propagated derivative, averaged over the polynomial layers.
        """
        h = self.in_map(x)
        h = self.in_drop(self.in_norm(torch.nn.functional.gelu(h)))
        dh = torch.ones_like(h)
        # Start from a tensor (not a Python float) so `reg` is always a tensor,
        # even with zero polynomial layers; callers apply torch.clamp to it.
        reg = h.new_zeros(())
        for layer in self.layers:
            h, dh = layer(h, dh)
            reg = reg + torch.mean(dh ** 2)
        logits = self.out_map(h)  # (B, num_classes)
        return logits, reg / max(1, len(self.layers))

# ============ Datasets (Combined Feature/Label Logic) ============

def _safe_load_npy(path: str):
    """
    Try memory-mapped load first; if the file was saved/truncated oddly,
    fall back to a normal load (no mmap). Raises a clear error if both fail.
    """
    try:
        return np.load(path, mmap_mode="r")
    except Exception as e_mmap:
        print(f"⚠️ mmap load failed for '{path}': {e_mmap}\n   → Retrying without mmap...")
        try:
            return np.load(path)  # no mmap
        except Exception as e_plain:
            raise RuntimeError(
                f"Failed to load '{path}' even without mmap. "
                f"The file may be truncated or corrupted. Original error: {e_plain}"
            )

class CombinedNumpyDataset(Dataset):
    """
    Dataset for pre-split NumPy files where the final column is the label.
    Accepts either a proper 2D numeric array (N, D+1) or a flat 1D numeric array
    that can be reshaped into (N, D+1) when divisible by a plausible column count.
    """
    def __init__(self, full_path: str, prefer_cols: int | None = None):
        print(f"Loading data from {full_path}...")
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File not found: {full_path}")

        data = _safe_load_npy(full_path)

        if data.dtype == np.object_:
            raise ValueError(
                f"{full_path} is an object array (likely pickled). "
                f"Re-save as numeric (e.g., float32) with shape (N, D+1)."
            )

        # Case 1: Already 2D
        if data.ndim == 2:
            self.features = np.array(data[:, :-1], dtype=np.float32, copy=True)  # was: np.asarray(...)
            self.labels   = np.array(data[:,  -1], dtype=np.float32, copy=True).reshape(-1)
            self.dim = self.features.shape[1]
            return

        # Case 2: Flat 1D → try to infer and reshape
        if data.ndim == 1:
            total = data.size
            candidates = []

            # Prefer a known columns count if provided (e.g., 300D + 1 label = 301)
            if prefer_cols is not None and total % prefer_cols == 0:
                candidates.append(prefer_cols)

            # Otherwise search a reasonable range of divisors
            if not candidates:
                for c in range(32, 4097):
                    if total % c == 0:
                        candidates.append(c)

            if not candidates:
                raise ValueError(
                    f"{full_path}: flat size {total:,} has no reasonable divisors in [32, 4096]; "
                    f"cannot infer (features+label) column count."
                )

            # Pick the candidate closest to common shapes like 301/309
            candidates.sort(key=lambda c: min(abs(c - 301), abs(c - 309)))
            cols = candidates[0]
            n = total // cols
            print(f"ℹ️  Inferred shape for {full_path}: ({n:,}, {cols}) → features={cols-1}, label=1.")

            data2d = np.array(data, dtype=np.float32, copy=True).reshape(n, cols)
            self.features = data2d[:, :-1]
            self.labels   = data2d[:,  -1].reshape(-1)
            self.dim = self.features.shape[1]
            return

        raise ValueError(f"{full_path} has ndim={data.ndim}; expected 1D (flat) or 2D numeric array.")

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        x = torch.from_numpy(self.features[idx])
        y = torch.tensor(int(self.labels[idx]), dtype=torch.long) # used to be y = torch.tensor(self.labels[idx], dtype=torch.long)
        return x, y

def build_loaders_from_files():
    """Load the pre-split combined NPY files and wrap them in DataLoaders.

    Reads FEATURE_TRAIN and FEATURE_VAL, plus FEATURE_TEST when that file exists.

    Returns:
        ``(train_loader, val_loader, test_loader, in_dim)`` where ``test_loader`` is
        ``None`` when no test file is present and ``in_dim`` is the feature width.

    Raises:
        RuntimeError: If any file fails to load (the original error is chained).
        ValueError: If the feature width differs between splits.
    """
    print("--- Data Loading ---")

    # If you know your exact columns (e.g., 300D + 1 label = 301), set prefer_cols=301.
    # Otherwise leave None to auto-infer for each file.
    prefer_cols = None  # e.g., set to 301 if you want to enforce 300D+label

    try:
        train_dataset = CombinedNumpyDataset(FEATURE_TRAIN, prefer_cols=prefer_cols)
        val_dataset   = CombinedNumpyDataset(FEATURE_VAL,   prefer_cols=prefer_cols)

        if os.path.exists(FEATURE_TEST):
            test_dataset = CombinedNumpyDataset(FEATURE_TEST, prefer_cols=prefer_cols)
        else:
            test_dataset = None
            print(f"Warning: Test file {FEATURE_TEST} not found. Test set will be skipped.")

    except Exception as e:
        raise RuntimeError(
            f"Error loading combined NPY files. Ensure files are numeric (features + label) and not truncated. "
            f"Error: {e}"
        ) from e

    in_dim = train_dataset.dim
    # Compare against None explicitly: a Dataset's truthiness is its length, so an
    # empty test split would otherwise be treated as a missing one.
    test_dim = None if test_dataset is None else test_dataset.dim
    if val_dataset.dim != in_dim or (test_dim is not None and test_dim != in_dim):
        raise ValueError(
            f"Embedding dimensions must match across splits: "
            f"train={in_dim}, val={val_dataset.dim}, "
            f"test={(test_dim if test_dim is not None else '—')}."
        )

    print(f"Total Train samples: {len(train_dataset):,} | Val: {len(val_dataset):,} | Dim: {in_dim}")

    pin = (DEVICE.type == "cuda")
    # cap at 2 when on this runtime
    n_workers = 2 if pin else 0
    persistent = (n_workers > 0)

    def _make_loader(dataset, batch_size, shuffle):
        # All splits share the same worker / pinning configuration.
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            pin_memory=pin, num_workers=n_workers, persistent_workers=persistent
        )

    train_loader = _make_loader(train_dataset, BATCH_TRAIN, True)
    val_loader = _make_loader(val_dataset, BATCH_TEST, False)
    test_loader = None if test_dataset is None else _make_loader(test_dataset, BATCH_TEST, False)

    return train_loader, val_loader, test_loader, in_dim

# ============ Training / Eval ============
def train_loop(model, train_loader, val_loader, epochs=EPOCHS, lr=LR, wdecay=WDECAY, dreg=DREG, use_amp=USE_AMP):
    """Train ``model`` and return an EMA copy holding the best validation weights.

    Uses AdamW, a per-step linear-warmup + cosine learning-rate schedule, gradient
    clipping (max norm 0.5) and an exponential moving average (decay 0.999) of the
    weights. Validation accuracy is measured on the EMA weights after each epoch and
    drives both checkpoint selection and early stopping (``PATIENCE`` epochs).

    NOTE(review): because validation uses the EMA weights, the reported accuracy
    can trail the raw model early in training; confirm this is intended.

    Args:
        model: Network returning ``(logits, reg)`` from ``forward``.
        train_loader: Iterable of ``(features, labels)`` training batches.
        val_loader: Iterable of validation batches, passed to ``eval_loop``.
        epochs: Maximum number of epochs.
        lr: Peak learning rate.
        wdecay: AdamW weight decay.
        dreg: Weight of the (clamped) derivative regularizer in the loss.
        use_amp: Enable float16 autocast + gradient scaling; only effective on CUDA.

    Returns:
        The ``AveragedModel`` (EMA) with the best-scoring weights loaded. ``model``
        itself is restored in place to its weights from the same epoch.
    """
    model = model.to(DEVICE)
    on_cuda = (DEVICE.type == "cuda")
    amp_on = bool(use_amp) and on_cuda

    # A disabled GradScaler is a pass-through (scale/unscale_/update are no-ops and
    # step() just calls opt.step()), so one code path serves both AMP and full precision.
    scaler = torch.amp.GradScaler('cuda', enabled=amp_on)

    # Fused AdamW only when the parameters live on CUDA.
    opt = AdamW(model.parameters(), lr=lr, weight_decay=wdecay, fused=on_cuda)

    # Per-step schedule: linear warmup, then cosine decay down to 1% of the peak LR.
    total_steps = max(1, epochs * len(train_loader))
    warmup_steps = max(10, len(train_loader))
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step + 1) / warmup_steps
        t = (step - warmup_steps) / max(1, (total_steps - warmup_steps))
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * t))
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    # EMA of the weights; this is what gets validated and returned.
    ema = AveragedModel(model, avg_fn=get_ema_avg_fn(0.999))

    # Expected feature width, read from the model's input layer rather than from a
    # module-level global (which only exists when the file is run as a script).
    expected_dim = getattr(getattr(model, "in_map", None), "in_features", None)

    def _snapshot(module):
        # Force a real copy: on CPU `.cpu()` returns the live tensor itself, so the
        # "best" weights would keep changing as training continues.
        return {k: v.detach().to("cpu", copy=True) for k, v in module.state_dict().items()}

    def _batch_loss(xb, yb):
        logits, reg = model(xb)
        reg = torch.clamp(reg, max=10.0)  # bound the regularizer's contribution
        return criterion(logits, yb) + dreg * reg

    best_val = -math.inf
    best = None
    best_ema = None
    no_improve = 0

    criterion = nn.CrossEntropyLoss()

    for ep in range(1, epochs + 1):
        model.train()
        tot, n = 0.0, 0
        for xb, yb in train_loader:
            xb = xb.to(DEVICE, non_blocking=True)

            # Ensure (B, in_dim)
            if expected_dim is not None and (xb.ndim != 2 or xb.shape[-1] != expected_dim):
                xb = xb.view(-1, expected_dim)

            # Integer class indices for CrossEntropyLoss
            yb = yb.to(DEVICE, non_blocking=True).long()

            opt.zero_grad(set_to_none=True)
            if amp_on:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    loss = _batch_loss(xb, yb)
            else:
                loss = _batch_loss(xb, yb)

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            scaler.step(opt)
            scaler.update()

            ema.update_parameters(model)
            scheduler.step()

            bs = xb.size(0)
            tot += loss.item() * bs
            n += bs

        val_acc = eval_loop(ema, val_loader)["acc"]
        print(f"Epoch {ep:02d}/{epochs} | train_loss={tot/max(1,n):.5f} | val_acc={val_acc:.4f}")

        if val_acc > best_val + 1e-6:
            best_val = val_acc
            best = _snapshot(model)
            best_ema = _snapshot(ema.module)
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                print("Early stopping.")
                break

    if best is not None:
        model.load_state_dict(best)
        if best_ema is not None:
            ema.module.load_state_dict(best_ema)

    # Return the EMA-smoothed model for downstream eval
    return ema

@torch.no_grad()
def eval_loop(model, loader):
    """Compute classification accuracy of ``model`` over ``loader``.

    Args:
        model: A network returning ``(logits, reg)``, or an ``AveragedModel``
            wrapping one (the wrapped module is evaluated).
        loader: Iterable of ``(features, labels)`` batches.

    Returns:
        ``{"acc": float}``. The accuracy is ``nan`` when ``loader`` yields no batches.

    The evaluated module is switched to eval mode and left that way.
    """
    eval_model = model.module if isinstance(model, AveragedModel) else model
    eval_model.eval()

    # Run on whatever device the model lives on; fall back to the module-level
    # DEVICE only for a parameter-free model.
    first_param = next(eval_model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    # Expected feature width comes from the model's input layer, not from a global.
    expected_dim = getattr(getattr(eval_model, "in_map", None), "in_features", None)

    ys_true, ys_pred = [], []

    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)

        # Ensure (B, in_dim)
        if expected_dim is not None and (xb.ndim != 2 or xb.shape[-1] != expected_dim):
            xb = xb.view(-1, expected_dim)

        logits, _ = eval_model(xb)          # shape: (B, C)
        pred = logits.argmax(dim=1)         # shape: (B,)
        ys_true.append(yb.cpu().numpy())
        ys_pred.append(pred.cpu().numpy())

    # np.concatenate raises on an empty list; report "no data" as nan instead.
    if not ys_true:
        return {"acc": float("nan")}

    y = np.concatenate(ys_true)
    p = np.concatenate(ys_pred)

    acc = float(np.mean(p == y))
    return {"acc": acc}

# ============ Main ============
if __name__ == "__main__":
    print(f"Device: {DEVICE.type}")

    # Fail fast when a required split is missing, before any data is loaded.
    train_file_missing = not os.path.exists(FEATURE_TRAIN)
    if train_file_missing:
        raise FileNotFoundError(f"Training file '{FEATURE_TRAIN}' not found. Ensure it is in the current working directory.")
    val_file_missing = not os.path.exists(FEATURE_VAL)
    if val_file_missing:
        raise FileNotFoundError(f"Validation file '{FEATURE_VAL}' not found. Ensure it is in the current working directory.")

    train_loader, val_loader, test_loader, in_dim = build_loaders_from_files()

    # The input width comes from the loaded data; the rest from the Config section.
    model_kwargs = dict(
        in_features=in_dim,
        hidden_features=hidden_features,
        num_layers=num_layers,
        degree=degrees,
    )
    model = ChainRulePolyNetStable(**model_kwargs)
    print(f"Model initialized with in_features={in_dim}")

    # train_loop returns the EMA-wrapped model, which replaces `model` from here on.
    started = time.time()
    model = train_loop(model, train_loader, val_loader)
    dur = time.time() - started

    print("\n=== Training Complete ===")

    if not test_loader:
        print("Test metrics skipped because the test data file was not found.")
    else:
        metrics = eval_loop(model, test_loader)
        print("\n=== Test Metrics ===")
        print(f"Accuracy: {metrics['acc']:.4f}")

    print(f"Total train time: {dur:.2f}s")



Device: cuda
--- Data Loading ---
Loading data from /content/drive/MyDrive/data/ag_glove_train.npy...
Loading data from /content/drive/MyDrive/data/ag_glove_val.npy...
Loading data from /content/drive/MyDrive/data/ag_glove_test.npy...
Total Train samples: 89,320 | Val: 19,140 | Dim: 308
Model initialized with in_features=308
Epoch 01/250 | train_loss=1.38272 | val_acc=0.2528
Epoch 02/250 | train_loss=1.31120 | val_acc=0.3358
Epoch 03/250 | train_loss=0.81106 | val_acc=0.2592
Epoch 04/250 | train_loss=0.39330 | val_acc=0.2460
Epoch 05/250 | train_loss=0.31228 | val_acc=0.2460
Epoch 06/250 | train_loss=0.27862 | val_acc=0.2460
Epoch 07/250 | train_loss=0.25733 | val_acc=0.2460
Epoch 08/250 | train_loss=0.24072 | val_acc=0.2460
Epoch 09/250 | train_loss=0.22768 | val_acc=0.2460
Epoch 10/250 | train_loss=0.21503 | val_acc=0.2460
Epoch 11/250 | train_loss=0.20341 | val_acc=0.2461
Epoch 12/250 | train_loss=0.19261 | val_acc=0.2472
Epoch 13/250 | train_loss=0.18238 | val_acc=0.2510
Epoch 14/2

In [None]:
# ---- Parameter accounting ----
def count_params(model):
    """Tally the trainable parameters of ``model``.

    An ``AveragedModel`` is unwrapped first so that the wrapped module is counted.

    Returns:
        ``(total, sub, poly_embedded)``:
        total         - number of trainable scalars,
        sub           - dict mapping each top-level submodule name to its trainable count,
        poly_embedded - size of every ``raw_coeffs`` and ``gamma`` found on the
                        entries of ``model.layers`` (0 when there is no such attribute).
    """
    net = model.module if isinstance(model, AveragedModel) else model

    # Single pass: accumulate the grand total and the per-submodule breakdown together.
    grand_total = 0
    per_module = {}
    for qualified_name, param in net.named_parameters():
        if param.requires_grad:
            size = param.numel()
            grand_total += size
            prefix = qualified_name.split('.')[0]  # e.g. 'in_map', 'in_norm', 'layers', 'out_map'
            per_module[prefix] = per_module.get(prefix, 0) + size

    # Polynomial-layer parameters; layers missing either attribute are tolerated.
    poly_total = sum(
        getattr(layer, attr).numel()
        for layer in getattr(net, 'layers', ())
        for attr in ('raw_coeffs', 'gamma')
        if hasattr(layer, attr)
    )

    return grand_total, per_module, poly_total

# Summarize the model size and relate test accuracy to it.
total, sub_breakdown, poly_embedded = count_params(model)

print("\n=== Parameter Counts ===")
print(f"Total trainable params: {total:,}")
for name, count in sorted(sub_breakdown.items()):
    print(f"  {name:8s}: {count:,}")
print(f"Embedded params (poly raw_coeffs + gamma): {poly_embedded:,}")
print("External frozen embeddings in .npy (not part of model): 0 trainable params")
# KPI = test accuracy in percent divided by log10 of the trainable-parameter count.
# Relies on `metrics` left behind by the training cell's test evaluation.
kpi = metrics['acc']*100/np.log10(total)
print(f"Your KPI is {kpi}")


=== Parameter Counts ===
Total trainable params: 1,615,004
  in_map  : 1,545,000
  in_norm : 10,000
  layers  : 40,000
  out_map : 20,004
Embedded params (poly raw_coeffs + gamma): 40,000
External frozen embeddings in .npy (not part of model): 0 trainable params
Your KPI is 15.086121664346816


In [None]:
import numpy as np
import torch
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report

# ============ FINAL EVALUATION CELL (NOMINAL) ============
print("\n=== Calculating F1, Precision, and Recall on Test Set (Nominal) ===")

def collect_preds_nominal(model, loader):
    """
    Collect true labels and predicted class indices for nominal multiclass classification.
    EMA-aware: an ``AveragedModel`` is unwrapped and its inner module is evaluated.

    Args:
        model: Network returning ``(logits, reg)``, or an ``AveragedModel`` around one.
        loader: Iterable of ``(features, labels)`` batches.

    Returns:
        ``(y_true, y_pred)`` as 1D integer NumPy arrays; both are empty when
        ``loader`` yields no batches.
    """
    eval_model = model.module if isinstance(model, torch.optim.swa_utils.AveragedModel) else model
    eval_model.eval()

    # Use the model's own device and input width instead of notebook globals, so
    # the function still works when the training cell's variables are absent.
    first_param = next(eval_model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    expected_dim = getattr(getattr(eval_model, "in_map", None), "in_features", None)

    ys_true = []
    ys_pred = []

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)

            # ensure correct shape
            if expected_dim is not None and (xb.ndim != 2 or xb.shape[-1] != expected_dim):
                xb = xb.view(-1, expected_dim)

            logits, _ = eval_model(xb)          # (B, C)
            preds = torch.argmax(logits, dim=1)  # (B,)

            ys_true.append(yb.cpu().numpy())
            ys_pred.append(preds.cpu().numpy())

    # np.concatenate raises on an empty list; return empty arrays instead.
    if not ys_true:
        empty = np.empty((0,), dtype=int)
        return empty, empty.copy()

    y_true = np.concatenate(ys_true).astype(int)
    y_pred = np.concatenate(ys_pred).astype(int)

    return y_true, y_pred


# ---- Run evaluation ----
# `test_loader` exists but is None when the test file was missing, so a name
# lookup alone is not enough; also skip a loader that has no batches.
_eval_loader = globals().get("test_loader")
if _eval_loader is not None and len(_eval_loader) > 0:

    y_true_int, y_pred_int = collect_preds_nominal(model, _eval_loader)

    # Weighted metrics (appropriate even if perfectly balanced)
    f1 = f1_score(y_true_int, y_pred_int, average="weighted")
    precision = precision_score(y_true_int, y_pred_int, average="weighted")
    recall = recall_score(y_true_int, y_pred_int, average="weighted")

    print("\n--- Weighted Classification Metrics ---")
    print(f"Test Precision (Weighted): {precision:.4f}")
    print(f"Test Recall (Weighted):    {recall:.4f}")
    print(f"Test F1-Score (Weighted):  {f1:.4f}")

    print("\nDetailed Multi-Class Classification Report:")
    # Pass the label ids explicitly so target_names always lines up with them,
    # even if some class never appears in the labels or the predictions.
    class_ids = list(range(NUM_CLASSES))
    target_names = [str(i) for i in class_ids]
    print(classification_report(y_true_int, y_pred_int, labels=class_ids, target_names=target_names))

else:
    print("Skipping evaluation: test_loader not available.")



=== Calculating F1, Precision, and Recall on Test Set (Nominal) ===

--- Weighted Classification Metrics ---
Test Precision (Weighted): 0.9370
Test Recall (Weighted):    0.9366
Test F1-Score (Weighted):  0.9365

Detailed Multi-Class Classification Report:
              precision    recall  f1-score   support

           0       0.96      0.93      0.94      4857
           1       0.98      0.99      0.98      4749
           2       0.92      0.89      0.91      4707
           3       0.90      0.94      0.92      4827

    accuracy                           0.94     19140
   macro avg       0.94      0.94      0.94     19140
weighted avg       0.94      0.94      0.94     19140

