# Setup

## Root and data folders

In [4]:
import os
import pandas as pd
import numpy as np

# Project root folder. The DISERTATIE_ROOT environment variable, when set, overrides the
# location so the notebook can run on another machine; otherwise the original path is used.
root_dir = os.environ.get(
    "DISERTATIE_ROOT",
    "/Users/silviumatu/Desktop/Code/Python/Disertatie/Disertatie_Matu_Silviu_v1",
)
os.makedirs(root_dir, exist_ok=True)

# Folder holding the CSV inputs read by the cells below.
data_dir = os.path.join(root_dir, "Data")
os.makedirs(data_dir, exist_ok=True)

# Load data

In [22]:
# Read the modelling table and its column-role dictionary from the data folder.
EXP_reg_df, columns_EXP_reg_df = (
    pd.read_csv(os.path.join(data_dir, csv_name))
    for csv_name in (
        "EXP_regression_data_forecast.csv",
        "columns_EXP_regression_data_forecast.csv",
    )
)

# Preview the first rows of the modelling table.
EXP_reg_df.head()

Unnamed: 0,x_participant_id,x_age,x_gender,x_BDI_TOTAL_pre,x_YSQ_D1_pre,x_YSQ_D2_pre,x_LSAS_ANX_pre,x_LSAS_AVOID_pre,x_day_participant,x_response_within_day_participant,...,x_q2_value_1,x_q2_value_2,x_q2_value_3,x_q2_value_4,x_q2_value_5,x_q2_value_6,x_q2_value_7,x_time_intervals_copy,x_time_difference,y_dep_score_next
0,6,22,1.0,20,133,61,51,47,0,0,...,0,0,1,0,0,0,0,0,1.0,10.0
1,6,22,1.0,20,133,61,51,47,0,1,...,0,1,0,0,0,0,0,1,1.0,2.0
2,6,22,1.0,20,133,61,51,47,0,2,...,1,0,0,0,0,0,0,2,5.0,4.0
3,6,22,1.0,20,133,61,51,47,1,0,...,0,0,0,1,0,0,0,7,3.0,3.0
4,6,22,1.0,20,133,61,51,47,1,1,...,1,0,0,0,0,0,0,10,5.0,15.0


In [23]:
def _columns_flagged_as(role: str) -> list:
    """Return the `column_name` values of the rows of columns_EXP_reg_df whose `role` column equals 1."""
    return columns_EXP_reg_df.loc[columns_EXP_reg_df[role] == 1, 'column_name'].tolist()


# Outcome column(s): rows marked with 1 in the "outcomes" column of columns_EXP_reg_df
EXP_reg_outcome_cols = _columns_flagged_as('outcomes')
EXP_reg_y = EXP_reg_df[EXP_reg_outcome_cols]

# Outcome-lag column(s)
EXP_reg_outcomes_lags_cols = _columns_flagged_as('outcomes_lags')
EXP_reg_outcomes_lags = EXP_reg_df[EXP_reg_outcomes_lags_cols]

# Participant identifier column(s)
EXP_reg_participant_cols = _columns_flagged_as('participant_id')
EXP_reg_participant_id = EXP_reg_df[EXP_reg_participant_cols]

# Time column(s)
EXP_reg_time_cols = _columns_flagged_as('time')
EXP_reg_time = EXP_reg_df[EXP_reg_time_cols]

# Forecast-horizon column(s)
EXP_reg_forecast_horizons_cols = _columns_flagged_as('forecast_horizons')
EXP_reg_forecast_horizons = EXP_reg_df[EXP_reg_forecast_horizons_cols]

# Column(s) flagged "only_fixed"
EXP_reg_only_fixed_cols = _columns_flagged_as('only_fixed')
EXP_reg_only_fixed = EXP_reg_df[EXP_reg_only_fixed_cols]

# Column(s) flagged "fixed_and_random"
EXP_reg_fixed_and_random_cols = _columns_flagged_as('fixed_and_random')
EXP_reg_fixed_and_random = EXP_reg_df[EXP_reg_fixed_and_random_cols]

# ARMED

## Architecture

In [24]:
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F

class GradientReversalFn(torch.autograd.Function):
    """Autograd function that is the identity going forward and multiplies the gradient by -lambd going backward."""

    @staticmethod
    def forward(ctx, x, lambd: float):
        # Keep the scale on the context so backward() can read it.
        ctx.lambd = float(lambd)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # One entry per forward() input: reversed gradient for x, no gradient for lambd.
        reversed_grad = grad_output * (-ctx.lambd)
        return reversed_grad, None

class GradientReversal(nn.Module):
    """Module wrapper around GradientReversalFn with an adjustable reversal strength."""

    def __init__(self, lambd: float = 1.0):
        super().__init__()
        # Stored as a plain Python float (not a parameter or buffer).
        self.lambd = float(lambd)

    def set_lambda(self, lambd: float):
        """Replace the reversal strength used by subsequent forward calls."""
        self.lambd = float(lambd)

    def forward(self, x):
        """Pass `x` through GradientReversalFn with the current strength."""
        strength = self.lambd
        return GradientReversalFn.apply(x, strength)

def mlp(in_dim: int, hidden: Iterable[int], out_dim: int, dropout: float = 0.0, last_activation: Optional[nn.Module] = None):
    """Build a feed-forward network as an nn.Sequential.

    Each hidden width contributes Linear -> ReLU (-> Dropout when `dropout` > 0).
    A final Linear maps to `out_dim`, optionally followed by `last_activation`.
    """
    widths = [in_dim, *hidden]
    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        if dropout > 0:
            modules.append(nn.Dropout(dropout))
    # Output projection from the last hidden width (or from in_dim when `hidden` is empty).
    modules.append(nn.Linear(widths[-1], out_dim))
    if last_activation is not None:
        modules.append(last_activation)
    return nn.Sequential(*modules)

class FixedAE(nn.Module):
    """Encoder for the fixed-effects input, with an optional decoder that maps the code back to input size."""

    def __init__(self, in_dim: int, enc_hidden=(128, 64), rep_dim=32, dropout=0.0,
                 use_decoder: bool = False, dec_hidden: Optional[Iterable[int]] = None):
        super().__init__()
        self.encoder = mlp(in_dim, enc_hidden, rep_dim, dropout)
        self.use_decoder = bool(use_decoder)
        if not self.use_decoder:
            self.decoder = None
        else:
            # Without explicit decoder widths, mirror the encoder's hidden widths.
            if dec_hidden is None:
                widths = list(enc_hidden)[::-1]
            else:
                widths = list(dec_hidden)
            self.decoder = mlp(rep_dim, widths, in_dim, dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return (code, reconstruction); the reconstruction is None when no decoder was built."""
        z = self.encoder(x)
        if self.decoder is None:
            return z, None
        return z, self.decoder(z)

class RandomEnc(nn.Module):
    """MLP encoder that maps the random-effects input to a `rep_dim`-wide representation."""

    def __init__(self, in_dim: int, hidden=(128, 64), rep_dim=32, dropout=0.0):
        super().__init__()
        self.net = mlp(in_dim, hidden, rep_dim, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        encoded = self.net(x)
        return encoded

class ParticipantEmbedding(nn.Module):
    """Embedding table with one row per seen participant plus one extra row used for negative indices."""

    def __init__(self, n_participants: int, rep_dim: int):
        super().__init__()
        n_seen = int(n_participants)
        self.n_seen = n_seen
        # The extra row sits right after the seen participants.
        self.unk_index = n_seen
        self.emb = nn.Embedding(n_seen + 1, rep_dim, padding_idx=None)

    def forward(self, pid_idx: torch.Tensor) -> torch.Tensor:
        """Look up embeddings; any index below zero is redirected to the extra row. `pid_idx` is not modified."""
        unseen = pid_idx < 0
        rows = pid_idx.masked_fill(unseen, self.unk_index)
        return self.emb(rows)

class FiLM(nn.Module):
    """Scale-and-shift modulation of an observed representation, driven by the participant representation."""

    def __init__(self, rep_dim: int, hidden=(64,), dropout=0.0):
        super().__init__()
        self.gamma = mlp(rep_dim, hidden, rep_dim, dropout)
        self.beta  = mlp(rep_dim, hidden, rep_dim, dropout)

    def forward(self, z_id: torch.Tensor, z_obs: torch.Tensor) -> torch.Tensor:
        """Return scale * z_obs + shift, both computed from z_id."""
        # tanh bounds the scale to the interval (0.9, 1.1).
        scale = torch.tanh(self.gamma(z_id)) * 0.1 + 1.0
        shift = self.beta(z_id) * 0.1
        return scale * z_obs + shift

class Adversary(nn.Module):
    """Participant classifier that sees its input through a gradient-reversal layer."""

    def __init__(self, in_dim: int, hidden=(64,), n_participants: int = 1,
                 dropout: float = 0.0, grl_lambda: float = 1.0):
        super().__init__()
        self.grl = GradientReversal(grl_lambda)
        # One output logit per participant.
        self.net = mlp(in_dim, hidden, n_participants, dropout)

    def set_lambda(self, lambd: float):
        """Forward the new reversal strength to the gradient-reversal layer."""
        self.grl.set_lambda(lambd)

    def forward(self, z_fixed: torch.Tensor) -> torch.Tensor:
        reversed_input = self.grl(z_fixed)
        return self.net(reversed_input)

class ARMEDTabular(nn.Module):
    """Tabular model with a fixed-effects branch, a participant-specific branch and an adversary.

    Components built in __init__:
      - `fixed`:  FixedAE encoder (optional decoder) on the fixed input
      - `id_emb`: ParticipantEmbedding, `random_rep_dim` wide
      - `random`: RandomEnc on the random input, only when random data is enabled and d_random > 0
      - `film`:   FiLM combiner, only when combine_mode == "film"
      - `head`:   MLP on the concatenated [fixed, random] representations, producing `y_dim` outputs
      - `adv`:    Adversary on the fixed representation, with `n_participants` logits
      - `norm_f` / `norm_r`: LayerNorm on each representation

    Raises:
        ValueError: if `combine_mode` is neither "add" nor "film".
    """

    def __init__(
        self,
        d_fixed: int,
        d_random: int = 0,
        y_dim: int = 1,
        n_participants: int = 1,
        include_random_data: bool = True,
        fixed_enc_hidden=(128, 64),
        fixed_rep_dim: int = 32,
        fixed_dropout: float = 0.0,
        use_fixed_decoder: bool = False,
        fixed_dec_hidden: Optional[Iterable[int]] = None,
        random_hidden=(128, 64),
        random_rep_dim: int = 32,
        random_dropout: float = 0.0,
        combine_mode: str = "add",   # "add" or "film"
        film_hidden=(64,),
        film_dropout: float = 0.0,
        adv_hidden=(64,),
        adv_dropout: float = 0.0,
        grl_lambda: float = 1.0,
        head_hidden=(64,),
    ):
        super().__init__()
        # Random-data branch is active only when requested AND there are random columns.
        self.include_random_data = bool(include_random_data and d_random > 0)

        # Submodules are created in a fixed order: fixed, id_emb, random, film, head, adv, norms.
        self.fixed = FixedAE(
            in_dim=d_fixed,
            enc_hidden=fixed_enc_hidden,
            rep_dim=fixed_rep_dim,
            dropout=fixed_dropout,
            use_decoder=use_fixed_decoder,
            dec_hidden=fixed_dec_hidden,
        )

        self.id_emb = ParticipantEmbedding(n_participants, rep_dim=random_rep_dim)

        if self.include_random_data:
            self.random = RandomEnc(
                in_dim=d_random,
                hidden=random_hidden,
                rep_dim=random_rep_dim,
                dropout=random_dropout,
            )
        else:
            self.random = None

        if combine_mode not in {"add", "film"}:
            raise ValueError("combine_mode must be 'add' or 'film'")
        self.combine_mode = combine_mode
        if combine_mode == "film":
            self.film = FiLM(rep_dim=random_rep_dim, hidden=film_hidden, dropout=film_dropout)
        else:
            self.film = None

        self.head = mlp(fixed_rep_dim + random_rep_dim, head_hidden, y_dim, dropout=0.0)
        self.adv  = Adversary(fixed_rep_dim, adv_hidden, n_participants, adv_dropout, grl_lambda)

        self.norm_f = nn.LayerNorm(fixed_rep_dim)
        self.norm_r = nn.LayerNorm(random_rep_dim)

    def forward(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        x_random: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Return (y_hat, adv_logits, xhat, z_f, z_r).

        `xhat` is the fixed-branch reconstruction (None without a decoder); `z_f` and `z_r`
        are the layer-normalised fixed and random representations.
        """
        z_f, xhat = self.fixed(x_fixed)
        z_f = self.norm_f(z_f)

        z_id = self.id_emb(pid_idx)
        use_observed = self.include_random_data and (x_random is not None)
        if not use_observed:
            # No random input available: the participant embedding alone is the random representation.
            z_r = z_id
        else:
            z_r_obs = self.random(x_random)
            if self.combine_mode == "add":
                z_r = z_r_obs + z_id
            else:
                z_r = self.film(z_id, z_r_obs)
        z_r = self.norm_r(z_r)

        joint = torch.cat([z_f, z_r], dim=1)
        y_hat = self.head(joint)            # continuous outputs
        adv_logits = self.adv(z_f)          # participant-classification logits

        return y_hat, adv_logits, xhat, z_f, z_r


## Wrapper

In [25]:
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Iterable, Sequence

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn


@dataclass
class ARMEDLossWeights:
    """Multipliers for the auxiliary terms that ARMEDWrapper.compute_losses adds to the prediction loss."""
    # Multiplier on the adversary (participant-classification) loss.
    lambda_adv: float = 1.0
    # Multiplier on the reconstruction loss; ARMEDWrapper._recon_loss returns zero when this is <= 0.
    lambda_recon: float = 0.0


class ARMEDWrapper:
    """
    Regression wrapper around an ARMED-style model:
      - picks a device and moves the model onto it
      - MSE loss for predictions, cross-entropy for the adversary, MSE for reconstruction
      - combines them into one weighted total
      - no-grad validation / prediction helpers returning continuous outputs
    """
    def __init__(self, model: nn.Module, loss_weights: Optional[ARMEDLossWeights] = None, device: Optional[torch.device] = None):
        self.model = model
        self.loss_w = loss_weights or ARMEDLossWeights()
        # Device preference when none is given: mps, then cuda, then cpu.
        chosen = device
        if not chosen and torch.backends.mps.is_available():
            chosen = torch.device("mps")
        if not chosen and torch.cuda.is_available():
            chosen = torch.device("cuda")
        if not chosen:
            chosen = torch.device("cpu")
        self.device = chosen
        self.model.to(self.device)

    def forward(self, x_fixed: torch.Tensor, pid_idx: torch.Tensor, x_random: Optional[torch.Tensor] = None):
        """Call the wrapped model; inputs are passed through as given (no device transfer here)."""
        return self.model(x_fixed, pid_idx, x_random)

    # -----------------------------
    # Loss components (regression)
    # -----------------------------
    def _pred_loss(self, y_hat: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Mean squared error between predictions and targets."""
        return F.mse_loss(input=y_hat, target=y_true)

    def _adv_loss(self, adv_logits: torch.Tensor, pid_idx: torch.Tensor) -> torch.Tensor:
        """Cross-entropy over rows with a non-negative participant index; zero when there are none."""
        seen_mask = (pid_idx >= 0)
        if not seen_mask.any():
            return torch.tensor(0.0, device=adv_logits.device)
        return F.cross_entropy(adv_logits[seen_mask], pid_idx[seen_mask])

    def _recon_loss(self, xhat: Optional[torch.Tensor], x: torch.Tensor) -> torch.Tensor:
        """MSE reconstruction loss; zero when there is no reconstruction or its weight is not positive."""
        if (xhat is None) or (self.loss_w.lambda_recon <= 0.0):
            return torch.tensor(0.0, device=x.device)
        return F.mse_loss(xhat, x)

    def compute_losses(
        self,
        y_true: torch.Tensor,
        y_hat: torch.Tensor,
        adv_logits: torch.Tensor,
        xhat: Optional[torch.Tensor],
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Return (total loss tensor, dict of detached float values for the total and each term)."""
        pred_term = self._pred_loss(y_hat, y_true)
        adv_term = self._adv_loss(adv_logits, pid_idx)
        recon_term = self._recon_loss(xhat, x_fixed)

        weights = self.loss_w
        total = pred_term + weights.lambda_adv * adv_term + weights.lambda_recon * recon_term

        named_terms = (
            ("loss_total", total),
            ("loss_pred", pred_term),
            ("loss_adv", adv_term),
            ("loss_recon", recon_term),
        )
        parts = {key: float(term.detach().cpu()) for key, term in named_terms}
        return total, parts

    # -----------------------------
    # Eval / predict helpers
    # -----------------------------
    @torch.no_grad()
    def validation_step(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        y_true: torch.Tensor,
        x_random: Optional[torch.Tensor] = None,
        prefix: str = "val"
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Evaluate one batch in eval mode, print the loss breakdown and return (loss, parts)."""
        self.model.eval()
        target_device = self.device
        x_fixed, pid_idx, y_true = (
            x_fixed.to(target_device),
            pid_idx.to(target_device),
            y_true.to(target_device),
        )
        if x_random is not None:
            x_random = x_random.to(target_device)

        y_hat, adv_logits, xhat, _, _ = self.forward(x_fixed, pid_idx, x_random)
        loss, parts = self.compute_losses(y_true, y_hat, adv_logits, xhat, x_fixed, pid_idx)

        print(
            f"{prefix}_loss: {parts['loss_total']:.6f} | "
            f"pred: {parts['loss_pred']:.6f} | "
            f"adv: {parts['loss_adv']:.6f} | "
            f"recon: {parts['loss_recon']:.6f}"
        )
        return loss, parts

    @torch.no_grad()
    def predict_logits(   # despite the name, returns the continuous predictions (y_hat)
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        x_random: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Move inputs to the wrapper's device, run the model in eval mode and return y_hat."""
        self.model.eval()
        target_device = self.device
        x_fixed = x_fixed.to(target_device)
        pid_idx = pid_idx.to(target_device)
        if x_random is not None:
            x_random = x_random.to(target_device)
        outputs = self.forward(x_fixed, pid_idx, x_random)
        return outputs[0]

    @torch.no_grad()
    def predict_values(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        x_random: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Alias of predict_logits."""
        return self.predict_logits(x_fixed, pid_idx, x_random)


## Evaluation procedure

In [26]:
from typing import Dict, Any, Optional, Tuple, Iterable, List
from dataclasses import dataclass

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import GroupKFold, TimeSeriesSplit, ParameterGrid
from sklearn.decomposition import PCA
from scipy.stats import t as student_t

# -------------------------------------------------------------------
# Dataset / loaders (unchanged)
# -------------------------------------------------------------------
class _ARMEDDataset(Dataset):
    def __init__(self, Xf, pid_idx, y, Xr=None, device=None):
        device = device or torch.device("cpu")
        Xf  = np.asarray(Xf, dtype=np.float32)
        y   = np.asarray(y,  dtype=np.float32)
        pid = np.asarray(pid_idx)

        if y.ndim == 1:
            y = y[:, None]

        self.Xf  = torch.as_tensor(Xf, dtype=torch.float32, device=device)
        self.pid = torch.as_tensor(pid, dtype=torch.long,    device=device)
        self.y   = torch.as_tensor(y,  dtype=torch.float32,  device=device)

        if Xr is None:
            self.Xr = torch.zeros((len(self.Xf), 0), dtype=torch.float32, device=device)
        else:
            Xr = np.asarray(Xr, dtype=np.float32)
            self.Xr = torch.as_tensor(Xr, dtype=torch.float32, device=device)

    def __len__(self):
        return self.Xf.shape[0]

    def __getitem__(self, idx):
        return self.Xf[idx], self.pid[idx], self.Xr[idx], self.y[idx]

def _make_loader(Xf, pid_idx, y, Xr=None, batch_size=256, shuffle=False, device=None):
    """Wrap the arrays in an _ARMEDDataset and return a single-process DataLoader that keeps the last partial batch."""
    dataset = _ARMEDDataset(Xf, pid_idx, y, Xr=Xr, device=device)
    loader_options = dict(batch_size=batch_size, shuffle=shuffle, num_workers=0, drop_last=False)
    return DataLoader(dataset, **loader_options)

# -------------------------------------------------------------------
# Regression metrics — IDENTICAL to the KAN routine
# -------------------------------------------------------------------
def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Per-task and macro-averaged regression metrics.

    1-D inputs are treated as a single task. For every column j the result holds
    task_{j+1}_MSE / RMSE / MAE / R2 / PearsonR. R2 is NaN when the target has zero
    variance; PearsonR is NaN when either series has zero standard deviation.
    macro_* entries are NaN-ignoring means over tasks.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim == 1:
        y_true = y_true[:, None]
    if y_pred.ndim == 1:
        y_pred = y_pred[:, None]
    assert y_true.shape == y_pred.shape

    n_tasks = y_true.shape[1]
    out = {}
    for j in range(n_tasks):
        target, pred = y_true[:, j], y_pred[:, j]
        err = pred - target
        mse = float(np.mean(err ** 2))
        target_var = float(np.var(target))
        if target_var > 0:
            r2 = float(1.0 - (mse / target_var))
        else:
            r2 = np.nan
        both_vary = np.std(target) > 0 and np.std(pred) > 0
        pearson = float(np.corrcoef(target, pred)[0, 1]) if both_vary else np.nan
        out.update({
            f"task_{j+1}_MSE": mse,
            f"task_{j+1}_RMSE": float(np.sqrt(mse)),
            f"task_{j+1}_MAE": float(np.mean(np.abs(err))),
            f"task_{j+1}_R2": r2,
            f"task_{j+1}_PearsonR": pearson,
        })

    for k in ("MSE","RMSE","MAE","R2","PearsonR"):
        per_task = [out[f"task_{j+1}_{k}"] for j in range(n_tasks)]
        out[f"macro_{k}"] = float(np.nanmean(per_task))
    return out

def _print_regression_metrics(metrics: dict, title: str = "Test metrics"):
    print(f"\n{title}:")
    macro_keys = [k for k in metrics.keys() if k.startswith("macro_")]
    for k in sorted(macro_keys):
        print(f"{k:>16}: {metrics[k]:.6f}")
    task_indices = sorted({int(k.split('_')[1]) for k in metrics.keys() if k.startswith("task_")})
    for j in task_indices:
        print(f"task_{j}: " +
              ", ".join(f"{m}={metrics.get(f'task_{j}_{m}', np.nan):.6f}"
                        for m in ("MSE","RMSE","MAE","R2","PearsonR")))

def _summarize_cv_folds(results_folds: List[dict]) -> dict:
    if not results_folds:
        return {}

    all_keys = set().union(*results_folds)
    summary = {}

    for k in sorted(all_keys):
        vals = np.array([fold.get(k, np.nan) for fold in results_folds], dtype=float)
        mask = np.isfinite(vals)
        n = int(mask.sum())

        if n == 0:
            m = np.nan; low = np.nan; high = np.nan
        elif n == 1:
            m = float(vals[mask][0]); low = np.nan; high = np.nan
        else:
            m = float(np.nanmean(vals))
            s = float(np.nanstd(vals, ddof=1))
            se = s / np.sqrt(n)
            tcrit = float(student_t.ppf(0.975, df=n - 1))
            low = m - tcrit * se
            high = m + tcrit * se

        summary[f"{k}_mean"] = m
        summary[f"{k}_95ci_low"] = low
        summary[f"{k}_95ci_high"] = high

    return summary

# -------------------------------------------------------------------
# PCA pipelines and helpers (unchanged)
# -------------------------------------------------------------------
# Columns whose training variance is not above this threshold are dropped before PCA.
_VAR_EPS = 1e-8
# Lower bound applied to per-column standard deviations before dividing by them.
_STD_EPS = 1e-6
# Standardised values are clipped to [-_CLIP_Z, _CLIP_Z] before the PCA fit / projection.
_CLIP_Z  = 8.0

@dataclass
class PCAPipeline:
    """Preprocessing state produced by _fit_pca_pipeline and consumed by _transform_pca_pipeline."""
    # Boolean mask over input columns: True where the training variance exceeded _VAR_EPS.
    keep_mask: np.ndarray
    # Per-column mean of the kept training columns (empty when no column is kept).
    mean_: np.ndarray
    # Per-column standard deviation of the kept training columns, floored at _STD_EPS.
    scale_: np.ndarray
    # PCA object fitted on the standardised, clipped training data (left unfitted when no column is kept).
    pca: PCA

def _fit_pca_pipeline(X_train: np.ndarray, var_ratio: float = 0.95, random_state: int | None = None) -> PCAPipeline:
    """Fit the column filter, standardisation and PCA on the training rows.

    Steps: replace non-finite values with 0, drop columns whose variance is not above
    _VAR_EPS, standardise with a floored standard deviation, clip to +/-_CLIP_Z, then
    fit a full-SVD PCA with `n_components=var_ratio`.

    When no column survives the variance filter, the returned pipeline carries an
    unfitted PCA and empty mean/scale arrays.

    Raises:
        RuntimeError: if the fitted PCA components contain non-finite values.
    """
    data = np.nan_to_num(np.asarray(X_train, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)

    keep = data.var(axis=0) > _VAR_EPS
    if not np.any(keep):
        empty = np.array([], dtype=np.float64)
        unfitted = PCA(n_components=0, svd_solver='full', random_state=random_state)
        return PCAPipeline(keep_mask=keep, mean_=empty,
                           scale_=np.array([], dtype=np.float64), pca=unfitted)

    kept = data[:, keep]
    mean = kept.mean(axis=0)
    # Floor the spread so the division below cannot blow up.
    std = np.maximum(kept.std(axis=0), _STD_EPS)

    Z = np.nan_to_num((kept - mean) / std, nan=0.0, posinf=0.0, neginf=0.0)
    np.clip(Z, -_CLIP_Z, _CLIP_Z, out=Z)

    pca = PCA(n_components=var_ratio, svd_solver='full', random_state=random_state)
    pca.fit(Z)
    components_ok = np.isfinite(pca.components_).all()
    if not components_ok:
        raise RuntimeError("PCA components contain non-finite values after fit.")

    return PCAPipeline(keep_mask=keep, mean_=mean, scale_=std, pca=pca)

def _transform_pca_pipeline(pipe: PCAPipeline | None, X: np.ndarray | None) -> np.ndarray | None:
    """Apply a fitted PCAPipeline to `X` and return the float32 projection.

    Returns None when either argument is None, and a zero-column array when the
    pipeline kept no columns. The projection is the standardised, clipped data
    multiplied by the transposed PCA components (the PCA's own mean_ is not subtracted).

    Raises:
        RuntimeError: if the standardised data or the PCA attributes fail the finiteness / magnitude checks.
    """
    if pipe is None or X is None:
        return None

    data = np.nan_to_num(np.asarray(X, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)

    mask = pipe.keep_mask
    if mask.size == 0 or not np.any(mask):
        return np.zeros((data.shape[0], 0), dtype=np.float32)

    Z = (data[:, mask] - pipe.mean_) / pipe.scale_
    Z = np.nan_to_num(Z, nan=0.0, posinf=0.0, neginf=0.0)
    np.clip(Z, -_CLIP_Z, _CLIP_Z, out=Z)

    # Sanity checks before the matrix product.
    finite = np.isfinite(Z)
    if not finite.all():
        bad = np.argwhere(~finite)[0]
        raise RuntimeError(f"[our PCA] Z non-finite at {tuple(bad)}: {Z[tuple(bad)]}")
    if np.abs(Z).max() > 1e6:
        raise RuntimeError(f"[our PCA] Z max |z| too large: {np.abs(Z).max()}")
    components = pipe.pca.components_
    if not np.isfinite(components).all():
        raise RuntimeError("[our PCA] components_ non-finite")
    if hasattr(pipe.pca, "mean_") and not np.isfinite(pipe.pca.mean_).all():
        raise RuntimeError("[our PCA] mean_ non-finite")

    Z64 = np.ascontiguousarray(Z, dtype=np.float64)
    CT  = np.ascontiguousarray(components.T, dtype=np.float64)

    with np.errstate(over='ignore', invalid='ignore', divide='ignore'):
        projected = Z64 @ CT

    return np.nan_to_num(projected, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

def _concat_safe(*arrays: Optional[np.ndarray]) -> np.ndarray:
    parts = [a for a in arrays if a is not None and a.size > 0]
    if not parts:
        return np.zeros((0, 0), dtype=np.float32)
    return np.concatenate(parts, axis=1).astype(np.float32)

def _filter_time_test_min_measurements(pid_idx: np.ndarray, test_idx: np.ndarray, min_meas: int = 3):
    """Keep only rows in test_idx belonging to pids with >= min_meas measurements overall."""
    pid = np.asarray(pid_idx)
    counts = {pid_val: np.sum(pid == pid_val) for pid_val in np.unique(pid)}
    keep = [i for i in test_idx if counts.get(pid[i], 0) >= min_meas]
    return np.array(keep, dtype=int)

# -------------------------------------------------------------------
# Splitting helpers (unchanged)
# -------------------------------------------------------------------
def _split_cases(pid_array, test_fraction=0.2, seed=42):
    rng = np.random.default_rng(seed)
    unique_ids = np.unique(pid_array)
    te_ids = rng.choice(unique_ids, size=max(1, int(len(unique_ids) * test_fraction)), replace=False)
    te_mask = np.isin(pid_array, te_ids)
    return np.where(~te_mask)[0], np.where(te_mask)[0]

def _split_time_basic(time_index, test_fraction=0.2):
    order = np.argsort(time_index)
    n = len(order)
    split = int(np.floor(n * (1.0 - test_fraction)))
    return order[:split], order[split:]

# -------------------------------------------------------------------
# Early-stop helper on loaders (regression)
# -------------------------------------------------------------------
def _eval_macro_metric_on_loader(wrapper, loader: DataLoader, which: str = "macro_R2") -> float:
    """Predict over every batch of `loader` and return one entry of regression_metrics.

    Returns NaN when `loader` is None or when `which` is not among the computed metrics.
    """
    if loader is None:
        return np.nan

    wrapper.model.eval()
    pred_chunks, true_chunks = [], []
    with torch.no_grad():
        for Xf_b, pid_b, Xr_b, y_b in loader:
            # A zero-column random block means "no random input".
            random_input = Xr_b if Xr_b.size(1) > 0 else None
            batch_pred = wrapper.predict_logits(Xf_b, pid_b, random_input)
            pred_chunks.append(batch_pred.cpu().numpy())
            true_chunks.append(y_b.cpu().numpy())

    y_pred = np.vstack(pred_chunks)
    y_true = np.vstack(true_chunks)
    scores = regression_metrics(y_true, y_pred)
    return float(scores.get(which, np.nan))

# -------------------------------------------------------------------
# Train loop (prints train & monitor; early stop by loss or macro metric)
# -------------------------------------------------------------------
def _fit_once(
    wrapper, optimizer,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader],          # used for early stopping (default)
    monitor_loader: Optional[DataLoader],      # printed each epoch; often the test set
    max_epochs: int = 100,
    patience: int = 10,
    early_stop_metric: str = "loss",           # "loss" | "macro_R2" | "macro_MAE"
    early_stop_on: str = "val",                # "val" | "train" | "monitor"
    verbose: bool = True,
):
    if early_stop_on == "val" and val_loader is None:
        early_stop_on = "train"

    higher_is_better = (early_stop_metric == "macro_R2")
    best_val = -np.inf if higher_is_better else np.inf
    best_state, no_improve = None, 0

    def _avg_loss(loader) -> Optional[float]:
        if loader is None:
            return None
        wrapper.model.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for Xf_b, pid_b, Xr_b, y_b in loader:
                y_hat, adv_logits, xhat, _, _ = wrapper.forward(Xf_b, pid_b, Xr_b if Xr_b.size(1) > 0 else None)
                l, _ = wrapper.compute_losses(y_b, y_hat, adv_logits, xhat, Xf_b, pid_b)
                bs = Xf_b.size(0)
                total += float(l.detach().cpu()) * bs
                n += bs
        return total / max(1, n)

    def _current_metric():
        if early_stop_metric == "loss":
            if early_stop_on == "val":
                return _avg_loss(val_loader)
            elif early_stop_on == "train":
                return _avg_loss(train_loader)
            else:
                return _avg_loss(monitor_loader)
        else:
            which = early_stop_metric
            if early_stop_on == "val":
                return _eval_macro_metric_on_loader(wrapper, val_loader, which=which)
            elif early_stop_on == "train":
                return _eval_macro_metric_on_loader(wrapper, train_loader, which=which)
            else:
                return _eval_macro_metric_on_loader(wrapper, monitor_loader, which=which)

    for epoch in range(1, max_epochs + 1):
        wrapper.model.train()
        total_tr, n_tr = 0.0, 0
        for Xf_b, pid_b, Xr_b, y_b in train_loader:
            y_hat, adv_logits, xhat, _, _ = wrapper.forward(
                Xf_b, pid_b, Xr_b if Xr_b.size(1) > 0 else None
            )
            loss, _ = wrapper.compute_losses(y_b, y_hat, adv_logits, xhat, Xf_b, pid_b)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            bs = Xf_b.size(0)
            total_tr += float(loss.detach().cpu()) * bs
            n_tr += bs
        train_loss = total_tr / max(1, n_tr)

        monitor_loss = _avg_loss(monitor_loader)
        if verbose:
            if monitor_loss is not None:
                print(f"Epoch {epoch:03d} | train {train_loss:.6f} | monitor_loss {monitor_loss:.6f}")
            else:
                print(f"Epoch {epoch:03d} | train {train_loss:.6f}")

        current = _current_metric()
        if higher_is_better:
            is_better = (current is not None) and (current > best_val + 1e-6)
        else:
            is_better = (current is not None) and (current < best_val - 1e-6)

        if is_better:
            best_val = current
            best_state = {k: v.detach().cpu().clone() for k, v in wrapper.model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                if verbose:
                    tag = f"{early_stop_metric}@{early_stop_on}"
                    print(f"Early stopping at epoch {epoch:03d} (best {tag} {best_val:.6f})")
                break

    if best_state is not None:
        wrapper.model.load_state_dict(best_state)

# -------------------------------------------------------------------
# PCA + loaders + per-split PID remap
# -------------------------------------------------------------------
def _prepare_split_and_loaders(
    X_only_fixed: np.ndarray,
    X_fixed_and_random: Optional[np.ndarray],
    y: np.ndarray,
    pid_idx_full: np.ndarray,
    indices_train: np.ndarray,
    indices_val: Optional[np.ndarray],
    indices_test: np.ndarray,
    batch_size: int,
    device: torch.device,
    random_state: int = 42,
    pca_var_ratio: float = 0.95,
):
    """Fit the PCA pipelines on the training rows and build train / val / test loaders.

    The fixed input of every split is the concatenation of the two PCA projections;
    the random input is the projection of the fixed-and-random block alone (None when
    that block is absent or has no columns). Participant ids are re-indexed to
    0..n_seen-1 over the ids present in the training rows; any other id maps to -1.

    Returns:
        (preprocessors, loaders): `preprocessors` holds the two fitted pipelines plus
        "d_fixed", "d_random" and "n_seen"; `loaders` holds "train", "val" (None when
        `indices_val` is None) and "test". Only the train loader shuffles.
    """
    of_pipe = _fit_pca_pipeline(X_only_fixed[indices_train], var_ratio=pca_var_ratio, random_state=random_state)
    has_fr_block = X_fixed_and_random is not None and X_fixed_and_random.shape[1] > 0
    if has_fr_block:
        fr_pipe = _fit_pca_pipeline(X_fixed_and_random[indices_train], var_ratio=pca_var_ratio, random_state=random_state)
    else:
        fr_pipe = None

    def transform_block(idxs):
        """Project the selected rows; returns (fixed input, random input)."""
        fixed_part = _transform_pca_pipeline(of_pipe, X_only_fixed[idxs])
        fr_rows = X_fixed_and_random[idxs] if X_fixed_and_random is not None else None
        random_part = _transform_pca_pipeline(fr_pipe, fr_rows)
        return _concat_safe(fixed_part, random_part), random_part

    Xf_tr, Xr_tr = transform_block(indices_train)
    Xf_te, Xr_te = transform_block(indices_test)
    Xf_va, Xr_va = transform_block(indices_val) if indices_val is not None else (None, None)

    # Participant ids seen in training get consecutive class indices.
    seen = np.unique(pid_idx_full[indices_train])
    pid_to_seen = dict(zip(seen, range(len(seen))))

    def map_pids(idxs):
        """Class index per selected row; -1 for participants absent from training."""
        return np.array([pid_to_seen.get(p, -1) for p in pid_idx_full[idxs]], dtype=np.int64)

    pid_tr = map_pids(indices_train)
    pid_te = map_pids(indices_test)
    pid_va = None if indices_val is None else map_pids(indices_val)

    tr_loader = _make_loader(Xf_tr, pid_tr, y[indices_train], Xr=Xr_tr, batch_size=batch_size, shuffle=True,  device=device)
    if Xf_va is not None:
        va_loader = _make_loader(Xf_va, pid_va, y[indices_val], Xr=Xr_va, batch_size=batch_size, shuffle=False, device=device)
    else:
        va_loader = None
    te_loader = _make_loader(Xf_te, pid_te, y[indices_test],  Xr=Xr_te, batch_size=batch_size, shuffle=False, device=device)

    preprocessors = {
        "only_fixed": of_pipe,
        "fixed_and_random": fr_pipe,
        "d_fixed": Xf_tr.shape[1],
        "d_random": 0 if Xr_tr is None else Xr_tr.shape[1],
        "n_seen": int(len(seen)),
    }
    loaders = {"train": tr_loader, "val": va_loader, "test": te_loader}
    return preprocessors, loaders

# -------------------------------------------------------------------
# One split fit/eval (regression)
# -------------------------------------------------------------------
def _fit_eval_once(
    build_model_fn, wrapper_cls,
    arch_params: Dict[str, Any],
    train_params: Dict[str, Any],
    X_of: np.ndarray,
    X_fr: Optional[np.ndarray],
    y: np.ndarray,
    pid_idx_full: np.ndarray,
    tr_idx: np.ndarray,
    va_idx: Optional[np.ndarray],
    te_idx: np.ndarray,
    device: torch.device,
    monitor_source: str = "test",
    verbose: bool = True,
):
    """Fit one model on one train/val/test split and evaluate it on the test rows.

    Training settings are read from `train_params` with defaults: batch_size 256,
    random_state 42, pca_var_ratio 0.95, lr 1e-3, weight_decay 0.0, max_epochs 100,
    patience 10, early_stop_metric "loss", early_stop_on "val", loss_weights None.
    The loader printed each epoch is the test loader when `monitor_source` is
    "test", otherwise the validation loader.

    Returns a dict with "metrics", "preprocessors", "wrapper", "model",
    "y_true_test" and "y_pred_test".
    """
    setting = train_params.get

    preprocessors, loaders = _prepare_split_and_loaders(
        X_of, X_fr, y, pid_idx_full,
        tr_idx, va_idx, te_idx,
        batch_size=setting("batch_size", 256),
        device=device,
        random_state=setting("random_state", 42),
        pca_var_ratio=setting("pca_var_ratio", 0.95),
    )

    # Model dimensions come from the fitted preprocessing, not from the raw inputs.
    model = build_model_fn(
        d_fixed=preprocessors["d_fixed"],
        d_random=preprocessors["d_random"],
        y_dim=(y.shape[1] if y.ndim == 2 else 1),
        n_participants=preprocessors["n_seen"],
        **arch_params
    ).to(device)

    wrapper = wrapper_cls(model, loss_weights=setting("loss_weights", None), device=device)
    opt = torch.optim.Adam(
        wrapper.model.parameters(),
        lr=setting("lr", 1e-3),
        weight_decay=setting("weight_decay", 0.0),
    )

    monitor_loader = loaders["test" if monitor_source == "test" else "val"]

    _fit_once(
        wrapper, opt,
        loaders["train"], loaders["val"], monitor_loader,
        max_epochs=setting("max_epochs", 100),
        patience=setting("patience", 10),
        early_stop_metric=setting("early_stop_metric", "loss"),
        early_stop_on=setting("early_stop_on", "val"),
        verbose=verbose
    )

    # Collect test-set predictions with the restored weights.
    wrapper.model.eval()
    pred_chunks, true_chunks = [], []
    with torch.no_grad():
        for Xf_b, pid_b, Xr_b, y_b in loaders["test"]:
            random_input = Xr_b if Xr_b.size(1) > 0 else None
            batch_pred = wrapper.predict_logits(Xf_b, pid_b, random_input)
            pred_chunks.append(batch_pred.cpu().numpy())
            true_chunks.append(y_b.cpu().numpy())
    y_pred_te = np.vstack(pred_chunks)
    y_true_te = np.vstack(true_chunks)

    return {
        "metrics": regression_metrics(y_true_te, y_pred_te),
        "preprocessors": preprocessors,
        "wrapper": wrapper,
        "model": wrapper.model,
        "y_true_test": y_true_te,
        "y_pred_test": y_pred_te,
    }

# -------------------------------------------------------------------
# Main entry: single / cv_only / nested_cv, scenarios cases/time/both
# -------------------------------------------------------------------
def run_training_and_eval_armed(
    X_only_fixed: np.ndarray,
    X_fixed_and_random: Optional[np.ndarray],
    y: np.ndarray,
    pid_idx: np.ndarray,
    time_index: np.ndarray,
    build_model_fn,
    wrapper_cls,
    *,
    mode: str = "single",
    scenario: str = "cases",
    outer_folds: int = 5,
    inner_folds: int = 3,
    param_grid: Optional[Dict[str, List]] = None,
    arch_defaults: Optional[Dict[str, Any]] = None,
    train_defaults: Optional[Dict[str, Any]] = None,
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Train and evaluate a model under one of three evaluation protocols.

    Parameters
    ----------
    X_only_fixed, X_fixed_and_random
        Feature blocks; both are cast to float32. ``X_fixed_and_random`` may be None.
    y
        Targets, cast to float32. May be 1-D or 2-D; only the first column is
        handed to the scikit-learn splitters.
    pid_idx
        Per-row participant codes (cast to int64); used as CV groups in the
        "cases" scenario and by the time-scenario test filter.
    time_index
        Per-row time values; rows are ordered by it in the "time" scenario.
    build_model_fn, wrapper_cls
        Forwarded unchanged to ``_fit_eval_once``.
    mode
        "single" (one hold-out fit), "cv_only" (outer CV with the default
        parameters) or "nested_cv" (inner grid search inside an outer CV,
        followed by a final hold-out refit).
    scenario
        "cases" (participant-grouped split), "time" (time-ordered split) or
        "both" (runs "cases" and then "time" with the same arguments).
    outer_folds, inner_folds
        Number of outer / inner CV splits.
    param_grid
        Grid for "nested_cv"; a built-in grid is used when empty or None. Keys
        ``fixed_rep_dim``, ``random_rep_dim``, ``combine_mode`` and ``grl_lambda``
        override ``arch_defaults``; every other key overrides ``train_defaults``.
    arch_defaults, train_defaults
        Base architecture / training parameters. ``train_defaults`` is also read
        here for ``random_state`` (42), ``val_fraction`` (0.10) and
        ``monitor_source`` ("test").
    device
        Torch device; defaults to CUDA when available, otherwise CPU.
    verbose
        Print per-fold metrics and summaries.

    Returns
    -------
    dict
        * scenario "both": ``{"cases": <result>, "time": <result>}``.
        * mode "single": the dict returned by ``_fit_eval_once``.
        * mode "cv_only": ``{"cv_folds_metrics", "cv_summary"}``.
        * mode "nested_cv": ``{"outer_folds", "cv_summary", "best_params", "final_refit"}``.

    Raises
    ------
    ValueError
        Unknown ``mode`` or ``scenario``.
    RuntimeError
        A time hold-out split is empty after the minimum-measurements filter, or
        nested CV produced no usable outer fold / no comparable outer score.
    """
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    X_of = np.asarray(X_only_fixed, dtype=np.float32)
    X_fr = None if X_fixed_and_random is None else np.asarray(X_fixed_and_random, dtype=np.float32)
    y    = np.asarray(y, dtype=np.float32)
    pid_idx = np.asarray(pid_idx, dtype=np.int64)
    time_ix = np.asarray(time_index)

    arch_defaults = arch_defaults or {}
    train_defaults = train_defaults or {}

    rnd = int(train_defaults.get("random_state", 42))
    val_frac = float(train_defaults.get("val_fraction", 0.10))
    monitor_source = train_defaults.get("monitor_source", "test")

    holdout_fraction = 0.2   # test share of the single / final-refit hold-out split
    min_test_meas = 3        # minimum measurements required by the time-scenario test filter
    arch_keys = ("fixed_rep_dim", "random_rep_dim", "combine_mode", "grl_lambda")

    def _make_train_val_split(idx_array: np.ndarray, seed: int) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Randomly carve a validation subset out of ``idx_array`` (None when too small or disabled)."""
        idx_array = np.asarray(idx_array)
        if len(idx_array) <= 10 or val_frac <= 0.0:
            return idx_array, None
        rng = np.random.default_rng(seed)
        perm = rng.permutation(len(idx_array))
        cut = max(1, int(val_frac * len(idx_array)))
        va_sel, tr_sel = perm[:cut], perm[cut:]
        return idx_array[tr_sel], idx_array[va_sel]

    def _first_target(values: np.ndarray) -> np.ndarray:
        """1-D target vector for the splitters (first column of a 2-D target)."""
        return values[:, 0] if values.ndim > 1 else values

    def _split_iter(rows: np.ndarray, n_splits: int):
        """Iterate (train_rows, test_rows) pairs over ``rows`` as GLOBAL row indices."""
        rows = np.asarray(rows)
        if scenario == "cases":
            splitter = GroupKFold(n_splits=n_splits)
            local_pairs = splitter.split(X_of[rows], _first_target(y[rows]), groups=pid_idx[rows])
            # The splitter yields positions relative to the subset it was given;
            # _fit_eval_once indexes the full arrays, so map back through `rows`.
            return ((rows[tr], rows[te]) for tr, te in local_pairs)
        if scenario == "time":
            ordered = rows[np.argsort(time_ix[rows])]
            splitter = TimeSeriesSplit(n_splits=n_splits)
            local_pairs = splitter.split(X_of[ordered], _first_target(y[ordered]))
            return ((ordered[tr], ordered[te]) for tr, te in local_pairs)
        raise ValueError("scenario must be 'cases' or 'time'")

    def _usable_test_rows(te_rows: np.ndarray) -> np.ndarray:
        """Apply the time-scenario minimum-measurements filter; other scenarios pass through."""
        if scenario != "time":
            return te_rows
        return _filter_time_test_min_measurements(pid_idx, te_rows, min_meas=min_test_meas)

    def _holdout_split(empty_msg: str):
        """Single train/test split for the active scenario; raises if the time test set is empty."""
        if scenario == "cases":
            return _split_cases(pid_idx, test_fraction=holdout_fraction, seed=rnd)
        if scenario == "time":
            tr_rows, te_raw = _split_time_basic(time_ix, test_fraction=holdout_fraction)
            te_rows = _usable_test_rows(te_raw)
            if len(te_rows) == 0:
                raise RuntimeError(empty_msg)
            return tr_rows, te_rows
        raise ValueError("scenario must be 'cases' or 'time'")

    def _fit(arch_p, train_p, tr_rows, va_rows, te_rows, *, monitor, show):
        """Thin wrapper around _fit_eval_once with the shared data arguments filled in."""
        return _fit_eval_once(
            build_model_fn, wrapper_cls,
            arch_p, train_p,
            X_of, X_fr, y,
            pid_idx,
            tr_rows, va_rows, te_rows,
            device=device,
            monitor_source=monitor,
            verbose=show
        )

    if scenario == "both":
        out_cases = run_training_and_eval_armed(
            X_of, X_fr, y, pid_idx, time_ix,
            build_model_fn, wrapper_cls,
            mode=mode, scenario="cases",
            outer_folds=outer_folds, inner_folds=inner_folds,
            param_grid=param_grid, arch_defaults=arch_defaults, train_defaults=train_defaults,
            device=device, verbose=verbose
        )
        out_time = run_training_and_eval_armed(
            X_of, X_fr, y, pid_idx, time_ix,
            build_model_fn, wrapper_cls,
            mode=mode, scenario="time",
            outer_folds=outer_folds, inner_folds=inner_folds,
            param_grid=param_grid, arch_defaults=arch_defaults, train_defaults=train_defaults,
            device=device, verbose=verbose
        )
        return {"cases": out_cases, "time": out_time}

    # -------------------- MODE: SINGLE --------------------
    if mode == "single":
        tr_idx_all, te_idx = _holdout_split("Time split produced empty test after >=3 measurements filter.")
        tr_idx, va_idx = _make_train_val_split(tr_idx_all, seed=rnd)

        res = _fit(arch_defaults, train_defaults, tr_idx, va_idx, te_idx,
                   monitor=monitor_source, show=verbose)

        if verbose:
            _print_regression_metrics(res["metrics"], title="Single-fit test metrics")
        return res

    # -------------------- MODE: CV-ONLY --------------------
    if mode == "cv_only":
        fold_metrics: List[Dict[str, float]] = []
        all_rows = np.arange(len(X_of))

        for fold_id, (tr_idx_all, te_idx_raw) in enumerate(_split_iter(all_rows, outer_folds), start=1):
            te_idx = _usable_test_rows(te_idx_raw)
            if scenario == "time" and len(te_idx) == 0:
                if verbose:
                    print(f"Fold {fold_id}: skipped (empty train/test).")
                continue

            tr_idx, va_idx = _make_train_val_split(tr_idx_all, seed=rnd + fold_id)

            res = _fit(arch_defaults, train_defaults, tr_idx, va_idx, te_idx,
                       monitor=monitor_source, show=False)
            fold_metrics.append(res["metrics"])
            if verbose:
                print(f"\nFold {fold_id}:")
                _print_regression_metrics(res["metrics"], title="Per-fold test metrics")

        cv_summary = _summarize_cv_folds(fold_metrics)

        if verbose:
            print("\nCV averages (±95% CI):")
            for key in sorted(cv_summary.keys()):
                if key.endswith("_mean"):
                    base = key[:-5]
                    low = cv_summary.get(f"{base}_95ci_low", np.nan)
                    high = cv_summary.get(f"{base}_95ci_high", np.nan)
                    print(f"{base:>20}: {cv_summary[key]:.6f}  (95% CI {low:.6f}, {high:.6f})")

        return {
            "cv_folds_metrics": fold_metrics,
            "cv_summary": cv_summary,
        }

    # -------------------- MODE: NESTED CV --------------------
    if mode == "nested_cv":
        if not param_grid:
            param_grid = {
                "fixed_rep_dim": [32, 64, 128],
                "random_rep_dim": [32],
                "combine_mode": ["add"],
                "grl_lambda": [1.0],
                "lr": [1e-3, 3e-4],
                "weight_decay": [0.0, 1e-4],
                "batch_size": [256],
                "max_epochs": [100],
                "patience": [10],
            }

        results_folds = []
        best_score_global, best_params_global = -np.inf, None
        all_rows = np.arange(len(X_of))

        for fold_id, (tr_idx_all, te_idx_raw) in enumerate(_split_iter(all_rows, outer_folds), start=1):
            if verbose:
                print(f"\nOuter fold {fold_id}/{outer_folds}")

            te_idx = _usable_test_rows(te_idx_raw)
            if scenario == "time" and len(te_idx) == 0:
                if verbose:
                    print(f"Skipping outer time fold {fold_id} (empty test after filter).")
                continue

            # Inner splits depend only on the outer-train rows, so build them once
            # per outer fold instead of once per grid point.
            inner_splits = []
            for in_tr, in_va in _split_iter(tr_idx_all, inner_folds):
                in_va = _usable_test_rows(in_va)
                if scenario == "time" and len(in_va) == 0:
                    continue
                tr_idx_inner, va_idx_inner = _make_train_val_split(in_tr, seed=rnd + fold_id)
                inner_splits.append((tr_idx_inner, va_idx_inner, in_va))

            best_inner_score, best_inner_params = -np.inf, None

            for params in ParameterGrid(param_grid):
                arch_params = dict(arch_defaults)
                train_params = dict(train_defaults)
                for k, v in params.items():
                    if k in arch_keys:
                        arch_params[k] = v
                    else:
                        train_params[k] = v

                inner_scores = []
                for tr_idx_inner, va_idx_inner, te_idx_inner in inner_splits:
                    res_inner = _fit(arch_params, train_params,
                                     tr_idx_inner, va_idx_inner, te_idx_inner,
                                     monitor="val", show=False)
                    inner_scores.append(res_inner["metrics"].get("macro_R2", np.nan))

                # A config with no usable inner score (none run, or all NaN) can never win.
                scores = np.asarray(inner_scores, dtype=float)
                avg_score = float(np.nanmean(scores)) if np.any(~np.isnan(scores)) else -np.inf
                if avg_score > best_inner_score:
                    best_inner_score = avg_score
                    best_inner_params = (arch_params, train_params)

            if best_inner_params is None:
                if verbose:
                    print("No viable inner config; skipping outer fold.")
                continue

            arch_params, train_params = best_inner_params
            tr_idx_outer, va_idx_outer = _make_train_val_split(tr_idx_all, seed=rnd + fold_id * 17)

            res_outer = _fit(arch_params, train_params,
                             tr_idx_outer, va_idx_outer, te_idx,
                             monitor=monitor_source, show=False)
            results_folds.append(res_outer)

            score_outer = res_outer["metrics"].get("macro_R2", -np.inf)
            if score_outer > best_score_global:
                best_score_global = score_outer
                best_params_global = (arch_params, train_params)

            if verbose:
                _print_regression_metrics(res_outer["metrics"], title="Outer fold test metrics")

        if not results_folds:
            raise RuntimeError("Nested CV: no outer fold could be evaluated (all folds were skipped).")

        def summarize(results_list: List[Dict[str, Any]]) -> Dict[str, float]:
            """Mean and t-based 95% CI of every metric across the outer folds."""
            out = {}
            for k in list(results_list[0]["metrics"].keys()):
                arr = np.array([res["metrics"][k] for res in results_list], dtype=float)
                # NaNs are ignored by nanmean/nanstd, so the sample size must ignore them too.
                n = int(np.count_nonzero(~np.isnan(arr)))
                m = float(np.nanmean(arr)) if n > 0 else np.nan
                if n > 1:
                    se = float(np.nanstd(arr, ddof=1)) / np.sqrt(n)
                    tcrit = float(student_t.ppf(0.975, df=n - 1))
                    ci = (m - tcrit * se, m + tcrit * se)
                else:
                    ci = (np.nan, np.nan)
                out[k + "_mean"] = m
                out[k + "_95ci_low"] = ci[0]
                out[k + "_95ci_high"] = ci[1]
            return out

        cv_summary = summarize(results_folds)

        if verbose:
            if best_params_global is not None:
                print("\nBest params (by outer macro_R2):")
                arch_p, train_p = best_params_global
                print("[ARCH]:")
                for k, v in arch_p.items():
                    print(f"  {k}: {v}")
                print("[TRAIN]:")
                for k, v in train_p.items():
                    print(f"  {k}: {v}")

            print("\nCross-validation results:")
            print("CV Summary:")
            print(cv_summary)

        if best_params_global is None:
            # Happens when every outer macro_R2 is NaN/missing: nothing to refit with.
            raise RuntimeError(
                "Nested CV: no outer fold produced a comparable macro_R2; "
                "cannot select parameters for the final refit."
            )

        # Optional final refit (kept silent to match print moments)
        tr_idx_all, te_idx = _holdout_split("Final refit: time split produced empty test after filter.")
        tr_idx_final, va_idx_final = _make_train_val_split(tr_idx_all, seed=rnd + 999)

        final_res = _fit(best_params_global[0], best_params_global[1],
                         tr_idx_final, va_idx_final, te_idx,
                         monitor=monitor_source, show=False)

        return {
            "outer_folds": results_folds,
            "cv_summary": cv_summary,
            "best_params": {"arch": best_params_global[0], "train": best_params_global[1]},
            "final_refit": final_res,
        }

    raise ValueError("mode must be one of {'single','cv_only','nested_cv'}")


## Model test

### Define variables and parameters

In [27]:
# Outcome columns as a float32 matrix; a 1-D result is promoted to a single column.
y_raw      = EXP_reg_y.to_numpy(dtype=np.float32)
y_np       = y_raw.reshape(-1, 1) if y_raw.ndim != 2 else y_raw

# Participant ids re-coded as integer positions into the sorted unique ids.
pid_raw    = np.ravel(EXP_reg_participant_id.to_numpy())
pid_uniqs, pid_encoded = np.unique(pid_raw, return_inverse=True)
pid_np     = pid_encoded.astype(np.int64)
n_ids      = int(pid_uniqs.shape[0])

# Flattened time column(s), used to order rows in the "time" scenario.
time_ix_np = np.ravel(EXP_reg_time.to_numpy())


def build_model_fn(d_fixed, d_random, y_dim, n_participants, **arch):
    """Model factory: construct an ``ARMEDTabular`` from sizes plus architecture options.

    The four named sizes are forwarded under their own keyword names; every
    additional keyword collected in ``arch`` is passed through unchanged.
    """
    size_kwargs = {
        "d_fixed": d_fixed,
        "d_random": d_random,
        "y_dim": y_dim,
        "n_participants": n_participants,
    }
    return ARMEDTabular(**size_kwargs, **arch)

# Training wrapper class handed to run_training_and_eval_armed as `wrapper_cls`;
# the fitting code instantiates it as wrapper_cls(model, loss_weights=..., device=...).
wrapper_cls = ARMEDWrapper


# Baseline architecture options; forwarded to build_model_fn as **arch.
arch_defaults = {
    "fixed_rep_dim": 256,
    "random_rep_dim": 256,
    "combine_mode": "add",
    "grl_lambda": 1.0,
}

# Baseline training / evaluation options shared by every run below.
train_defaults = {
    # Optimiser and training-loop settings
    "lr": 1e-4,
    "weight_decay": 1e-4,
    "batch_size": 256,
    "max_epochs": 100,
    "patience": 20,
    "loss_weights": ARMEDLossWeights(lambda_adv=1.0, lambda_recon=0.0),

    # PCA explained-variance ratio forwarded to the preprocessing step
    "pca_var_ratio": 0.95,

    # Which loader feeds the per-epoch "monitor_loss" printout: "test" or "val"
    "monitor_source": "test",

    # Early stopping: metric is one of "loss" (prediction MSE), "macro_R2" (higher is
    # better) or "macro_MAE" (lower is better); it is evaluated on "val", "train" or "monitor"
    "early_stop_metric": "macro_R2",
    "early_stop_on": "val",

    # Split controls read by run_training_and_eval_armed
    "random_state": 42,
    "val_fraction": 0.10,
}


### Simple split test

In [28]:
# One hold-out fit per split scenario: "both" runs the "cases" split, then the "time" split.
res_single = run_training_and_eval_armed(
    # data
    X_only_fixed=EXP_reg_only_fixed,
    X_fixed_and_random=EXP_reg_fixed_and_random,  # None is also accepted
    y=y_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # model factory and training wrapper
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    # evaluation protocol
    scenario="both",  # alternatives: "cases", "time"
    mode="single",
    # baseline parameters and logging
    train_defaults=train_defaults,
    arch_defaults=arch_defaults,
    verbose=True,
)


Epoch 001 | train 28.668227 | monitor_loss 13.754825
Epoch 002 | train 24.318638 | monitor_loss 12.064663
Epoch 003 | train 21.372946 | monitor_loss 11.418877
Epoch 004 | train 19.531485 | monitor_loss 10.973756
Epoch 005 | train 18.108686 | monitor_loss 10.185671
Epoch 006 | train 16.918664 | monitor_loss 9.735870
Epoch 007 | train 16.016300 | monitor_loss 9.569500
Epoch 008 | train 15.430674 | monitor_loss 9.600272
Epoch 009 | train 15.014173 | monitor_loss 9.658867
Epoch 010 | train 14.685780 | monitor_loss 9.745872
Epoch 011 | train 14.407295 | monitor_loss 9.735779
Epoch 012 | train 14.179964 | monitor_loss 9.800245
Epoch 013 | train 13.988724 | monitor_loss 9.709022
Epoch 014 | train 13.817087 | monitor_loss 9.869119
Epoch 015 | train 13.687785 | monitor_loss 9.851102
Epoch 016 | train 13.539674 | monitor_loss 9.905873
Epoch 017 | train 13.416805 | monitor_loss 9.845573
Epoch 018 | train 13.289579 | monitor_loss 9.850217
Epoch 019 | train 13.225102 | monitor_loss 9.771808
Epoch 0

### Nested CV with parameter search

In [None]:
# Search space for nested CV: the first four keys override the architecture
# defaults, the remaining keys override the training defaults.
param_grid = dict(
    fixed_rep_dim=[64, 128, 256],
    random_rep_dim=[64, 128, 256],
    combine_mode=["add"],  # "film" also supported
    grl_lambda=[0.5, 1.0],
    lr=[1e-3, 3e-4, 1e-4],
    weight_decay=[0.0, 1e-4],
    batch_size=[64, 128, 256],
    max_epochs=[100],
    patience=[20],
)

# Nested CV (inner grid search inside an outer CV), run for the "cases" and the "time" scenario.
res_nested = run_training_and_eval_armed(
    # data
    X_only_fixed=EXP_reg_only_fixed,
    X_fixed_and_random=EXP_reg_fixed_and_random,  # None is also accepted
    y=y_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # model factory and training wrapper
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    # evaluation protocol
    scenario="both",  # alternatives: "cases", "time"
    mode="nested_cv",
    inner_folds=3,
    outer_folds=5,
    # search space, baseline parameters and logging
    param_grid=param_grid,
    train_defaults=train_defaults,
    arch_defaults=arch_defaults,
    verbose=True,
)



Outer fold 1/5
Outer fold macro: macro_MAE=2.6737, macro_MSE=11.3197, macro_RMSE=3.3645, macro_MedAE=2.0651, macro_R2=0.3200, macro_EVS=0.3535, macro_MAPE=144465515.8408, macro_sMAPE=1.4069

Outer fold 2/5
Outer fold macro: macro_MAE=2.7981, macro_MSE=12.9290, macro_RMSE=3.5957, macro_MedAE=2.1650, macro_R2=0.4179, macro_EVS=0.4213, macro_MAPE=97661197.6936, macro_sMAPE=1.1188

Outer fold 3/5
Outer fold macro: macro_MAE=2.2361, macro_MSE=10.3293, macro_RMSE=3.2139, macro_MedAE=1.4546, macro_R2=0.2600, macro_EVS=0.2691, macro_MAPE=62917285.7755, macro_sMAPE=1.1926

Outer fold 4/5
Outer fold macro: macro_MAE=2.4923, macro_MSE=15.4924, macro_RMSE=3.9360, macro_MedAE=1.0839, macro_R2=0.0088, macro_EVS=0.2327, macro_MAPE=28625357.8250, macro_sMAPE=1.6312

Outer fold 5/5
Outer fold macro: macro_MAE=2.4270, macro_MSE=9.7935, macro_RMSE=3.1295, macro_MedAE=1.8025, macro_R2=0.1034, macro_EVS=0.1334, macro_MAPE=105777933.2864, macro_sMAPE=1.2509

Best params (by outer macro_R2):
[ARCH]:
  fixed

### CV without parameter search

In [29]:
# Outer CV with the baseline parameters only (no grid search), for both split scenarios.
res_cv = run_training_and_eval_armed(
    # data
    X_only_fixed=EXP_reg_only_fixed,
    X_fixed_and_random=EXP_reg_fixed_and_random,
    y=y_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # model factory and training wrapper
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    # evaluation protocol
    scenario="both",  # alternatives: "cases", "time"
    mode="cv_only",
    outer_folds=5,
    # baseline parameters and logging
    train_defaults=train_defaults,
    arch_defaults=arch_defaults,
    verbose=True,
)


Fold 1:

Per-fold test metrics:
       macro_MAE: 2.304174
       macro_MSE: 10.692920
  macro_PearsonR: 0.602544
        macro_R2: 0.357656
      macro_RMSE: 3.270003
task_1: MSE=10.692920, RMSE=3.270003, MAE=2.304174, R2=0.357656, PearsonR=0.602544

Fold 2:

Per-fold test metrics:
       macro_MAE: 2.876254
       macro_MSE: 12.502810
  macro_PearsonR: 0.725843
        macro_R2: 0.437114
      macro_RMSE: 3.535931
task_1: MSE=12.502810, RMSE=3.535931, MAE=2.876254, R2=0.437114, PearsonR=0.725843

Fold 3:

Per-fold test metrics:
       macro_MAE: 2.280378
       macro_MSE: 10.972196
  macro_PearsonR: 0.494338
        macro_R2: 0.213945
      macro_RMSE: 3.312430
task_1: MSE=10.972196, RMSE=3.312430, MAE=2.280378, R2=0.213945, PearsonR=0.494338

Fold 4:

Per-fold test metrics:
       macro_MAE: 2.609625
       macro_MSE: 11.495363
  macro_PearsonR: 0.522278
        macro_R2: 0.264511
      macro_RMSE: 3.390481
task_1: MSE=11.495363, RMSE=3.390481, MAE=2.609625, R2=0.264511, PearsonR=0