# Setup

## Root and data folders

In [2]:
import os
import pandas as pd
import numpy as np

# Project root. Set the DISERTATIE_ROOT environment variable to run this notebook on
# another machine; when it is unset, the author's original local checkout is used.
root_dir = os.environ.get(
    "DISERTATIE_ROOT",
    "/Users/silviumatu/Desktop/Code/Python/Disertatie/Disertatie_Matu_Silviu_v1",
)
os.makedirs(root_dir, exist_ok=True)

# All CSV inputs are read from <root_dir>/Data.
data_dir = os.path.join(root_dir, "Data")
os.makedirs(data_dir, exist_ok=True)

# Load data

In [17]:
# Forecasting dataset, plus the metadata table that flags the role of each of its columns.
GHQ_cat_df, columns_GHQ_cat_df = (
    pd.read_csv(os.path.join(data_dir, file_name))
    for file_name in (
        "PED_GHQ_categorical_data_forecast.csv",
        "columns_PED_GHQ_categorical_data_forecast_S.csv",
    )
)

GHQ_cat_df.head()

Unnamed: 0,x_age_baseline,x_gender_baseline,x_previous_psychiatric_diagnostic_baseline,x_previous_psychiatric_treatment_baseline,x_previous_psychologist_baseline,x_curent_psychiatric_treatment_baseline,x_current_psychologist_baseline,x_ABS_irrational_TOTAL_baseline,x_ABS_rational_TOTAL_baseline,x_ATS_Generalization_TOTAL_baseline,...,x_mental_health_last_month_baseline_value_5,x_stress_management_first_day_value_0,x_stress_management_first_day_value_1,x_stress_management_first_day_value_2,x_participant_id,x_GHQ_TOTAL_score_category_first_day,x_time_copy,x_time_difference_first_day,y_GHQ_TOTAL_score_category_next,x_time_difference_first_day_copy
0,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,0,1,0,1,0.0,1.0,0.0,1.0
1,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,1,0,0,0,1.0,3.0,1.0,3.0
2,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,1,0,0,1,4.0,1.0,0.0,1.0
3,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,1,0,0,0,5.0,3.0,0.0,3.0
4,18.0,0.0,0.0,0.0,0.0,0.0,0.0,44.0,48.0,10.0,...,0,0,0,1,2,1,0.0,2.0,1.0,2.0


In [18]:
def _select_role_columns(meta_df: pd.DataFrame, role: str) -> list:
    """Return the ``column_name`` entries whose ``role`` flag equals 1 in ``meta_df``.

    Raises ``KeyError`` (from pandas) if ``role`` is not a column of ``meta_df``.
    """
    return meta_df.loc[meta_df[role] == 1, "column_name"].tolist()


# Outcome column(s): flagged with 1 in the "outcomes" column of columns_GHQ_cat_df.
GHQ_cat_outcome_cols = _select_role_columns(columns_GHQ_cat_df, "outcomes")
assert GHQ_cat_outcome_cols, "no column is flagged as an outcome in columns_GHQ_cat_df"
GHQ_cat_y = GHQ_cat_df[GHQ_cat_outcome_cols]

# Lagged outcome column(s).
GHQ_cat_outcomes_lags_cols = _select_role_columns(columns_GHQ_cat_df, "outcomes_lags")
GHQ_cat_outcomes_lags = GHQ_cat_df[GHQ_cat_outcomes_lags_cols]

# Participant identifier column(s).
GHQ_cat_participant_cols = _select_role_columns(columns_GHQ_cat_df, "participant_id")
GHQ_cat_participant_id = GHQ_cat_df[GHQ_cat_participant_cols]

# Time column(s).
GHQ_cat_time_cols = _select_role_columns(columns_GHQ_cat_df, "time")
GHQ_cat_time = GHQ_cat_df[GHQ_cat_time_cols]

# Forecast-horizon column(s).
GHQ_cat_forecast_horizons_cols = _select_role_columns(columns_GHQ_cat_df, "forecast_horizons")
GHQ_cat_forecast_horizons = GHQ_cat_df[GHQ_cat_forecast_horizons_cols]

# Column(s) used only as fixed effects.
GHQ_cat_only_fixed_cols = _select_role_columns(columns_GHQ_cat_df, "only_fixed")
GHQ_cat_only_fixed = GHQ_cat_df[GHQ_cat_only_fixed_cols]

# Column(s) used as both fixed and random effects.
GHQ_cat_fixed_and_random_cols = _select_role_columns(columns_GHQ_cat_df, "fixed_and_random")
GHQ_cat_fixed_and_random = GHQ_cat_df[GHQ_cat_fixed_and_random_cols]

# Show the outcome frame. This must be the cell's last expression: a bare
# `.head()` in the middle of a cell is evaluated and then discarded.
GHQ_cat_y.head()

# Custom model

## Architecture

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict, Any

try:
    from efficient_kan import KANLinear  # noqa: F401
except Exception:
    # efficient_kan could not be imported: define an inert stand-in under the same
    # name so that `isinstance(m, KANLinear)` checks elsewhere keep working.
    class KANLinear(nn.Module):
        """Placeholder for ``efficient_kan.KANLinear``.

        Ignores all constructor arguments, returns its input unchanged, and reports a
        zero regularisation loss as a 0-dim CPU tensor.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()

        def forward(self, x):
            return x

        def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
            return torch.zeros((), device=torch.device("cpu"))

def _make_kan(in_dim, out_dim,
              grid_size=8, spline_order=3,
              scale_noise=0.1, scale_base=1.0, scale_spline=1.0,
              enable_standalone_scale_spline=True,
              base_activation=torch.nn.SiLU,
              grid_eps=0.02, grid_range=(-1.0, 1.0)):
    return nn.Linear(in_dim, out_dim)

class KANBlock(nn.Module):
    """Feed-forward block used as the building unit of every branch in this model.

    Despite the name, the block is built from ``nn.Linear`` layers. The spline/grid
    arguments (``grid_size``, ``spline_order``, ``scale_*``,
    ``enable_standalone_scale_spline``, ``grid_eps``, ``grid_range``) are accepted for
    signature compatibility with a KAN implementation but are ignored.

    Layout: ``in_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]``. Each hidden linear
    layer is followed by its own ``base_activation()`` instance, plus ``nn.Dropout``
    when ``dropout > 0``. When ``out_dim`` is given, a final linear projection to
    ``out_dim`` is appended with no activation.
    """

    def __init__(self, in_dim, hidden_dims=(128, 64), out_dim=None,
                 dropout=0.0,
                 grid_size=8, spline_order=3,
                 scale_noise=0.1, scale_base=1.0, scale_spline=1.0,
                 enable_standalone_scale_spline=True,
                 base_activation=torch.nn.SiLU,
                 grid_eps=0.02, grid_range=(-1.0, 1.0)):
        super().__init__()
        dims = [in_dim] + list(hidden_dims)
        layers = []
        for d0, d1 in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(d0, d1))
            # Use a fresh activation per layer. Sharing one instance would tie the
            # weights of parametric activations (e.g. nn.PReLU) across all layers.
            layers.append(base_activation())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        if out_dim is not None:
            layers.append(nn.Linear(dims[-1], out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum ``regularization_loss`` over every ``KANLinear`` submodule.

        Always returns a tensor, so callers can safely ``.detach()`` or ``.item()`` it.
        When no ``KANLinear`` submodule exists (the layers built in ``__init__`` are
        plain ``nn.Linear``), the result is a 0-dim zero on this block's parameter
        device and dtype.
        """
        ref = next(self.parameters(), None)
        reg = ref.new_zeros(()) if ref is not None else torch.zeros(())
        for m in self.modules():
            if isinstance(m, KANLinear):
                reg = reg + m.regularization_loss(
                    regularize_activation=regularize_activation,
                    regularize_entropy=regularize_entropy
                )
        return reg

class TemporalKernelAttentionKAN(nn.Module):
    """Lag-attention branch driven by a learned mixture of exponential decay kernels.

    Each lag ``l`` gets weight ``w_l = sum_k pi_k * exp(-lambda_k * dt_l)``, with
    ``pi = softmax(pi_logits)`` and ``lambda = softplus(lam_raw) + 1e-6``. The weights
    are optionally normalised to sum to 1 per row. The weighted sum of the lag values
    is pooled into one scalar per row, encoded to ``d_att`` features by ``summarize``,
    and mapped to a single output by ``out_head``.
    """

    def __init__(self, n_kernels: int = 4, d_att: int = 32,
                 grid_size=8, spline_order=3, dropout=0.0,
                 normalize_weights: bool = False):
        super().__init__()
        self.n_k = n_kernels
        self.normalize = normalize_weights
        # Unconstrained parameters, mapped to mixture weights / positive decay rates in forward().
        self.pi_logits = nn.Parameter(torch.zeros(n_kernels))
        self.lam_raw = nn.Parameter(torch.zeros(n_kernels))
        block_kwargs = dict(dropout=dropout, grid_size=grid_size, spline_order=spline_order)
        self.summarize = KANBlock(1, hidden_dims=(d_att,), out_dim=d_att, **block_kwargs)
        self.out_head = KANBlock(d_att, hidden_dims=(d_att,), out_dim=1, **block_kwargs)

    def forward(self, y_lags, dt_lags):
        """Return ``(e_att, w_lags, z_att)``.

        ``y_lags`` must be 2-D (``[B, L]``). ``dt_lags`` is presumably ``[B, L]`` as
        well — TODO confirm against callers. ``w_lags`` holds the per-lag weights and
        ``z_att`` the ``d_att``-wide summary.
        """
        batch_size, _ = y_lags.shape
        mix = F.softmax(self.pi_logits, dim=-1)
        rates = F.softplus(self.lam_raw) + 1e-6
        # [B, 1, K] so both broadcast against the [B, L, 1] time gaps.
        mix_b = mix.view(1, 1, -1).expand(batch_size, 1, -1)
        rates_b = rates.view(1, 1, -1).expand(batch_size, 1, -1)
        decay = torch.exp(-rates_b * dt_lags.unsqueeze(-1))
        weights = (mix_b * decay).sum(dim=-1)
        if self.normalize:
            weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
        pooled = (weights * y_lags).sum(dim=1, keepdim=True)
        z_att = self.summarize(pooled)
        return self.out_head(z_att), weights, z_att

class FixedBranchKAN(nn.Module):
    """Fixed-effects branch: encode ``X_fix`` to ``d_latent`` features, then read out one scalar.

    ``forward`` returns ``(e, z)``: the scalar contribution and the latent it was computed from.
    """

    def __init__(self, d_fix, d_latent=64, grid_size=8, spline_order=3, dropout=0.0):
        super().__init__()
        shared = {"dropout": dropout, "grid_size": grid_size, "spline_order": spline_order}
        self.enc = KANBlock(d_fix, hidden_dims=(256, 128), out_dim=d_latent, **shared)
        self.head = KANBlock(d_latent, hidden_dims=(64,), out_dim=1, **shared)

    def forward(self, X_fix):
        latent = self.enc(X_fix)
        return self.head(latent), latent

class RandEncoderKAN(nn.Module):
    """Encode the random-effect covariates ``Zrand`` into ``d_latent`` features."""

    def __init__(self, d_zrand, d_latent=64, grid_size=8, spline_order=3, dropout=0.0):
        super().__init__()
        self.enc = KANBlock(
            in_dim=d_zrand,
            hidden_dims=(256, 128),
            out_dim=d_latent,
            dropout=dropout,
            grid_size=grid_size,
            spline_order=spline_order,
        )

    def forward(self, Zrand):
        encoded = self.enc(Zrand)
        return encoded

class TCEncoderKAN(nn.Module):
    """Encode the ``TC`` covariates into ``d_latent`` features (the FiLM conditioning input)."""

    def __init__(self, d_tc, d_latent=64, grid_size=8, spline_order=3, dropout=0.0):
        super().__init__()
        self.enc = KANBlock(
            in_dim=d_tc,
            hidden_dims=(256, 128),
            out_dim=d_latent,
            dropout=dropout,
            grid_size=grid_size,
            spline_order=spline_order,
        )

    def forward(self, TC):
        encoded = self.enc(TC)
        return encoded

class RandomHeadKAN(nn.Module):
    """Map a ``d_latent``-wide random-effect latent to a single scalar contribution."""

    def __init__(self, d_latent=64, grid_size=8, spline_order=3, dropout=0.0):
        super().__init__()
        self.head = KANBlock(
            in_dim=d_latent,
            hidden_dims=(64,),
            out_dim=1,
            dropout=dropout,
            grid_size=grid_size,
            spline_order=spline_order,
        )

    def forward(self, z):
        contribution = self.head(z)
        return contribution

class FiLMFromTC(nn.Module):
    """Produce FiLM scale (``gamma``) and shift (``beta``) vectors from an encoded TC latent.

    Both outputs have width ``d_latent`` and come from two independent blocks.
    """

    def __init__(self, d_latent=64, grid_size=8, spline_order=3, dropout=0.0):
        super().__init__()

        def film_net():
            return KANBlock(d_latent, hidden_dims=(64,), out_dim=d_latent,
                            dropout=dropout, grid_size=grid_size, spline_order=spline_order)

        self.gamma = film_net()
        self.beta = film_net()

    def forward(self, e_tc):
        scale = self.gamma(e_tc)
        shift = self.beta(e_tc)
        return scale, shift

class KANAdditiveMixedEffects(nn.Module):
    """Additive mixed-effects classifier producing one logit per outcome.

    For every outcome ``j`` the logit is ``e_att_j + e_fix_j + e_rand_j``. Each term
    comes from that outcome's own (non-shared) sub-networks:

    * ``e_att`` - ``TemporalKernelAttentionKAN`` over the outcome's lag values and lag
      time gaps; zero when ``use_attention=False``.
    * ``e_fix`` - ``FixedBranchKAN`` on ``X_fix``.
    * ``e_rand`` - ``Zrand`` encoded by ``RandEncoderKAN`` and modulated as
      ``gamma * z + beta``, where ``(gamma, beta)`` come from ``FiLMFromTC`` applied to
      a ``TCEncoderKAN`` encoding of ``TC``. The result is mapped to a scalar by
      ``RandomHeadKAN``. This term is zero when ``use_random=False``.

    NOTE(review): ``n_ids`` and ``use_id_intercept`` (constructor), and ``pid_idx`` and
    ``pid_seen_mask`` (forward), are accepted but not used by this class.
    """

    def __init__(self,
                 y_dim: int,
                 d_fix: int,
                 d_tc: int,
                 d_zrand: int,
                 n_ids: Optional[int] = None,
                 use_id_intercept: bool = False,
                 n_kernels: int = 4,
                 d_att: int = 128,
                 d_fix_latent: int = 128,
                 d_rand_latent: int = 128,
                 grid_size: int = 8,
                 spline_order: int = 3,
                 dropout: float = 0.0,
                 normalize_att_weights: bool = True,
                 use_attention: bool = True,
                 use_random: bool = True):
        super().__init__()
        self.y_dim = y_dim
        self.d_fix = d_fix
        self.d_tc  = d_tc
        self.d_zr  = d_zrand
        self.d_rand_latent = d_rand_latent
        self.d_att = d_att
        self.use_attention = use_attention
        self.use_random = use_random

        # NOTE: submodule construction order is kept stable; it fixes both state_dict
        # key order and the order in which random initialisation consumes the RNG.
        self.att_branches = nn.ModuleList([
            TemporalKernelAttentionKAN(n_kernels=n_kernels, d_att=d_att,
                                       grid_size=grid_size, spline_order=spline_order,
                                       dropout=dropout, normalize_weights=normalize_att_weights)
            for _ in range(y_dim)
        ]) if use_attention else None

        self.fix_branches = nn.ModuleList([
            FixedBranchKAN(d_fix=d_fix, d_latent=d_fix_latent,
                           grid_size=grid_size, spline_order=spline_order, dropout=dropout)
            for _ in range(y_dim)
        ])

        if use_random:
            # Encoders whose input width is 0 are not built at all.
            self.rand_encoders = nn.ModuleList([
                RandEncoderKAN(d_zrand=d_zrand, d_latent=d_rand_latent,
                               grid_size=grid_size, spline_order=spline_order, dropout=dropout)
                for _ in range(y_dim)
            ]) if d_zrand > 0 else None
            self.tc_encoders = nn.ModuleList([
                TCEncoderKAN(d_tc=d_tc, d_latent=d_rand_latent,
                             grid_size=grid_size, spline_order=spline_order, dropout=dropout)
                for _ in range(y_dim)
            ]) if d_tc > 0 else None
            self.film_from_tc = nn.ModuleList([
                FiLMFromTC(d_latent=d_rand_latent, grid_size=grid_size, spline_order=spline_order, dropout=dropout)
                for _ in range(y_dim)
            ]) if d_tc > 0 else None
            self.rand_heads = nn.ModuleList([
                RandomHeadKAN(d_latent=d_rand_latent, grid_size=grid_size, spline_order=spline_order, dropout=dropout)
                for _ in range(y_dim)
            ])
        else:
            self.rand_encoders = None
            self.tc_encoders = None
            self.film_from_tc = None
            self.rand_heads = None

    def _split_lags(self, y_lags):
        """Return one lag tensor per outcome.

        A ``[B, L]`` input is shared by every outcome. A ``[B, L, y_dim]`` input is
        sliced along its last axis.

        Raises:
            ValueError: ``y_lags`` is ``None``, not 2-/3-D, or its last axis is not ``y_dim``.
        """
        if y_lags is None or y_lags.dim() not in (2, 3):
            raise ValueError("y_lags must be [B, L] or [B, L, y_dim].")
        if y_lags.dim() == 2:
            return [y_lags for _ in range(self.y_dim)]
        if y_lags.size(2) != self.y_dim:
            raise ValueError(
                f"y_lags has {y_lags.size(2)} outcome channels but y_dim={self.y_dim}."
            )
        return [y_lags[:, :, j] for j in range(self.y_dim)]

    def _random_effect(self, j, TC, Zrand, B, device, dtype):
        """Random-effect logit for outcome ``j`` plus the latents the loss regularises.

        Returns ``(e_rand, z_rand, z_tc, gamma, beta, z_tilde)``, where
        ``z_tilde = gamma * z_rand + beta``. When the random branch is disabled,
        every latent is zero, ``gamma`` is one, and ``e_rand`` is a ``[B, 1]`` zero.
        """
        if not (self.use_random and (self.rand_heads is not None)):
            z_rand = torch.zeros(B, self.d_rand_latent, device=device, dtype=dtype)
            return (
                torch.zeros(B, 1, device=device, dtype=dtype),
                z_rand,
                torch.zeros_like(z_rand),
                torch.ones_like(z_rand),
                torch.zeros_like(z_rand),
                z_rand,
            )

        if (Zrand is not None) and (Zrand.size(1) > 0) and (self.rand_encoders is not None):
            z_rand = self.rand_encoders[j](Zrand)
        else:
            z_rand = torch.zeros(B, self.d_rand_latent, device=device, dtype=dtype)

        if self.d_tc > 0 and (TC is not None) and (TC.size(1) > 0) and (self.tc_encoders is not None):
            z_tc = self.tc_encoders[j](TC)
            if self.film_from_tc is not None:
                gamma, beta = self.film_from_tc[j](z_tc)
            else:
                gamma, beta = torch.ones_like(z_rand), torch.zeros_like(z_rand)
        else:
            # No usable TC input: identity FiLM (gamma = 1, beta = 0).
            z_tc = torch.zeros_like(z_rand)
            gamma = torch.ones_like(z_rand)
            beta = torch.zeros_like(z_rand)

        z_tilde = gamma * z_rand + beta
        e_rand = self.rand_heads[j](z_tilde)
        return e_rand, z_rand, z_tc, gamma, beta, z_tilde

    def forward(self, X_fix, TC, Zrand, y_lags, dt_lags,
                pid_idx: Optional[torch.Tensor] = None,
                pid_seen_mask: Optional[torch.Tensor] = None):
        """Compute per-outcome logits and the intermediate pieces used by the loss.

        Args:
            X_fix: fixed-effect covariates. Its batch size, device and dtype are used
                for every zero placeholder.
            TC: optional FiLM conditioning input. It is ignored when ``None``, when it
                has zero columns, or when the model was built with ``d_tc == 0``.
            Zrand: optional random-effect covariates. Ignored when ``None`` or when it
                has zero columns.
            y_lags: lag values, ``[B, L]`` (shared by all outcomes) or ``[B, L, y_dim]``.
                Required only when the attention branch is enabled.
            dt_lags: lag time gaps, passed to the attention branch with ``y_lags``.
            pid_idx, pid_seen_mask: unused.

        Returns:
            ``(logits, parts)``. ``logits = e_att + e_fix + e_rand``, one column per
            outcome. ``parts`` holds the three concatenated components plus per-outcome
            lists of latents, attention weights, and FiLM ``gamma``/``beta``.

        Raises:
            ValueError: attention is enabled and ``y_lags`` is missing, is not 2-/3-D,
                or its last axis does not match ``y_dim``.
        """
        B = X_fix.size(0)
        device, dtype = X_fix.device, X_fix.dtype
        use_att = self.use_attention and (self.att_branches is not None)
        # Lags are consumed only by the attention branch, so they are validated
        # (and required) only when that branch is active.
        y_lags_list = self._split_lags(y_lags) if use_att else None
        # Lag count, used for the all-zero attention-weight placeholder.
        L = dt_lags.size(1) if dt_lags is not None and dt_lags.dim() == 2 else 1

        e_att_all, e_fix_all, e_rand_all = [], [], []
        z_att_list, z_fix_list = [], []
        z_rand_list, z_rand_film_list, z_tc_list = [], [], []
        w_lags_list, e_rand_mod_list, gamma_list, beta_list = [], [], [], []

        for j in range(self.y_dim):
            if use_att:
                e_att_j, w_lags_j, z_att_j = self.att_branches[j](y_lags_list[j], dt_lags)
            else:
                e_att_j = torch.zeros(B, 1, device=device, dtype=dtype)
                w_lags_j = torch.zeros(B, L, device=device, dtype=dtype)
                z_att_j = torch.zeros(B, self.d_att, device=device, dtype=dtype)

            e_fix_j, z_fix_j = self.fix_branches[j](X_fix)
            e_rand_j, z_rand_j, z_tc_j, gamma_j, beta_j, z_tilde_j = self._random_effect(
                j, TC, Zrand, B, device, dtype
            )

            e_att_all.append(e_att_j)
            e_fix_all.append(e_fix_j)
            e_rand_all.append(e_rand_j)

            z_att_list.append(z_att_j)
            z_fix_list.append(z_fix_j)
            z_rand_list.append(z_rand_j)
            z_rand_film_list.append(z_tilde_j)
            z_tc_list.append(z_tc_j)

            w_lags_list.append(w_lags_j)
            e_rand_mod_list.append(e_rand_j)
            gamma_list.append(gamma_j)
            beta_list.append(beta_j)

        def _stack(columns):
            # One column per outcome; a zero [B, y_dim] placeholder when y_dim == 0.
            if columns:
                return torch.cat(columns, dim=1)
            return torch.zeros(B, self.y_dim, device=device, dtype=dtype)

        e_att = _stack(e_att_all)
        e_fix = _stack(e_fix_all)
        e_rand = _stack(e_rand_all)
        logits = e_att + e_fix + e_rand

        parts = {
            "e_att": e_att,
            "e_fix": e_fix,
            "e_rand": e_rand,
            "z_att_list": z_att_list,
            "z_fix_list": z_fix_list,
            "z_rand_list": z_rand_list,
            "z_rand_film_list": z_rand_film_list,
            "z_tc_list": z_tc_list,
            "w_lags_list": w_lags_list,
            "e_rand_mod_list": e_rand_mod_list,
            "film_gamma_list": gamma_list,
            "film_beta_list": beta_list,
        }
        return logits, parts


## Wrapper

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Any

def _bce_logits(
    y_hat_logits: torch.Tensor,
    y_true: torch.Tensor,
    pos_weight: Optional[torch.Tensor] = None,
    sample_weight: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if y_true.shape != y_hat_logits.shape:
        if y_true.dim() == 1 and y_hat_logits.dim() == 2 and y_hat_logits.size(1) == 1:
            y_true = y_true.unsqueeze(1)
        else:
            raise ValueError(f"y_true shape {y_true.shape} must equal logits shape {y_hat_logits.shape}")
    pw = None
    if pos_weight is not None:
        pw = pos_weight.to(y_hat_logits.device, dtype=y_hat_logits.dtype)
        if pw.numel() == 1:
            pw = pw.view(1)
        elif y_hat_logits.dim() == 2 and pw.numel() == y_hat_logits.size(1):
            pw = pw.view(-1)
        else:
            raise ValueError(
                f"pos_weight has {pw.numel()} elements but needs 1 or {y_hat_logits.size(-1)}"
            )
    sw = None
    if sample_weight is not None:
        sw = sample_weight.to(y_hat_logits.device, dtype=y_hat_logits.dtype)
    return F.binary_cross_entropy_with_logits(
        y_hat_logits, y_true.float(), weight=sw, pos_weight=pw
    )

def _orthogonality_penalty_latents(z_fix: torch.Tensor, z_rand_film: torch.Tensor) -> torch.Tensor:
    B = z_fix.size(0)
    if B <= 1:
        return z_fix.new_zeros(())
    zf = z_fix - z_fix.mean(dim=0, keepdim=True)
    zr = z_rand_film - z_rand_film.mean(dim=0, keepdim=True)
    M = (zf.T @ zr) / float(B)
    return (M ** 2).mean()

def _optimal_threshold_exact_np(y_true: np.ndarray, y_prob: np.ndarray, beta: float = 1.0) -> float:
    y = np.asarray(y_true, dtype=int).ravel()
    p = np.asarray(y_prob, dtype=float).ravel()
    if y.size == 0:
        return 0.5
    P = int(y.sum()); N = y.size - P
    if P == 0:
        return 1.0
    if N == 0:
        return 0.0
    o = np.argsort(-p)
    p = p[o]; y = y[o]
    tp = np.cumsum(y)
    fp = np.cumsum(1 - y)
    prec = tp / np.maximum(1, tp + fp)
    rec  = tp / max(1, P)
    f = (1 + beta**2) * prec * rec / np.maximum(1e-12, beta**2 * prec + rec)
    idx = int(np.nanargmax(f))
    next_p = p[idx+1] if idx + 1 < len(p) else -np.inf
    thr = (p[idx] + next_p) / 2.0 if np.isfinite(next_p) and next_p < p[idx] else max(0.0, p[idx] - np.finfo(p.dtype).eps)
    return float(thr)

def _combine_logits_from_parts(parts: dict, combine: str = "all"):
    alias = {
        "time_constant": "only_fixed",
        "time_varying": "all",
        "fixed": "only_fixed",
    }
    mode = alias.get(combine, combine)
    e_fix  = parts["e_fix"]
    e_rand = parts.get("e_rand", None)
    e_att  = parts.get("e_att", None)
    if mode == "only_fixed":
        return e_fix
    elif mode == "fixed_and_random":
        if e_rand is None:
            e_rand = torch.zeros_like(e_fix)
        return e_fix + e_rand
    elif mode == "all":
        if e_rand is None:
            e_rand = torch.zeros_like(e_fix)
        if e_att is None:
            e_att = torch.zeros_like(e_fix)
        return e_fix + e_rand + e_att
    else:
        raise ValueError("combine must be one of {'only_fixed','fixed_and_random','all','time_constant','time_varying','fixed'}")

class KANMixedEffectsWrapper:
    def __init__(self, model: nn.Module, cfg: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None):
        self.model = model
        self.cfg = {
            "lambda_mean0": 1e-4,
            "lambda_ridge": 1e-4,
            "lambda_orth_latent": 1e-3,
            "lambda_film_identity": 1e-4,
            "lambda_kan": 0.0,
            "kan_reg_activation": 1.0,
            "kan_reg_entropy": 1.0,
            "clip_grad": 5.0,
            "lr": 1e-3,
            "weight_decay": 1e-4,
            "max_epochs": 100,
            "patience": 10,
            "batch_size": 256,
            "amp": True,
            "threshold": 0.5,
            "auto_pos_weight": True,
            "pos_weight_eps": 1e-6,
            "sample_weight_index": None,
        }
        if cfg:
            self.cfg.update(cfg)
        self.device = (
            device
            or (torch.device("mps") if torch.backends.mps.is_available() else None)
            or (torch.device("cuda") if torch.cuda.is_available() else None)
            or torch.device("cpu")
        )
        self.model.to(self.device)
        self.thresholds_: Optional[np.ndarray] = None
        self.history_: Dict[str, list] = {"train_loss": [], "val_loss": []}

    def forward(
        self,
        X_fix: torch.Tensor,
        TC: Optional[torch.Tensor],
        Zrand: Optional[torch.Tensor],
        y_lags: torch.Tensor,
        dt_lags: torch.Tensor,
        pid_idx: Optional[torch.Tensor] = None,
        pid_seen_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        return self.model(
            X_fix=X_fix,
            TC=TC,
            Zrand=Zrand,
            y_lags=y_lags,
            dt_lags=dt_lags,
            pid_idx=pid_idx,
            pid_seen_mask=pid_seen_mask
        )

    def _sum_kan_regularization(self) -> torch.Tensor:
        if hasattr(self.model, "regularization_loss"):
            return self.model.regularization_loss(
                regularize_activation=self.cfg["kan_reg_activation"],
                regularize_entropy=self.cfg["kan_reg_entropy"]
            )
        device = self.device
        reg = torch.tensor(0.0, device=device)
        use_att = bool(getattr(self.model, "use_attention", True))
        use_rand = bool(getattr(self.model, "use_random", True))
        if hasattr(self.model, "fix_branches") and self.model.fix_branches is not None:
            for fb in self.model.fix_branches:
                if hasattr(fb, "enc"):  reg = reg + fb.enc.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
                if hasattr(fb, "head"): reg = reg + fb.head.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
        if use_att and hasattr(self.model, "att_branches") and self.model.att_branches is not None:
            for ab in self.model.att_branches:
                if hasattr(ab, "summarize"): reg = reg + ab.summarize.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
                if hasattr(ab, "out_head"):  reg = reg + ab.out_head.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
        if use_rand:
            if hasattr(self.model, "rand_encoders") and self.model.rand_encoders is not None:
                for re in self.model.rand_encoders:
                    if hasattr(re, "enc"): reg = reg + re.enc.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
            if hasattr(self.model, "tc_encoders") and self.model.tc_encoders is not None:
                for te in self.model.tc_encoders:
                    if hasattr(te, "enc"): reg = reg + te.enc.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
            if hasattr(self.model, "film_from_tc") and self.model.film_from_tc is not None:
                for film in self.model.film_from_tc:
                    if hasattr(film, "gamma"): reg = reg + film.gamma.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
                    if hasattr(film, "beta"):  reg = reg + film.beta.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
            if hasattr(self.model, "rand_heads") and self.model.rand_heads is not None:
                for rh in self.model.rand_heads:
                    if hasattr(rh, "head"): reg = reg + rh.head.regularization_loss(self.cfg["kan_reg_activation"], self.cfg["kan_reg_entropy"])
        return reg

    def _compute_pos_weight_from_loader(self, train_loader) -> torch.Tensor:
        device = self.device
        pos_sum = None
        total_sum = 0
        with torch.no_grad():
            for batch in train_loader:
                y_b = batch[3].to(device)
                if y_b.dim() == 1:
                    y_b = y_b.unsqueeze(1)
                bs, y_dim = y_b.shape
                if pos_sum is None:
                    pos_sum = torch.zeros(y_dim, device=device, dtype=torch.float32)
                pos_sum += y_b.float().sum(dim=0)
                total_sum += bs
        if pos_sum is None:
            return torch.ones(1, device=device, dtype=torch.float32)
        P = pos_sum
        T = torch.tensor(float(total_sum), device=device, dtype=torch.float32)
        N = T - P
        eps = float(self.cfg.get("pos_weight_eps", 1e-6))
        pos_weight = N / torch.clamp(P, min=eps)
        pos_weight = torch.where(torch.isfinite(pos_weight), pos_weight, torch.ones_like(pos_weight))
        pos_weight = torch.clamp(pos_weight, min=eps)
        return pos_weight

    def compute_loss(
        self,
        y_true: torch.Tensor,
        logits: torch.Tensor,
        parts: Dict[str, Any],
        *,
        X_fix: torch.Tensor,
        pid_idx: Optional[torch.Tensor] = None,
        TC: Optional[torch.Tensor] = None,
        pos_weight: Optional[torch.Tensor] = None,
        sample_weight: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute the total training objective and a float breakdown of its terms.

            total = BCE(logits, y_true; pos_weight, sample_weight)
                    + lambda_mean0         * sum_j mean_batch(e_rand_j)^2
                    + lambda_ridge         * mean(e_rand^2)
                    + lambda_orth_latent   * mean_j orth(z_fix_j, z_rand_film_j)
                    + lambda_film_identity * mean_j [mean((gamma_j - 1)^2) + mean(beta_j^2)]
                    + lambda_kan           * KAN regulariser

        The four random-effect terms are zero when the model's ``use_random`` flag
        (default True) is false. The KAN term is evaluated only when ``lambda_kan > 0``.
        ``X_fix``, ``pid_idx`` and ``TC`` are accepted for interface compatibility but
        are not used.

        Returns:
            ``(total_loss, parts_out)``. ``total_loss`` is the differentiable scalar.
            ``parts_out`` maps ``loss_total``, ``loss_pred``, ``loss_mean0``,
            ``loss_ridge``, ``loss_orth``, ``loss_fi`` (FiLM identity) and ``loss_kan``
            to Python floats.
        """
        cfg = self.cfg
        use_rand = bool(getattr(self.model, "use_random", True))

        loss_pred = _bce_logits(logits, y_true, pos_weight=pos_weight, sample_weight=sample_weight)

        if use_rand:
            e_rand = parts.get("e_rand", None)
            if e_rand is None:
                e_rand = torch.zeros_like(logits)
            # Squared batch mean of each outcome's random-effect logit, summed over outcomes.
            loss_mean0 = (e_rand.mean(dim=0) ** 2).sum()
            # Ridge penalty on the random-effect logits.
            loss_ridge = (e_rand ** 2).mean()

            # Decorrelate the fixed and FiLM-modulated random latents, averaged over outcomes.
            z_fix_list = parts.get("z_fix_list", [])
            z_rand_film_list = parts.get("z_rand_film_list", [])
            loss_orth = torch.tensor(0.0, device=logits.device)
            if len(z_fix_list) and len(z_rand_film_list) and len(z_fix_list) == len(z_rand_film_list):
                acc = 0.0
                for zf, zr in zip(z_fix_list, z_rand_film_list):
                    acc = acc + _orthogonality_penalty_latents(zf, zr)
                loss_orth = acc / float(len(z_fix_list))

            # Keep FiLM close to the identity map (gamma ~ 1, beta ~ 0), averaged over outcomes.
            gamma_list = parts.get("film_gamma_list", [])
            beta_list  = parts.get("film_beta_list", [])
            loss_film = torch.tensor(0.0, device=logits.device)
            if len(gamma_list) and len(beta_list):
                acc = 0.0
                one = torch.tensor(1.0, device=logits.device, dtype=logits.dtype)
                for g, b in zip(gamma_list, beta_list):
                    acc = acc + ((g - one) ** 2).mean() + (b ** 2).mean()
                loss_film = acc / float(len(gamma_list))
        else:
            loss_mean0 = torch.tensor(0.0, device=logits.device)
            loss_ridge = torch.tensor(0.0, device=logits.device)
            loss_orth  = torch.tensor(0.0, device=logits.device)
            loss_film  = torch.tensor(0.0, device=logits.device)

        kan_reg = torch.tensor(0.0, device=logits.device)
        if cfg["lambda_kan"] > 0:
            # A model- or block-level `regularization_loss` may return a plain Python
            # float (e.g. an accumulator that starts at 0.0 and finds no KAN layers).
            # Coerce it to a tensor so the `.detach()` below cannot fail. Tensors pass
            # through unchanged, so autograd is preserved.
            kan_reg = torch.as_tensor(self._sum_kan_regularization(), device=logits.device)

        total_loss = (
            loss_pred
            + cfg["lambda_mean0"] * loss_mean0
            + cfg["lambda_ridge"] * loss_ridge
            + cfg["lambda_orth_latent"] * loss_orth
            + cfg["lambda_film_identity"] * loss_film
            + cfg["lambda_kan"] * kan_reg
        )

        parts_out = {
            "loss_total": float(total_loss.detach().cpu()),
            "loss_pred":  float(loss_pred.detach().cpu()),
            "loss_mean0": float(loss_mean0.detach().cpu()),
            "loss_ridge": float(loss_ridge.detach().cpu()),
            "loss_orth":  float(loss_orth.detach().cpu()),
            "loss_fi":    float(loss_film.detach().cpu()),
            "loss_kan":   float(kan_reg.detach().cpu()),
        }
        return total_loss, parts_out

    def fit(
        self,
        train_loader,
        val_loader=None,
        *,
        verbose: bool = True,
        pos_weight: Optional[torch.Tensor] = None
    ):
        cfg = self.cfg
        model = self.model
        device = self.device
        opt = torch.optim.Adam(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])
        use_amp = bool(cfg.get("amp", True) and torch.cuda.is_available())
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
        if pos_weight is None and cfg.get("auto_pos_weight", True):
            pos_weight = self._compute_pos_weight_from_loader(train_loader)
        if pos_weight is not None:
            pos_weight = pos_weight.to(device, dtype=next(model.parameters()).dtype)
        best_val = float("inf")
        best_state = None
        no_improve = 0
        for epoch in range(1, cfg["max_epochs"] + 1):
            model.train()
            total_tr, n_tr = 0.0, 0
            for batch in train_loader:
                tensors = [b.to(device) for b in batch]
                Xf_b, TC_b, Zr_b, y_b, yl_b, dt_b, pid_b, seen_b = tensors[:8]
                sample_w_b = None
                if cfg.get("sample_weight_index") is not None:
                    idx_w = int(cfg["sample_weight_index"])
                    if idx_w < len(tensors):
                        sample_w_b = tensors[idx_w]
                TC_in = TC_b if TC_b.size(1) > 0 else None
                Zr_in = Zr_b if Zr_b.size(1) > 0 else None
                opt.zero_grad(set_to_none=True)
                with torch.amp.autocast('cuda', enabled=use_amp):
                    logits, parts = self.forward(
                        X_fix=Xf_b, TC=TC_in, Zrand=Zr_in,
                        y_lags=yl_b, dt_lags=dt_b,
                        pid_idx=pid_b, pid_seen_mask=seen_b
                    )
                    loss, _ = self.compute_loss(
                        y_true=y_b, logits=logits, parts=parts,
                        X_fix=Xf_b, pid_idx=pid_b, TC=TC_in,
                        pos_weight=pos_weight,
                        sample_weight=sample_w_b,
                    )
                scaler.scale(loss).backward()
                if cfg.get("clip_grad", None):
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["clip_grad"])
                scaler.step(opt)
                scaler.update()
                bs = Xf_b.size(0)
                total_tr += float(loss.detach().cpu()) * bs
                n_tr += bs
            train_loss = total_tr / max(1, n_tr)
            self.history_["train_loss"].append(train_loss)
            if val_loader is not None:
                model.eval()
                total_va, n_va = 0.0, 0
                with torch.no_grad():
                    for batch in val_loader:
                        tensors = [b.to(device) for b in batch]
                        Xf_b, TC_b, Zr_b, y_b, yl_b, dt_b, pid_b, seen_b = tensors[:8]
                        sample_w_b = None
                        TC_in = TC_b if TC_b.size(1) > 0 else None
                        Zr_in = Zr_b if Zr_b.size(1) > 0 else None
                        with torch.amp.autocast('cuda', enabled=use_amp):
                            logits, parts = self.forward(
                                X_fix=Xf_b, TC=TC_in, Zrand=Zr_in,
                                y_lags=yl_b, dt_lags=dt_b,
                                pid_idx=pid_b, pid_seen_mask=seen_b
                            )
                            l, _ = self.compute_loss(
                                y_true=y_b, logits=logits, parts=parts,
                                X_fix=Xf_b, pid_idx=pid_b, TC=TC_in,
                                pos_weight=pos_weight,
                                sample_weight=sample_w_b,
                            )
                        bs = Xf_b.size(0)
                        total_va += float(l.detach().cpu()) * bs
                        n_va += bs
                val_loss = total_va / max(1, n_va)
                self.history_["val_loss"].append(val_loss)
                if verbose:
                    print(f"Epoch {epoch:03d} | train {train_loss:.6f} | val {val_loss:.6f}")
                if val_loss < best_val - 1e-6:
                    best_val = val_loss
                    best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                    no_improve = 0
                else:
                    no_improve += 1
                    if no_improve >= cfg["patience"]:
                        if verbose:
                            print(f"Early stopping at epoch {epoch:03d} (best val {best_val:.6f})")
                        break
            else:
                if verbose:
                    print(f"Epoch {epoch:03d} | train {train_loss:.6f}")
        if best_state is not None:
            model.load_state_dict(best_state)
        self.thresholds_ = None
        if val_loader is not None:
            model.eval()
            all_prob, all_true = [], []
            with torch.no_grad():
                for batch in val_loader:
                    tensors = [b.to(device) for b in batch]
                    Xf_b, TC_b, Zr_b, y_b, yl_b, dt_b, pid_b, seen_b = tensors[:8]
                    TC_in = TC_b if TC_b.size(1) > 0 else None
                    Zr_in = Zr_b if Zr_b.size(1) > 0 else None
                    logits, _ = self.forward(
                        X_fix=Xf_b, TC=TC_in, Zrand=Zr_in,
                        y_lags=yl_b, dt_lags=dt_b,
                        pid_idx=pid_b, pid_seen_mask=seen_b
                    )
                    all_prob.append(torch.sigmoid(logits).cpu().numpy())
                    all_true.append(y_b.cpu().numpy())
            y_prob = np.vstack(all_prob)
            y_true = np.vstack(all_true)
            y_dim = y_prob.shape[1] if y_prob.ndim == 2 else 1
            thr_vec = np.zeros((y_dim,), dtype=float)
            for j in range(y_dim):
                thr_vec[j] = _optimal_threshold_exact_np(y_true[:, j], y_prob[:, j], beta=1.0)
            self.thresholds_ = thr_vec
        return {"best_val_loss": (best_val if val_loader is not None else self.history_["train_loss"][-1])}

    @torch.no_grad()
    def predict_logits(
        self,
        X_fix: torch.Tensor,
        TC: Optional[torch.Tensor],
        Zrand: Optional[torch.Tensor],
        y_lags: torch.Tensor,
        dt_lags: torch.Tensor,
        pid_idx: Optional[torch.Tensor] = None,
        pid_seen_mask: Optional[torch.Tensor] = None,
        *,
        combine: str = "all",
    ) -> torch.Tensor:
        """Gradient-free forward pass returning combined logits.

        Switches ``self.model`` to eval mode (it is not switched back), runs
        ``self.forward`` and hands the per-part outputs it returns to
        ``_combine_logits_from_parts`` with the requested ``combine`` mode.
        The full logits that ``forward`` also returns are not used here.
        """
        self.model.eval()
        forward_out = self.forward(X_fix, TC, Zrand, y_lags, dt_lags, pid_idx, pid_seen_mask)
        _, components = forward_out
        return _combine_logits_from_parts(components, combine=combine)

    @torch.no_grad()
    def predict_proba(self, *args, combine: str = "all", **kwargs) -> torch.Tensor:
        """Sigmoid of ``predict_logits``; every argument is forwarded unchanged."""
        return torch.sigmoid(self.predict_logits(*args, combine=combine, **kwargs))

    @torch.no_grad()
    def predict(
        self,
        *args,
        threshold: Optional[float] = None,
        thresholds: Optional[np.ndarray] = None,
        use_fitted_thresholds: bool = False,
        combine: str = "all",
        **kwargs
    ) -> torch.Tensor:
        """Binarize ``predict_proba`` output into 0/1 int32 labels (``prob >= thr``).

        Threshold precedence:
          1. ``thresholds``: explicit per-task values (a single value applies to all tasks).
          2. ``self.thresholds_`` when ``use_fitted_thresholds`` is True and it is set.
          3. ``threshold`` scalar, falling back to ``self.cfg["threshold"]`` (default 0.5).

        The result has exactly the shape of the probabilities: thresholds are
        broadcast along the last (task) axis, so a 1-D ``(B,)`` probability
        vector yields a ``(B,)`` prediction vector rather than ``(1, B)``.
        """
        probs = self.predict_proba(*args, combine=combine, **kwargs)
        if thresholds is not None:
            thr_source = thresholds
        elif use_fitted_thresholds and (self.thresholds_ is not None):
            thr_source = self.thresholds_
        else:
            thr_source = self.cfg.get("threshold", 0.5) if threshold is None else threshold
        # A 1-D threshold vector broadcasts over the last axis: (B, y_dim) vs
        # (y_dim,) or (1,) keeps (B, y_dim); (B,) vs (1,) keeps (B,).
        thr = torch.as_tensor(thr_source, device=probs.device, dtype=probs.dtype).reshape(-1)
        return (probs >= thr).to(torch.int32)


## Evaluation

In [7]:
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_fscore_support, confusion_matrix
from sklearn.calibration import calibration_curve
from sklearn.decomposition import PCA
from scipy.stats import t as student_t
from dataclasses import dataclass

# Numerical guards used by the PCA preprocessing helpers (_fit_pca_pipeline / _transform_pca_pipeline).
# Columns whose training variance is <= _VAR_EPS are treated as constant and dropped.
_VAR_EPS = 1e-8
# Lower bound applied to per-column standard deviations before standardizing.
_STD_EPS = 1e-6
# Standardized values (z-scores) are clipped to [-_CLIP_Z, _CLIP_Z] before PCA.
_CLIP_Z  = 8.0

@dataclass
class PCAPipeline:
    """Fitted state of the variance-filter -> standardize -> clip -> PCA preprocessing."""
    # Boolean mask over input columns that were kept (training variance > _VAR_EPS).
    keep_mask: np.ndarray
    # Training mean of each kept column, used to standardize.
    mean_: np.ndarray
    # Training std of each kept column (floored at _STD_EPS), used to standardize.
    scale_: np.ndarray
    # sklearn PCA fitted on the clipped z-scores; left unfitted (n_components=0) when no column is kept.
    pca: PCA

def _fit_pca_pipeline(X_train: np.ndarray, var_ratio: float = 0.95, random_state: int | None = None) -> PCAPipeline:
    """Fit the variance-filter -> standardize -> clip -> PCA pipeline on training rows.

    Steps:
      1. Cast to float64 and replace NaN / +-inf with 0.
      2. Keep only columns whose variance exceeds ``_VAR_EPS``.
      3. Standardize the kept columns (std floored at ``_STD_EPS``), scrub
         non-finite values again and clip z-scores to +-``_CLIP_Z``.
      4. Fit a full-SVD PCA retaining ``var_ratio`` of the explained variance.

    When no column survives step 2, the returned pipeline carries empty
    mean/scale arrays and an unfitted ``PCA(n_components=0)``.

    Raises:
        RuntimeError: if the fitted PCA components contain non-finite values.
    """
    data = np.nan_to_num(np.asarray(X_train, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    keep = data.var(axis=0) > _VAR_EPS
    if not np.any(keep):
        return PCAPipeline(
            keep_mask=keep,
            mean_=np.array([], dtype=np.float64),
            scale_=np.array([], dtype=np.float64),
            pca=PCA(n_components=0, svd_solver='full', random_state=random_state),
        )
    kept = data[:, keep]
    center = kept.mean(axis=0)
    spread = np.maximum(kept.std(axis=0), _STD_EPS)
    z_scores = np.nan_to_num((kept - center) / spread, nan=0.0, posinf=0.0, neginf=0.0)
    z_scores = np.clip(z_scores, -_CLIP_Z, _CLIP_Z)
    pca = PCA(n_components=var_ratio, svd_solver='full', random_state=random_state)
    pca.fit(z_scores)
    if not np.all(np.isfinite(pca.components_)):
        raise RuntimeError("PCA components contain non-finite values after fit.")
    return PCAPipeline(keep_mask=keep, mean_=center, scale_=spread, pca=pca)

def _transform_pca_pipeline(pipe: PCAPipeline | None, X: np.ndarray | None) -> np.ndarray | None:
    """Project rows of ``X`` through a pipeline fitted by ``_fit_pca_pipeline``.

    Applies the fit-time cleaning (NaN / +-inf -> 0, keep-mask column selection,
    standardization with the stored training mean/scale, clipping to
    +-``_CLIP_Z``). It then projects onto the PCA components after centering with
    the PCA's own ``mean_``, i.e. the same computation as
    ``sklearn.decomposition.PCA.transform`` (non-whitened).

    Returns:
        ``None`` if ``pipe`` or ``X`` is ``None``; an ``(n_rows, 0)`` float32 array if
        the pipeline kept no columns; otherwise ``(n_rows, n_components)`` float32.

    Raises:
        RuntimeError: if the fitted PCA ``components_`` or ``mean_`` are non-finite.
    """
    if pipe is None or X is None:
        return None
    X = np.asarray(X, dtype=np.float64)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
    if pipe.keep_mask.size == 0 or not np.any(pipe.keep_mask):
        return np.zeros((X.shape[0], 0), dtype=np.float32)
    Xk = X[:, pipe.keep_mask]
    Z = (Xk - pipe.mean_) / pipe.scale_
    # After nan_to_num + clip, Z is finite and bounded by _CLIP_Z by construction.
    # No further value checks are needed, and a |Z|.max() check would raise on
    # zero-row input.
    Z = np.nan_to_num(Z, nan=0.0, posinf=0.0, neginf=0.0)
    np.clip(Z, -_CLIP_Z, _CLIP_Z, out=Z)
    if not np.isfinite(pipe.pca.components_).all():
        raise RuntimeError("[our PCA] components_ non-finite")
    pca_mean = getattr(pipe.pca, "mean_", None)
    if pca_mean is not None:
        if not np.isfinite(pca_mean).all():
            raise RuntimeError("[our PCA] mean_ non-finite")
        # PCA was fitted on the *clipped* training z-scores, whose mean is not
        # exactly zero; center with it exactly as PCA.transform does.
        Z = Z - pca_mean
    Z64 = np.ascontiguousarray(Z, dtype=np.float64)
    CT  = np.ascontiguousarray(pipe.pca.components_.T, dtype=np.float64)
    with np.errstate(over='ignore', invalid='ignore', divide='ignore'):
        Xt = Z64 @ CT
    Xt = np.nan_to_num(Xt, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return Xt

def _metrics_binary_full(y_true, y_prob, thr=0.5) -> dict:
    """Threshold-free and thresholded metrics for a single binary task.

    Args:
        y_true: 0/1 labels (cast to int).
        y_prob: predicted positive-class probabilities (cast to float).
        thr: decision threshold; a sample is predicted positive when ``y_prob >= thr``.

    Returns:
        dict with AUC, AUPRC, Brier, ACC, F1, Precision, Recall, Sensitivity and
        Specificity. AUC/AUPRC are NaN when sklearn reports them as undefined
        (ValueError, e.g. only one class present). Sensitivity/Specificity are
        NaN when their denominator is zero.
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)
    y_pred = (y_prob >= thr).astype(int)
    # Catch only ValueError (sklearn's signal for undefined metrics). A bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit and real bugs.
    try:
        auc = float(roc_auc_score(y_true, y_prob))
    except ValueError:
        auc = float("nan")
    try:
        auprc = float(average_precision_score(y_true, y_prob))
    except ValueError:
        auprc = float("nan")
    brier = float(np.mean((y_prob - y_true) ** 2))
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)
    acc = float((y_pred == y_true).mean())
    try:
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        sens = float(tp / (tp + fn)) if (tp + fn) > 0 else float("nan")
        spec = float(tn / (tn + fp)) if (tn + fp) > 0 else float("nan")
    except ValueError:
        sens, spec = float("nan"), float("nan")
    return {"AUC": auc, "AUPRC": auprc, "Brier": brier, "ACC": acc, "F1": float(f1),
            "Precision": float(prec), "Recall": float(rec),
            "Sensitivity": sens, "Specificity": spec}

def evaluate_multitask(y_true: np.ndarray, y_prob: np.ndarray, thr=0.5) -> dict:
    """Per-task and macro-averaged binary metrics at a decision threshold.

    Args:
        y_true: labels, shape (n,) or (n, y_dim).
        y_prob: probabilities with the same shape as ``y_true``.
        thr: decision threshold. A scalar applies to every task; alternatively
            pass one value per task (length ``y_dim``).

    Returns:
        dict with ``task_{j}_{metric}`` for j = 1..y_dim and ``macro_{metric}``
        (NaN-ignoring mean over tasks) for every metric of ``_metrics_binary_full``.

    Raises:
        ValueError: if ``y_true`` and ``y_prob`` shapes differ, or ``thr`` cannot be
            broadcast to one value per task.
    """
    y_true = np.asarray(y_true); y_prob = np.asarray(y_prob)
    if y_true.ndim == 1: y_true = y_true[:, None]
    if y_prob.ndim == 1: y_prob = y_prob[:, None]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if y_true.shape != y_prob.shape:
        raise ValueError(f"y_true shape {y_true.shape} does not match y_prob shape {y_prob.shape}")
    y_dim = y_true.shape[1]
    thr_vec = np.broadcast_to(np.asarray(thr, dtype=float).reshape(-1), (y_dim,))
    out, macro = {}, {}
    for j in range(y_dim):
        m = _metrics_binary_full(y_true[:, j], y_prob[:, j], float(thr_vec[j]))
        for k, v in m.items():
            out[f"task_{j+1}_{k}"] = float(v)
            macro.setdefault(k, []).append(float(v))
    for k, vals in macro.items():
        out[f"macro_{k}"] = float(np.nanmean(np.array(vals, dtype=float)))
    return out

def _evaluate_with_thresholds(y_true: np.ndarray, y_prob: np.ndarray, thresholds: np.ndarray) -> dict:
    """Per-task and macro metrics using one decision threshold per task.

    Keys are ``task_{j}_{metric}_optthr`` for j = 1..y_dim and
    ``macro_{metric}_optthr`` (NaN-ignoring mean over tasks).

    Args:
        y_true: labels, shape (n,) or (n, y_dim); array-likes are accepted.
        y_prob: probabilities with the same shape as ``y_true``.
        thresholds: one threshold per task (length ``y_dim``); a single value is
            applied to every task.

    Raises:
        ValueError: if the shapes differ or the number of thresholds does not
            match the number of tasks (previously extra tasks/thresholds were
            silently ignored).
    """
    # asarray for consistency with evaluate_multitask (lists would fail on .ndim).
    y_true = np.asarray(y_true); y_prob = np.asarray(y_prob)
    if y_true.ndim == 1: y_true = y_true[:, None]
    if y_prob.ndim == 1: y_prob = y_prob[:, None]
    if y_true.shape != y_prob.shape:
        raise ValueError(f"y_true shape {y_true.shape} does not match y_prob shape {y_prob.shape}")
    y_dim = y_true.shape[1]
    thr_vec = np.asarray(thresholds, dtype=float).reshape(-1)
    if thr_vec.size == 1:
        thr_vec = np.repeat(thr_vec, y_dim)
    if thr_vec.size != y_dim:
        raise ValueError(f"got {thr_vec.size} thresholds for {y_dim} tasks")
    out, macro = {}, {}
    for j in range(y_dim):
        m = _metrics_binary_full(y_true[:, j], y_prob[:, j], thr=float(thr_vec[j]))
        for k, v in m.items():
            out[f"task_{j+1}_{k}_optthr"] = float(v)
            macro.setdefault(k, []).append(float(v))
    for k, vals in macro.items():
        out[f"macro_{k}_optthr"] = float(np.nanmean(np.array(vals, dtype=float)))
    return out

def compute_calibration_curves(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 10, strategy: str = "quantile") -> dict:
    """Reliability curves plus ECE / MCE / Brier score for each task.

    Args:
        y_true: labels, shape (n,) or (n, y_dim); cast to int per task.
        y_prob: probabilities with the same layout; clipped to [1e-6, 1 - 1e-6].
        n_bins: number of bins.
        strategy: passed to sklearn's ``calibration_curve``. For the per-bin
            sample counts, "uniform" uses equal-width edges on [0, 1] and any
            other value uses quantile edges of the clipped probabilities.

    Returns:
        dict with
          * "per_task": {task_index: {"mean_pred", "frac_pos", "counts", "ece", "mce", "brier"}}
          * "macro_ECE", "macro_MCE": NaN-ignoring means over tasks.

    ECE and MCE are NaN for a task whose ``calibration_curve`` call raised or
    returned no points.
    """
    y_true = np.asarray(y_true); y_prob = np.asarray(y_prob)
    if y_true.ndim == 1: y_true = y_true[:, None]
    if y_prob.ndim == 1: y_prob = y_prob[:, None]
    y_dim = y_true.shape[1]
    per_task, eces, mces = {}, [], []
    for j in range(y_dim):
        t = y_true[:, j].astype(int)
        # Keep probabilities strictly inside (0, 1).
        p = np.clip(y_prob[:, j].astype(float), 1e-6, 1-1e-6)
        # calibration_curve returns only non-empty bins. Any failure yields
        # empty curves, which makes ECE/MCE NaN below.
        try:
            frac_pos, mean_pred = calibration_curve(t, p, n_bins=n_bins, strategy=strategy)
        except Exception:
            frac_pos, mean_pred = np.array([]), np.array([])
        # Rebuild bin edges to count samples per bin; quantile outer edges are pinned to 0 and 1.
        # NOTE(review): for an empty task, np.quantile on an empty array likely raises here.
        if strategy == "uniform":
            edges = np.linspace(0, 1, n_bins+1)
        else:
            qs = np.linspace(0, 1, n_bins+1)
            edges = np.quantile(p, qs); edges[0], edges[-1] = 0.0, 1.0
        # Only inner edges are used: bin i gets edges[i] < p <= edges[i+1] (right=True).
        # NOTE(review): presumably mirrors calibration_curve's own binning; confirm.
        bin_ids = np.digitize(p, edges[1:-1], right=True)
        counts = np.bincount(bin_ids, minlength=n_bins).astype(float)
        N = max(1, len(p))
        if len(mean_pred) == n_bins:
            # Every bin is non-empty: curve points align 1:1 with counts.
            weights = counts / N
        else:
            if len(mean_pred) > 0:
                # Some bins were dropped: map each returned point back to its bin
                # through its mean prediction, then take that bin's sample fraction.
                idxs = np.digitize(mean_pred, edges[1:-1], right=True)
                weights = counts[idxs] / N
            else:
                weights = np.array([])
        # ECE = sample-weighted mean |observed - predicted|; MCE = worst bin gap.
        gaps = np.abs(frac_pos - mean_pred) if len(mean_pred) else np.array([np.nan])
        ece = float(np.nansum(weights * gaps)) if len(mean_pred) else float("nan")
        mce = float(np.nanmax(gaps)) if len(mean_pred) else float("nan")
        # NOTE(review): N = max(1, len(p)), so this guard is always true.
        brier = float(np.mean((p - t) ** 2)) if N > 0 else float("nan")
        per_task[j] = {"mean_pred": mean_pred.tolist(), "frac_pos": frac_pos.tolist(),
                       "counts": counts.tolist(), "ece": ece, "mce": mce, "brier": brier}
        eces.append(ece); mces.append(mce)
    return {"per_task": per_task, "macro_ECE": float(np.nanmean(eces)), "macro_MCE": float(np.nanmean(mces))}

def _split_cases(pid_array, test_fraction=0.2, seed=42):
    rng = np.random.default_rng(seed)
    u = np.unique(pid_array)
    te_ids = rng.choice(u, size=max(1, int(len(u)*test_fraction)), replace=False)
    te_mask = np.isin(pid_array, te_ids)
    return np.where(~te_mask)[0], np.where(te_mask)[0]

def _split_time_basic(time_index, test_fraction=0.2):
    order = np.argsort(time_index)
    n = len(order)
    split = int(np.floor(n*(1.0 - test_fraction)))
    return order[:split], order[split:]

def _filter_time_test_min_measurements(pid_idx: np.ndarray, test_idx: np.ndarray, min_meas: int = 2):
    pid = np.asarray(pid_idx)
    counts = {pid_val: np.sum(pid == pid_val) for pid_val in np.unique(pid)}
    keep = [i for i in test_idx if counts.get(pid[i], 0) >= min_meas]
    return np.array(keep, dtype=int)

def _concat_safe(*arrays: Optional[np.ndarray]) -> np.ndarray:
    parts = [a for a in arrays if a is not None and a.size > 0]
    if not parts:
        return np.zeros((0, 0), dtype=np.float32)
    return np.concatenate(parts, axis=1).astype(np.float32)

from typing import Dict, Any, Optional, Tuple, List
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import GroupKFold, TimeSeriesSplit, ParameterGrid
from scipy.stats import t as student_t

class _KANDataset(Dataset):
    def __init__(
        self,
        X_fix: np.ndarray,
        TC: Optional[np.ndarray],
        Zrand: Optional[np.ndarray],
        y: np.ndarray,
        y_lags: np.ndarray,
        dt_lags: np.ndarray,
        pid_idx: np.ndarray,
        pid_seen_mask: np.ndarray,
        device: torch.device
    ):
        N = X_fix.shape[0]
        assert y.shape[0] == N and y_lags.shape[0] == N and dt_lags.shape[0] == N and pid_idx.shape[0] == N
        def _to_tensor(a, dtype=torch.float32):
            if a is None:
                return torch.zeros((N, 0), dtype=dtype, device=device)
            return torch.as_tensor(a, dtype=dtype, device=device)
        self.X_fix  = _to_tensor(X_fix, torch.float32)
        self.TC     = _to_tensor(TC,    torch.float32)
        self.Zrand  = _to_tensor(Zrand, torch.float32)
        self.y      = _to_tensor(y,     torch.float32)
        self.y_lags = _to_tensor(y_lags, torch.float32)
        self.dt     = _to_tensor(dt_lags, torch.float32)
        self.pid    = torch.as_tensor(pid_idx, dtype=torch.long, device=device)
        self.seen   = torch.as_tensor(pid_seen_mask.astype(bool), dtype=torch.bool, device=device)
    def __len__(self):
        return self.X_fix.shape[0]
    def __getitem__(self, i):
        return (
            self.X_fix[i], self.TC[i], self.Zrand[i],
            self.y[i], self.y_lags[i], self.dt[i],
            self.pid[i], self.seen[i]
        )

def _make_loader_kan(
    X_fix, TC, Zrand, y, y_lags, dt_lags,
    pid_idx, pid_seen_mask,
    batch_size: int,
    shuffle: bool,
    device: torch.device
) -> DataLoader:
    """Wrap the arrays in a ``_KANDataset`` and return a single-process DataLoader.

    The last partial batch is kept (``drop_last=False``) and no worker processes
    are used (``num_workers=0``). The dataset's tensors are created on ``device``.
    """
    dataset = _KANDataset(X_fix, TC, Zrand, y, y_lags, dt_lags, pid_idx, pid_seen_mask, device)
    loader_options = {"batch_size": batch_size, "shuffle": shuffle, "drop_last": False, "num_workers": 0}
    return DataLoader(dataset, **loader_options)

def _optimal_threshold_exact_np(y_true: np.ndarray, y_prob: np.ndarray, beta: float = 1.0) -> float:
    y = y_true.astype(int).ravel()
    p = y_prob.astype(float).ravel()
    if y.size == 0:
        return 0.5
    P = int(y.sum()); N = y.size - P
    if P == 0: return 1.0
    if N == 0: return 0.0
    o = np.argsort(-p); p = p[o]; y = y[o]
    tp = np.cumsum(y); fp = np.cumsum(1 - y)
    prec = tp / np.maximum(1, tp + fp)
    rec  = tp / max(1, P)
    f = (1 + beta**2) * prec * rec / np.maximum(1e-12, beta**2 * prec + rec)
    idx = int(np.nanargmax(f))
    next_p = p[idx+1] if idx+1 < len(p) else -np.inf
    thr = (p[idx] + next_p)/2.0 if np.isfinite(next_p) and next_p < p[idx] else max(0.0, p[idx] - np.finfo(p.dtype).eps)
    return float(thr)

def _find_best_thresholds_from_loader(wrapper, loader: DataLoader) -> Tuple[np.ndarray, Dict[str, float]]:
    """Choose per-task F1-optimal thresholds from the predictions on ``loader``.

    Puts ``wrapper.model`` in eval mode, collects sigmoid probabilities from
    ``wrapper.predict_logits`` and the labels for every batch, and picks one
    threshold per task with ``_optimal_threshold_exact_np`` (beta = 1).

    Returns:
        (thresholds, summary): ``thresholds`` has shape (y_dim,); ``summary`` holds
        ``"macro_F1_trainthr"``, the mean per-task F1 at those thresholds on the
        same rows.
    """
    wrapper.model.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for Xf_b, TC_b, Zr_b, y_b, yl_b, dt_b, pid_b, seen_b in loader:
            batch_logits = wrapper.predict_logits(
                X_fix=Xf_b,
                TC=None if TC_b.size(1) == 0 else TC_b,
                Zrand=None if Zr_b.size(1) == 0 else Zr_b,
                y_lags=yl_b,
                dt_lags=dt_b,
                pid_idx=pid_b,
                pid_seen_mask=seen_b,
            )
            prob_chunks.append(torch.sigmoid(batch_logits).cpu().numpy())
            label_chunks.append(y_b.cpu().numpy())
    y_prob = np.vstack(prob_chunks)
    y_true = np.vstack(label_chunks)
    n_tasks = y_prob.shape[1]
    thr = np.array(
        [_optimal_threshold_exact_np(y_true[:, j], y_prob[:, j], beta=1.0) for j in range(n_tasks)],
        dtype=float,
    )
    preds = (y_prob >= thr[None, :]).astype(int)
    # precision_recall_fscore_support is imported at the top of this cell.
    f1_scores = [
        float(precision_recall_fscore_support(y_true[:, j], preds[:, j], average='binary', zero_division=0)[2])
        for j in range(n_tasks)
    ]
    return thr, {"macro_F1_trainthr": float(np.nanmean(f1_scores))}

def _print_split_info(name, idxs, pid_idx_full):
    n = int(len(idxs))
    u = int(len(np.unique(pid_idx_full[idxs]))) if n > 0 else 0
    print(f"[split] {name:<5} | rows={n:5d} | unique_ids={u:5d}")

def _prepare_split_and_loaders_kan(
    X_only_fixed: np.ndarray,
    X_fixed_and_random: Optional[np.ndarray],
    y: np.ndarray,
    y_lags: np.ndarray,
    dt_lags: np.ndarray,
    pid_idx_full: np.ndarray,
    indices_train: np.ndarray,
    indices_val: Optional[np.ndarray],
    indices_test: np.ndarray,
    scenario: str,
    batch_size: int,
    device: torch.device,
    random_state: int = 42,
    pca_var_ratio: float = 0.95,
    verbose: bool = True,
):
    """Fit train-only PCA preprocessing, transform every split, and build loaders.

    Feature blocks (pipelines fitted on ``indices_train`` rows only):
      * ``TC``    = PCA(``X_only_fixed``)
      * ``Zrand`` = PCA(``X_fixed_and_random``) when that matrix is given and has
        columns, else ``None``
      * ``X_fix`` = column-wise concatenation of ``TC`` and ``Zrand``

    ``pid_seen_mask`` flags rows whose participant id also occurs among the
    training rows. It is derived from the actual split for every scenario. A
    time-based split, or the random validation split, can place all of a
    participant's rows outside train, so assuming "all seen" would be wrong.
    ``scenario`` is kept in the signature for backward compatibility.

    Returns:
        (preprocessors, loaders):
          * ``preprocessors`` holds the fitted pipelines and the block widths
            (``d_fix``, ``d_tc``, ``d_zrand``) plus ``n_ids`` for the model builder.
          * ``loaders`` maps "train"/"val"/"test" to DataLoaders ("val" is
            ``None`` without ``indices_val``).
    """
    of_pipe = _fit_pca_pipeline(X_only_fixed[indices_train], var_ratio=pca_var_ratio, random_state=random_state)
    fr_pipe = None
    if X_fixed_and_random is not None and X_fixed_and_random.shape[1] > 0:
        fr_pipe = _fit_pca_pipeline(X_fixed_and_random[indices_train], var_ratio=pca_var_ratio, random_state=random_state)

    def transform_block(idxs):
        # Returns (X_fix, TC, Zrand) for the given rows.
        of = _transform_pca_pipeline(of_pipe, X_only_fixed[idxs])
        fr = _transform_pca_pipeline(fr_pipe, None if X_fixed_and_random is None else X_fixed_and_random[idxs])
        return _concat_safe(of, fr), of, fr

    if verbose:
        _print_split_info("train", indices_train, pid_idx_full)
        if indices_val is not None:
            _print_split_info("val", indices_val, pid_idx_full)
        _print_split_info("test", indices_test, pid_idx_full)
    Xf_tr, TC_tr, Zr_tr = transform_block(indices_train)
    Xf_te, TC_te, Zr_te = transform_block(indices_test)
    if indices_val is not None:
        Xf_va, TC_va, Zr_va = transform_block(indices_val)
    else:
        Xf_va = TC_va = Zr_va = None

    train_pids = np.unique(pid_idx_full[indices_train])

    def seen_mask(idxs):
        # True where the row's participant also appears among the training rows.
        return np.isin(pid_idx_full[idxs], train_pids)

    seen_tr = np.ones(indices_train.shape[0], dtype=bool)
    seen_va = seen_mask(indices_val) if indices_val is not None else None
    seen_te = seen_mask(indices_test)

    tr_loader = _make_loader_kan(
        Xf_tr, TC_tr, Zr_tr, y[indices_train], y_lags[indices_train], dt_lags[indices_train],
        pid_idx_full[indices_train], seen_tr,
        batch_size=batch_size, shuffle=True, device=device
    )
    va_loader = None
    if indices_val is not None:
        va_loader = _make_loader_kan(
            Xf_va, TC_va, Zr_va, y[indices_val], y_lags[indices_val], dt_lags[indices_val],
            pid_idx_full[indices_val], seen_va,
            batch_size=batch_size, shuffle=False, device=device
        )
    te_loader = _make_loader_kan(
        Xf_te, TC_te, Zr_te, y[indices_test], y_lags[indices_test], dt_lags[indices_test],
        pid_idx_full[indices_test], seen_te,
        batch_size=batch_size, shuffle=False, device=device
    )

    all_pids = np.unique(pid_idx_full)
    # pid_idx reaches the model as a long tensor. If it is used as an index
    # (presumably an embedding lookup; confirm in build_model_fn), the id space
    # must cover the largest id, not only the number of distinct ids (ids may
    # be non-contiguous). This equals the distinct count when ids are 0..n-1.
    n_ids = int(max(all_pids.size, int(all_pids.max()) + 1)) if all_pids.size else 0
    preprocessors = {
        "of_pipe": of_pipe,
        "fr_pipe": fr_pipe,
        "d_fix":   Xf_tr.shape[1],
        "d_tc":    (Xf_tr.shape[1] - (0 if Zr_tr is None else Zr_tr.shape[1])),
        "d_zrand": 0 if Zr_tr is None else Zr_tr.shape[1],
        "n_ids":   n_ids,
    }
    loaders = {"train": tr_loader, "val": va_loader, "test": te_loader}
    return preprocessors, loaders

def _model_param_count(model: torch.nn.Module) -> int:
    return int(sum(p.numel() for p in model.parameters()))

def _fit_eval_once_kan(
    build_model_fn, wrapper_cls,
    arch_params: Dict[str, Any],
    train_params: Dict[str, Any],
    X_only_fixed: np.ndarray,
    X_fixed_and_random: Optional[np.ndarray],
    y: np.ndarray,
    y_lags: np.ndarray,
    dt_lags: np.ndarray,
    pid_idx_full: np.ndarray,
    tr_idx: np.ndarray,
    va_idx: Optional[np.ndarray],
    te_idx: np.ndarray,
    device: torch.device,
    scenario: str,
    threshold_selection_source: str = "train",
    verbose: bool = True,
):
    """Train one model on a fixed split and evaluate it on the test rows.

    Builds preprocessing and loaders, builds the model with ``build_model_fn``,
    wraps it in ``wrapper_cls`` and calls ``wrapper.fit`` (the val loader may be
    ``None``). It then predicts test probabilities and reports:
      * metrics at the fixed threshold ``train_params["threshold"]`` (default 0.5),
      * metrics at per-task F1-optimal thresholds,
      * calibration summary (macro ECE / MCE).

    Args:
        threshold_selection_source: rows used to choose the F1-optimal thresholds:
            "train" (default), "val" (falls back to "train" when there is no
            validation split), or "test" (explicit opt-in only; choosing
            thresholds on the rows being scored inflates the "@optthr" metrics).

    Returns:
        dict with test metrics, chosen thresholds, preprocessors, the fitted
        wrapper/model, final params, calibration summary and test predictions.

    Raises:
        ValueError: if ``threshold_selection_source`` is not "train", "val" or "test".
    """
    # Validate before any (expensive) training.
    if threshold_selection_source not in ("train", "val", "test"):
        raise ValueError(
            f"threshold_selection_source must be 'train', 'val' or 'test', got {threshold_selection_source!r}"
        )
    preprocessors, loaders = _prepare_split_and_loaders_kan(
        X_only_fixed, X_fixed_and_random, y, y_lags, dt_lags, pid_idx_full,
        tr_idx, va_idx, te_idx, scenario,
        batch_size=train_params.get("batch_size", 256),
        device=device,
        random_state=train_params.get("random_state", 42),
        pca_var_ratio=train_params.get("pca_var_ratio", 0.95),
        verbose=verbose,
    )
    y_dim = y.shape[1] if y.ndim == 2 else 1
    model = build_model_fn(
        y_dim=y_dim,
        d_fix=preprocessors["d_fix"],
        d_tc=preprocessors["d_tc"],
        d_zrand=preprocessors["d_zrand"],
        n_ids=preprocessors["n_ids"],
        **arch_params
    ).to(device)
    wrapper = wrapper_cls(model, cfg=train_params, device=device)
    wrapper.fit(
        loaders["train"], loaders["val"],
        verbose=verbose,
        pos_weight=train_params.get("pos_weight", None)
    )
    wrapper.model.eval()
    probs_all, y_all = [], []
    with torch.no_grad():
        for Xf_b, TC_b, Zr_b, y_b, yl_b, dt_b, pid_b, seen_b in loaders["test"]:
            logits = wrapper.predict_logits(
                X_fix=Xf_b,
                TC=TC_b if TC_b.size(1) > 0 else None,
                Zrand=Zr_b if Zr_b.size(1) > 0 else None,
                y_lags=yl_b,
                dt_lags=dt_b,
                pid_idx=pid_b,
                pid_seen_mask=seen_b,
            )
            probs_all.append(torch.sigmoid(logits).cpu().numpy())
            y_all.append(y_b.cpu().numpy())
    y_prob_te = np.vstack(probs_all); y_true_te = np.vstack(y_all)
    metrics_050 = evaluate_multitask(y_true_te, y_prob_te, thr=train_params.get("threshold", 0.5))
    thr_source = threshold_selection_source
    if thr_source == "val" and loaders["val"] is None:
        # Never fall through to the test loader here: selecting thresholds on
        # the rows that are then scored would leak test labels into "@optthr".
        if verbose:
            print("[thresholds] no validation split available; selecting thresholds on train")
        thr_source = "train"
    thr_vec, thr_summary = _find_best_thresholds_from_loader(wrapper, loaders[thr_source])
    metrics_opt = _evaluate_with_thresholds(y_true_te, y_prob_te, thr_vec)
    calib = compute_calibration_curves(
        y_true_te, y_prob_te,
        n_bins=train_params.get("calibration_bins", 10),
        strategy=train_params.get("calibration_strategy", "quantile"),
    )
    y_pred_050 = (y_prob_te >= float(train_params.get("threshold", 0.5))).astype(int)
    y_pred_opt = (y_prob_te >= thr_vec[None, :]).astype(int)
    return {
        "metrics@0.5": metrics_050,
        "metrics@optthr": metrics_opt,
        "opt_thresholds": thr_vec,
        "opt_thresholds_summary": thr_summary,
        "preprocessors": preprocessors,
        "wrapper": wrapper,
        "model": wrapper.model,
        "model_class": wrapper.model.__class__.__name__,
        "model_param_count": _model_param_count(wrapper.model),
        "arch_params_final": dict(arch_params),
        "train_params_final": dict(train_params),
        "macro_ECE": calib["macro_ECE"],
        "macro_MCE": calib["macro_MCE"],
        "y_true_test": y_true_te,
        "y_prob_test": y_prob_te,
        "y_pred_test@0.5": y_pred_050,
        "y_pred_test@optthr": y_pred_opt,
    }

def run_training_and_eval_kan(
    X_only_fixed: np.ndarray,
    X_fixed_and_random: Optional[np.ndarray],
    y: np.ndarray,
    y_lags: np.ndarray,
    dt_lags: np.ndarray,
    pid_idx: np.ndarray,
    time_index: np.ndarray,
    build_model_fn,
    wrapper_cls,
    *,
    mode: str = "single",
    scenario: str = "cases",
    outer_folds: int = 5,
    inner_folds: int = 3,
    param_grid: Optional[Dict[str, List]] = None,
    arch_defaults: Optional[Dict[str, Any]] = None,
    train_defaults: Optional[Dict[str, Any]] = None,
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    device = (
        device
        or (torch.device("mps") if torch.backends.mps.is_available() else None)
        or (torch.device("cuda") if torch.cuda.is_available() else None)
        or torch.device("cpu")
    )
    if scenario == "both":
        out_cases = run_training_and_eval_kan(
            X_only_fixed, X_fixed_and_random, y, y_lags, dt_lags, pid_idx, time_index,
            build_model_fn, wrapper_cls,
            mode=mode, scenario="cases", outer_folds=outer_folds, inner_folds=inner_folds,
            param_grid=param_grid, arch_defaults=arch_defaults, train_defaults=train_defaults,
            device=device, verbose=verbose
        )
        out_time = run_training_and_eval_kan(
            X_only_fixed, X_fixed_and_random, y, y_lags, dt_lags, pid_idx, time_index,
            build_model_fn, wrapper_cls,
            mode=mode, scenario="time", outer_folds=outer_folds, inner_folds=inner_folds,
            param_grid=param_grid, arch_defaults=arch_defaults, train_defaults=train_defaults,
            device=device, verbose=verbose
        )
        return {"scenario_cases": out_cases, "scenario_time": out_time}
    X_of = np.asarray(X_only_fixed, dtype=np.float32)
    X_fr = None if X_fixed_and_random is None else np.asarray(X_fixed_and_random, dtype=np.float32)
    y    = np.asarray(y, dtype=np.float32)
    if y.ndim == 1: y = y[:, None]
    y_lags = np.asarray(y_lags, dtype=np.float32)
    dt_lags = np.asarray(dt_lags, dtype=np.float32)
    pid_idx = np.asarray(pid_idx, dtype=np.int64)
    time_ix = np.asarray(time_index)
    arch_defaults = arch_defaults or {}
    train_defaults = train_defaults or {}
    rnd = int(train_defaults.get("random_state", 42))
    val_frac = float(train_defaults.get("val_fraction", 0.10))
    thr_source = train_defaults.get("threshold_selection_source", "train")
    def _make_train_val_split(idx_array: np.ndarray, seed: int) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        idx_array = np.asarray(idx_array)
        if len(idx_array) <= 10 or val_frac <= 0.0:
            return idx_array, None
        rng = np.random.default_rng(seed)
        perm = rng.permutation(len(idx_array))
        cut = max(1, int(val_frac * len(idx_array)))
        va_sel, tr_sel = perm[:cut], perm[cut:]
        return idx_array[tr_sel], idx_array[va_sel]
    if mode == "single":
        if scenario == "cases":
            tr_idx_all, te_idx = _split_cases(pid_idx, test_fraction=0.2, seed=rnd)
        elif scenario == "time":
            tr_idx_all, te_idx_raw = _split_time_basic(time_ix, test_fraction=0.2)
            te_idx = _filter_time_test_min_measurements(
                pid_idx, te_idx_raw, min_meas=train_defaults.get("min_meas_test", 3)
            )
            if len(te_idx) == 0:
                raise RuntimeError("Time split produced empty test after ≥3-measurements filter.")
        else:
            raise ValueError("scenario must be 'cases' or 'time'")
        tr_idx, va_idx = _make_train_val_split(tr_idx_all, seed=rnd)
        res = _fit_eval_once_kan(
            build_model_fn, wrapper_cls,
            arch_defaults, train_defaults,
            X_of, X_fr, y, y_lags, dt_lags, pid_idx,
            tr_idx, va_idx, te_idx,
            device=device, scenario=scenario,
            threshold_selection_source=thr_source,
            verbose=verbose
        )
        if verbose:
            print("\nSingle-fit test metrics @0.5:")
            for k, v in res["metrics@0.5"].items():
                print(f"{k:>18}: {v:.4f}")
            print("\nSingle-fit test metrics @F1-opt per task:")
            for k, v in res["metrics@optthr"].items():
                print(f"{k:>18}: {v:.4f}")
            print(f"\nModel: {res['model_class']} | params={res['model_param_count']}")
            print("Final arch params:", res["arch_params_final"])
            print("Final train params:", res["train_params_final"])
        return res
    if mode == "cv_only":
        fold_metrics_050, fold_metrics_opt = [], []
        if scenario == "cases":
            outer = GroupKFold(n_splits=outer_folds)
            outer_iter = outer.split(X_of, y[:, 0], groups=pid_idx)
        else:
            tss = TimeSeriesSplit(n_splits=outer_folds)
            order = np.argsort(time_ix)
            X_order = X_of[order]; y_order = y[order]
            outer_iter = ((order[tr], order[te]) for tr, te in tss.split(X_order, y_order[:, 0]))
        for fold_id, (tr_idx_all, te_idx) in enumerate(outer_iter, start=1):
            tr_idx, va_idx = _make_train_val_split(np.asarray(tr_idx_all), seed=rnd + fold_id)
            if verbose:
                print(f"\n[CV fold {fold_id}/{outer_folds}] scenario={scenario}")
            res = _fit_eval_once_kan(
                build_model_fn, wrapper_cls,
                arch_defaults, train_defaults,
                X_of, X_fr, y, y_lags, dt_lags, pid_idx,
                tr_idx, va_idx, te_idx,
                device=device, scenario=scenario,
                threshold_selection_source=thr_source,
                verbose=verbose
            )
            fold_metrics_050.append(res["metrics@0.5"])
            fold_metrics_opt.append(res["metrics@optthr"])
            if verbose:
                macro_keys = [k for k in res["metrics@optthr"].keys() if k.startswith("macro_")]
                print("Fold macro (optthr): " + ", ".join(f"{k}={res['metrics@optthr'][k]:.4f}" for k in macro_keys))
        def _summarize_cv_folds(results_folds: List[Dict[str, float]]) -> Dict[str, float]:
            if not results_folds:
                return {}
            all_keys = set().union(*results_folds)
            summary = {}
            for k in sorted(all_keys):
                vals = np.array([fold.get(k, np.nan) for fold in results_folds], dtype=float)
                mask = np.isfinite(vals); n = int(mask.sum())
                if n == 0:
                    m = low = high = np.nan
                elif n == 1:
                    m = float(vals[mask][0]); low = high = np.nan
                else:
                    m = float(np.nanmean(vals)); s = float(np.nanstd(vals, ddof=1)); se = s/np.sqrt(n)
                    tcrit = float(student_t.ppf(0.975, df=n-1))
                    low, high = m - tcrit*se, m + tcrit*se
                summary[f"{k}_mean"] = m
                summary[f"{k}_95ci_low"] = low
                summary[f"{k}_95ci_high"] = high
            return summary
        cv_summary_050 = _summarize_cv_folds(fold_metrics_050)
        cv_summary_opt = _summarize_cv_folds(fold_metrics_opt)
        if verbose:
            print("\nCV averages (±95% CI) @0.5:")
            for key in sorted(cv_summary_050.keys()):
                if key.endswith("_mean"):
                    base = key[:-5]
                    low = cv_summary_050.get(f"{base}_95ci_low", np.nan)
                    high = cv_summary_050.get(f"{base}_95ci_high", np.nan)
                    print(f"{base:>20}: {cv_summary_050[key]:.4f}  (95% CI {low:.4f}, {high:.4f})")
            print("\nCV averages (±95% CI) @opt thresholds:")
            for key in sorted(cv_summary_opt.keys()):
                if key.endswith("_mean"):
                    base = key[:-5]
                    low = cv_summary_opt.get(f"{base}_95ci_low", np.nan)
                    high = cv_summary_opt.get(f"{base}_95ci_high", np.nan)
                    print(f"{base:>20}: {cv_summary_opt[key]:.4f}  (95% CI {low:.4f}, {high:.4f})")
            print("Arch params (used in all folds):", arch_defaults)
            print("Train params (used in all folds):", train_defaults)
        return {
            "cv_folds_metrics@0.5": fold_metrics_050,
            "cv_folds_metrics@optthr": fold_metrics_opt,
            "cv_summary@0.5": cv_summary_050,
            "cv_summary@optthr": cv_summary_opt,
            "arch_params_final": dict(arch_defaults),
            "train_params_final": dict(train_defaults),
        }
    if mode == "nested_cv":
        if not param_grid:
            param_grid = {
                "d_fix_latent": [128, 256],
                "d_rand_latent": [128],
                "n_kernels": [4],
                "dropout": [0.0, 0.1],
                "lr": [1e-3, 3e-4],
                "weight_decay": [0.0, 1e-4],
                "batch_size": [256],
                "max_epochs": [100],
                "patience": [10],
            }
        results_folds = []
        best_score_global, best_params_global = -np.inf, None
        if scenario == "cases":
            outer = GroupKFold(n_splits=outer_folds)
            outer_iter = outer.split(X_of, y[:, 0], groups=pid_idx)
        else:
            tss = TimeSeriesSplit(n_splits=outer_folds)
            order = np.argsort(time_ix)
            X_order = X_of[order]; y_order = y[order]
            outer_iter = ((order[tr], order[te]) for tr, te in tss.split(X_order, y_order[:, 0]))
        for fold_id, (tr_idx_all, te_idx) in enumerate(outer_iter, start=1):
            if verbose:
                print(f"\nOuter fold {fold_id}/{outer_folds} scenario={scenario}")
            def inner_iter():
                if scenario == "cases":
                    inner = GroupKFold(n_splits=inner_folds)
                    return inner.split(X_of[tr_idx_all], y[tr_idx_all, 0], groups=pid_idx[tr_idx_all])
                else:
                    tr_order = np.argsort(time_ix[tr_idx_all])
                    X_tr_order = X_of[tr_idx_all][tr_order]
                    y_tr_order = y[tr_idx_all][tr_order]
                    inner_tss = TimeSeriesSplit(n_splits=inner_folds)
                    return ((tr_idx_all[tr_order][itr], tr_idx_all[tr_order][iva])
                            for itr, iva in inner_tss.split(X_tr_order, y_tr_order[:, 0]))
            best_inner_score, best_inner_params = -np.inf, None
            for params in ParameterGrid(param_grid):
                arch_params = dict(arch_defaults)
                train_params = dict(train_defaults)
                for k, v in params.items():
                    if k in ("d_fix_latent", "d_rand_latent", "n_kernels", "dropout", "use_attention", "use_random"):
                        arch_params[k] = v
                    else:
                        train_params[k] = v
                inner_scores = []
                for in_tr, in_va in inner_iter():
                    tr_idx_inner, va_idx_inner = _make_train_val_split(np.asarray(in_tr), seed=rnd + fold_id)
                    res_inner = _fit_eval_once_kan(
                        build_model_fn, wrapper_cls,
                        arch_params, train_params,
                        X_of, X_fr, y, y_lags, dt_lags, pid_idx,
                        tr_idx_inner, va_idx_inner, in_va,
                        device=device, scenario=scenario,
                        threshold_selection_source="train",
                        verbose=False
                    )
                    score = res_inner["metrics@optthr"].get("macro_F1_optthr", np.nan)
                    inner_scores.append(score)
                avg_score = float(np.nanmean(inner_scores)) if len(inner_scores) else -np.inf
                if avg_score > best_inner_score:
                    best_inner_score = avg_score
                    best_inner_params = (arch_params, train_params)
            if best_inner_params is None:
                if verbose: print("No viable inner config; skipping outer fold.")
                continue
            arch_params, train_params = best_inner_params
            tr_idx_outer, va_idx_outer = _make_train_val_split(np.asarray(tr_idx_all), seed=rnd + fold_id * 17)
            res_outer = _fit_eval_once_kan(
                build_model_fn, wrapper_cls,
                arch_params, train_params,
                X_of, X_fr, y, y_lags, dt_lags, pid_idx,
                tr_idx_outer, va_idx_outer, te_idx,
                device=device, scenario=scenario,
                threshold_selection_source="train",
                verbose=False
            )
            results_folds.append(res_outer)
            score_outer = res_outer["metrics@optthr"].get("macro_F1_optthr", -np.inf)
            if score_outer > best_score_global:
                best_score_global = score_outer
                best_params_global = (arch_params, train_params)
            if verbose:
                macro_keys = [k for k in res_outer["metrics@optthr"].keys() if k.startswith("macro_")]
                print("Outer fold macro (optthr): " + ", ".join(f"{k}={res_outer['metrics@optthr'][k]:.4f}" for k in macro_keys))
        def _summarize_block(results_list: List[Dict[str, Any]], which: str) -> Dict[str, float]:
            keys = list(results_list[0][which].keys())
            out = {}
            for k in keys:
                arr = np.array([res[which][k] for res in results_list], dtype=float)
                m = float(np.nanmean(arr)); s = float(np.nanstd(arr, ddof=1)); n = len(arr)
                se = s / np.sqrt(n) if n > 1 else np.nan
                if n > 1:
                    tcrit = float(student_t.ppf(0.975, df=n-1))
                    ci = (m - tcrit * se, m + tcrit * se)
                else:
                    ci = (np.nan, np.nan)
                out[k + "_mean"] = m
                out[k + "_95ci_low"] = ci[0]
                out[k + "_95ci_high"] = ci[1]
            return out
        cv_summary_050    = _summarize_block(results_folds, "metrics@0.5")
        cv_summary_optthr = _summarize_block(results_folds, "metrics@optthr")
        if verbose and best_params_global is not None:
            print("\nBest params (by outer macro_F1 at opt thresholds):")
            arch_p, train_p = best_params_global
            print("[ARCH]:");   [print(f"  {k}: {v}") for k, v in arch_p.items()]
            print("[TRAIN]:");  [print(f"  {k}: {v}") for k, v in train_p.items()]
        print("\nCross-validation results:")
        print("CV Summary @ 0.5:")
        print(cv_summary_050)
        print("\nCV Summary @ Optimal Threshold:")
        print(cv_summary_optthr)
        return {
            "outer_folds": results_folds,
            "cv_summary@0.5": cv_summary_050,
            "cv_summary@optthr": cv_summary_optthr,
            "best_params": (best_params_global[0], best_params_global[1]) if best_params_global else None,
        }
    raise ValueError("mode must be one of {'single','cv_only','nested_cv'}")


## Model test

### Define variables and parameters

In [19]:
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

import numpy as np

# ---------- targets & aux ----------
# DataFrame.to_numpy always yields a 2-D array; the reshape branch only guards
# against a 1-D input (e.g. a Series) so y_np is always (n_rows, n_outcomes).
y_raw      = GHQ_cat_y.to_numpy(np.float32)
y_np       = y_raw if y_raw.ndim == 2 else y_raw.reshape(-1, 1)

y_lags_np  = GHQ_cat_outcomes_lags.to_numpy(np.float32)
dt_lags_np = GHQ_cat_forecast_horizons.to_numpy(np.float32)

# ---------- inputs ----------
# Keep the two feature blocks SEPARATE. The evaluator (run_training_and_eval_kan)
# is expected -- not verified in this cell -- to build:
#   - TC    <- only_fixed (after PCA)
#   - X_fix <- concat(only_fixed_PCA, fixed_and_random_PCA)
#   - Zrand <- fixed_and_random_PCA
X_only_fixed_np        = GHQ_cat_only_fixed.to_numpy(np.float32)
X_fixed_and_random_np  = GHQ_cat_fixed_and_random.to_numpy(np.float32)

# ---------- ids & time ----------
pid_raw    = GHQ_cat_participant_id.to_numpy().ravel()
# Re-encode arbitrary participant ids as contiguous integers 0..n_ids-1
# (np.unique returns the sorted unique ids and, per row, the index into them).
pid_uniqs, pid_encoded = np.unique(pid_raw, return_inverse=True)
pid_np     = pid_encoded.astype(np.int64)
n_ids      = int(len(pid_uniqs))

time_ix_np = GHQ_cat_time.to_numpy().ravel()

# Every per-row array must describe the same rows. In particular, .ravel() on an
# id/time frame with more than one column would interleave values and silently
# return an array longer than the data -- fail loudly instead.
assert len({len(a) for a in (y_np, y_lags_np, dt_lags_np, X_only_fixed_np,
                             X_fixed_and_random_np, pid_np, time_ix_np)}) == 1, \
    "per-row arrays disagree on the number of rows (multi-column id/time selection?)"

# ---------- dynamic builder ----------
def build_model_fn(
    *,
    y_dim: int,
    d_fix: int,
    d_tc: int,
    d_zrand: int,
    n_ids: int,
    **arch
):
    """Construct a ``KANAdditiveMixedEffects`` model for the given input widths.

    Every argument is keyword-only and is forwarded unchanged to the model
    constructor.

    Args:
        y_dim: number of outcome columns.
        d_fix: width of the fixed-effects input, intended as
            dim(concat(only_fixed_PCA, fixed_and_random_PCA)).
        d_tc: width of the PCA'd only_fixed block (noted as the FiLM input).
        d_zrand: width of the PCA'd fixed_and_random block.
        n_ids: presumably the number of distinct participant ids -- confirm
            against the evaluator that supplies it.
        **arch: any further architecture options (e.g. d_fix_latent, n_kernels,
            dropout, use_attention, use_random).

    Returns:
        A newly constructed ``KANAdditiveMixedEffects`` instance.
    """
    dims = {"y_dim": y_dim, "d_fix": d_fix, "d_tc": d_tc, "d_zrand": d_zrand, "n_ids": n_ids}
    return KANAdditiveMixedEffects(**dims, **arch)



# Architecture settings. In nested CV, grid keys d_fix_latent / d_rand_latent /
# n_kernels / dropout / use_attention / use_random override these.
arch_defaults = {
    "d_fix_latent": 256,
    "d_rand_latent": 256,
    "n_kernels": 8,
    "dropout": 0.0,
}

# Optimisation, early-stopping and regularisation settings; every other grid key
# overrides these.
train_defaults = {
    "lr": 3e-4,
    "weight_decay": 3e-6,
    "batch_size": 256,
    "max_epochs": 100,
    "patience": 20,
    "threshold": 0.5,
    "lambda_mean0": 1e-3,
    "lambda_ridge": 3e-3,
    "lambda_orth_latent": 3e-3,
    "lambda_film_identity": 3e-3,
    "lambda_kan": 0,
    "random_state": 42,
}

# Sanity check: all per-row arrays should share the leading dimension.
print(y_np.shape, X_only_fixed_np.shape, X_fixed_and_random_np.shape, pid_np.shape, time_ix_np.shape)

(1031, 1) (1031, 130) (1031, 66) (1031,) (1031,)


### Simple cases split test

In [147]:
# NOTE(review): the section header says "cases" but this cell runs scenario="time";
# confirm which split was intended.
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol: mode in {"single", "cv_only", "nested_cv"};
    #      scenarios used in this notebook: "cases", "time", "both" ----
    mode="single",
    scenario="time",
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


[split] train | rows=  742 | unique_ids=  247
[split] val   | rows=   82 | unique_ids=   72
[split] test  | rows=  204 | unique_ids=   93
Epoch 001 | train 0.655723 | val 0.654983
Epoch 002 | train 0.647988 | val 0.648192
Epoch 003 | train 0.638431 | val 0.637774
Epoch 004 | train 0.627023 | val 0.624220
Epoch 005 | train 0.611346 | val 0.606442
Epoch 006 | train 0.590734 | val 0.584293
Epoch 007 | train 0.566892 | val 0.558976
Epoch 008 | train 0.540184 | val 0.534596
Epoch 009 | train 0.518674 | val 0.517156
Epoch 010 | train 0.506928 | val 0.510262
Epoch 011 | train 0.498021 | val 0.510095
Epoch 012 | train 0.489918 | val 0.507949
Epoch 013 | train 0.479653 | val 0.504469
Epoch 014 | train 0.471373 | val 0.501960
Epoch 015 | train 0.463004 | val 0.503813
Epoch 016 | train 0.455794 | val 0.506980
Epoch 017 | train 0.447606 | val 0.514464
Epoch 018 | train 0.437956 | val 0.524415
Epoch 019 | train 0.429094 | val 0.538501
Epoch 020 | train 0.417618 | val 0.554193
Epoch 021 | train 0.40

### Simple time split test

In [9]:
# NOTE(review): the section header says "time" but this cell runs scenario="both";
# confirm which split was intended.
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol: mode in {"single", "cv_only", "nested_cv"};
    #      scenarios used in this notebook: "cases", "time", "both" ----
    mode="single",
    scenario="both",
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


[split] train | rows=  751 | unique_ids=  200
[split] val   | rows=   83 | unique_ids=   67
[split] test  | rows=  197 | unique_ids=   51
Epoch 001 | train 0.625355 | val 0.621029
Epoch 002 | train 0.614971 | val 0.610853
Epoch 003 | train 0.603476 | val 0.598269
Epoch 004 | train 0.589149 | val 0.582723
Epoch 005 | train 0.570891 | val 0.563352
Epoch 006 | train 0.548452 | val 0.540165
Epoch 007 | train 0.521367 | val 0.515338
Epoch 008 | train 0.492249 | val 0.491707
Epoch 009 | train 0.463986 | val 0.474024
Epoch 010 | train 0.442665 | val 0.463530
Epoch 011 | train 0.430051 | val 0.460556
Epoch 012 | train 0.424118 | val 0.459908
Epoch 013 | train 0.416297 | val 0.462933
Epoch 014 | train 0.407116 | val 0.464862
Epoch 015 | train 0.398769 | val 0.459882
Epoch 016 | train 0.390784 | val 0.455816
Epoch 017 | train 0.382113 | val 0.451618
Epoch 018 | train 0.373503 | val 0.449699
Epoch 019 | train 0.363996 | val 0.450802
Epoch 020 | train 0.354566 | val 0.451563
Epoch 021 | train 0.34

### Cases split test CV without parameter search

In [154]:

# Plain K-fold CV (no hyper-parameter search) with the cases split, 3 folds.
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol ----
    mode="cv_only",
    scenario="cases",
    outer_folds=3,
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


[CV fold 1/3] scenario=cases
[split] train | rows=  619 | unique_ids=  162
[split] val   | rows=   68 | unique_ids=   51
[split] test  | rows=  344 | unique_ids=   85
Epoch 001 | train 0.626386 | val 0.631180
Epoch 002 | train 0.618664 | val 0.626191
Epoch 003 | train 0.609833 | val 0.619980
Epoch 004 | train 0.598692 | val 0.613087
Epoch 005 | train 0.584390 | val 0.604130
Epoch 006 | train 0.565252 | val 0.593680
Epoch 007 | train 0.543033 | val 0.582972
Epoch 008 | train 0.516012 | val 0.574479
Epoch 009 | train 0.490675 | val 0.571622
Epoch 010 | train 0.468066 | val 0.571354
Epoch 011 | train 0.451320 | val 0.577060
Epoch 012 | train 0.438476 | val 0.582741
Epoch 013 | train 0.429051 | val 0.587713
Epoch 014 | train 0.419968 | val 0.590652
Epoch 015 | train 0.412148 | val 0.592757
Epoch 016 | train 0.403862 | val 0.595539
Epoch 017 | train 0.395921 | val 0.599732
Epoch 018 | train 0.388455 | val 0.602804
Epoch 019 | train 0.380055 | val 0.607732
Epoch 020 | train 0.370436 | val 0

### Time split test CV without parameter search

In [20]:
# NOTE(review): the section header says "time" but this cell runs scenario="both";
# confirm which split was intended.
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol ----
    mode="cv_only",
    scenario="both",
    outer_folds=3,
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


[CV fold 1/3] scenario=cases
[split] train | rows=  619 | unique_ids=  162
[split] val   | rows=   68 | unique_ids=   51
[split] test  | rows=  344 | unique_ids=   85
Epoch 001 | train 0.629266 | val 0.639404
Epoch 002 | train 0.608112 | val 0.623038
Epoch 003 | train 0.585267 | val 0.607057
Epoch 004 | train 0.558784 | val 0.592006
Epoch 005 | train 0.529264 | val 0.579286
Epoch 006 | train 0.494826 | val 0.573656
Epoch 007 | train 0.465392 | val 0.582767
Epoch 008 | train 0.444371 | val 0.610108
Epoch 009 | train 0.436558 | val 0.640593
Epoch 010 | train 0.431692 | val 0.650805
Epoch 011 | train 0.425172 | val 0.645518
Epoch 012 | train 0.415187 | val 0.629519
Epoch 013 | train 0.405254 | val 0.616713
Epoch 014 | train 0.397368 | val 0.607551
Epoch 015 | train 0.391718 | val 0.602223
Epoch 016 | train 0.385515 | val 0.603144
Epoch 017 | train 0.377531 | val 0.608961
Epoch 018 | train 0.368462 | val 0.620888
Epoch 019 | train 0.359260 | val 0.636419
Epoch 020 | train 0.350693 | val 0

### Ablation

In [157]:
# Ablation: the same architecture sizes as arch_defaults in the setup cell, but
# with use_attention=False and use_random=False.
# A separate name is used on purpose. Rebinding the global `arch_defaults` here
# would leak the ablated flags into every later cell that reads it; the nested-CV
# search copies arch_defaults, so it would silently tune the ablated model.
arch_ablation = dict(
    d_fix_latent=256, d_rand_latent=256, n_kernels=8, dropout=0.0,
    use_attention=False, use_random=False,
)

res = run_training_and_eval_kan(
    X_only_fixed=X_only_fixed_np,                 # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,     # -> RANDOM and part of FIXED
    y=y_np, y_lags=y_lags_np, dt_lags=dt_lags_np,
    pid_idx=pid_np, time_index=time_ix_np,
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    mode="cv_only",          # one of "single" / "cv_only" / "nested_cv"
    scenario="both",         # scenarios used in this notebook: "cases" / "time" / "both"
    arch_defaults=arch_ablation,
    train_defaults=train_defaults,
    verbose=True,
    outer_folds=5
)


[CV fold 1/5] scenario=cases
[split] train | rows=  742 | unique_ids=  201
[split] val   | rows=   82 | unique_ids=   65
[split] test  | rows=  207 | unique_ids=   51
Epoch 001 | train 0.662666 | val 0.657978
Epoch 002 | train 0.660247 | val 0.655375
Epoch 003 | train 0.657050 | val 0.651790
Epoch 004 | train 0.652186 | val 0.646179
Epoch 005 | train 0.644552 | val 0.637915
Epoch 006 | train 0.634448 | val 0.626129
Epoch 007 | train 0.619174 | val 0.611002
Epoch 008 | train 0.598446 | val 0.592786
Epoch 009 | train 0.575270 | val 0.573702
Epoch 010 | train 0.549078 | val 0.557326
Epoch 011 | train 0.523538 | val 0.547020
Epoch 012 | train 0.502395 | val 0.547514
Epoch 013 | train 0.486301 | val 0.553952
Epoch 014 | train 0.472944 | val 0.562537
Epoch 015 | train 0.461520 | val 0.567270
Epoch 016 | train 0.448859 | val 0.566741
Epoch 017 | train 0.436466 | val 0.563500
Epoch 018 | train 0.423948 | val 0.566173
Epoch 019 | train 0.413287 | val 0.566165
Epoch 020 | train 0.402715 | val 0

### Cases split: nested CV with hyperparameter search

In [None]:
# Hyper-parameter grid for nested CV. Keys d_fix_latent, d_rand_latent, n_kernels
# and dropout (also use_attention / use_random, if present) are routed to the
# architecture; every other key overrides train_defaults.
# 8 keys have two candidates -> 2**8 = 256 configurations. Each is fitted once per
# inner fold in every outer fold: 256 * 2 inner * 3 outer = 1536 inner fits.
param_grid = {
    # -- architecture (held fixed) --
    "d_fix_latent": [128],
    "d_rand_latent": [128],
    "n_kernels": [6],
    "dropout": [0.1],
    # -- optimisation --
    "lr": [1e-4, 3e-4],
    "weight_decay": [1e-4, 1e-3],
    "batch_size": [128, 64],
    "max_epochs": [100],
    "patience": [10],
    # -- regularisation penalties --
    "lambda_mean0": [1e-4, 1e-3],
    "lambda_ridge": [1e-4, 3e-4],
    "lambda_orth_latent": [1e-3, 1e-4],
    "lambda_film_identity": [1e-4, 1e-3],
    "lambda_kan": [0.0, 1e-4],
}

res_cases = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol: nested CV with participant-grouped outer/inner folds ----
    mode="nested_cv",
    scenario="cases",
    outer_folds=3,
    inner_folds=2,
    param_grid=param_grid,
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


Outer fold 1/3
Outer fold macro (optthr): macro_AUC_optthr=0.7936, macro_AUPRC_optthr=0.8187, macro_Brier_optthr=0.1857, macro_ACC_optthr=0.6860, macro_F1_optthr=0.7465, macro_Precision_optthr=0.6570, macro_Recall_optthr=0.8641, macro_Sensitivity_optthr=0.8641, macro_Specificity_optthr=0.4813

Outer fold 2/3


### Time split: nested CV with hyperparameter search

In [None]:
# Nested CV with a chronological (TimeSeriesSplit) outer/inner split.
# Reuses `param_grid` from the cases nested-CV cell above.
# The feature arrays are X_only_fixed_np / X_fixed_and_random_np. The previously
# referenced X_fix_np / Zrand_np are not defined anywhere, so this cell raised
# NameError. The result is stored in `res_time` so it no longer clobbers the
# `res_cases` produced by the cases cell.
res_time = run_training_and_eval_kan(
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    mode="nested_cv",                          # one of "single" / "cv_only" / "nested_cv"
    scenario="time",                           # scenarios used in this notebook: "cases" / "time" / "both"
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
    outer_folds=3,
    inner_folds=2,
    param_grid=param_grid
)

## New outcome: Body dataset (categorical forecast)

In [21]:
# Load the Body dataset and its column-role metadata table. The next cell reads
# the metadata's "column_name" column and its 0/1 role flags.
Body_cat_df, columns_Body_cat_df = (
    pd.read_csv(os.path.join(data_dir, csv_name))
    for csv_name in (
        "Body_data_for_categorical_forecast.csv",
        "columns_Body_data_for_categorical_forecast.csv",
    )
)

Body_cat_df.head()

Unnamed: 0,id,age,group,day,beep,Body_check,Restr,Comp,BE,EE,...,EDDS15,EDDS16,EDDS17,EDDS18,education2,education3,education4,critical_event_next,day_beep_diff,day_beep_diff_copy
0,1,29,1,0,4,3,3,1,2,1,...,0,2,12,5,0,1,0,1,2,2
1,1,29,1,1,1,3,4,1,2,1,...,0,2,12,5,0,1,0,0,1,1
2,1,29,1,1,2,1,5,1,1,1,...,0,2,12,5,0,1,0,1,1,1
3,1,29,1,1,3,4,3,1,2,1,...,0,2,12,5,0,1,0,0,1,1
4,1,29,1,1,4,4,2,1,2,1,...,0,2,12,5,0,1,0,0,2,2


In [22]:
def _role_columns(meta: pd.DataFrame, role: str) -> list:
    """Return the ``column_name`` entries of *meta* whose ``role`` flag equals 1."""
    return meta.loc[meta[role] == 1, "column_name"].tolist()


# Outcome(s) to forecast
Body_cat_outcome_cols = _role_columns(columns_Body_cat_df, "outcomes")
Body_cat_y = Body_cat_df[Body_cat_outcome_cols]

# Lagged outcome(s)
Body_cat_outcomes_lags_cols = _role_columns(columns_Body_cat_df, "outcomes_lags")
Body_cat_outcomes_lags = Body_cat_df[Body_cat_outcomes_lags_cols]

# Participant identifier(s)
Body_cat_participant_cols = _role_columns(columns_Body_cat_df, "participant_id")
Body_cat_participant_id = Body_cat_df[Body_cat_participant_cols]

# Time index column(s)
Body_cat_time_cols = _role_columns(columns_Body_cat_df, "time")
Body_cat_time = Body_cat_df[Body_cat_time_cols]

# Forecast-horizon column(s)
Body_cat_forecast_horizons_cols = _role_columns(columns_Body_cat_df, "forecast_horizons")
Body_cat_forecast_horizons = Body_cat_df[Body_cat_forecast_horizons_cols]

# Fixed-effects-only predictors
Body_cat_only_fixed_cols = _role_columns(columns_Body_cat_df, "only_fixed")
Body_cat_only_fixed = Body_cat_df[Body_cat_only_fixed_cols]

# Predictors entering both fixed and random parts
Body_cat_fixed_and_random_cols = _role_columns(columns_Body_cat_df, "fixed_and_random")
Body_cat_fixed_and_random = Body_cat_df[Body_cat_fixed_and_random_cols]

In [23]:
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

import numpy as np

# ---------- targets & aux ----------
# DataFrame.to_numpy always yields a 2-D array; the reshape branch only guards
# against a 1-D input (e.g. a Series) so y_np is always (n_rows, n_outcomes).
y_raw      = Body_cat_y.to_numpy(np.float32)
y_np       = y_raw if y_raw.ndim == 2 else y_raw.reshape(-1, 1)

y_lags_np  = Body_cat_outcomes_lags.to_numpy(np.float32)
dt_lags_np = Body_cat_forecast_horizons.to_numpy(np.float32)

# ---------- inputs ----------
# Keep the two feature blocks SEPARATE. The evaluator (run_training_and_eval_kan)
# is expected -- not verified in this cell -- to build:
#   - TC    <- only_fixed (after PCA)
#   - X_fix <- concat(only_fixed_PCA, fixed_and_random_PCA)
#   - Zrand <- fixed_and_random_PCA
X_only_fixed_np        = Body_cat_only_fixed.to_numpy(np.float32)
X_fixed_and_random_np  = Body_cat_fixed_and_random.to_numpy(np.float32)

# ---------- ids & time ----------
pid_raw    = Body_cat_participant_id.to_numpy().ravel()
# Re-encode arbitrary participant ids as contiguous integers 0..n_ids-1
# (np.unique returns the sorted unique ids and, per row, the index into them).
pid_uniqs, pid_encoded = np.unique(pid_raw, return_inverse=True)
pid_np     = pid_encoded.astype(np.int64)
n_ids      = int(len(pid_uniqs))

time_ix_np = Body_cat_time.to_numpy().ravel()

# Every per-row array must describe the same rows. In particular, .ravel() on an
# id/time frame with more than one column would interleave values and silently
# return an array longer than the data -- fail loudly instead.
assert len({len(a) for a in (y_np, y_lags_np, dt_lags_np, X_only_fixed_np,
                             X_fixed_and_random_np, pid_np, time_ix_np)}) == 1, \
    "per-row arrays disagree on the number of rows (multi-column id/time selection?)"

# ---------- dynamic builder ----------
def build_model_fn(
    *,
    y_dim: int,
    d_fix: int,
    d_tc: int,
    d_zrand: int,
    n_ids: int,
    **arch
):
    """Construct a ``KANAdditiveMixedEffects`` model for the given input widths.

    NOTE(review): this is identical to the builder defined in the GHQ setup cell.
    It is redefined here only so the Body section can run on its own.

    Args:
        y_dim: number of outcome columns.
        d_fix: width of the fixed-effects input, intended as
            dim(concat(only_fixed_PCA, fixed_and_random_PCA)).
        d_tc: width of the PCA'd only_fixed block (noted as the FiLM input).
        d_zrand: width of the PCA'd fixed_and_random block.
        n_ids: presumably the number of distinct participant ids -- confirm
            against the evaluator that supplies it.
        **arch: any further architecture options, forwarded unchanged.

    Returns:
        A newly constructed ``KANAdditiveMixedEffects`` instance.
    """
    dims = {"y_dim": y_dim, "d_fix": d_fix, "d_tc": d_tc, "d_zrand": d_zrand, "n_ids": n_ids}
    return KANAdditiveMixedEffects(**dims, **arch)



# Smaller architecture for the Body dataset.
arch_defaults = {
    "d_fix_latent": 32,
    "d_rand_latent": 32,
    "n_kernels": 6,
    "dropout": 0.00,
}

# Optimisation, early-stopping and (much lighter) regularisation settings.
train_defaults = {
    "lr": 5e-4,
    "weight_decay": 3e-5,
    "batch_size": 64,
    "max_epochs": 100,
    "patience": 20,
    "threshold": 0.5,
    "lambda_mean0": 1e-6,
    "lambda_ridge": 1e-5,
    "lambda_orth_latent": 3e-5,
    "lambda_film_identity": 3e-5,
    "lambda_kan": 0,
    "random_state": 42,
}


In [134]:
# Single train/val/test run on the Body data with the cases split.
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol ----
    mode="single",
    scenario="cases",
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


Epoch 001 | train 0.526399 | val 0.536068
Epoch 002 | train 0.504631 | val 0.514953
Epoch 003 | train 0.473995 | val 0.491043
Epoch 004 | train 0.437780 | val 0.471919
Epoch 005 | train 0.404375 | val 0.438891
Epoch 006 | train 0.365621 | val 0.424810
Epoch 007 | train 0.342439 | val 0.437593
Epoch 008 | train 0.331724 | val 0.469246
Epoch 009 | train 0.324498 | val 0.485574
Epoch 010 | train 0.315082 | val 0.449413
Epoch 011 | train 0.302206 | val 0.451521
Epoch 012 | train 0.296104 | val 0.458941
Epoch 013 | train 0.287374 | val 0.455781
Epoch 014 | train 0.280117 | val 0.467125
Epoch 015 | train 0.272721 | val 0.477199
Epoch 016 | train 0.263935 | val 0.496088
Epoch 017 | train 0.259578 | val 0.487866
Epoch 018 | train 0.249786 | val 0.514506
Epoch 019 | train 0.240466 | val 0.541955
Epoch 020 | train 0.230078 | val 0.528945
Epoch 021 | train 0.217564 | val 0.551798
Epoch 022 | train 0.201423 | val 0.574612
Epoch 023 | train 0.190832 | val 0.623385
Epoch 024 | train 0.171922 | val 0

In [89]:
# Single train/val/test run on the Body data with the time split.
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol ----
    mode="single",
    scenario="time",
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


Epoch 001 | train 0.691174 | val 0.674552
Epoch 002 | train 0.673155 | val 0.657005
Epoch 003 | train 0.657176 | val 0.638801
Epoch 004 | train 0.640712 | val 0.618207
Epoch 005 | train 0.626919 | val 0.593741
Epoch 006 | train 0.609616 | val 0.562917
Epoch 007 | train 0.593699 | val 0.524263
Epoch 008 | train 0.571390 | val 0.477930
Epoch 009 | train 0.546989 | val 0.425699
Epoch 010 | train 0.520621 | val 0.375775
Epoch 011 | train 0.492534 | val 0.341928
Epoch 012 | train 0.476113 | val 0.328653
Epoch 013 | train 0.466694 | val 0.323414
Epoch 014 | train 0.459762 | val 0.317265
Epoch 015 | train 0.450997 | val 0.313831
Epoch 016 | train 0.437306 | val 0.325763
Epoch 017 | train 0.423853 | val 0.345655
Epoch 018 | train 0.414891 | val 0.364518
Epoch 019 | train 0.408511 | val 0.371982
Epoch 020 | train 0.401708 | val 0.369975
Epoch 021 | train 0.395335 | val 0.371225
Epoch 022 | train 0.388389 | val 0.373633
Epoch 023 | train 0.381316 | val 0.377838
Epoch 024 | train 0.373431 | val 0

In [135]:

# 5-fold CV (no hyper-parameter search) on the Body data, cases split.
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol ----
    mode="cv_only",
    scenario="cases",
    outer_folds=5,
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)

Fold 1: macro_AUC_optthr=0.8232, macro_AUPRC_optthr=0.9041, macro_Brier_optthr=0.1664, macro_ACC_optthr=0.8058, macro_F1_optthr=0.8214, macro_Precision_optthr=0.9020, macro_Recall_optthr=0.7541, macro_Sensitivity_optthr=0.7541, macro_Specificity_optthr=0.8810
Fold 2: macro_AUC_optthr=0.7019, macro_AUPRC_optthr=0.6877, macro_Brier_optthr=0.3025, macro_ACC_optthr=0.4706, macro_F1_optthr=0.6301, macro_Precision_optthr=0.4792, macro_Recall_optthr=0.9200, macro_Sensitivity_optthr=0.9200, macro_Specificity_optthr=0.0385
Fold 3: macro_AUC_optthr=0.2531, macro_AUPRC_optthr=0.6796, macro_Brier_optthr=0.4370, macro_ACC_optthr=0.4804, macro_F1_optthr=0.6345, macro_Precision_optthr=0.7077, macro_Recall_optthr=0.5750, macro_Sensitivity_optthr=0.5750, macro_Specificity_optthr=0.1364
Fold 4: macro_AUC_optthr=0.6719, macro_AUPRC_optthr=0.6236, macro_Brier_optthr=0.2280, macro_ACC_optthr=0.3960, macro_F1_optthr=0.5674, macro_Precision_optthr=0.4124, macro_Recall_optthr=0.9091, macro_Sensitivity_optthr=

In [24]:
# 5-fold CV (no hyper-parameter search) on the Body data, scenario="both".
res = run_training_and_eval_kan(
    # ---- data ----
    X_only_fixed=X_only_fixed_np,              # -> TC and part of FIXED
    X_fixed_and_random=X_fixed_and_random_np,  # -> RANDOM and part of FIXED
    y=y_np,
    y_lags=y_lags_np,
    dt_lags=dt_lags_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    # ---- model factory ----
    build_model_fn=build_model_fn,
    wrapper_cls=KANMixedEffectsWrapper,
    # ---- protocol ----
    mode="cv_only",
    scenario="both",
    outer_folds=5,
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


[CV fold 1/5] scenario=cases
[split] train | rows=  366 | unique_ids=   27
[split] val   | rows=   40 | unique_ids=   15
[split] test  | rows=  103 | unique_ids=    6
Epoch 001 | train 0.522843 | val 0.569819
Epoch 002 | train 0.488875 | val 0.526627
Epoch 003 | train 0.450330 | val 0.478078
Epoch 004 | train 0.394434 | val 0.416660
Epoch 005 | train 0.366790 | val 0.391940
Epoch 006 | train 0.352934 | val 0.392635
Epoch 007 | train 0.337484 | val 0.371298
Epoch 008 | train 0.320750 | val 0.357873
Epoch 009 | train 0.309322 | val 0.349957
Epoch 010 | train 0.298842 | val 0.338049
Epoch 011 | train 0.291214 | val 0.330355
Epoch 012 | train 0.284389 | val 0.326210
Epoch 013 | train 0.280014 | val 0.342395
Epoch 014 | train 0.274001 | val 0.338431
Epoch 015 | train 0.268582 | val 0.326771
Epoch 016 | train 0.263268 | val 0.345018
Epoch 017 | train 0.260025 | val 0.354825
Epoch 018 | train 0.255105 | val 0.329520
Epoch 019 | train 0.251449 | val 0.353573
Epoch 020 | train 0.244268 | val 0