In [1]:
# Imports
import os
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, TensorDataset

import metrics
from plot_utils import plot_cv_results, plot_training_history

# NOTE(review): the cells below also use h5py, mne, mne_bids.BIDSPath and tqdm, plus the helpers
# get_sequential_folds, get_zero_shot_folds and compute_word_embedding_task_metrics. None of them
# is imported or defined in this notebook - add the matching imports for your environment.

## Configuration

In [None]:
# Dictionary to map names to metrics. Can configure to use these how you would like. See below training configuration.
# The configuration cell rebinds the global name `metrics` to a list of metric names, so looking the
# functions up through `metrics.<fn>` only works the first time this cell runs. Importing the module
# under its own alias keeps this cell re-runnable in any order.
import metrics as metric_fns

simple_metric_dict = {
    "mse": metric_fns.mse,
    "nll_embedding": metric_fns.compute_nll_contextual,
    "cosine_sim": metric_fns.cosine_similarity,
    "cosine_distance": metric_fns.cosine_distance,
    "similarity_entropy": metric_fns.similarity_entropy,
}

In [4]:
# Data configuration
# Time relative to word onset to center neural data around. In ms (get_data divides it by 1000
# to convert it to seconds).
lag: int = 0
# The width of neural data to gather around each word onset in seconds. The window is centered
# on word onset + lag.
window_width: float = 0.5
# The name of the embeddings to use. Currently supports gpt-2xl and arbitrary.
embedding_type: str = "gpt-2xl"
# Layer of model to gather embeddings from. Required if using gpt-2xl: the value is interpolated
# into the HDF5 dataset name "layer-{embedding_layer}" when the embeddings are loaded.
embedding_layer: Optional[int] = None
# Root of data folder.
data_root: str = "data"
# Number of dimensions to reduce the embeddings to using PCA. If None, don't run PCA.
# Also used as the dimensionality of the random vectors when embedding_type is "arbitrary".
embedding_pca_dim: Optional[int] = None
# CSV file with columns subject (subject integer id) and elec (string name of electrode).
# If not set then defaults to configured subject_ids and channel_reg_ex. See our significant electrode file.
electrode_file_path: Optional[str] = None
# The subject ids to include in your analysis. For the podcast data they must all be in the range [1, 9].
# load_raws iterates over this list, so leaving it empty loads no recordings.
subject_ids: list[int] = []
# A regular expression to pick which channels you are interested in.
# (i.e. "LG[AB]*" will select channels that start with "LGA" or "LGB")
# Only consulted when no electrode file mapping is supplied.
channel_reg_ex: Optional[str] = None
# Column name in dataframe to use for grouping by words. Its values are returned by get_data
# alongside the neural data.
word_column: Optional[str] = "lemmatized_word"
# Used in preprocessor specific to the PITOM model. See preprocess_neural_data() below for overwriting with your own.
# The default preprocessor averages consecutive groups of this many time samples.
num_average_samples: int = 32 

# Training Configuration
# The batch size to train our decoder with.
batch_size: int = 32
# The maximum number of epochs to train over each fold with.
epochs: int = 100
# The learning rate to use when training. TODO: currently static lr, could use a scheduler in the future.
learning_rate: float = 0.001
# The amount of weight decay to use as regularization in our optimizer.
weight_decay: float = 0.0001
# If the early stopping metric on the validation set does not improve after this many consecutive
# epochs, stop training for this fold early.
early_stopping_patience: int = 10
# Number of folds to train over per-lag.
n_folds: int = 5
# Path to write model checkpoints to.
# NOTE(review): model_dir is assigned again under "Other configuration"; keep the two in sync.
model_dir: str = "models"
# Type of fold generation to use. One of "sequential_folds" or "zero_shot_folds". Sequential folds mean we
# use time segmented folds. zero_shot_folds requires that no word in the test set appears in the training set.
fold_type: str = "sequential_folds"
# Losses to use, by default use nll_embedding. Must be defined in simple_metric_dict. We find that nll_embedding
# tends to be the best for our decoding tests because it is a contrastive loss (roughly equivalent to negative log likelihood over a batch)
losses: list[str] = ["nll_embedding"]
# Weight to assign to each loss. Should be parallel array with losses.
loss_weights: list[float] = [1.0]
# Metrics to track during training. Must be defined in simple_metric_dict.
# NOTE(review): this rebinds the global name `metrics`, which until now referred to the imported
# `metrics` module. Any cell that needs the module must run before this one or import it under
# another name.
metrics: list[str] = ["mse", "cosine_sim"]
# Metric to use for early stopping over validation set. Must be either the loss or in metrics.
early_stopping_metric: str = "cosine_sim"
# Whether or not a smaller value is better for early_stopping_metric. Should be False for metrics you
# want to increase (i.e. cosine similarity) but True for ones you want to decrease (i.e. MSE).
smaller_is_better: bool = False
# Number of gradient accumulation steps.
grad_accumulation_steps: int = 1
# TODO: Generalize parameters to metrics based on config. So we don't need to have these last few.
# Minimum number of occurrences of a word in training set to be used for ROC-AUC calculation.
min_train_freq_auc: int = 5
# Minimum number of occurrences of a word in test set to be used for ROC-AUC calculation.
min_test_freq_auc: int = -1
# Sets the k we use in top-k metrics.
top_k_thresholds: list[int] = [1, 5, 10]

# Model configuration (specific to your model). Passed in the get_model function below.
# The data-dependent number of input channels is not part of this dict; it has to be supplied
# when the model is built.
model_params = {
    "conv_filters": 128,
    "reg": 0.35,
    "reg_head": 0,
    "dropout": 0.2,
    "num_models": 10,
    "embedding_dim": 50,
}

# Other configuration
# Base directory to output results to.
output_dir: str = "results"
# Base directory to write models to. Same setting as model_dir in the training configuration;
# this later assignment is the one that takes effect, so keep both values identical.
model_dir: str = "models"
# Whether to plot the per-fold training history and the cross-validation summary.
# The training cell reads this flag, so it must be defined before that cell runs.
plot_results: bool = True

## Model Setup

You can replace this with your own DIVER model code here, or just import from your own files.

This model is roughly the one used for decoding in the 2022 paper: https://www.nature.com/articles/s41593-022-01026-4#Sec31. If your model is set up to return a predicted word embedding, the EmbeddingPrediction class below may be useful for translating that into a prediction over all of the words, as specified in the Decoding analysis section of the paper above.

In [2]:
class PitomModel(nn.Module):
    """
    PyTorch implementation of the PITOM decoding model.

    The network is a small 1D CNN over the time axis: two convolutional blocks
    separated by a max-pool, one more convolution, a global max-pool over time
    and a dense head followed by LayerNorm and tanh.

    Args:
        input_channels: Number of electrodes in data (int)
        output_dim: Dimension of output vector (int)
        conv_filters: Number of convolutional filters (default: 128)
        reg: L2 regularization factor for convolutional layers (default: 0.35).
            Only stored on the instance; this class does not apply it itself.
        reg_head: L2 regularization factor for dense head (default: 0).
            Only stored on the instance; this class does not apply it itself.
        dropout: Dropout rate (default: 0.2)
    """

    def __init__(
        self,
        input_channels,
        output_dim,
        conv_filters=128,
        reg=0.35,
        reg_head=0,
        dropout=0.2
    ):
        super().__init__()

        # Keep the hyper-parameters around for inspection.
        self.conv_filters = conv_filters
        self.reg = reg
        self.reg_head = reg_head
        self.dropout = dropout
        self.output_dim = output_dim

        # Architecture description: (number of filters | 'max', kernel size).
        self.desc = [(conv_filters, 3), ('max', 2), (conv_filters, 2)]

        stages = []
        in_channels = input_channels
        for spec, width in self.desc:
            if spec == 'max':
                stages.append(
                    nn.MaxPool1d(kernel_size=width, stride=width, padding=width // 2)
                )
                continue
            stages.extend(self._conv_block(in_channels, spec, width, dropout))
            # Every later convolution consumes the previous block's filters.
            in_channels = spec
        self.layers = nn.ModuleList(stages)

        # Stand-in for the locally connected layer of the original paper: PyTorch
        # has no locally connected layer, so a regular Conv1d is used instead.
        self.final_conv = nn.Conv1d(
            in_channels=conv_filters,
            out_channels=conv_filters,
            kernel_size=2,
            stride=1,
            padding=0,  # 'valid' in Keras
            bias=True
        )
        self.final_bn = nn.BatchNorm1d(conv_filters)
        self.final_act = nn.ReLU()

        # Output head.
        self.dense = nn.Linear(conv_filters, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)
        self.tanh = nn.Tanh()

    @staticmethod
    def _conv_block(in_channels, filters, kernel_size, dropout):
        """Return the modules of one block: bias-free conv -> ReLU -> BatchNorm -> Dropout."""
        return [
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=filters,
                kernel_size=kernel_size,
                stride=1,
                padding=0,  # 'valid' in Keras
                bias=False
            ),
            nn.ReLU(),
            nn.BatchNorm1d(filters),
            nn.Dropout(dropout),
        ]

    def forward(self, x):
        """Map a (batch, input_channels, time) tensor to a (batch, output_dim) tensor."""
        hidden = x
        for stage in self.layers:
            hidden = stage(hidden)

        hidden = self.final_act(self.final_bn(self.final_conv(hidden)))

        # Global max-pool over the remaining time steps.
        pooled = F.adaptive_max_pool1d(hidden, 1).squeeze(-1)

        return self.tanh(self.layer_norm(self.dense(pooled)))


class EnsemblePitomModel(nn.Module):
    def __init__(
        self,
        num_models: int,
        input_channels,
        output_dim: int,
        conv_filters=128,
        reg=0.35,
        reg_head=0,
        dropout=0.2
    ):
        """
        Ensemble of independently initialised PITOM decoding models.

        Args:
            num_models: The number of models to include in the ensemble. The outputs will be averaged at the end.
            input_channels: Number of electrodes in data (int)
            output_dim: Dimensionality of output (int)
            conv_filters: Number of convolutional filters (default: 128)
            reg: L2 regularization factor for convolutional layers (default: 0.35)
            reg_head: L2 regularization factor for dense head (default: 0)
            dropout: Dropout rate (default: 0.2)
        """
        super(EnsemblePitomModel, self).__init__()

        self.models = nn.ModuleList()
        for _ in range(num_models):
            self.models.append(PitomModel(
                input_channels,
                output_dim,
                conv_filters=conv_filters,
                reg=reg,
                reg_head=reg_head,
                dropout=dropout
            ))

    def forward(self, x):
        """
        Run every member on x and return the mean of their outputs.

        The result has the same shape as a single member's output. Returning
        the raw stack would add a leading num_models axis that the losses and
        metrics, which compare against per-sample targets, do not expect.
        """
        embeddings = torch.stack([model(x) for model in self.models])
        return embeddings.mean(dim=0)


# Make sure to update this get_model function here which takes in your configured model parameters above.
def get_model(model_params):
    """
    Build the decoding model from the configured model parameters.

    Args:
        model_params: Mapping of constructor keyword arguments for EnsemblePitomModel.
            The configuration names the output size "embedding_dim"; it is forwarded as
            the constructor's "output_dim" unless "output_dim" is given explicitly.
            The mapping itself is not modified.

    Returns:
        EnsemblePitomModel: A freshly initialised model.
    """
    constructor_kwargs = dict(model_params)
    if "embedding_dim" in constructor_kwargs:
        # EnsemblePitomModel has no `embedding_dim` argument; passing it through
        # unchanged would raise a TypeError.
        embedding_dim = constructor_kwargs.pop("embedding_dim")
        constructor_kwargs.setdefault("output_dim", embedding_dim)
    return EnsemblePitomModel(**constructor_kwargs)

## Data Setup

Make sure you've run ./setup.sh to download the dataset to your machine. If you want it to create a venv with gpu dependencies you can run ./setup.sh --gpu.

In [None]:
# Overwrite this to preprocess the data as needed for your model. Will be passed straight into your decoding model in this form.
# see below for where it is called.
def preprocess_neural_data(data, preprocessor_params=None):
    """
    Downsample the neural data by averaging consecutive time samples.

    Args:
        data: Array whose first two axes are kept as-is; the remaining elements are
            grouped into consecutive runs of `num_average_samples` values and averaged.
            Their count must therefore be divisible by `num_average_samples`.
        preprocessor_params: Unused by this default implementation. It is optional so
            the function can be called with the data alone; custom preprocessors may
            use it for their own settings.

    Returns:
        Array of shape (data.shape[0], data.shape[1], n_groups).
    """
    return data.reshape(
        data.shape[0], data.shape[1], -1, num_average_samples
    ).mean(-1)

### Utils

In [None]:
def get_gpt_2xl_embeddings(df_contextual):
    """
    Load GPT-2 XL contextual embeddings and collapse them to one vector per word.

    The embeddings are read from `stimuli/gpt2-xl/features.hdf5` under `data_root`,
    from the dataset named `layer-{embedding_layer}`. Rows of `df_contextual` that
    share a `word_idx` are treated as pieces of the same word: the DataFrame index
    values of those rows select the embedding rows, which are then averaged.

    Args:
        df_contextual (pd.DataFrame): Token-level table with a `word_idx` column. Its
            index is used to index into the embedding array.

    Returns:
        np.ndarray: One averaged embedding per `word_idx` group, stacked along axis 0
        in group order.
    """
    features_file = os.path.join(data_root, "stimuli/gpt2-xl/features.hdf5")
    layer_key = f"layer-{embedding_layer}"

    # Read the whole layer into memory; the file is closed before any grouping happens.
    with h5py.File(features_file, "r") as handle:
        token_embeddings = handle[layer_key][...]

    # Some words are split into several sub-tokens: average their vectors.
    word_vectors = [
        token_embeddings[token_rows.index.to_numpy()].mean(0)
        for _, token_rows in df_contextual.groupby("word_idx")
    ]
    return np.stack(word_vectors)

def get_arbitrary_embeddings(df_word):
    """
    Generates arbitrary (random) embeddings for each unique word in the input DataFrame.

    Parameters:
    -----------
    df_word : pandas.DataFrame
        A DataFrame containing a column named 'word'. It is modified in place.
    Returns:
    --------
    pd.DataFrame:
        df_word with the arbitrary embeddings stored in the 'target' column.
        Every occurrence of the same word receives the same vector.

    Notes:
    ------
    - Embeddings are sampled from a uniform distribution in the range [-1.0, 1.0).
    - The dimensionality is `embedding_pca_dim`, or 50 when that is not set.
    - Vectors are assigned to words in order of first appearance, so the result is
      reproducible once NumPy's global random seed is fixed.
    - Useful as a placeholder or for testing models where real word embeddings are not required.
    """
    # Without a configured size np.random.uniform would be handed `None` as a
    # dimension and fail; fall back to the default of 50.
    embedding_dim = embedding_pca_dim if embedding_pca_dim else 50

    words = df_word.word.tolist()

    # Map each distinct word to a row, numbered by first appearance. Iterating a
    # set here would make the word -> vector assignment depend on hash ordering.
    word_to_row = {}
    for word in words:
        word_to_row.setdefault(word, len(word_to_row))

    embeddings_per_word = np.random.uniform(
        low=-1.0, high=1.0, size=(len(word_to_row), embedding_dim)
    )
    row_per_occurrence = np.fromiter(
        (word_to_row[word] for word in words), dtype=np.intp, count=len(words)
    )

    df_word["target"] = list(embeddings_per_word[row_per_occurrence])

    return df_word

def word_embedding_decoding_task():
    """
    Loads and processes word-level data and retrieves corresponding embeddings based on specified parameters.

    This function performs the following steps:
    1. Loads a transcript file containing token-level information.
    2. Groups sub-token entries into full words using word indices and adds a
       normalized (`norm_word`) and a lemmatized (`lemmatized_word`) form of each word.
    3. Attaches one embedding per word in the `target` column, either GPT-2 XL
       embeddings or arbitrary random ones depending on `embedding_type`.
    4. Optionally applies PCA to reduce the dimensionality of the embeddings.

    Returns:
        pd.DataFrame: One row per word with the columns word, start, end, norm_word,
        lemmatized_word and target (the word-level embedding).

    Raises:
        ValueError: If `embedding_type` is neither "gpt-2xl" nor "arbitrary".
    """
    import nltk
    from nltk.stem import WordNetLemmatizer

    # Fail before touching the disk: an unknown type would otherwise leave the
    # frame without a `target` column and break much later.
    if embedding_type not in ("gpt-2xl", "arbitrary"):
        raise ValueError(
            f"Unknown embedding_type: {embedding_type!r}. "
            "Supported values are 'gpt-2xl' and 'arbitrary'."
        )

    try:
        nltk.data.find("corpora/wordnet")
        print("WordNet already downloaded")
    except LookupError:
        print("Downloading WordNet...")
        nltk.download("wordnet")

    transcript_path = os.path.join(
        data_root, "stimuli/gpt2-xl/transcript.tsv"
    )

    # Load transcript
    df_contextual = pd.read_csv(transcript_path, sep="\t", index_col=0)

    # Group sub-tokens together into words.
    df_word = df_contextual.groupby("word_idx").agg(
        dict(word="first", start="first", end="last")
    )
    # Lower-case and strip leading/trailing punctuation.
    df_word["norm_word"] = df_word.word.str.lower().str.replace(
        r"^[^\w\s]+|[^\w\s]+$", "", regex=True
    )
    # One lemmatizer for the whole column instead of a new one per word.
    lemmatizer = WordNetLemmatizer()
    df_word["lemmatized_word"] = df_word.norm_word.apply(
        lambda word: lemmatizer.lemmatize(word)
    )

    if embedding_type == "gpt-2xl":
        # Only read the (large) GPT-2 XL feature file when it is actually needed.
        df_word["target"] = list(get_gpt_2xl_embeddings(df_contextual))
    else:
        df_word = get_arbitrary_embeddings(df_word)

    if embedding_pca_dim:
        pca = PCA(n_components=embedding_pca_dim, svd_solver="auto")
        df_word["target"] = list(pca.fit_transform(df_word.target.tolist()))

    return df_word


def get_data(
    lag,
    raws: list[mne.io.Raw],
    df_word: pd.DataFrame,
    window_width: float,
    word_column: Optional[str] = None,
):
    """Gather data for every word in df_word from raw.

    Args:
        lag: the lag relative to each word onset to gather data around, in ms
        raws: list of mne.Raw object holding electrode data
        df_word: dataframe containing columns start, end, word, and target
        window_width: the width of the window which is gathered around each word onset + lag,
            in seconds
        word_column: If provided, will return the column of words specified here.

    Returns:
        tuple: (datas, targets, words) where
            - datas is the preprocessed neural data; the recordings are concatenated along
              axis 1 before preprocess_neural_data is applied,
            - targets is an np.ndarray with one stacked target per retained word,
            - words is an np.ndarray of the `word_column` values for the retained words,
              or None when `word_column` is not given.

    Raises:
        ValueError: If no recording contains a valid event, or if the recordings do
            not retain the same set of words (their epochs could not be aligned).
    """
    # Window bounds relative to word onset, in seconds. They do not depend on the recording.
    tmin = lag / 1000 - window_width / 2
    # NOTE(review): the 2e-3 presumably trims the final sample so the window has the
    # intended sample count - confirm against the sampling rate.
    tmax = lag / 1000 + window_width / 2 - 2e-3

    datas = []
    selected_targets = None
    selected_words = None
    reference_onsets = None

    for raw in raws:
        data_duration = raw.times[-1]  # End time of the data

        # Filter out events where the time window falls outside data bounds
        valid_mask = (df_word.start + tmin >= 0) & (
            df_word.start + tmax <= data_duration
        )
        df_word_valid = df_word[valid_mask].reset_index(drop=True)

        if len(df_word_valid) == 0:
            # No valid events for this raw, skip
            continue

        # One event per word; only the sample column is filled in.
        events = np.zeros((len(df_word_valid), 3), dtype=int)
        events[:, 0] = (df_word_valid.start * raw.info["sfreq"]).astype(int)

        epochs = mne.Epochs(
            raw,
            events,
            tmin=tmin,
            tmax=tmax,
            baseline=None,
            proj=False,
            event_id=None,
            preload=True,
            on_missing="ignore",
            event_repeated="merge",
            verbose="ERROR",
        )

        data = epochs.get_data(copy=False)
        # epochs.selection holds the positions of the events that survived epoching.
        kept_words = df_word_valid.iloc[epochs.selection]
        kept_onsets = kept_words.start.to_numpy()

        # The returned targets/words describe every recording at once, so each
        # recording must have kept exactly the same words.
        if reference_onsets is None:
            reference_onsets = kept_onsets
        elif not np.array_equal(reference_onsets, kept_onsets):
            raise ValueError(
                "Recordings retained different sets of words; cannot align their epochs"
            )

        selected_targets = kept_words.target
        # TODO: Clean this up so we don't need to pass around this potentially None variable.
        if word_column:
            selected_words = kept_words[word_column].to_numpy()
        else:
            selected_words = None

        # Make sure the number of samples match
        assert data.shape[0] == selected_targets.shape[0], "Sample counts don't match"
        if selected_words is not None:
            assert data.shape[0] == selected_words.shape[0], "Words don't match"

        datas.append(data)

    if len(datas) == 0:
        raise ValueError("No valid events found within data time bounds")

    datas = np.concatenate(datas, axis=1)
    # Your preprocessor is called here. The second argument carries the preprocessor
    # settings for implementations that read them from a dict.
    datas = preprocess_neural_data(
        datas, {"num_average_samples": num_average_samples}
    )

    # Stack the per-word targets into a plain array so callers can index them by
    # position and convert them to tensors.
    targets = np.stack(selected_targets.to_list())

    return datas, targets, selected_words

def load_raws(per_subject_electrodes):
    """
    Loads raw iEEG data for multiple subjects based on specified parameters.

    This function:
    1. Determines the subjects to load: the configured `subject_ids`, or, when that
       list is empty and an electrode mapping is given, the subjects of the mapping.
    2. Constructs BIDS-compliant file paths to locate preprocessed high-gamma iEEG data.
    3. Loads each subject's data using MNE's `read_raw_fif` function.
    4. Restricts the channels to the subject's electrodes from the mapping, or, when
       no mapping is given, to the channels matching `channel_reg_ex` (if set).
    5. Collects and returns a list of raw MNE objects.

    Args:
        per_subject_electrodes: dictionary mapping subject ID's to a list of string names for the electrodes to use.
            May be None or empty to fall back to `channel_reg_ex`.

    Returns:
        List[mne.io.Raw]: A list of raw iEEG recordings for the specified subjects.

    Raises:
        KeyError: If an electrode mapping is given but lacks one of the subjects to load.
    """
    # An electrode file alone is enough to choose the subjects; without this the
    # default empty subject_ids would silently load nothing.
    if per_subject_electrodes and not subject_ids:
        ids_to_load = list(per_subject_electrodes)
    else:
        ids_to_load = subject_ids

    raws = []
    for sub_id in ids_to_load:
        if per_subject_electrodes and sub_id not in per_subject_electrodes:
            raise KeyError(
                f"Subject {sub_id} is listed in subject_ids but has no electrodes "
                "in the electrode mapping"
            )

        file_path = BIDSPath(
            root=os.path.join(data_root, "derivatives/ecogprep"),
            subject=f"{sub_id:02}",
            task="podcast",
            datatype="ieeg",
            description="highgamma",
            suffix="ieeg",
            extension=".fif",
        )

        raw = mne.io.read_raw_fif(file_path, verbose=False)
        if per_subject_electrodes:
            subject_electrode_names = per_subject_electrodes[sub_id]
            picks = mne.pick_channels(raw.ch_names, subject_electrode_names)
            raw = raw.pick(picks)
        elif channel_reg_ex:
            picks = mne.pick_channels_regexp(raw.ch_names, channel_reg_ex)
            raw = raw.pick(picks)
        raws.append(raw)

    return raws


def read_electrode_file(file_path: str):
    """
    Build a subject -> electrode-names mapping from an electrode CSV file.

    Each row of the CSV names one electrode of one subject. Rows are processed in
    file order, so every subject's electrode list keeps the order in which the
    electrodes appear in the file, and subjects appear in order of first occurrence.

    Args:
        file_path (str): Path to the CSV file. It must provide the columns
                        'subject' (converted with int()) and 'elec'.

    Returns:
        dict: Keys are subject IDs (int), values are lists of electrode names for
              that subject, e.g. {1: ['A1', 'A2', 'B1'], 2: ['C1', 'C2']}.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        AttributeError: If the 'subject' or 'elec' column is missing from the CSV.

    Example:
        >>> # CSV file contains:
        >>> # subject,elec
        >>> # 1,A1
        >>> # 1,A2
        >>> # 2,C1
        >>> read_electrode_file('electrodes.csv')
        {1: ['A1', 'A2'], 2: ['C1']}
    """
    table = pd.read_csv(file_path)
    subject_column = table.subject
    electrode_column = table.elec

    electrodes_by_subject = {}
    for subject_id, electrode_name in zip(subject_column, electrode_column):
        electrodes_by_subject.setdefault(int(subject_id), []).append(electrode_name)

    return electrodes_by_subject

### Get data

In [None]:
# Load all data.
# load_raws and get_data are defined in the Utils cell of this notebook, so they are
# called directly rather than through a separate data_utils module.
if electrode_file_path:
    subject_electrode_mapping = read_electrode_file(electrode_file_path)
else:
    subject_electrode_mapping = None

raws = load_raws(subject_electrode_mapping)
df_word = word_embedding_decoding_task()

X, Y, selected_words = get_data(
    lag,
    raws,
    df_word,
    window_width,
    word_column=word_column,
)

## Training

In [None]:
def setup_metrics_and_loss():
    """
    Resolve every configured loss and metric name to its function.

    The names come from the `losses` and `metrics` configuration lists and are
    looked up in `simple_metric_dict`.

    Returns:
        dict: Dictionary mapping metric names to callable functions, losses first.
    """
    resolved = {}
    for fn_name in losses + metrics:
        resolved[fn_name] = simple_metric_dict[fn_name]
    return resolved


def compute_loss(out, groundtruth, all_fns):
    """
    Combine the configured losses into one weighted sum.

    Args:
        out: Model output passed to every loss function.
        groundtruth: Targets passed to every loss function.
        all_fns: Mapping from loss/metric name to callable, as built by
            setup_metrics_and_loss.

    Returns:
        The sum over `losses` of `loss_weights[i] * loss_i(out, groundtruth)`,
        starting from 0.0.
    """
    total = 0.0
    for position in range(len(losses)):
        loss_fn = all_fns[losses[position]]
        total = total + loss_weights[position] * loss_fn(out, groundtruth)
    return total


def validate_early_stopping_config():
    """
    Validate that early stopping configuration is valid.

    The early stopping metric is read from the per-epoch validation results, which
    hold the combined "loss", every configured loss and every configured metric.

    Raises:
        ValueError: If early stopping metric is not in available metrics
    """
    # "loss" is the weighted sum reported by each epoch; the individual losses and
    # metrics are reported under their own names.
    available_metrics = ["loss"] + losses + metrics

    if early_stopping_metric not in available_metrics:
        raise ValueError(
            f"Early stopping metric '{early_stopping_metric}' "
            f"must be either the loss function or in the metrics list. "
            f"Available: {available_metrics}"
        )


def setup_early_stopping_state():
    """
    Create the starting state for early stopping.

    The initial best value is the worst possible one for the configured direction
    (`smaller_is_better`), so the first measured value always counts as an
    improvement. The patience counter starts at zero.

    Returns:
        tuple: (best_val, patience) initial values
    """
    worst_possible = float("inf") if smaller_is_better else -float("inf")
    return worst_possible, 0

In [None]:
# 1. Prepare device & output dir
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(model_dir, exist_ok=True)

# 2. Convert to tensors if needed
if isinstance(X, np.ndarray):
    X = torch.tensor(X, dtype=torch.float32)
# Targets may arrive as a pandas Series holding one vector per word; stack them into
# a regular array first so they can be indexed by position like X.
if isinstance(Y, pd.Series):
    Y = np.stack(Y.to_list())
if isinstance(Y, np.ndarray):
    Y = torch.tensor(Y, dtype=torch.float32)

# 3. Get fold indices
# NOTE(review): get_sequential_folds / get_zero_shot_folds are not defined in this
# notebook and must be imported from your project.
if fold_type == "sequential_folds":
    fold_indices = get_sequential_folds(X, num_folds=n_folds)
elif fold_type == "zero_shot_folds":
    fold_indices = get_zero_shot_folds(selected_words, num_folds=n_folds)
else:
    raise ValueError(f"Unknown fold_type: {fold_type}")

# 4. Build a single dict of all metric functions (including loss)
all_fns = setup_metrics_and_loss()
metric_names = list(all_fns)

# 5. Initialize CV containers
phases = ("train", "val", "test")
cv_results = {f"{phase}_{name}": [] for phase in phases for name in metric_names}
cv_results["num_epochs"] = []

# Hardcode embedding task metrics for now since they need to be handled a bit differently.
# Clean this up later. Hardcoding for now since generalizing this like other metrics would
# get complicated.
# Test type is split between "word" and "occ" where word is averaged over
# each time a word occurs and occ is per-each occurrence of the word so is
# more difficult and depends on contextual embeddings.
embedding_metrics = [
    "test_word_avg_auc_roc",
    "test_word_train_weighted_auc_roc",
    "test_word_test_weighted_auc_roc",
    "test_word_perplexity",
    "test_occ_perplexity",
]

# Top-K metrics.
for k_val in top_k_thresholds:
    for test_type in ["word", "occ"]:
        embedding_metrics.append(f"test_{test_type}_top_{k_val}")

for metric in embedding_metrics:
    cv_results[metric] = []

models, histories = [], []

def run_epoch(model, loader, optimizer=None):
    """
    Run one pass over `loader`.

    If optimizer is provided: does a training pass (with gradient accumulation over
    `grad_accumulation_steps` batches). Otherwise: does an eval pass without gradients.

    Args:
        model: The model to train or evaluate.
        loader: Iterable of (inputs, targets) batches that supports len().
        optimizer: Optimizer to step during training, or None for evaluation.

    Returns:
        dict: { metric_name: average_value } for every entry of `metric_names` plus
        "loss" (the combined, un-normalized loss). Values are the plain mean of the
        per-batch values.
    """
    is_train = optimizer is not None
    model.train(is_train)

    sums = {name: 0.0 for name in metric_names}
    sums["loss"] = 0.0

    grad_steps = grad_accumulation_steps
    num_batches = len(loader)
    if is_train:
        optimizer.zero_grad()

    for i, (Xb, yb) in enumerate(loader):
        Xb, yb = Xb.to(device), yb.to(device)

        if is_train:
            out = model(Xb)
            loss = compute_loss(out, yb, all_fns)
            # Scale only the gradient for accumulation; the reported loss stays
            # un-normalized so train and validation losses are comparable.
            (loss / grad_steps).backward()

            # Step after every `grad_steps` batches, and on the final batch so a
            # trailing partial group is not dropped.
            is_last_batch = (i + 1) == num_batches
            if (i + 1) % grad_steps == 0 or is_last_batch:
                optimizer.step()
                optimizer.zero_grad()
        else:
            with torch.no_grad():
                out = model(Xb)
                loss = compute_loss(out, yb, all_fns)

        # accumulate each metric; no autograd graph is needed for reporting
        with torch.no_grad():
            for name, fn in all_fns.items():
                val = fn(out, yb)
                # get a scalar float
                if torch.is_tensor(val):
                    val = val.detach().mean().item()
                sums[name] += val

        # add loss to sums
        if torch.is_tensor(loss):
            loss = loss.detach().mean().item()
        sums["loss"] += loss
    return {name: sums[name] / num_batches for name in sums}

# 6. Cross‐val loop
for fold, (tr_idx, va_idx, te_idx) in enumerate(fold_indices, start=1):
    model_path = os.path.join(model_dir, f"best_model_fold{fold}.pt")

    # DataLoaders
    datasets = {
        "train": TensorDataset(X[tr_idx], Y[tr_idx]),
        "val": TensorDataset(X[va_idx], Y[va_idx]),
        "test": TensorDataset(X[te_idx], Y[te_idx]),
    }
    loaders = {
        phase: DataLoader(
            ds, batch_size=batch_size, shuffle=(phase == "train")
        )
        for phase, ds in datasets.items()
    }

    # Model, optimizer, early‐stop setup
    # The number of input channels depends on the loaded data (axis 1 of X), so it
    # is filled in here unless the configuration already provides it.
    fold_model_params = dict(model_params)
    fold_model_params.setdefault("input_channels", X.shape[1])
    model = get_model(fold_model_params).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    best_val, patience = setup_early_stopping_state()
    best_epoch = 0

    # per‐fold history (only train & val, for plotting)
    history = {
        f"{phase}_{name}": [] for phase in ("train", "val") for name in metric_names
    }
    history["train_loss"] = []
    history["val_loss"] = []
    history["num_epochs"] = None

    loop = tqdm(range(epochs), desc=f"Lag {lag}, Fold {fold}")
    for epoch in loop:
        train_mets = run_epoch(model, loaders["train"], optimizer)
        val_mets = run_epoch(model, loaders["val"])

        # record per-epoch history
        for name, val in train_mets.items():
            history[f"train_{name}"].append(val)
        for name, val in val_mets.items():
            history[f"val_{name}"].append(val)

        # early stopping on requested metric
        cur = val_mets[early_stopping_metric]
        improved = cur < best_val if smaller_is_better else cur > best_val
        if improved:
            best_val = cur
            best_epoch = epoch
            torch.save(model.state_dict(), model_path)
            patience = 0
        else:
            patience += 1
            if patience >= early_stopping_patience:
                break

        loop.set_postfix(
            {
                early_stopping_metric: f"{best_val:.4f}",
                **{f"train_{name}": val for name, val in train_mets.items()},
                **{f"val_{name}": val for name, val in val_mets.items()},
            }
        )

    history["num_epochs"] = best_epoch + 1

    # load best and eval on test set; map_location keeps the checkpoint loadable
    # on whichever device is active now
    model.load_state_dict(torch.load(model_path, map_location=device))
    test_mets = run_epoch(model, loaders["test"])

    # record into cv_results
    for name in metric_names:
        cv_results[f"train_{name}"].append(history[f"train_{name}"][best_epoch])
        cv_results[f"val_{name}"].append(history[f"val_{name}"][best_epoch])
        cv_results[f"test_{name}"].append(test_mets[name])
    cv_results["num_epochs"].append(history["num_epochs"])

    # word‐level ROC and top-k. Only useful for word embedding task.
    # Hardcoded for now since this would be a bit complicated
    # to generalize at the moment.
    # NOTE(review): compute_word_embedding_task_metrics is not defined in this
    # notebook and must be imported from your project.
    results = compute_word_embedding_task_metrics(
        X[te_idx],
        Y[te_idx],
        model,
        device,
        selected_words,
        te_idx,
        tr_idx,
        top_k_thresholds,
        min_train_freq_auc,
        min_test_freq_auc,
    )
    for key, val in results.items():
        cv_results[key].append(val)

    models.append(model)
    histories.append(history)

    if plot_results:
        plot_training_history(history, fold=fold)

# 7. Print CV summary
separator = "=" * 60
summary_line = "Mean {label}: {mean:.4f} ± {std:.4f}"

print("\n" + separator)
print("CROSS-VALIDATION RESULTS")
print(separator)

# Mean ± standard deviation across folds for every tracked loss/metric ...
for phase in phases:
    for name in metric_names:
        vals = cv_results[f"{phase}_{name}"]
        print(
            summary_line.format(
                label=f"{phase} {name}", mean=np.mean(vals), std=np.std(vals)
            )
        )

# ... and for the word-embedding task metrics.
for metric_name in embedding_metrics:
    vals = cv_results[metric_name]
    print(summary_line.format(label=metric_name, mean=np.mean(vals), std=np.std(vals)))

if plot_results:
    plot_cv_results(cv_results)

# Key outputs: models, histories, cv_results