# Setup

## Root and data folders

In [18]:
import os
import pandas as pd
import numpy as np

root_dir = "/Users/silviumatu/Desktop/Code/Python/Disertatie/Disertatie_Matu_Silviu_v1"
os.makedirs(root_dir, exist_ok=True)

data_dir = os.path.join(root_dir, "Data")
os.makedirs(data_dir, exist_ok=True)

# Data Extraction

## Helper function to split columns with multiple values stored into one-hot encoding

In [19]:
def one_hot_split_column(df, column_name, separator):
    split_values = df[column_name].fillna("").str.strip().str.replace(separator, "", regex=False).str.split()

    unique_values = sorted(set(val for sublist in split_values for val in sublist))

    for val in unique_values:
        new_col = f"{column_name}_{val}"
        df[new_col] = split_values.apply(lambda x: val in x).astype(int)
    
    return df

## PED GHQ data extraction

In [20]:
# Loading the PED data
PED_data_file = "All_timeframes_merged_processed.csv"
PED_GHQ_df = pd.read_csv(
    os.path.join(data_dir, PED_data_file))

# Clening the data
PED_GHQ_df = PED_GHQ_df[PED_GHQ_df["email_baseline"] != "test"]
PED_GHQ_df = PED_GHQ_df[PED_GHQ_df["to_exclude_baseline"] != "1"]
PED_GHQ_df = PED_GHQ_df[PED_GHQ_df["to_exclude_first_day"] != "1"]
PED_GHQ_df = PED_GHQ_df.drop_duplicates(subset=["email_baseline", "week_weekly"], keep="first")

# Drop rows that have missing data on the basline and first day reports
PED_GHQ_df = PED_GHQ_df.dropna(subset=["response_id_baseline"])
PED_GHQ_df = PED_GHQ_df.dropna(subset=["response_id_first_day"])

# Save the columns to filter them manually
PED_GHQ_df.columns.to_series().to_csv(os.path.join(data_dir, "PED_GHQ_baseline_day1_columns_to_filter.csv"), index = False)

# Reading the columns after they have been filtered
PED_columns_to_keep = pd.read_csv(os.path.join(data_dir, "PED_GHQ_baseline_day1_filtered_columns_forecast.csv"))["columns"].tolist()

# Keeping only filtered columns
PED_GHQ_df = PED_GHQ_df[PED_columns_to_keep]

# Handle an error in the calculation of a score
PED_GHQ_df["PDA_sadness_TOTAL_baseline"] = (
    PED_GHQ_df["PDA_sadness_TOTAL_baseline"].fillna(0) +
    PED_GHQ_df["PDA_saddness_TOTAL_baseline"].fillna(0)
)
# Drop the misspelled column
PED_GHQ_df = PED_GHQ_df.drop(columns=["PDA_saddness_TOTAL_baseline"])

# Drop text type columns
PED_GHQ_df = PED_GHQ_df.drop(columns=["locations_today_first_day",
                              "locations_now_first_day",
                              "activity_now_first_day",
                              "ER_ESM_1_first_day",
                              "ER_ESM_14_first_day",
                              "ER_ESM_15_first_day",
                              "ER_ESM_16_first_day",
                              "ER_ESM_17_first_day",
                              "ER_ESM_18_first_day"])

# Drop incomplete cases
PED_GHQ_df=PED_GHQ_df.dropna()

# Using the helper function to split columns with multiple values to one-hot encoding
PED_GHQ_df = one_hot_split_column(PED_GHQ_df, "stress_factors_first_day", "_")

# Drop the column that was split
PED_GHQ_df = PED_GHQ_df.drop(columns=["stress_factors_first_day"])

# Convert 1 to 0 and 2 to 1
PED_GHQ_cols_to_convert_0_1 = [
    "gender_baseline",
    "previous_psychiatric_diagnostic_baseline",
    "previous_psychiatric_treatment_baseline",
    "previous_psychologist_baseline",
    "curent_psychiatric_treatment_baseline",
    "current_psychologist_baseline"
]
PED_GHQ_df[PED_GHQ_cols_to_convert_0_1] = PED_GHQ_df[PED_GHQ_cols_to_convert_0_1].replace({1: 0, 2: 1})


# Columns for one hot encoding
PED_GHQ_columns_to_one_hot_encode = [
    "nationality_recoded_baseline",
    "live_same_city_baseline",
    "work_status_baseline",
    "siblings_baseline",
    "familly_income_baseline",
    "home_town_type_baseline",
    "type_study_program_baseline",
    "relationship_baseline",
    "religion_baseline",
    "parents_higher_education_baseline",
    "health_last_month_baseline",
    "mental_health_last_month_baseline",
    "stress_management_first_day"
]
PED_GHQ_df = pd.get_dummies(PED_GHQ_df, columns=PED_GHQ_columns_to_one_hot_encode, prefix_sep="_value_", drop_first=False)
PED_GHQ_df_dummy_cols = [col for col in PED_GHQ_df.columns if any(prefix in col for prefix in PED_GHQ_columns_to_one_hot_encode)]


# Convert booledn dummy columns to int
PED_GHQ_df[PED_GHQ_df_dummy_cols] = PED_GHQ_df[PED_GHQ_df_dummy_cols].astype(int)


# Remove .0 from the names of the columns 
PED_GHQ_df.columns = PED_GHQ_df.columns.str.replace(r"\.0\b", "", regex=True)


# Creat unique ids from emails
PED_email_to_id = {email: idx for idx, email in enumerate(PED_GHQ_df["email_baseline"].unique())}
PED_GHQ_df["participant_id"] = PED_GHQ_df["email_baseline"].map(PED_email_to_id)
PED_GHQ_mapping_df = pd.DataFrame(list(PED_email_to_id.items()), columns=["email_baseline", "participant_id"])
PED_GHQ_mapping_df.to_csv(os.path.join(data_dir,"PED_GHQ_email_id_mapping_forecast.csv"), index=False)


# Drop the email column to keep the data anonymus
PED_GHQ_df = PED_GHQ_df.drop(columns=["email_baseline"])


# Renaiming and processing time columns so that first completion for each participant is time 0
PED_GHQ_df["week_weekly"] = PED_GHQ_df["week_weekly"] - PED_GHQ_df.groupby("participant_id")["week_weekly"].transform("min")
PED_GHQ_df = PED_GHQ_df.rename(columns={"week_weekly": "time"})


# Categorical version of GHQ_Total_score
PED_GHQ_df["GHQ_TOTAL_score_category_first_day"] = np.where(PED_GHQ_df["GHQ_TOTAL_score_first_day"] >= 12, 1, 0)


# Rename columns to start with x
PED_GHQ_df = PED_GHQ_df.rename(columns={
    col: f"x_{col}" for col in PED_GHQ_df.columns
})

PED_GHQ_df["x_time_copy"] = PED_GHQ_df["x_time"]



# Save the data to a CSV file
PED_GHQ_df.to_csv(os.path.join(data_dir, "PED_GHQ_filtered_data_forecast.csv"), index = False)


# Creating the time structure for the data
PED_GHQ_df = PED_GHQ_df.sort_values(by=["x_participant_id", "x_time"])
PED_GHQ_df["x_time_difference_first_day"] = PED_GHQ_df.groupby("x_participant_id")["x_time"].shift(-1) - PED_GHQ_df["x_time"]
PED_GHQ_df["y_GHQ_TOTAL_score_next"] = PED_GHQ_df.groupby("x_participant_id")["x_GHQ_TOTAL_score_first_day"].shift(-1)
PED_GHQ_df["y_GHQ_TOTAL_score_category_next"] = PED_GHQ_df.groupby("x_participant_id")["x_GHQ_TOTAL_score_category_first_day"].shift(-1)

#Making a copy of this
PED_GHQ_df["x_time_difference_first_day_copy"] = PED_GHQ_df["x_time_difference_first_day"]

# Dropping rows were there is nothing to predict
PED_GHQ_df = PED_GHQ_df[PED_GHQ_df["y_GHQ_TOTAL_score_next"].notna()]


# Compute a count for each ID
PED_GHQ_participant_counts = PED_GHQ_df["x_participant_id"].value_counts().reset_index()
PED_GHQ_participant_counts.columns = ["participant_id", "count"]
PED_GHQ_participant_counts.to_csv(os.path.join(data_dir, "PED_GHQ_participant_count_forecast.csv"), index = False)


# Saving files for regression and categorical analyses for the forecast scenario
PED_GHQ_df_regression = PED_GHQ_df.drop(columns=["x_GHQ_TOTAL_score_category_first_day" ,"y_GHQ_TOTAL_score_category_next"])
PED_GHQ_df_regression.to_csv(os.path.join(data_dir, "PED_GHQ_regression_data_forecast.csv"), index = False)

PED_GHQ_df_categorical = PED_GHQ_df.drop(columns=["x_GHQ_TOTAL_score_first_day" ,"y_GHQ_TOTAL_score_next"])
PED_GHQ_df_categorical.to_csv(os.path.join(data_dir, "PED_GHQ_categorical_data_forecast.csv"), index = False)

  PED_GHQ_df = pd.read_csv(


# Load data

In [50]:
GHQ_cat_df = pd.read_csv(os.path.join(data_dir, "PED_GHQ_categorical_data_forecast.csv"))
columns_GHQ_cat_df = pd.read_csv(os.path.join(data_dir, "columns_PED_GHQ_categorical_data_forecast.csv"))

GHQ_cat_df.head()

Unnamed: 0,x_age_baseline,x_gender_baseline,x_previous_psychiatric_diagnostic_baseline,x_previous_psychiatric_treatment_baseline,x_previous_psychologist_baseline,x_curent_psychiatric_treatment_baseline,x_current_psychologist_baseline,x_ABS_irrational_TOTAL_baseline,x_ABS_rational_TOTAL_baseline,x_ATS_Generalization_TOTAL_baseline,...,x_mental_health_last_month_baseline_value_5,x_stress_management_first_day_value_0,x_stress_management_first_day_value_1,x_stress_management_first_day_value_2,x_participant_id,x_GHQ_TOTAL_score_category_first_day,x_time_copy,x_time_difference_first_day,y_GHQ_TOTAL_score_category_next,x_time_difference_first_day_copy
0,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,0,1,0,1,0.0,1.0,0.0,1.0
1,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,1,0,0,0,1.0,3.0,1.0,3.0
2,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,1,0,0,1,4.0,1.0,0.0,1.0
3,19.0,0.0,0.0,0.0,1.0,0.0,0.0,45.0,26.0,8.0,...,0,0,1,0,0,0,5.0,3.0,0.0,3.0
4,18.0,0.0,0.0,0.0,0.0,0.0,0.0,44.0,48.0,10.0,...,0,0,0,1,2,1,0.0,2.0,1.0,2.0


In [51]:
# Select the outcome column(s) marked with 1 in the "outcomes" column of columns_GHQ_cat_df
GHQ_cat_outcome_cols = columns_GHQ_cat_df.loc[columns_GHQ_cat_df['outcomes'] == 1, 'column_name'].tolist()
GHQ_cat_y = GHQ_cat_df[GHQ_cat_outcome_cols]
GHQ_cat_y.head()

# Same for outcomes lags column(s)
GHQ_cat_outcomes_lags_cols = columns_GHQ_cat_df.loc[columns_GHQ_cat_df['outcomes_lags'] == 1, 'column_name'].tolist()
GHQ_cat_outcomes_lags = GHQ_cat_df[GHQ_cat_outcomes_lags_cols]

# Same for participant column(s)
GHQ_cat_participant_cols = columns_GHQ_cat_df.loc[columns_GHQ_cat_df['participant_id'] == 1, 'column_name'].tolist()
GHQ_cat_participant_id = GHQ_cat_df[GHQ_cat_participant_cols]

# Same for time column(s)
GHQ_cat_time_cols = columns_GHQ_cat_df.loc[columns_GHQ_cat_df['time'] == 1, 'column_name'].tolist()
GHQ_cat_time = GHQ_cat_df[GHQ_cat_time_cols]

# Same for forecast horizons column(s)
GHQ_cat_forecast_horizons_cols = columns_GHQ_cat_df.loc[columns_GHQ_cat_df['forecast_horizons'] == 1, 'column_name'].tolist()
GHQ_cat_forecast_horizons = GHQ_cat_df[GHQ_cat_forecast_horizons_cols]

# Same for fixed effects column(s)
GHQ_cat_only_fixed_cols = columns_GHQ_cat_df.loc[columns_GHQ_cat_df['only_fixed'] == 1, 'column_name'].tolist()
GHQ_cat_only_fixed = GHQ_cat_df[GHQ_cat_only_fixed_cols]

# Same for random effects column(s)
GHQ_cat_fixed_and_random_cols = columns_GHQ_cat_df.loc[columns_GHQ_cat_df['fixed_and_random'] == 1, 'column_name'].tolist()
GHQ_cat_fixed_and_random = GHQ_cat_df[GHQ_cat_fixed_and_random_cols]

# ARMED

## Architecutre

In [23]:
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F

class GradientReversalFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd: float):
        ctx.lambd = float(lambd)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None


class GradientReversal(nn.Module):
    def __init__(self, lambd: float = 1.0):
        super().__init__()
        self.lambd = float(lambd)

    def set_lambda(self, lambd: float):
        self.lambd = float(lambd)
    def forward(self, x):
        return GradientReversalFn.apply(x, self.lambd)


def mlp(in_dim: int, hidden: Iterable[int], out_dim: int, dropout: float = 0.0, last_activation: Optional[nn.Module] = None):
    layers: list[nn.Module] = []
    dims = [in_dim] + list(hidden)
    for d0, d1 in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(d0, d1))
        layers.append(nn.ReLU())
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
    layers.append(nn.Linear(dims[-1], out_dim))
    if last_activation is not None:
        layers.append(last_activation)
    return nn.Sequential(*layers)


class FixedAE(nn.Module):
    def __init__(self, in_dim: int, enc_hidden=(128, 64), rep_dim=32, dropout=0.0,
                 use_decoder: bool = False, dec_hidden: Optional[Iterable[int]] = None):
        super().__init__()
        self.encoder = mlp(in_dim, enc_hidden, rep_dim, dropout)
        self.use_decoder = bool(use_decoder)
        if self.use_decoder:
            dec_hidden = list(dec_hidden) if dec_hidden is not None else list(enc_hidden)[::-1]
            self.decoder = mlp(rep_dim, dec_hidden, in_dim, dropout)
        else:
            self.decoder = None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        z = self.encoder(x)
        xhat = self.decoder(z) if self.decoder is not None else None
        return z, xhat


class RandomEnc(nn.Module):
    def __init__(self, in_dim: int, hidden=(128, 64), rep_dim=32, dropout=0.0):
        super().__init__()
        self.net = mlp(in_dim, hidden, rep_dim, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ParticipantEmbedding(nn.Module):
    def __init__(self, n_participants: int, rep_dim: int):
        super().__init__()
        self.n_seen = int(n_participants)
        self.unk_index = self.n_seen
        self.emb = nn.Embedding(self.n_seen + 1, rep_dim, padding_idx=None)

    def forward(self, pid_idx: torch.Tensor) -> torch.Tensor:
        # Map unseen (-1) → UNK
        idx = pid_idx.clone()
        idx = torch.where(idx >= 0, idx, torch.full_like(idx, self.unk_index))
        return self.emb(idx)


class FiLM(nn.Module):
    def __init__(self, rep_dim: int, hidden=(64,), dropout=0.0):
        super().__init__()
        self.gamma = mlp(rep_dim, hidden, rep_dim, dropout)
        self.beta  = mlp(rep_dim, hidden, rep_dim, dropout)

    def forward(self, z_id: torch.Tensor, z_obs: torch.Tensor) -> torch.Tensor:
        # Stabilized residual scaling: gamma centered near 1, beta near 0
        g = 1.0 + 0.1 * torch.tanh(self.gamma(z_id))
        b = 0.1 * self.beta(z_id)
        return g * z_obs + b


class Adversary(nn.Module):
    def __init__(self, in_dim: int, hidden=(64,), n_participants: int = 1,
                 dropout: float = 0.0, grl_lambda: float = 1.0):
        super().__init__()
        self.grl = GradientReversal(grl_lambda)
        self.net = mlp(in_dim, hidden, n_participants, dropout)

    def set_lambda(self, lambd: float):
        self.grl.set_lambda(lambd)

    def forward(self, z_fixed: torch.Tensor) -> torch.Tensor:
        return self.net(self.grl(z_fixed))

class ARMEDTabular(nn.Module):
    def __init__(
        self,
        d_fixed: int,
        d_random: int = 0,
        y_dim: int = 1,
        n_participants: int = 1,
        include_random_data: bool = True,
        
        fixed_enc_hidden=(128, 64),
        fixed_rep_dim: int = 32,
        fixed_dropout: float = 0.0,
        use_fixed_decoder: bool = False,
        fixed_dec_hidden: Optional[Iterable[int]] = None,

        random_hidden=(128, 64),
        random_rep_dim: int = 32,
        random_dropout: float = 0.0,
        combine_mode: str = "add",   # "add" or "film"
        film_hidden=(64,),
        film_dropout: float = 0.0,


        adv_hidden=(64,),
        adv_dropout: float = 0.0,
        grl_lambda: float = 1.0,
        head_hidden=(64,),
    ):
        super().__init__()
        self.include_random_data = bool(include_random_data and d_random > 0)

        self.fixed = FixedAE(
            in_dim=d_fixed,
            enc_hidden=fixed_enc_hidden,
            rep_dim=fixed_rep_dim,
            dropout=fixed_dropout,
            use_decoder=use_fixed_decoder,
            dec_hidden=fixed_dec_hidden,
        )

        self.id_emb = ParticipantEmbedding(n_participants, rep_dim=random_rep_dim)

        self.random = RandomEnc(
            in_dim=d_random,
            hidden=random_hidden,
            rep_dim=random_rep_dim,
            dropout=random_dropout,
        ) if self.include_random_data else None

        self.combine_mode = combine_mode
        if combine_mode not in {"add", "film"}:
            raise ValueError("combine_mode must be 'add' or 'film'")
        self.film = FiLM(rep_dim=random_rep_dim, hidden=film_hidden, dropout=film_dropout) if combine_mode == "film" else None

        self.head = mlp(fixed_rep_dim + random_rep_dim, head_hidden, y_dim, dropout=0.0)
        self.adv  = Adversary(fixed_rep_dim, adv_hidden, n_participants, adv_dropout, grl_lambda)

        self.norm_f = nn.LayerNorm(fixed_rep_dim)
        self.norm_r = nn.LayerNorm(random_rep_dim)

    def forward(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        x_random: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:

        z_f, xhat = self.fixed(x_fixed)
        z_f = self.norm_f(z_f)

        z_id = self.id_emb(pid_idx)  
        if self.include_random_data and (x_random is not None):
            z_r_obs = self.random(x_random)
            if self.combine_mode == "add":
                z_r = z_r_obs + z_id
            else:
                z_r = self.film(z_id, z_r_obs)
        else:
            z_r = z_id
        z_r = self.norm_r(z_r)

        y_logits  = self.head(torch.cat([z_f, z_r], dim=1))
        adv_logits = self.adv(z_f)

        return y_logits, adv_logits, xhat, z_f, z_r

## Wrapper

In [24]:
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Iterable, Sequence

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn


@dataclass
class ARMEDLossWeights:
    lambda_adv: float = 1.0
    lambda_recon: float = 0.0


class ARMEDWrapper:
    """
    Wrapper around ARMEDTabular adding:
      - device management
      - loss computation (unchanged)
      - validation/prediction helpers (unchanged)
      - NEW: per-task F1-optimal threshold selection utilities
             * find_best_thresholds_f1_from_arrays(y_true, y_prob)
             * find_best_thresholds_f1_from_loader(loader)
             * set_thresholds(thresholds)
             * predict_with_thresholds(..., thresholds=...) / predict_using_stored_thresholds(...)
    """
    def __init__(self, model: nn.Module, loss_weights: Optional[ARMEDLossWeights] = None, device: Optional[torch.device] = None):
        self.model = model
        self.loss_w = loss_weights or ARMEDLossWeights()
        self.device = (device
                       or (torch.device("mps") if torch.backends.mps.is_available() else None)
                       or (torch.device("cuda") if torch.cuda.is_available() else None)
                       or torch.device("cpu"))
        self.model.to(self.device)

        # Holds per-task thresholds after selection (numpy array shape [y_dim])
        self.thresholds_: Optional[np.ndarray] = None

    def forward(self, x_fixed: torch.Tensor, pid_idx: torch.Tensor, x_random: Optional[torch.Tensor] = None):
        return self.model(x_fixed, pid_idx, x_random)

    # -----------------------------
    # Loss components (unchanged)
    # -----------------------------
    def _pred_loss(self, logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        return F.binary_cross_entropy_with_logits(logits, y_true.float())

    def _adv_loss(self, adv_logits: torch.Tensor, pid_idx: torch.Tensor) -> torch.Tensor:
        seen_mask = (pid_idx >= 0)
        if seen_mask.any():
            return F.cross_entropy(adv_logits[seen_mask], pid_idx[seen_mask])
        else:
            return torch.tensor(0.0, device=adv_logits.device)

    def _recon_loss(self, xhat: Optional[torch.Tensor], x: torch.Tensor) -> torch.Tensor:
        if (xhat is None) or (self.loss_w.lambda_recon <= 0.0):
            return torch.tensor(0.0, device=x.device)
        return F.mse_loss(xhat, x)

    def compute_losses(
        self,
        y_true: torch.Tensor,
        y_logits: torch.Tensor,
        adv_logits: torch.Tensor,
        xhat: Optional[torch.Tensor],
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        lp = self._pred_loss(y_logits, y_true)
        la = self._adv_loss(adv_logits, pid_idx)
        lr = self._recon_loss(xhat, x_fixed)

        total = lp + self.loss_w.lambda_adv * la + self.loss_w.lambda_recon * lr

        parts = {
            "loss_total": float(total.detach().cpu()),
            "loss_pred":  float(lp.detach().cpu()),
            "loss_adv":   float(la.detach().cpu()),
            "loss_recon": float(lr.detach().cpu()),
        }
        return total, parts

    # -----------------------------
    # Eval / predict helpers
    # -----------------------------
    @torch.no_grad()
    def validation_step(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        y_true: torch.Tensor,
        x_random: Optional[torch.Tensor] = None,
        prefix: str = "val"
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        self.model.eval()
        x_fixed = x_fixed.to(self.device)
        pid_idx = pid_idx.to(self.device)
        y_true  = y_true.to(self.device)
        x_random = x_random.to(self.device) if x_random is not None else None

        y_logits, adv_logits, xhat, _, _ = self.forward(x_fixed, pid_idx, x_random)
        loss, parts = self.compute_losses(y_true, y_logits, adv_logits, xhat, x_fixed, pid_idx)

        print(
            f"{prefix}_loss: {parts['loss_total']:.6f} | "
            f"pred: {parts['loss_pred']:.6f} | "
            f"adv: {parts['loss_adv']:.6f} | "
            f"recon: {parts['loss_recon']:.6f}"
        )
        return loss, parts

    @torch.no_grad()
    def predict_logits(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        x_random: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        self.model.eval()
        x_fixed = x_fixed.to(self.device)
        pid_idx = pid_idx.to(self.device)
        x_random = x_random.to(self.device) if x_random is not None else None
        y_logits, _, _, _, _ = self.forward(x_fixed, pid_idx, x_random)
        return y_logits

    @torch.no_grad()
    def predict_proba(self, *args, **kwargs) -> torch.Tensor:
        logits = self.predict_logits(*args, **kwargs)
        return torch.sigmoid(logits)

    @torch.no_grad()
    def predict(self, *args, threshold: float = 0.5, **kwargs) -> torch.Tensor:
        proba = self.predict_proba(*args, **kwargs)
        return (proba >= threshold).to(torch.int32)

    # -----------------------------
    # NEW: F1-optimal thresholds
    # -----------------------------
    @staticmethod
    def _best_threshold_f1_1d(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
        """
        Exact F1 maximizing threshold for a single task by scanning the midpoints
        between unique scores (plus endpoints). Returns (best_thr, best_f1).
        """
        y_true = y_true.astype(int).ravel()
        y_prob = y_prob.astype(float).ravel()
        if y_prob.size == 0:
            return 0.5, 0.0

        # Sort unique scores and consider midpoints between them
        scores = np.unique(y_prob)
        # Edge case: constant probabilities
        if scores.size == 1:
            thr = scores[0]
            y_pred = (y_prob >= thr).astype(int)
            tp = np.sum((y_true == 1) & (y_pred == 1))
            fp = np.sum((y_pred == 1) & (y_true == 0))
            fn = np.sum((y_pred == 0) & (y_true == 1))
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall    = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
            return float(thr), float(f1)

        mids = (scores[:-1] + scores[1:]) / 2.0
        # include slightly outside endpoints to catch all regimes
        candidates = np.concatenate(([scores[0] - 1e-12], mids, [scores[-1] + 1e-12]))

        best_f1, best_thr = -1.0, 0.5
        for thr in candidates:
            y_pred = (y_prob >= thr).astype(int)
            tp = np.sum((y_true == 1) & (y_pred == 1))
            fp = np.sum((y_pred == 1) & (y_true == 0))
            fn = np.sum((y_pred == 0) & (y_true == 1))
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall    = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
            if f1 > best_f1 or (f1 == best_f1 and thr < best_thr):
                best_f1, best_thr = f1, thr

        return float(best_thr), float(best_f1)

    @classmethod
    def find_best_thresholds_f1_from_arrays(
        cls,
        y_true: np.ndarray,
        y_prob: np.ndarray
    ) -> Tuple[np.ndarray, Dict[str, float]]:
        """
        y_true, y_prob: arrays of shape [N, y_dim] (or [N] for single task).
        Returns:
          - thresholds: np.ndarray of shape [y_dim]
          - summary: dict with per-task F1 and macro_F1
        """
        y_true = np.asarray(y_true)
        y_prob = np.asarray(y_prob)
        if y_true.ndim == 1:
            y_true = y_true[:, None]
        if y_prob.ndim == 1:
            y_prob = y_prob[:, None]
        assert y_true.shape == y_prob.shape, "y_true and y_prob must have same shape"

        y_dim = y_true.shape[1]
        thresholds = np.zeros(y_dim, dtype=float)
        f1s = np.zeros(y_dim, dtype=float)

        for j in range(y_dim):
            thr, f1 = cls._best_threshold_f1_1d(y_true[:, j], y_prob[:, j])
            thresholds[j] = thr
            f1s[j] = f1

        summary = {f"task_{j+1}_F1_opt": float(f1s[j]) for j in range(y_dim)}
        summary["macro_F1_opt"] = float(np.nanmean(f1s))
        return thresholds, summary

    @torch.no_grad()
    def find_best_thresholds_f1_from_loader(self, loader) -> Tuple[np.ndarray, Dict[str, float]]:
        """
        Computes per-task F1-optimal thresholds using predictions on the given loader.
        """
        self.model.eval()
        probs_all, y_all = [], []
        for Xf_b, pid_b, Xr_b, y_b in loader:
            logits = self.predict_logits(Xf_b, pid_b, Xr_b if Xr_b.size(1) > 0 else None)
            probs_all.append(torch.sigmoid(logits).cpu().numpy())
            y_all.append(y_b.cpu().numpy())

        y_prob = np.vstack(probs_all)
        y_true = np.vstack(y_all)
        thresholds, summary = self.find_best_thresholds_f1_from_arrays(y_true, y_prob)
        return thresholds, summary

    def set_thresholds(self, thresholds: Sequence[float]) -> None:
        th = np.asarray(thresholds, dtype=float).ravel()
        self.thresholds_ = th

    @torch.no_grad()
    def predict_with_thresholds(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        x_random: Optional[torch.Tensor] = None,
        thresholds: Optional[Sequence[float]] = None
    ) -> torch.Tensor:
        """
        Predict labels using provided per-task thresholds (or stored thresholds_ if None).
        Returns an int32 tensor of shape [N, y_dim].
        """
        probs = self.predict_proba(x_fixed, pid_idx, x_random)
        if thresholds is None:
            if self.thresholds_ is None:
                raise ValueError("No thresholds provided and self.thresholds_ is not set.")
            thr = torch.as_tensor(self.thresholds_, device=probs.device, dtype=probs.dtype).view(1, -1)
        else:
            thr = torch.as_tensor(np.asarray(thresholds, dtype=float), device=probs.device, dtype=probs.dtype).view(1, -1)
        return (probs >= thr).to(torch.int32)

    @torch.no_grad()
    def predict_using_stored_thresholds(
        self,
        x_fixed: torch.Tensor,
        pid_idx: torch.Tensor,
        x_random: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Convenience wrapper that uses self.thresholds_ (must be set).
        """
        return self.predict_with_thresholds(x_fixed, pid_idx, x_random, thresholds=None)


## Evaluation procedure

In [55]:
from typing import Dict, Any, Optional, Tuple, Iterable, List
from dataclasses import dataclass

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import GroupKFold, TimeSeriesSplit, ParameterGrid
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from scipy.stats import t as student_t

# -------------------------------------------------------------------
# Dataset / loaders
# -------------------------------------------------------------------
class _ARMEDDataset(Dataset):
    def __init__(self, Xf, pid_idx, y, Xr=None, device=None):
        device = device or torch.device("cpu")
        Xf  = np.asarray(Xf, dtype=np.float32)
        y   = np.asarray(y,  dtype=np.float32)
        pid = np.asarray(pid_idx)

        if y.ndim == 1:
            y = y[:, None]

        self.Xf  = torch.as_tensor(Xf, dtype=torch.float32, device=device)
        self.pid = torch.as_tensor(pid, dtype=torch.long,    device=device)
        self.y   = torch.as_tensor(y,  dtype=torch.float32,  device=device)

        if Xr is None:
            self.Xr = torch.zeros((len(self.Xf), 0), dtype=torch.float32, device=device)
        else:
            Xr = np.asarray(Xr, dtype=np.float32)
            self.Xr = torch.as_tensor(Xr, dtype=torch.float32, device=device)

    def __len__(self):
        return self.Xf.shape[0]

    def __getitem__(self, idx):
        return self.Xf[idx], self.pid[idx], self.Xr[idx], self.y[idx]

def _make_loader(Xf, pid_idx, y, Xr=None, batch_size=256, shuffle=False, device=None):
    ds = _ARMEDDataset(Xf, pid_idx, y, Xr=Xr, device=device)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0, drop_last=False)

# -------------------------------------------------------------------
# Metrics
# -------------------------------------------------------------------
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    precision_recall_fscore_support, confusion_matrix
)

def _metrics_binary_full(y_true, y_prob, thr=0.5) -> Dict[str, float]:
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)
    y_pred = (y_prob >= thr).astype(int)

    # Threshold-independent metrics
    try:
        auc = float(roc_auc_score(y_true, y_prob))
    except Exception:
        auc = float("nan")
    try:
        auprc = float(average_precision_score(y_true, y_prob))
    except Exception:
        auprc = float("nan")

    # NEW: Brier score (lower is better)
    brier = float(np.mean((y_prob - y_true) ** 2))
    # (If you prefer sklearn: from sklearn.metrics import brier_score_loss; brier = float(brier_score_loss(y_true, y_prob)))

    # Thresholded metrics
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', zero_division=0
    )
    acc = float((y_pred == y_true).mean())
    try:
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        sens = float(tp / (tp + fn)) if (tp + fn) > 0 else float("nan")
        spec = float(tn / (tn + fp)) if (tn + fp) > 0 else float("nan")
    except Exception:
        sens, spec = float("nan"), float("nan")

    return {
        "AUC": auc,
        "AUPRC": auprc,
        "Brier": brier,          # <— added
        "ACC": acc,
        "F1": float(f1),
        "Precision": float(prec),
        "Recall": float(rec),
        "Sensitivity": sens,
        "Specificity": spec,
    }


import numpy as np
from typing import Dict, Any
from sklearn.calibration import calibration_curve

def compute_calibration_curves(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bins: int = 10,
    strategy: str = "quantile",   # "uniform" or "quantile"
) -> Dict[str, Any]:
    """
    Returns:
      {
        "per_task": {
          0: {"mean_pred": [...], "frac_pos": [...], "counts": [...], "ece": float, "mce": float, "brier": float},
          1: {...}, ...
        },
        "macro_ECE": float,
        "macro_MCE": float,
      }
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    assert y_true.shape == y_prob.shape
    y_dim = y_true.shape[1] if y_true.ndim == 2 else 1
    if y_true.ndim == 1:
        y_true = y_true[:, None]
        y_prob = y_prob[:, None]

    per_task = {}
    eces, mces = [], []

    for j in range(y_dim):
        t = y_true[:, j].astype(int)
        p = np.clip(y_prob[:, j].astype(float), 1e-6, 1.0 - 1e-6)

        # Try to build bins; if only one class present, fall back to NaNs
        try:
            frac_pos, mean_pred = calibration_curve(t, p, n_bins=n_bins, strategy=strategy)
        except Exception:
            frac_pos = np.array([])
            mean_pred = np.array([])

        # Reconstruct bin edges and counts to weight ECE properly
        if strategy == "uniform":
            edges = np.linspace(0.0, 1.0, n_bins + 1)
        else:
            # quantile bins over predictions
            qs = np.linspace(0.0, 1.0, n_bins + 1)
            edges = np.quantile(p, qs)
            edges[0], edges[-1] = 0.0, 1.0

        bin_ids = np.digitize(p, edges[1:-1], right=True)  # 0..n_bins-1
        counts = np.bincount(bin_ids, minlength=n_bins).astype(float)
        N = max(1, len(p))

        # Align counts to the returned bins (calibration_curve may drop empty bins)
        # Map each returned mean_pred to its bin to pull the right weight
        # If shapes match, we can compute directly; otherwise do a safe join.
        if len(mean_pred) == n_bins:
            weights = counts / N
        else:
            # Safe mapping by nearest edge (rare when calibration_curve prunes empties)
            weights = np.zeros_like(mean_pred, dtype=float)
            if len(mean_pred) > 0:
                # find each mean_pred's bin index
                idxs = np.digitize(mean_pred, edges[1:-1], right=True)
                weights = counts[idxs] / N

        # ECE / MCE
        gaps = np.abs(frac_pos - mean_pred) if len(mean_pred) else np.array([np.nan])
        ece = float(np.nansum(weights * gaps)) if len(mean_pred) else float("nan")
        mce = float(np.nanmax(gaps)) if len(mean_pred) else float("nan")

        # Brier (already in your metrics, but handy here too)
        brier = float(np.mean((p - t) ** 2)) if N > 0 else float("nan")

        per_task[j] = {
            "mean_pred": mean_pred.tolist(),
            "frac_pos": frac_pos.tolist(),
            "counts": counts.tolist(),
            "ece": ece,
            "mce": mce,
            "brier": brier,
        }
        eces.append(ece); mces.append(mce)

    return {
        "per_task": per_task,
        "macro_ECE": float(np.nanmean(eces)),
        "macro_MCE": float(np.nanmean(mces)),
    }

import matplotlib.pyplot as plt

def plot_calibration_curves(calib: Dict[str, Any], task_names=None, suptitle=None):
    per_task = calib["per_task"]
    for j, d in per_task.items():
        mp = np.array(d["mean_pred"], dtype=float)
        fp = np.array(d["frac_pos"], dtype=float)
        ece = d.get("ece", np.nan)

        plt.figure()
        plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1)
        if mp.size and fp.size:
            plt.plot(mp, fp, marker="o")
        name = f"Task {j+1}" if not task_names else task_names[j]
        plt.title(f"{name} — Reliability (ECE={ece:.3f})")
        plt.xlabel("Mean predicted probability")
        plt.ylabel("Observed positive rate")
        plt.grid(True, alpha=0.3)
        if suptitle:
            plt.suptitle(suptitle)
        plt.show()



def evaluate_multitask(y_true: np.ndarray, y_prob: np.ndarray, thr=0.5) -> dict:
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    assert y_true.shape == y_prob.shape
    y_dim = y_true.shape[1]

    out, macro = {}, {}
    for j in range(y_dim):
        m = _metrics_binary_full(y_true[:, j], y_prob[:, j], thr)
        for k, v in m.items():
            out[f"task_{j+1}_{k}"] = float(v)
            macro.setdefault(k, []).append(float(v))

    for k, vals in macro.items():
        out[f"macro_{k}"] = float(np.nanmean(np.array(vals, dtype=float)))
    return out

import numpy as np
from scipy.stats import t as student_t

def _summarize_cv_folds(results_folds: list[dict]) -> dict:
    """
    results_folds: list of per-fold metrics dicts (same keys each fold; missing -> NaN OK)
    Returns { "<metric>_mean": ..., "<metric>_95ci_low": ..., "<metric>_95ci_high": ... }
    """
    if not results_folds:
        return {}

    # Union of all metric keys across folds
    all_keys = set().union(*results_folds)
    summary = {}

    for k in sorted(all_keys):
        vals = np.array([fold.get(k, np.nan) for fold in results_folds], dtype=float)
        mask = np.isfinite(vals)
        n = int(mask.sum())

        if n == 0:
            m = np.nan; low = np.nan; high = np.nan
        elif n == 1:
            m = float(vals[mask][0]); low = np.nan; high = np.nan
        else:
            m = float(np.nanmean(vals))
            s = float(np.nanstd(vals, ddof=1))
            se = s / np.sqrt(n)
            tcrit = float(student_t.ppf(0.975, df=n - 1))
            low = m - tcrit * se
            high = m + tcrit * se

        summary[f"{k}_mean"] = m
        summary[f"{k}_95ci_low"] = low
        summary[f"{k}_95ci_high"] = high

    return summary


# -------------------------------------------------------------------
# PCA pipelines and helpers
# -------------------------------------------------------------------
from dataclasses import dataclass
import numpy as np
from sklearn.decomposition import PCA



# Safe defaults (tighten if needed)
_VAR_EPS = 1e-8   # drop columns with train variance <= this (use 1e-6 if still unstable)
_STD_EPS = 1e-6   # minimum std used in scaling
_CLIP_Z  = 8.0   # clip z-scores to [-CLIP_Z, CLIP_Z] before PCA

@dataclass
class PCAPipeline:
    keep_mask: np.ndarray         # True = keep column (decided on TRAIN variance)
    mean_: np.ndarray             # scaler mean (after masking)
    scale_: np.ndarray            # scaler std with epsilon floor (after masking)
    pca: PCA

def _fit_pca_pipeline(X_train: np.ndarray, var_ratio: float = 0.95, random_state: int | None = None) -> PCAPipeline:
    # 1) sanitize + float64 for numerical stability
    X = np.asarray(X_train, dtype=np.float64)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    # 2) drop near-constant columns on TRAIN ONLY
    var = X.var(axis=0)
    keep = var > _VAR_EPS
    if not np.any(keep):
        # degenerate: no usable features
        pca = PCA(n_components=0, svd_solver='full', random_state=random_state)
        return PCAPipeline(keep_mask=keep, mean_=np.array([], dtype=np.float64),
                           scale_=np.array([], dtype=np.float64), pca=pca)

    Xk = X[:, keep]

    # 3) robust scaling params with epsilon floor
    mean = Xk.mean(axis=0)
    std  = Xk.std(axis=0)
    std  = np.maximum(std, _STD_EPS)

    # 4) scale -> sanitize -> clip
    Z = (Xk - mean) / std
    Z = np.nan_to_num(Z, nan=0.0, posinf=0.0, neginf=0.0)
    np.clip(Z, -_CLIP_Z, _CLIP_Z, out=Z)

    # 5) PCA fit
    pca = PCA(n_components=var_ratio, svd_solver='full', random_state=random_state)
    pca.fit(Z)
    if not np.isfinite(pca.components_).all():
        raise RuntimeError("PCA components contain non-finite values after fit.")

    return PCAPipeline(keep_mask=keep, mean_=mean, scale_=std, pca=pca)

def _transform_pca_pipeline(pipe: PCAPipeline | None, X: np.ndarray | None) -> np.ndarray | None:
    
    if pipe is None or X is None:
        return None

    # 1) sanitize + float64
    X = np.asarray(X, dtype=np.float64)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    # 2) apply TRAIN mask
    if pipe.keep_mask.size == 0 or not np.any(pipe.keep_mask):
        return np.zeros((X.shape[0], 0), dtype=np.float32)
    Xk = X[:, pipe.keep_mask]

    # 3) scale with epsilon-protected std; then sanitize & clip
    Z = (Xk - pipe.mean_) / pipe.scale_
    Z = np.nan_to_num(Z, nan=0.0, posinf=0.0, neginf=0.0)
    np.clip(Z, -_CLIP_Z, _CLIP_Z, out=Z)


    if not np.isfinite(Z).all():
        bad = np.argwhere(~np.isfinite(Z))[0]
        raise RuntimeError(f"[our PCA] Z non-finite at {tuple(bad)}: {Z[tuple(bad)]}")
    if np.abs(Z).max() > 1e6:
        raise RuntimeError(f"[our PCA] Z max |z| too large: {np.abs(Z).max()}")
    if not np.isfinite(pipe.pca.components_).all():
        raise RuntimeError("[our PCA] components_ non-finite")
    if hasattr(pipe.pca, "mean_") and not np.isfinite(pipe.pca.mean_).all():
        raise RuntimeError("[our PCA] mean_ non-finite")

    # 4) PCA transform (manual projection in float64, then cast)
    Z64 = np.ascontiguousarray(Z, dtype=np.float64)
    CT  = np.ascontiguousarray(pipe.pca.components_.T, dtype=np.float64)

    
    with np.errstate(over='ignore', invalid='ignore', divide='ignore'):
        Xt = Z64 @ CT

    # ensure strictly finite output (belt & suspenders)
    Xt = np.nan_to_num(Xt, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return Xt




def _concat_safe(*arrays: Optional[np.ndarray]) -> np.ndarray:
    parts = [a for a in arrays if a is not None and a.size > 0]
    if not parts:
        return np.zeros((0, 0), dtype=np.float32)
    return np.concatenate(parts, axis=1).astype(np.float32)

def _filter_time_test_min_measurements(pid_idx: np.ndarray, test_idx: np.ndarray, min_meas: int = 3):
    """Keep only rows in test_idx belonging to pids with >= min_meas measurements overall."""
    pid = np.asarray(pid_idx)
    counts = {pid_val: np.sum(pid == pid_val) for pid_val in np.unique(pid)}
    keep = [i for i in test_idx if counts.get(pid[i], 0) >= min_meas]
    return np.array(keep, dtype=int)

def _evaluate_with_thresholds(y_true: np.ndarray, y_prob: np.ndarray, thresholds: np.ndarray) -> Dict[str, float]:
    y_true = np.asarray(y_true); y_prob = np.asarray(y_prob)
    if y_true.ndim == 1: y_true = y_true[:, None]
    if y_prob.ndim == 1: y_prob = y_prob[:, None]
    y_dim = y_true.shape[1]
    out, macro = {}, {}
    for j in range(y_dim):
        thr = thresholds[j]
        m = _metrics_binary_full(y_true[:, j], y_prob[:, j], thr)
        for k, v in m.items():
            out[f"task_{j+1}_{k}_optthr"] = float(v)
            macro.setdefault(k, []).append(float(v))
    for k, vals in macro.items():
        out[f"macro_{k}_optthr"] = float(np.nanmean(np.array(vals, dtype=float)))
    return out

# -------------------------------------------------------------------
# Splitting
# -------------------------------------------------------------------
def _split_cases(pid_array, test_fraction=0.2, seed=42):
    rng = np.random.default_rng(seed)
    unique_ids = np.unique(pid_array)
    te_ids = rng.choice(unique_ids, size=max(1, int(len(unique_ids) * test_fraction)), replace=False)
    te_mask = np.isin(pid_array, te_ids)
    return np.where(~te_mask)[0], np.where(te_mask)[0]

def _split_time_basic(time_index, test_fraction=0.2):
    order = np.argsort(time_index)
    n = len(order)
    split = int(np.floor(n * (1.0 - test_fraction)))
    return order[:split], order[split:]

# -------------------------------------------------------------------
# Train loop (prints train & monitor; early stop on val loss by default)
# -------------------------------------------------------------------
def _eval_macro_auc_on_loader(wrapper, loader: DataLoader) -> float:
    wrapper.model.eval()
    probs_all, y_all = [], []
    with torch.no_grad():
        for Xf_b, pid_b, Xr_b, y_b in loader:
            logits = wrapper.predict_logits(Xf_b, pid_b, Xr_b if Xr_b.size(1) > 0 else None)
            probs_all.append(torch.sigmoid(logits).cpu().numpy())
            y_all.append(y_b.cpu().numpy())
    y_prob = np.vstack(probs_all); y_true = np.vstack(y_all)
    m = evaluate_multitask(y_true, y_prob, thr=0.5)
    return float(m.get("macro_AUC", np.nan))

def _fit_once(
    wrapper, optimizer,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader],          # used for early stopping (default)
    monitor_loader: Optional[DataLoader],      # printed each epoch; often the test set
    max_epochs: int = 100,
    patience: int = 10,
    early_stop_metric: str = "loss",           # "loss" | "macro_AUC"
    early_stop_on: str = "val",                # "val" | "train" | "monitor"
    verbose: bool = True,
):
    # If early_stop_on='val' but val_loader is None, fall back to 'train'
    if early_stop_on == "val" and val_loader is None:
        early_stop_on = "train"

    best_val = np.inf if early_stop_metric == "loss" else -np.inf
    best_state, no_improve = None, 0

    def _avg_loss(loader) -> Optional[float]:
        if loader is None:
            return None
        wrapper.model.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for Xf_b, pid_b, Xr_b, y_b in loader:
                y_logits, adv_logits, xhat, _, _ = wrapper.forward(Xf_b, pid_b, Xr_b if Xr_b.size(1) > 0 else None)
                l, _ = wrapper.compute_losses(y_b, y_logits, adv_logits, xhat, Xf_b, pid_b)
                bs = Xf_b.size(0)
                total += float(l.detach().cpu()) * bs
                n += bs
        return total / max(1, n)

    for epoch in range(1, max_epochs + 1):
        # --- Train pass ---
        wrapper.model.train()
        total_tr, n_tr = 0.0, 0
        for Xf_b, pid_b, Xr_b, y_b in train_loader:
            y_logits, adv_logits, xhat, _, _ = wrapper.forward(
                Xf_b, pid_b, Xr_b if Xr_b.size(1) > 0 else None
            )
            loss, _ = wrapper.compute_losses(y_b, y_logits, adv_logits, xhat, Xf_b, pid_b)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            bs = Xf_b.size(0)
            total_tr += float(loss.detach().cpu()) * bs
            n_tr += bs
        train_loss = total_tr / max(1, n_tr)

        # --- Monitor print ---
        monitor_loss = _avg_loss(monitor_loader)
        if verbose:
            if monitor_loss is not None:
                print(f"Epoch {epoch:03d} | train {train_loss:.6f} | val_loss(monitor) {monitor_loss:.6f}")
            else:
                print(f"Epoch {epoch:03d} | train {train_loss:.6f}")

        # --- Early stopping ---
        if early_stop_metric == "loss":
            if early_stop_on == "val":
                current = _avg_loss(val_loader)
                is_better = (current is not None) and (current < best_val - 1e-6)
                metric_for_best = current
            elif early_stop_on == "train":
                current = train_loss
                is_better = current < best_val - 1e-6
                metric_for_best = current
            else:  # "monitor"
                current = monitor_loss
                is_better = (current is not None) and (current < best_val - 1e-6)
                metric_for_best = current
        else:  # early_stop_metric == "macro_AUC"
            if early_stop_on == "val":
                current = _eval_macro_auc_on_loader(wrapper, val_loader) if val_loader is not None else np.nan
            elif early_stop_on == "train":
                current = _eval_macro_auc_on_loader(wrapper, train_loader)
            else:
                current = _eval_macro_auc_on_loader(wrapper, monitor_loader) if monitor_loader is not None else np.nan
            is_better = (not np.isnan(current)) and (current > best_val + 1e-6)
            metric_for_best = current

        if is_better:
            best_val = metric_for_best
            best_state = {k: v.detach().cpu().clone() for k, v in wrapper.model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                if verbose:
                    tag = f"{early_stop_metric}@{early_stop_on}"
                    print(f"Early stopping at epoch {epoch:03d} (best {tag} {best_val:.6f})")
                break

    if best_state is not None:
        wrapper.model.load_state_dict(best_state)

# -------------------------------------------------------------------
# PCA + loaders + per-split PID remap
# -------------------------------------------------------------------
def _prepare_split_and_loaders(
    X_only_fixed: np.ndarray,
    X_fixed_and_random: Optional[np.ndarray],
    y: np.ndarray,
    pid_idx_full: np.ndarray,
    indices_train: np.ndarray,
    indices_val: Optional[np.ndarray],
    indices_test: np.ndarray,
    batch_size: int,
    device: torch.device,
    random_state: int = 42,
    pca_var_ratio: float = 0.95,
):
    # Fit PCA pipelines on TRAIN only
    of_pipe = _fit_pca_pipeline(X_only_fixed[indices_train], var_ratio=pca_var_ratio, random_state=random_state)
    fr_pipe = None
    if X_fixed_and_random is not None and X_fixed_and_random.shape[1] > 0:
        fr_pipe = _fit_pca_pipeline(X_fixed_and_random[indices_train], var_ratio=pca_var_ratio, random_state=random_state)

    # Transform helper
    def transform_block(idxs):
        of = _transform_pca_pipeline(of_pipe, X_only_fixed[idxs])
        fr = _transform_pca_pipeline(fr_pipe, None if X_fixed_and_random is None else X_fixed_and_random[idxs])
        Xf = _concat_safe(of, fr)  # fixed = OF ⊕ FR
        Xr = fr                    # random = FR
        return Xf, Xr

    Xf_tr, Xr_tr = transform_block(indices_train)
    Xf_te, Xr_te = transform_block(indices_test)
    if indices_val is not None:
        Xf_va, Xr_va = transform_block(indices_val)
    else:
        Xf_va, Xr_va = None, None

    # Per-split PID remap: train-seen -> 0..n_seen-1, unseen -> -1
    seen = np.unique(pid_idx_full[indices_train])
    pid_to_seen = {p: i for i, p in enumerate(seen)}

    def map_pids(idxs):
        vals = pid_idx_full[idxs]
        mapped = np.array([pid_to_seen.get(p, -1) for p in vals], dtype=np.int64)
        return mapped

    pid_tr = map_pids(indices_train)
    pid_te = map_pids(indices_test)
    pid_va = map_pids(indices_val) if indices_val is not None else None

    # Build loaders
    tr_loader = _make_loader(Xf_tr, pid_tr, y[indices_train], Xr=Xr_tr, batch_size=batch_size, shuffle=True,  device=device)
    va_loader = _make_loader(Xf_va, pid_va, y[indices_val], Xr=Xr_va, batch_size=batch_size, shuffle=False, device=device) if Xf_va is not None else None
    te_loader = _make_loader(Xf_te, pid_te, y[indices_test],  Xr=Xr_te, batch_size=batch_size, shuffle=False, device=device)

    # Post-PCA dims
    d_fixed  = Xf_tr.shape[1]
    d_random = 0 if Xr_tr is None else Xr_tr.shape[1]

    preprocessors = {
        "only_fixed": of_pipe,
        "fixed_and_random": fr_pipe,
        "d_fixed": d_fixed,
        "d_random": d_random,
        "n_seen": int(len(seen)),
    }
    loaders = {"train": tr_loader, "val": va_loader, "test": te_loader}
    return preprocessors, loaders

# -------------------------------------------------------------------
# One split fit/eval (used by all modes)
# -------------------------------------------------------------------
def _fit_eval_once(
    build_model_fn, wrapper_cls,
    arch_params: Dict[str, Any],
    train_params: Dict[str, Any],
    X_of: np.ndarray,
    X_fr: Optional[np.ndarray],
    y: np.ndarray,
    pid_idx_full: np.ndarray,
    tr_idx: np.ndarray,
    va_idx: Optional[np.ndarray],
    te_idx: np.ndarray,
    device: torch.device,
    monitor_source: str = "test",               # "test" (default) or "val"
    threshold_selection_source: str = "train",  # "train" (default) | "val" | "test"
    verbose: bool = True,
):
    preprocessors, loaders = _prepare_split_and_loaders(
        X_of, X_fr, y, pid_idx_full,
        tr_idx, va_idx, te_idx,
        batch_size=train_params.get("batch_size", 256),
        device=device,
        random_state=train_params.get("random_state", 42),
        pca_var_ratio=train_params.get("pca_var_ratio", 0.95),
    )

    d_fixed  = preprocessors["d_fixed"]
    d_random = preprocessors["d_random"]
    n_seen   = preprocessors["n_seen"]

    # Build model with post-PCA dims and train-seen participant count
    y_dim = y.shape[1] if y.ndim == 2 else 1
    model = build_model_fn(
        d_fixed=d_fixed,
        d_random=d_random,
        y_dim=y_dim,
        n_participants=n_seen,
        **arch_params
    ).to(device)

    wrapper = wrapper_cls(model, loss_weights=train_params.get("loss_weights", None), device=device)
    opt = torch.optim.Adam(wrapper.model.parameters(),
                           lr=train_params.get("lr", 1e-3),
                           weight_decay=train_params.get("weight_decay", 0.0))

    monitor_loader = loaders["test"] if monitor_source == "test" else loaders["val"]

    _fit_once(
        wrapper, opt,
        loaders["train"], loaders["val"], monitor_loader,
        max_epochs=train_params.get("max_epochs", 100),
        patience=train_params.get("patience", 10),
        early_stop_metric=train_params.get("early_stop_metric", "loss"),
        early_stop_on=train_params.get("early_stop_on", "val"),
        verbose=verbose
    )

    # Predict on TEST loader
    wrapper.model.eval()
    probs_all, y_all = [], []
    with torch.no_grad():
        for Xf_b, pid_b, Xr_b, y_b in loaders["test"]:
            logits = wrapper.predict_logits(Xf_b, pid_b, Xr_b if Xr_b.size(1) > 0 else None)
            probs_all.append(torch.sigmoid(logits).cpu().numpy())
            y_all.append(y_b.cpu().numpy())
    y_prob_te = np.vstack(probs_all); y_true_te = np.vstack(y_all)

    # Metrics @ 0.5
    metrics_050 = evaluate_multitask(y_true_te, y_prob_te, thr=0.5)

    # Threshold selection (per-task exact F1-opt)
    if threshold_selection_source == "train":
        thr_vec, thr_summary = wrapper.find_best_thresholds_f1_from_loader(loaders["train"])
    elif threshold_selection_source == "val" and loaders["val"] is not None:
        thr_vec, thr_summary = wrapper.find_best_thresholds_f1_from_loader(loaders["val"])
    else:  # "test" or no val available fallback
        thr_vec, thr_summary = wrapper.find_best_thresholds_f1_from_loader(loaders["test"])

    metrics_opt = _evaluate_with_thresholds(y_true_te, y_prob_te, thr_vec)

    calib = compute_calibration_curves(y_true_te, y_prob_te,
    n_bins=train_params.get("calibration_bins", 10),
    strategy=train_params.get("calibration_strategy", "quantile"),)

    # add scalar summaries into metrics
    macro_ECE = calib["macro_ECE"]
    macro_MCE = calib["macro_MCE"]

    return {
        "metrics@0.5": metrics_050,
        "metrics@optthr": metrics_opt,
        "opt_thresholds": thr_vec,
        "opt_thresholds_summary": thr_summary,
        "preprocessors": preprocessors,
        "wrapper": wrapper,
        "model": wrapper.model,
        "macro_ECE": macro_ECE,
        "macro_MCE": macro_MCE
    }



# -------------------------------------------------------------------
# Main entry: single / cv_only / nested_cv, scenarios cases/time/both
# -------------------------------------------------------------------
def run_training_and_eval_armed(
    X_only_fixed: np.ndarray,
    X_fixed_and_random: Optional[np.ndarray],
    y: np.ndarray,
    pid_idx: np.ndarray,
    time_index: np.ndarray,
    build_model_fn,               # callable(d_fixed, d_random, y_dim, n_participants, **arch)
    wrapper_cls,                  # ARMEDWrapper
    *,
    mode: str = "single",         # "single" | "cv_only" | "nested_cv"
    scenario: str = "cases",      # "cases" | "time" | "both"
    outer_folds: int = 5,
    inner_folds: int = 3,
    param_grid: Optional[Dict[str, List]] = None,
    arch_defaults: Optional[Dict[str, Any]] = None,
    train_defaults: Optional[Dict[str, Any]] = None,
    device: Optional[torch.device] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Unified runner:
      - scaling + PCA(0.95) per block fit on TRAIN inside each split
      - time scenario: test uses participants with >=3 measurements
      - inner model selection by macro-F1 at per-task F1-opt thresholds (chosen on TRAIN)
      - prints train + monitor (default monitor=test) losses per epoch
    """
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    X_of = np.asarray(X_only_fixed, dtype=np.float32)
    X_fr = None if X_fixed_and_random is None else np.asarray(X_fixed_and_random, dtype=np.float32)
    y    = np.asarray(y, dtype=np.float32)
    pid_idx = np.asarray(pid_idx, dtype=np.int64)
    time_ix = np.asarray(time_index)

    arch_defaults = arch_defaults or {}
    train_defaults = train_defaults or {}

    rnd = int(train_defaults.get("random_state", 42))
    val_frac = float(train_defaults.get("val_fraction", 0.10))
    monitor_source = train_defaults.get("monitor_source", "test")                     # printed each epoch
    thr_source = train_defaults.get("threshold_selection_source", "train")            # choose thresholds on TRAIN by default

    def _make_train_val_split(idx_array: np.ndarray, seed: int) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Return (train_indices, val_indices). If too small, no val."""
        idx_array = np.asarray(idx_array)
        if len(idx_array) <= 10 or val_frac <= 0.0:
            return idx_array, None
        rng = np.random.default_rng(seed)
        perm = rng.permutation(len(idx_array))
        cut = max(1, int(val_frac * len(idx_array)))
        va_sel, tr_sel = perm[:cut], perm[cut:]
        return idx_array[tr_sel], idx_array[va_sel]

    # Run both scenarios if requested
    if scenario == "both":
        out_cases = run_training_and_eval_armed(
            X_of, X_fr, y, pid_idx, time_ix,
            build_model_fn, wrapper_cls,
            mode=mode, scenario="cases",
            outer_folds=outer_folds, inner_folds=inner_folds,
            param_grid=param_grid, arch_defaults=arch_defaults, train_defaults=train_defaults,
            device=device, verbose=verbose
        )
        out_time = run_training_and_eval_armed(
            X_of, X_fr, y, pid_idx, time_ix,
            build_model_fn, wrapper_cls,
            mode=mode, scenario="time",
            outer_folds=outer_folds, inner_folds=inner_folds,
            param_grid=param_grid, arch_defaults=arch_defaults, train_defaults=train_defaults,
            device=device, verbose=verbose
        )
        return {"cases": out_cases, "time": out_time}

    # -------------------- MODE: SINGLE --------------------
    if mode == "single":
        if scenario == "cases":
            tr_idx_all, te_idx = _split_cases(pid_idx, test_fraction=0.2, seed=rnd)
        elif scenario == "time":
            tr_idx_all, te_idx_raw = _split_time_basic(time_ix, test_fraction=0.2)
            te_idx = _filter_time_test_min_measurements(pid_idx, te_idx_raw, min_meas=3)
            if len(te_idx) == 0:
                raise RuntimeError("Time split produced empty test after >=3 measurements filter.")
            te_idx = te_idx
        else:
            raise ValueError("scenario must be 'cases' or 'time'")

        if scenario == "cases":
            te_idx_use = te_idx
        else:
            te_idx_use = te_idx

        tr_idx, va_idx = _make_train_val_split(np.asarray(tr_idx_all), seed=rnd)

        res = _fit_eval_once(
            build_model_fn, wrapper_cls,
            arch_defaults, train_defaults,
            X_of, X_fr, y,
            pid_idx,     # ALWAYS pass full pid array
            tr_idx, va_idx, te_idx_use,
            device=device,
            monitor_source=monitor_source,
            threshold_selection_source=thr_source,
            verbose=verbose
        )

        if verbose:
            print("\nSingle-fit test metrics @0.5:")
            for k, v in res["metrics@0.5"].items():
                print(f"{k:>18}: {v:.4f}")
            print("\nSingle-fit test metrics @F1-opt per task:")
            for k, v in res["metrics@optthr"].items():
                print(f"{k:>18}: {v:.4f}")

        return res

    # -------------------- MODE: CV-ONLY --------------------
    if mode == "cv_only":
        fold_metrics_050, fold_metrics_opt = [], []

        if scenario == "cases":
            outer = GroupKFold(n_splits=outer_folds)
            outer_iter = outer.split(X_of, y[:, 0] if y.ndim > 1 else y, groups=pid_idx)
        elif scenario == "time":
            tss = TimeSeriesSplit(n_splits=outer_folds)
            order = np.argsort(time_ix)
            X_order = X_of[order]; y_order = y[order]
            outer_iter = ((order[tr], order[te]) for tr, te in tss.split(X_order, y_order[:, 0] if y.ndim > 1 else y_order))
        else:
            raise ValueError("scenario must be 'cases' or 'time'")

        for fold_id, (tr_idx_all, te_idx_raw) in enumerate(outer_iter, start=1):
            if scenario == "time":
                te_idx = _filter_time_test_min_measurements(pid_idx, te_idx_raw, min_meas=3)
                if len(te_idx) == 0:
                    if verbose: print(f"Skipping time fold {fold_id} (empty test after filter).")
                    continue
                te_idx_use = te_idx
            else:
                te_idx_use = te_idx_raw

            tr_idx, va_idx = _make_train_val_split(np.asarray(tr_idx_all), seed=rnd + fold_id)

            res = _fit_eval_once(
                build_model_fn, wrapper_cls,
                arch_defaults, train_defaults,
                X_of, X_fr, y,
                pid_idx,
                tr_idx, va_idx, te_idx_use,
                device=device,
                monitor_source=monitor_source,
                threshold_selection_source=thr_source,   # TRAIN by default
                verbose=False
            )
            fold_metrics_050.append(res["metrics@0.5"])
            fold_metrics_opt.append(res["metrics@optthr"])

            if verbose:
                macro_keys = [k for k in res["metrics@optthr"].keys() if k.startswith("macro_")]
                print(f"Fold {fold_id}: " + ", ".join(f"{k}={res['metrics@optthr'][k]:.4f}" for k in macro_keys))
        
        cv_summary_050 = _summarize_cv_folds(fold_metrics_050)
        cv_summary_opt = _summarize_cv_folds(fold_metrics_opt)

        if verbose:
            print("\nCV averages (±95% CI) for 0.50 treshold:")
            for key in sorted(cv_summary_050.keys()):
                if key.endswith("_mean"):
                    base = key[:-5]
                    mean = cv_summary_050[key]
                    low  = cv_summary_050.get(f"{base}_95ci_low", np.nan)
                    high = cv_summary_050.get(f"{base}_95ci_high", np.nan)
                    print(f"{base:>20}: {mean:.4f}  (95% CI {low:.4f}, {high:.4f})")

            print("\nCV averages (±95% CI) for optimal threshold:")
            for key in sorted(cv_summary_opt.keys()):
                if key.endswith("_mean"):
                    base = key[:-5]
                    mean = cv_summary_opt[key]
                    low  = cv_summary_opt.get(f"{base}_95ci_low", np.nan)
                    high = cv_summary_opt.get(f"{base}_95ci_high", np.nan)
                    print(f"{base:>20}: {mean:.4f}  (95% CI {low:.4f}, {high:.4f})")


        return {
            "cv_folds_metrics@0.5": fold_metrics_050,
            "cv_folds_metrics@optthr": fold_metrics_opt,
            "cv_summary@0.5": cv_summary_050,
            "cv_summary@optthr": cv_summary_opt,
        }


    # -------------------- MODE: NESTED CV --------------------
    if mode == "nested_cv":
        if not param_grid:
            param_grid = {
                "fixed_rep_dim": [32, 64, 128],
                "random_rep_dim": [32],
                "combine_mode": ["add"],
                "grl_lambda": [1.0],
                "lr": [1e-3, 3e-4],
                "weight_decay": [0.0, 1e-4],
                "batch_size": [256],
                "max_epochs": [100],
                "patience": [10],
            }

        results_folds = []
        best_score_global, best_params_global = -np.inf, None

        # Outer iterator
        if scenario == "cases":
            outer = GroupKFold(n_splits=outer_folds)
            outer_iter = outer.split(X_of, y[:, 0] if y.ndim > 1 else y, groups=pid_idx)
        elif scenario == "time":
            tss = TimeSeriesSplit(n_splits=outer_folds)
            order = np.argsort(time_ix)
            X_order = X_of[order]; y_order = y[order]
            outer_iter = ((order[tr], order[te]) for tr, te in tss.split(X_order, y_order[:, 0] if y.ndim > 1 else y_order))
        else:
            raise ValueError("scenario must be 'cases' or 'time'")

        for fold_id, (tr_idx_all, te_idx_raw) in enumerate(outer_iter, start=1):
            if verbose:
                print(f"\nOuter fold {fold_id}/{outer_folds}")

            if scenario == "time":
                te_idx = _filter_time_test_min_measurements(pid_idx, te_idx_raw, min_meas=3)
                if len(te_idx) == 0:
                    if verbose: print(f"Skipping outer time fold {fold_id} (empty test after filter).")
                    continue
                te_idx_use = te_idx
            else:
                te_idx_use = te_idx_raw

            # ----- INNER CV: model selection by macro F1 at per-task F1-opt thresholds (chosen on TRAIN) -----
            def inner_iter():
                if scenario == "cases":
                    inner = GroupKFold(n_splits=inner_folds)
                    return inner.split(X_of[tr_idx_all], (y[tr_idx_all, 0] if y.ndim > 1 else y[tr_idx_all]), groups=pid_idx[tr_idx_all])
                else:
                    tr_order = np.argsort(time_ix[tr_idx_all])
                    X_tr_order = X_of[tr_idx_all][tr_order]
                    y_tr_order = y[tr_idx_all][tr_order]
                    inner_tss = TimeSeriesSplit(n_splits=inner_folds)
                    return ((tr_idx_all[tr_order][itr], tr_idx_all[tr_order][iva])
                            for itr, iva in inner_tss.split(X_tr_order, y_tr_order[:, 0] if y.ndim > 1 else y_tr_order))

            best_inner_score, best_inner_params = -np.inf, None

            for params in ParameterGrid(param_grid):
                arch_params = dict(arch_defaults)
                train_params = dict(train_defaults)
                for k, v in params.items():
                    if k in ("fixed_rep_dim", "random_rep_dim", "combine_mode", "grl_lambda"):
                        arch_params[k] = v
                    else:
                        train_params[k] = v

                inner_scores = []
                for in_tr, in_va in inner_iter():
                    # For time scenario, apply ≥3 rule on inner held-out set
                    if scenario == "time":
                        in_va_f = _filter_time_test_min_measurements(pid_idx, in_va, min_meas=3)
                        if len(in_va_f) == 0:
                            continue
                        in_va = in_va_f

                    # carve small early-stop val from in_tr
                    tr_idx_inner, va_idx_inner = _make_train_val_split(np.asarray(in_tr), seed=rnd + fold_id)

                    # Fit once: thresholds from TRAIN; evaluate on inner "test" (=in_va)
                    res_inner = _fit_eval_once(
                        build_model_fn, wrapper_cls,
                        arch_params, train_params,
                        X_of, X_fr, y,
                        pid_idx,
                        tr_idx_inner, va_idx_inner, in_va,
                        device=device,
                        monitor_source="val",
                        threshold_selection_source="train",   # << thresholds from TRAIN
                        verbose=False
                    )
                    score = res_inner["metrics@optthr"].get("macro_F1_optthr", np.nan)
                    inner_scores.append(score)

                if len(inner_scores) == 0:
                    avg_score = -np.inf
                else:
                    avg_score = float(np.nanmean(inner_scores))

                if avg_score > best_inner_score:
                    best_inner_score = avg_score
                    best_inner_params = (arch_params, train_params)

            # ----- Outer evaluation with best inner params -----
            if best_inner_params is None:
                if verbose: print("No viable inner config; skipping outer fold.")
                continue

            arch_params, train_params = best_inner_params
            tr_idx_outer, va_idx_outer = _make_train_val_split(np.asarray(tr_idx_all), seed=rnd + fold_id * 17)

            res_outer = _fit_eval_once(
                build_model_fn, wrapper_cls,
                arch_params, train_params,
                X_of, X_fr, y,
                pid_idx,
                tr_idx_outer, va_idx_outer, te_idx_use,
                device=device,
                monitor_source=monitor_source,
                threshold_selection_source="train",   # << thresholds from TRAIN
                verbose=False
            )
            results_folds.append(res_outer)

            score_outer = res_outer["metrics@optthr"].get("macro_F1_optthr", -np.inf)
            if score_outer > best_score_global:
                best_score_global = score_outer
                best_params_global = (arch_params, train_params)

            if verbose:
                macro_keys = [k for k in res_outer["metrics@optthr"].keys() if k.startswith("macro_")]
                print("Outer fold macro (optthr): " + ", ".join(f"{k}={res_outer['metrics@optthr'][k]:.4f}" for k in macro_keys))

        # summarize folds
        def summarize(results_list: List[Dict[str, Any]], which: str) -> Dict[str, float]:
            keys = list(results_list[0][which].keys())
            out = {}
            for k in keys:
                arr = np.array([res[which][k] for res in results_list], dtype=float)
                m = float(np.nanmean(arr)); s = float(np.nanstd(arr, ddof=1)); n = len(arr)
                se = s / np.sqrt(n) if n > 1 else np.nan
                if n > 1:
                    tcrit = float(student_t.ppf(0.975, df=n-1))
                    ci = (m - tcrit * se, m + tcrit * se)
                else:
                    ci = (np.nan, np.nan)
                out[k + "_mean"] = m
                out[k + "_95ci_low"] = ci[0]
                out[k + "_95ci_high"] = ci[1]
            return out

        cv_summary_050    = summarize(results_folds, "metrics@0.5")
        cv_summary_optthr = summarize(results_folds, "metrics@optthr")

        # Optional final refit on a fresh 80/20 with best params
        if verbose:
            print("\nBest params (by outer macro_F1 at opt thresholds):")
            arch_p, train_p = best_params_global
            print("[ARCH]:");   [print(f"  {k}: {v}") for k, v in arch_p.items()]
            print("[TRAIN]:");  [print(f"  {k}: {v}") for k, v in train_p.items()]

        # Final refit (fresh 80/20) for reporting
        if scenario == "cases":
            tr_idx_all, te_idx = _split_cases(pid_idx, test_fraction=0.2, seed=rnd)
            te_idx_use = te_idx
        else:
            tr_idx_all, te_idx_raw = _split_time_basic(time_ix, test_fraction=0.2)
            te_idx = _filter_time_test_min_measurements(pid_idx, te_idx_raw, min_meas=2)
            if len(te_idx) == 0:
                raise RuntimeError("Final refit: time split produced empty test after filter.")
            te_idx_use = te_idx

        tr_idx_final, va_idx_final = _make_train_val_split(np.asarray(tr_idx_all), seed=rnd + 999)

        final_res = _fit_eval_once(
            build_model_fn, wrapper_cls,
            best_params_global[0], best_params_global[1],
            X_of, X_fr, y,
            pid_idx,
            tr_idx_final, va_idx_final, te_idx_use,
            device=device,
            monitor_source=monitor_source,
            threshold_selection_source="train",
            verbose=verbose
        )

        return {
            "outer_folds": results_folds,
            "cv_summary@0.5": cv_summary_050,
            "cv_summary@optthr": cv_summary_optthr,
            "best_params": {"arch": best_params_global[0], "train": best_params_global[1]},
            "final_refit": final_res,
        }

    raise ValueError("mode must be one of {'single','cv_only','nested_cv'}")


## Model test

### Define variables and parameters

In [56]:
y_raw      = GHQ_cat_y.to_numpy(np.float32)
y_np       = y_raw if y_raw.ndim == 2 else y_raw.reshape(-1, 1)

pid_raw    = GHQ_cat_participant_id.to_numpy().ravel()
pid_uniqs, pid_encoded = np.unique(pid_raw, return_inverse=True)
pid_np     = pid_encoded.astype(np.int64)       
n_ids      = int(len(pid_uniqs))                

time_ix_np = GHQ_cat_time.to_numpy().ravel()



def build_model_fn(d_fixed, d_random, y_dim, n_participants, **arch):
    return ARMEDTabular(
        d_fixed=d_fixed,
        d_random=d_random,
        y_dim=y_dim,
        n_participants=n_participants,
        **arch
    )

wrapper_cls = ARMEDWrapper

arch_defaults = dict(
    fixed_rep_dim=256,
    random_rep_dim=256,
    combine_mode="add",
    grl_lambda=1.0,
)

train_defaults = dict(
    lr=1e-4,
    weight_decay=1e-4,
    batch_size=256,
    max_epochs=100,
    patience=20,
    loss_weights=ARMEDLossWeights(lambda_adv=1.0, lambda_recon=0.0),

    # NEW knobs:
    pca_var_ratio=0.95,                   # 95% variance PCA
    threshold_selection_source="train",   # pick per-task F1 thresholds on TRAIN
    monitor_source="test",                # print “val_loss” on the test loader
    early_stop_metric="loss",             # default (BCE + λ_adv CE + λ_recon MSE)
    early_stop_on="val",                  # stop on validation loss
)


### Simple split test

In [17]:
res_simple = run_training_and_eval_armed(
    X_only_fixed=GHQ_cat_only_fixed.values,
    X_fixed_and_random=GHQ_cat_fixed_and_random.values,
    y=y_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    mode="single",
    scenario="both",          # <- runs cases and time in one go
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


Epoch 001 | train 6.030224 | val_loss(monitor) 0.692274
Epoch 002 | train 6.012431 | val_loss(monitor) 0.688858
Epoch 003 | train 5.997552 | val_loss(monitor) 0.686105
Epoch 004 | train 5.984345 | val_loss(monitor) 0.683502
Epoch 005 | train 5.972251 | val_loss(monitor) 0.680849
Epoch 006 | train 5.960769 | val_loss(monitor) 0.678124
Epoch 007 | train 5.950499 | val_loss(monitor) 0.675621
Epoch 008 | train 5.938741 | val_loss(monitor) 0.673116
Epoch 009 | train 5.928022 | val_loss(monitor) 0.670584
Epoch 010 | train 5.917136 | val_loss(monitor) 0.667917
Epoch 011 | train 5.905789 | val_loss(monitor) 0.665497
Epoch 012 | train 5.894815 | val_loss(monitor) 0.663129
Epoch 013 | train 5.884111 | val_loss(monitor) 0.660909
Epoch 014 | train 5.872952 | val_loss(monitor) 0.659218
Epoch 015 | train 5.862207 | val_loss(monitor) 0.657941
Epoch 016 | train 5.851113 | val_loss(monitor) 0.657120
Epoch 017 | train 5.839987 | val_loss(monitor) 0.656684
Epoch 018 | train 5.829613 | val_loss(monitor) 0

### Nested CV with parameter search

In [13]:
param_grid = {
    # arch params
    "fixed_rep_dim": [64, 128, 256],
    "random_rep_dim": [64, 128],
    "combine_mode": ["add", "film"],
    "grl_lambda": [0.5, 1.0],
    # train params
    "lr": [1e-3, 3e-4, 1e-4],
    "weight_decay": [0.0, 1e-4],
    "batch_size": [128, 256],
    "max_epochs": [100],
    "patience": [20],
}

res_nested = run_training_and_eval_armed(
    X_only_fixed=GHQ_cat_only_fixed.values,
    X_fixed_and_random=GHQ_cat_fixed_and_random.values,
    y=GHQ_cat_y.values,
    pid_idx=pid_np,
    time_index=time_ix_np,
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    mode="nested_cv",
    scenario="both",
    outer_folds=5,
    inner_folds=3,
    param_grid=param_grid,
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)



Outer fold 1/5
Outer fold macro (optthr): macro_AUC_optthr=0.7566, macro_AUPRC_optthr=0.7738, macro_Brier_optthr=0.2010, macro_ACC_optthr=0.7246, macro_F1_optthr=0.7532, macro_Precision_optthr=0.7909, macro_Recall_optthr=0.7190, macro_Sensitivity_optthr=0.7190, macro_Specificity_optthr=0.7326

Outer fold 2/5
Outer fold macro (optthr): macro_AUC_optthr=0.7379, macro_AUPRC_optthr=0.7717, macro_Brier_optthr=0.2088, macro_ACC_optthr=0.6748, macro_F1_optthr=0.6884, macro_Precision_optthr=0.7115, macro_Recall_optthr=0.6667, macro_Sensitivity_optthr=0.6667, macro_Specificity_optthr=0.6842

Outer fold 3/5
Outer fold macro (optthr): macro_AUC_optthr=0.8183, macro_AUPRC_optthr=0.8112, macro_Brier_optthr=0.1830, macro_ACC_optthr=0.7184, macro_F1_optthr=0.7500, macro_Precision_optthr=0.6797, macro_Recall_optthr=0.8365, macro_Sensitivity_optthr=0.8365, macro_Specificity_optthr=0.5980

Outer fold 4/5
Outer fold macro (optthr): macro_AUC_optthr=0.8578, macro_AUPRC_optthr=0.8803, macro_Brier_optthr=0

### CV without parameter search

In [57]:
res_cv_only = run_training_and_eval_armed(
    X_only_fixed=GHQ_cat_only_fixed.values,
    X_fixed_and_random=GHQ_cat_fixed_and_random.values,
    y=y_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    mode="cv_only",
    scenario="both",
    outer_folds=3,
    arch_defaults=arch_defaults,      # used as-is; no grid
    train_defaults=train_defaults,    # used as-is; no grid
    verbose=True,
)


Fold 1: macro_AUC_optthr=0.6701, macro_AUPRC_optthr=0.6703, macro_Brier_optthr=0.2709, macro_ACC_optthr=0.5872, macro_F1_optthr=0.5535, macro_Precision_optthr=0.6567, macro_Recall_optthr=0.4783, macro_Sensitivity_optthr=0.4783, macro_Specificity_optthr=0.7125
Fold 2: macro_AUC_optthr=0.7695, macro_AUPRC_optthr=0.7929, macro_Brier_optthr=0.2187, macro_ACC_optthr=0.6744, macro_F1_optthr=0.6686, macro_Precision_optthr=0.7687, macro_Recall_optthr=0.5916, macro_Sensitivity_optthr=0.5916, macro_Specificity_optthr=0.7778
Fold 3: macro_AUC_optthr=0.7899, macro_AUPRC_optthr=0.7908, macro_Brier_optthr=0.2020, macro_ACC_optthr=0.6851, macro_F1_optthr=0.7545, macro_Precision_optthr=0.6409, macro_Recall_optthr=0.9171, macro_Sensitivity_optthr=0.9171, macro_Specificity_optthr=0.4259

CV averages (±95% CI) for 0.50 treshold:
           macro_ACC: 0.6480  (95% CI 0.5205, 0.7754)
           macro_AUC: 0.7432  (95% CI 0.5840, 0.9023)
         macro_AUPRC: 0.7513  (95% CI 0.5770, 0.9257)
         macro_B

## New outcome

In [59]:
Body_cat_df = pd.read_csv(os.path.join(data_dir, "Body_data_for_categorical_forecast.csv"))
columns_Body_cat_df = pd.read_csv(os.path.join(data_dir, "columns_Body_data_for_categorical_forecast.csv"))

Body_cat_df.head()

Unnamed: 0,id,age,group,day,beep,Body_check,Restr,Comp,BE,EE,...,EDDS15,EDDS16,EDDS17,EDDS18,education2,education3,education4,critical_event_next,day_beep_diff,day_beep_diff_copy
0,1,29,1,0,4,3,3,1,2,1,...,0,2,12,5,0,1,0,1,2,2
1,1,29,1,1,1,3,4,1,2,1,...,0,2,12,5,0,1,0,0,1,1
2,1,29,1,1,2,1,5,1,1,1,...,0,2,12,5,0,1,0,1,1,1
3,1,29,1,1,3,4,3,1,2,1,...,0,2,12,5,0,1,0,0,1,1
4,1,29,1,1,4,4,2,1,2,1,...,0,2,12,5,0,1,0,0,2,2


In [60]:
# Select the outcome column(s) marked with 1 in the "outcomes" column of columns_Body_cat_df
Body_cat_df_outcome_cols = columns_Body_cat_df.loc[columns_Body_cat_df['outcomes'] == 1, 'column_name'].tolist()
Body_cat_y = Body_cat_df[Body_cat_df_outcome_cols]
Body_cat_y.head()

# Same for outcomes lags column(s)
Body_cat_outcomes_lags_cols = columns_Body_cat_df.loc[columns_Body_cat_df['outcomes_lags'] == 1, 'column_name'].tolist()
Body_cat_outcomes_lags = Body_cat_df[Body_cat_outcomes_lags_cols]

# Same for participant column(s)
Body_cat_participant_cols = columns_Body_cat_df.loc[columns_Body_cat_df['participant_id'] == 1, 'column_name'].tolist()
Body_cat_participant_id = Body_cat_df[Body_cat_participant_cols]

# Same for time column(s)
Body_cat_time_cols = columns_Body_cat_df.loc[columns_Body_cat_df['time'] == 1, 'column_name'].tolist()
Body_cat_time = Body_cat_df[Body_cat_time_cols]

# Same for forecast horizons column(s)
Body_cat_forecast_horizons_cols = columns_Body_cat_df.loc[columns_Body_cat_df['forecast_horizons'] == 1, 'column_name'].tolist()
Body_cat_forecast_horizons = Body_cat_df[Body_cat_forecast_horizons_cols]

# Same for fixed effects column(s)
Body_cat_only_fixed_cols = columns_Body_cat_df.loc[columns_Body_cat_df['only_fixed'] == 1, 'column_name'].tolist()
Body_cat_only_fixed = Body_cat_df[Body_cat_only_fixed_cols]

# Same for random effects column(s)
Body_cat_fixed_and_random_cols = columns_Body_cat_df.loc[columns_Body_cat_df['fixed_and_random'] == 1, 'column_name'].tolist()
Body_cat_fixed_and_random = Body_cat_df[Body_cat_fixed_and_random_cols]

In [62]:
y_raw      = Body_cat_y.to_numpy(np.float32)
y_np       = y_raw if y_raw.ndim == 2 else y_raw.reshape(-1, 1)

pid_raw    = Body_cat_participant_id.to_numpy().ravel()
pid_uniqs, pid_encoded = np.unique(pid_raw, return_inverse=True)
pid_np     = pid_encoded.astype(np.int64)       
n_ids      = int(len(pid_uniqs))                

time_ix_np = Body_cat_time.to_numpy().ravel()



def build_model_fn(d_fixed, d_random, y_dim, n_participants, **arch):
    return ARMEDTabular(
        d_fixed=d_fixed,
        d_random=d_random,
        y_dim=y_dim,
        n_participants=n_participants,
        **arch
    )

wrapper_cls = ARMEDWrapper

arch_defaults = dict(
    fixed_rep_dim=256,
    random_rep_dim=256,
    combine_mode="add",
    grl_lambda=1.0,
)

train_defaults = dict(
    lr=1e-4,
    weight_decay=1e-4,
    batch_size=256,
    max_epochs=100,
    patience=20,
    loss_weights=ARMEDLossWeights(lambda_adv=1.0, lambda_recon=0.0),

    # NEW knobs:
    pca_var_ratio=0.95,                   # 95% variance PCA
    threshold_selection_source="train",   # pick per-task F1 thresholds on TRAIN
    monitor_source="test",                # print “val_loss” on the test loader
    early_stop_metric="loss",             # default (BCE + λ_adv CE + λ_recon MSE)
    early_stop_on="val",                  # stop on validation loss
)


In [48]:
res_simple = run_training_and_eval_armed(
    X_only_fixed=Body_cat_only_fixed.values,
    X_fixed_and_random=Body_cat_fixed_and_random.values,
    y=y_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    mode="single",
    scenario="both",          # <- runs cases and time in one go
    arch_defaults=arch_defaults,
    train_defaults=train_defaults,
    verbose=True,
)


Epoch 001 | train 4.028074 | val_loss(monitor) 0.699170
Epoch 002 | train 4.006329 | val_loss(monitor) 0.703826
Epoch 003 | train 3.987052 | val_loss(monitor) 0.709048
Epoch 004 | train 3.969835 | val_loss(monitor) 0.714651
Epoch 005 | train 3.954430 | val_loss(monitor) 0.720502
Epoch 006 | train 3.941038 | val_loss(monitor) 0.726300
Epoch 007 | train 3.928426 | val_loss(monitor) 0.731854
Epoch 008 | train 3.917093 | val_loss(monitor) 0.737196
Epoch 009 | train 3.906043 | val_loss(monitor) 0.742041
Epoch 010 | train 3.896091 | val_loss(monitor) 0.746426
Epoch 011 | train 3.886171 | val_loss(monitor) 0.750105
Epoch 012 | train 3.877179 | val_loss(monitor) 0.753221
Epoch 013 | train 3.868231 | val_loss(monitor) 0.755465
Epoch 014 | train 3.859213 | val_loss(monitor) 0.757318
Epoch 015 | train 3.851116 | val_loss(monitor) 0.758610
Epoch 016 | train 3.842240 | val_loss(monitor) 0.759756
Epoch 017 | train 3.834473 | val_loss(monitor) 0.761044
Epoch 018 | train 3.826720 | val_loss(monitor) 0

In [63]:
res_cv_only = run_training_and_eval_armed(
    X_only_fixed=Body_cat_only_fixed.values,
    X_fixed_and_random=Body_cat_fixed_and_random.values,
    y=y_np,
    pid_idx=pid_np,
    time_index=time_ix_np,
    build_model_fn=build_model_fn,
    wrapper_cls=wrapper_cls,
    mode="cv_only",
    scenario="both",
    outer_folds=3,
    arch_defaults=arch_defaults,      # used as-is; no grid
    train_defaults=train_defaults,    # used as-is; no grid
    verbose=True,
)


Fold 1: macro_AUC_optthr=0.8036, macro_AUPRC_optthr=0.8602, macro_Brier_optthr=0.1651, macro_ACC_optthr=0.7529, macro_F1_optthr=0.8158, macro_Precision_optthr=0.7949, macro_Recall_optthr=0.8378, macro_Sensitivity_optthr=0.8378, macro_Specificity_optthr=0.5932
Fold 2: macro_AUC_optthr=0.6982, macro_AUPRC_optthr=0.6902, macro_Brier_optthr=0.3161, macro_ACC_optthr=0.5176, macro_F1_optthr=0.6822, macro_Precision_optthr=0.5176, macro_Recall_optthr=1.0000, macro_Sensitivity_optthr=1.0000, macro_Specificity_optthr=0.0000
Fold 3: macro_AUC_optthr=0.6884, macro_AUPRC_optthr=0.8076, macro_Brier_optthr=0.2141, macro_ACC_optthr=0.5976, macro_F1_optthr=0.6495, macro_Precision_optthr=0.7079, macro_Recall_optthr=0.6000, macro_Sensitivity_optthr=0.6000, macro_Specificity_optthr=0.5938

CV averages (±95% CI) for 0.50 treshold:
           macro_ACC: 0.6326  (95% CI 0.3235, 0.9417)
           macro_AUC: 0.7301  (95% CI 0.5713, 0.8888)
         macro_AUPRC: 0.7860  (95% CI 0.5698, 1.0022)
         macro_B