In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from chronos import BaseChronosPipeline
from sklearn.preprocessing import StandardScaler

import os
from glob import glob
from pathlib import Path
from tqdm.notebook import tqdm
import ipywidgets as widgets
from IPython.display import display, clear_output
from matplotlib import pyplot as plt

from modules import * 

# Remove pandas' column/row display caps so frames are rendered in full.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

In [2]:
# Locate the parquet files relative to the notebook's working directory.
data_path = Path("../../data")
notebooks_path = Path(os.getcwd())
data_dir = {}

# Each "set*.parquet" file is exposed as a module-level DataFrame whose name is
# the file stem with dashes replaced by underscores (e.g. set-a-filled -> set_a_filled).
# Later cells look these frames up by name, so they are published via globals().
parquet_root = notebooks_path / data_path
for file_path in parquet_root.glob("set*.parquet"):
    var_name = file_path.stem.replace("-", "_")
    frame = pd.read_parquet(file_path)
    globals()[var_name] = frame
    print(f"Reading {file_path} -> {var_name}")


# Identifier columns (never used as model inputs).
ID_vars = ["PatientID", "Time", "RecordID"]  # TODO
# Stationary variables ("ICUType" is currently left out).
stationary_vars = ["Age", "Gender", "Height"]
# Dynamic variables: every remaining column of set_a except identifiers,
# stationary variables and the label.
non_dynamic_cols = stationary_vars + ID_vars + ['In-hospital_death']
dynamic_vars = set_a.columns.difference(non_dynamic_cols).tolist()

# Model inputs: dynamic variables first, stationary variables last.
feature_cols = dynamic_vars + stationary_vars

Reading /home/pbaertschi/ICU-TimeSeries-Mortality-Prediction/notebooks/4_Foundation_Models/../../data/set-a-filled.parquet -> set_a_filled
Reading /home/pbaertschi/ICU-TimeSeries-Mortality-Prediction/notebooks/4_Foundation_Models/../../data/set-b.parquet -> set_b
Reading /home/pbaertschi/ICU-TimeSeries-Mortality-Prediction/notebooks/4_Foundation_Models/../../data/set-a.parquet -> set_a
Reading /home/pbaertschi/ICU-TimeSeries-Mortality-Prediction/notebooks/4_Foundation_Models/../../data/set-c-filled.parquet -> set_c_filled
Reading /home/pbaertschi/ICU-TimeSeries-Mortality-Prediction/notebooks/4_Foundation_Models/../../data/set-c.parquet -> set_c
Reading /home/pbaertschi/ICU-TimeSeries-Mortality-Prediction/notebooks/4_Foundation_Models/../../data/set-b-filled.parquet -> set_b_filled


In [3]:
def preprocess_parquet_for_lstm(key, scaler=None, fit_scaler=False):
    """Convert one filled data split into sequence tensors.

    Reads the module-level DataFrame ``set_{key}_filled``, zero-fills any
    remaining NaNs in ``feature_cols``, standardises those columns and stacks
    one sequence per ``RecordID`` (rows ordered by ``Time``).

    Parameters
    ----------
    key : str
        Split suffix used to build the DataFrame name, e.g. ``"a"`` for
        ``set_a_filled``.
    scaler : fitted scaler or None
        Scaler to apply. When ``None`` a new ``StandardScaler`` is fitted on
        this split.
    fit_scaler : bool
        When true, a new ``StandardScaler`` is fitted on this split even if
        ``scaler`` was passed.

    Returns
    -------
    X_tensor : torch.Tensor
        Float tensor of shape (n_records, seq_len, n_features).
    y_tensor : torch.Tensor
        Float tensor of shape (n_records,), holding the label taken from the
        first row of each record.
    scaler
        The scaler that was applied, for reuse on the other splits.

    Raises
    ------
    KeyError
        If no DataFrame named ``set_{key}_filled`` has been loaded.
    ValueError
        If the records do not all have the same number of rows (they cannot
        be stacked into one tensor), or if the split contains no records.
    """
    labelname = 'In-hospital_death'
    frame_name = f"set_{key}_filled"
    if frame_name not in globals():
        raise KeyError(
            f"DataFrame '{frame_name}' is not loaded; run the parquet loading cell first."
        )
    df = globals()[frame_name].copy()

    # Sort and fill NaNs
    df = df.sort_values(["RecordID", "Time"])
    df[feature_cols] = df[feature_cols].fillna(0)

    # TODO: encode categorical features (original note: "Encode categories").

    # --- Fit scaler on all feature data if requested ---
    if fit_scaler or scaler is None:
        scaler = StandardScaler()
        scaler.fit(df[feature_cols])

    # --- Apply scaling ---
    df[feature_cols] = scaler.transform(df[feature_cols])

    # Group by record. The frame is already ordered by (RecordID, Time) and
    # groupby keeps the row order inside each group, so no per-group re-sort
    # is needed.
    X = []
    y = []
    for _, group in df.groupby("RecordID"):
        X.append(group[feature_cols].values)
        y.append(group[labelname].iloc[0])

    # np.stack needs equally long sequences; fail with an actionable message.
    lengths = {len(seq) for seq in X}
    if len(lengths) > 1:
        raise ValueError(
            f"Records in '{frame_name}' have differing sequence lengths "
            f"({min(lengths)}..{max(lengths)}); cannot stack into a single tensor."
        )

    X_tensor = torch.tensor(np.stack(X)).float()  # (n_records, seq_len, n_features)
    y_tensor = torch.tensor(y).float()            # (n_records,)

    return X_tensor, y_tensor, scaler  # return scaler for reuse on val/test


# Fit the scaler on set A only, then reuse that same scaler for sets B and C.
X_train, y_train, fitted_scaler = preprocess_parquet_for_lstm("a", fit_scaler=True)
X_val, y_val, _ = preprocess_parquet_for_lstm("b", scaler=fitted_scaler)
X_test, y_test, _ = preprocess_parquet_for_lstm("c", scaler=fitted_scaler)


# Dimensions of the training tensor: (records, time steps, features).
num_patients, sequence_length, num_features = X_train.shape

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained Chronos model.
# `model_id` is the single place to change the checkpoint
# (e.g. "amazon/chronos-bolt-small" for the corresponding Chronos-Bolt model).
model_id = "amazon/chronos-t5-small"
pipeline = BaseChronosPipeline.from_pretrained(
    model_id,
    device_map=device,  # use "cpu" for CPU inference
    torch_dtype=torch.float16,
)
# Manually move tokenizer boundaries to match model device
pipeline.tokenizer.boundaries = pipeline.tokenizer.boundaries.to(pipeline.model.device)

# Replacement for the Chronos tokenizer's EOS-append step that avoids a device mismatch.
def patched_append_eos_token(self, token_ids, attention_mask):
    """Append an EOS column to ``token_ids`` and a ``True`` column to ``attention_mask``.

    Both new columns have one entry per row of ``token_ids`` and are created
    directly on ``token_ids.device``. The EOS id is read from
    ``self.config.eos_token_id``.

    Returns the extended ``(token_ids, attention_mask)`` pair.
    """
    target = token_ids.device
    column_shape = (token_ids.shape[0], 1)

    eos_column = torch.full(column_shape, fill_value=self.config.eos_token_id, device=target)
    mask_column = torch.full(column_shape, fill_value=True, device=target)

    return (
        torch.concat((token_ids, eos_column), dim=1),
        torch.concat((attention_mask, mask_column), dim=1),
    )

# Bind the replacement to this tokenizer instance so it is called as a regular method.
setattr(
    pipeline.tokenizer,
    "_append_eos_token",
    patched_append_eos_token.__get__(pipeline.tokenizer),
)

In [5]:
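# One-off cache generation: compute the Chronos embeddings (return_stack = False) for each split and
# save them in the data directory; the Part 1 cell below loads these files. Uncomment to regenerate.
# X_train_embeddings = get_chronos_embeddings(X_train, num_features, pipeline, return_stack = False)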
# X_val_embeddings = get_chronos_embeddings(X_val, num_features, pipeline, return_stack = False)
# X_test_embeddings = get_chronos_embeddings(X_test, num_features, pipeline, return_stack = False)
# torch.save(X_train_embeddings, notebooks_path / data_path / Path('X_train_embeddings_Q431.pt'))
# torch.save(X_val_embeddings, notebooks_path / data_path / Path('X_val_embeddings_Q431.pt'))
# torch.save(X_test_embeddings, notebooks_path / data_path / Path('X_test_embeddings_Q431.pt'))

## Part 1: fixed pooling

In [48]:
# ==== 1. Define Model ====
class LinearProbe(nn.Module):
    """Single linear layer applied to the flattened input.

    Parameters
    ----------
    input_dim : int
        Number of values per sample after flattening every non-batch dimension.
    n_classes : int
        Number of outputs per sample (default 1).
    """

    def __init__(self, input_dim, n_classes = 1):
        super().__init__()
        self.layer = nn.Linear(input_dim, n_classes)

    def forward(self, x):
        """Return logits of shape (batch,) when ``n_classes == 1``, else (batch, n_classes)."""
        # Flatten all dimensions but batch. `reshape` is used instead of `view`
        # so non-contiguous inputs (e.g. transposed tensors) work as well.
        x_flat = x.reshape(x.shape[0], -1)
        # squeeze(-1) only drops the last axis when it has size 1.
        return self.layer(x_flat).squeeze(-1)

def _load_embeddings(split):
    """Load the cached Chronos embeddings of one split as a float32 CPU tensor.

    ``map_location="cpu"`` lets the cache load on machines without a GPU even
    if it was saved from a CUDA tensor; `train_model` moves each batch to the
    training device itself.
    """
    # NOTE(review): torch.load unpickles the file - only load caches you created yourself.
    cache_file = notebooks_path / data_path / Path(f'X_{split}_embeddings_Q431.pt')
    return torch.load(cache_file, map_location="cpu").float()


X_train_embeddings = _load_embeddings("train")
X_val_embeddings = _load_embeddings("val")
X_test_embeddings = _load_embeddings("test")

# Single-sample batch, used to derive the flattened input size of the probe.
x = X_train_embeddings[0].unsqueeze(0)

# Smoke test: one sample through an untrained probe should give one logit.
lp = LinearProbe(input_dim = x.view(1, -1).shape[-1])
lp(x).shape

torch.Size([1])

In [50]:
# ==== 2. Training & Evaluation Logic ====
def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader`` and return ranking metrics.

    The model is switched to eval mode, run without gradient tracking, and its
    outputs are passed through a sigmoid before scoring.

    Returns a dict with keys ``'AUROC'`` and ``'AUPRC'``.

    NOTE(review): ``roc_auc_score`` and ``average_precision_score`` are not
    imported explicitly in the import cell; presumably they arrive through
    ``from modules import *`` - confirm.
    """
    model.eval()
    score_chunks = []
    label_chunks = []
    with torch.no_grad():
        for features, targets in loader:
            features = features.to(device)
            targets = targets.to(device)
            batch_probs = torch.sigmoid(model(features))
            score_chunks.append(batch_probs.cpu())
            label_chunks.append(targets.cpu())

    y_score = torch.cat(score_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()

    metrics = {}
    metrics['AUROC'] = roc_auc_score(y_true, y_score)
    metrics['AUPRC'] = average_precision_score(y_true, y_score)
    return metrics
    
def train_model(X_train, y_train, X_val, y_val, X_test, y_test, model, 
                epochs=20, batch_size=32, lr=1e-3):
    """Train ``model`` as a binary classifier and report validation/test metrics.

    Uses Adam with learning rate ``lr`` and ``BCEWithLogitsLoss``. Only the
    training loader is shuffled. After every epoch the validation AUROC/AUPRC
    are printed; after the last epoch the model is evaluated once on the test
    split.

    Returns ``(model, test_metrics)`` where ``test_metrics`` is the dict
    produced by ``evaluate``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def build_loader(features, labels, shuffle=False):
        # Wrap a (features, labels) pair in a dataset and batch it.
        return DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=shuffle)

    train_loader = build_loader(X_train, y_train, shuffle=True)
    val_loader = build_loader(X_val, y_val)
    test_loader = build_loader(X_test, y_test)

    # Model, loss and optimiser
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in tqdm(range(epochs)):
        model.train()
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device).float()

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()

        val_metrics = evaluate(model, val_loader, device)
        print(f"[Epoch {epoch+1}] Val AUROC: {val_metrics['AUROC']:.4f}, AUPRC: {val_metrics['AUPRC']:.4f}")

    print("✅ Finished training. Evaluating on test set...")
    test_metrics = evaluate(model, test_loader, device)
    print(f"Test AUROC: {test_metrics['AUROC']:.4f}, AUPRC: {test_metrics['AUPRC']:.4f}")

    return model, test_metrics

# Size of one flattened sample; `x` is the single-sample batch built in the cell above.
probe_input_dim = x.view(1, -1).shape[-1]
model = LinearProbe(input_dim=probe_input_dim)
model, test_metrics = train_model(
    X_train_embeddings, y_train,
    X_val_embeddings, y_val,
    X_test_embeddings, y_test,
    model,
    epochs=20, batch_size=32, lr=1e-3,
)

  0%|          | 0/20 [00:00<?, ?it/s]

[Epoch 1] Val AUROC: 0.7358, AUPRC: 0.3557
[Epoch 2] Val AUROC: 0.7580, AUPRC: 0.3758
[Epoch 3] Val AUROC: 0.7743, AUPRC: 0.4015
[Epoch 4] Val AUROC: 0.7791, AUPRC: 0.4086
[Epoch 5] Val AUROC: 0.7828, AUPRC: 0.4163
[Epoch 6] Val AUROC: 0.7850, AUPRC: 0.4212
[Epoch 7] Val AUROC: 0.7868, AUPRC: 0.4279
[Epoch 8] Val AUROC: 0.7876, AUPRC: 0.4274
[Epoch 9] Val AUROC: 0.7897, AUPRC: 0.4334
[Epoch 10] Val AUROC: 0.7910, AUPRC: 0.4349
[Epoch 11] Val AUROC: 0.7923, AUPRC: 0.4385
[Epoch 12] Val AUROC: 0.7936, AUPRC: 0.4427
[Epoch 13] Val AUROC: 0.7943, AUPRC: 0.4436
[Epoch 14] Val AUROC: 0.7949, AUPRC: 0.4458
[Epoch 15] Val AUROC: 0.7952, AUPRC: 0.4452
[Epoch 16] Val AUROC: 0.7958, AUPRC: 0.4475
[Epoch 17] Val AUROC: 0.7960, AUPRC: 0.4486
[Epoch 18] Val AUROC: 0.7970, AUPRC: 0.4500
[Epoch 19] Val AUROC: 0.7973, AUPRC: 0.4502
[Epoch 20] Val AUROC: 0.7974, AUPRC: 0.4502
✅ Finished training. Evaluating on test set...
Test AUROC: 0.7825, AUPRC: 0.4133


## Part 2: learnt pooling

In [None]:
# # First we need the stacked embeddings. They are very memory-intensive, so only one split can be held in memory at a time.
# X_train_embeddings = get_chronos_embeddings(X_train, num_features, pipeline, return_stack = True)
# torch.save(X_train_embeddings, notebooks_path / data_path / Path('X_train_embeddings_Q431_stacked.pt'))
# del X_train_embeddings

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

In [7]:
# X_val_embeddings = get_chronos_embeddings(X_val, num_features, pipeline, return_stack = True)
# torch.save(X_val_embeddings, notebooks_path / data_path / Path('X_val_embeddings_Q431_stacked.pt'))
# del X_val_embeddings

  0%|          | 0/41 [00:00<?, ?it/s]

In [None]:
# Compute the stacked test embeddings only when the cache file is missing, so
# a full re-run of the notebook does not repeat this expensive step. Delete the
# file to force regeneration (also do so if a previous run was interrupted
# while writing it).
stacked_test_path = notebooks_path / data_path / Path('X_test_embeddings_Q431_stacked.pt')
if not stacked_test_path.exists():
    # A dedicated name is used so the `X_test_embeddings` tensor loaded for
    # Part 1 is neither overwritten nor deleted by this cell.
    X_test_embeddings_stacked = get_chronos_embeddings(X_test, num_features, pipeline, return_stack = True)
    # Serialization flag kept exactly as in the original cell.
    torch.save(X_test_embeddings_stacked, stacked_test_path, _use_new_zipfile_serialization=False)
    # Free the large tensor immediately.
    del X_test_embeddings_stacked

  0%|          | 0/41 [00:00<?, ?it/s]

In [53]:
class LinearProbe_LearnablePooling(nn.Module):
    """Linear probe on top of a learned, softmax-weighted sum over segments.

    Parameters
    ----------
    input_dim : int
        Size of the last axis of the input (fed to the linear layer).
    num_segments : int
        Size of axis 1 of the input; one learnable pooling weight per segment.
    n_features : int
        Accepted for interface compatibility; not used by this module.
    n_classes : int
        Number of outputs per sample (default 1).
    """

    def __init__(self, input_dim, num_segments,  n_features, n_classes = 1):
        super().__init__()
        # pooling weights, one per segment (normalised with softmax in forward)
        self.segment_weights = nn.Parameter(torch.randn(1, num_segments, 1))

        # linear probe
        self.layer = nn.Linear(input_dim, n_classes)

    def forward(self, x):
        """Pool ``x`` of shape (batch, num_segments, input_dim) and apply the probe.

        Returns logits of shape (batch,) when ``n_classes == 1``, else
        (batch, n_classes).

        Raises
        ------
        ValueError
            If ``x`` is not 3-D or its axis 1 does not match ``num_segments``.
            Without this check such inputs could broadcast silently against
            the pooling weights and produce wrong results.
        """
        num_segments = self.segment_weights.shape[1]
        if x.dim() != 3 or x.shape[1] != num_segments:
            raise ValueError(
                f"expected input of shape (batch, {num_segments}, input_dim), "
                f"got {tuple(x.shape)}"
            )

        # (1, num_segments, 1): softmax over the segment axis so weights sum to 1
        weights = F.softmax(self.segment_weights, dim=1)
        weighted = x * weights  # broadcasting over batch and feature dim
        pooled = weighted.sum(dim=1)  # sum over segments
        return self.layer(pooled).squeeze(-1)
        