In [1]:
# --- Step 1: Environment Setup ---
!pip install vitaldb tensorflow boto3 pandas numpy matplotlib scikit-learn joblib

# --- Step 2: Imports ---
import os
import time
import vitaldb
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
import matplotlib.pyplot as plt
import boto3
from botocore import UNSIGNED
from botocore.client import Config
import joblib # For saving the scaler

# --- Step 3: GPU Configuration ---
print("--- Verifying and Configuring GPUs ---")
gpu_devices = tf.config.list_physical_devices('GPU')
if gpu_devices:
    try:
        for gpu in gpu_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"{len(gpu_devices)} GPU(s) detected and configured.")
    except RuntimeError as e:
        print(e)
else:
    print("No GPU detected.")
print("-----------------------------------\n")

Collecting vitaldb
  Downloading vitaldb-1.5.6-py3-none-any.whl.metadata (314 bytes)
Collecting wfdb (from vitaldb)
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Downloading vitaldb-1.5.6-py3-none-any.whl (59 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.9/59.9 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading wfdb-4.3.0-py3-none-any.whl (163 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: wfdb, vitaldb
Successfully installed vitaldb-1.5.6 wfdb-4.3.0


2025-07-11 04:22:42.769030: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752207763.124224      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752207763.231868      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


--- Verifying and Configuring GPUs ---
2 GPU(s) detected and configured.
-----------------------------------



In [2]:
# --- Configuration Parameters for the Current Run ---

# <<<< SET THE SCALE FOR THIS RUN >>>>
# SMOKE TEST: 10 patients
# BENCHMARK: 100 patients
# FULL RUN: 1000 patients
NUM_PATIENTS_TO_PROCESS = 100

# <<<< SET EPOCHS FOR THIS RUN >>>>
# SMOKE TEST: 1 epoch
# BENCHMARK / FULL RUN: 25 epochs
EPOCHS = 25

# --- Static Parameters ---
VITAL_FILES_LOCAL_DIR = f'vital_files_{NUM_PATIENTS_TO_PROCESS}_patients'
MODEL_SAVE_PATH = f'vitaldb_bis_transformer_model_{NUM_PATIENTS_TO_PROCESS}.h5'
SCALER_SAVE_PATH = f'standard_scaler_{NUM_PATIENTS_TO_PROCESS}.joblib'

TRAIN_TEST_SPLIT_RATIO = 0.8
RANDOM_SEED = 42
SEQUENCE_LENGTH = 300
BATCH_SIZE = 256
LEARNING_RATE = 1e-4
CLIPNORM = 1.0

# Vitals to use as features
INPUT_VITALS = ['Solar8000/HR', 'Solar8000/ART_MBP', 'Solar8000/PLETH_SPO2', 'Solar8000/ETCO2']
TARGET_VITAL = 'BIS/BIS'

print(f"--- CONFIGURATION FOR THIS RUN ---")
print(f"Patient Count: {NUM_PATIENTS_TO_PROCESS}")
print(f"Epochs: {EPOCHS}")
print(f"Scaler Path: {SCALER_SAVE_PATH}")
print(f"Model Path: {MODEL_SAVE_PATH}")
print(f"------------------------------------")

--- CONFIGURATION FOR THIS RUN ---
Patient Count: 100
Epochs: 25
Scaler Path: standard_scaler_100.joblib
Model Path: vitaldb_bis_transformer_model_100.h5
------------------------------------


In [3]:
# --- Transformer Model Definition ---
class PositionalEncoding(layers.Layer):
    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.position_embedding = layers.Embedding(input_dim=sequence_length, output_dim=output_dim)
        self.sequence_length = sequence_length
    def call(self, inputs):
        positions = tf.range(start=0, limit=self.sequence_length, delta=1)
        return inputs + self.position_embedding(positions)

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential([layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)])
        self.layernorm1, self.layernorm2 = layers.LayerNormalization(epsilon=1e-6), layers.LayerNormalization(epsilon=1e-6)
        self.dropout1, self.dropout2 = layers.Dropout(rate), layers.Dropout(rate)
    def call(self, inputs, training=False):
        attn_output = self.dropout1(self.att(inputs, inputs), training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.dropout2(self.ffn(out1), training=training)
        return self.layernorm2(out1 + ffn_output)

# <<<< MODEL PARAMETERS INCREASED HERE >>>>
def build_transformer_model(sequence_length, num_features, embed_dim=128, num_heads=8, ff_dim=256, num_blocks=2):
    print(f"\n--- Building model with increased capacity: embed_dim={embed_dim}, ff_dim={ff_dim}, num_heads={num_heads} ---")
    inputs = keras.Input(shape=(sequence_length, num_features))
    x = layers.Dense(embed_dim, activation='relu')(inputs)
    x = PositionalEncoding(sequence_length, embed_dim)(x)
    for _ in range(num_blocks):
        x = TransformerBlock(embed_dim, num_heads, ff_dim)(x)
    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
    x = layers.Dropout(0.1)(x)
    x = layers.Dense(32, activation="relu")(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(1)(x)
    return keras.Model(inputs=inputs, outputs=outputs)

# --- Memory-Safe, Scaled Data Generator ---
class MemorySafeVitalDBGenerator(keras.utils.Sequence):
    def __init__(self, name, case_ids, vital_dir, batch_size, sequence_length, feature_cols, target_col, scaler):
        self.name = name
        self.case_ids = case_ids
        self.vital_dir = vital_dir
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.feature_cols = feature_cols
        self.target_col = target_col
        self.scaler = scaler
        self.patient_data = []
        self.index_map = []
        self._prepare_data()

    def _prepare_data(self):
        print(f"\n--- Preparing {self.name} Data Generator ({len(self.case_ids)} patients) ---")
        total_sequences = 0
        for patient_idx, case_id in enumerate(self.case_ids):
            file_path = os.path.join(self.vital_dir, f'{case_id:04d}.vital')
            if not os.path.exists(file_path): continue
            try:
                vf = vitaldb.VitalFile(file_path)
                df = vf.to_pandas(self.feature_cols + [self.target_col], interval=1)
                df.ffill(inplace=True)
                df.dropna(inplace=True)
                if len(df) > self.sequence_length:
                    df[self.feature_cols] = self.scaler.transform(df[self.feature_cols])
                    arr = df[self.feature_cols + [self.target_col]].values
                    self.patient_data.append(arr)
                    num_sequences_in_patient = len(arr) - self.sequence_length
                    for seq_idx in range(num_sequences_in_patient):
                        self.index_map.append((len(self.patient_data) - 1, seq_idx))
                    total_sequences += num_sequences_in_patient
            except Exception as e:
                print(f"  ERROR: Could not process case {case_id}: {e}")
        print(f"--- {self.name} Generator ready. Total sequences: {total_sequences} ---")

    def __len__(self):
        return int(np.ceil(len(self.index_map) / self.batch_size))

    def __getitem__(self, idx):
        batch_index_map = self.index_map[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x = np.zeros((len(batch_index_map), self.sequence_length, len(self.feature_cols)))
        batch_y = np.zeros(len(batch_index_map))
        for i, (patient_idx, seq_start_idx) in enumerate(batch_index_map):
            seq_end_idx = seq_start_idx + self.sequence_length
            data_slice = self.patient_data[patient_idx][seq_start_idx:seq_end_idx]
            batch_x[i] = data_slice[:, :-1]
            batch_y[i] = data_slice[-1, -1]
        return batch_x, batch_y

print("Model and Generator classes defined.")

Model and Generator classes defined.


In [4]:
# --- Download Data ---
os.makedirs(VITAL_FILES_LOCAL_DIR, exist_ok=True)
s3_client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
S3_BUCKET_NAME, S3_BASE_KEY = 'physionet-open', 'vitaldb/1.0.0/vital_files/'
print(f"Downloading {NUM_PATIENTS_TO_PROCESS} files...")
for i in range(1, NUM_PATIENTS_TO_PROCESS + 1):
    file_name = f'{i:04d}.vital'
    local_file_path = os.path.join(VITAL_FILES_LOCAL_DIR, file_name)
    if not os.path.exists(local_file_path):
        try: s3_client.download_file(S3_BUCKET_NAME, os.path.join(S3_BASE_KEY, file_name), local_file_path)
        except Exception as e: print(f"  ERROR downloading {file_name}: {e}")
print("--- Download complete. ---\n")

# --- Prepare Patient ID Split ---
all_case_ids = list(range(1, NUM_PATIENTS_TO_PROCESS + 1))
train_case_ids, val_case_ids = train_test_split(all_case_ids, test_size=(1 - TRAIN_TEST_SPLIT_RATIO), random_state=RANDOM_SEED)
print(f"Patient Split: {len(train_case_ids)} for Training, {len(val_case_ids)} for Validation.")

# --- Fit or Load the StandardScaler ---
if os.path.exists(SCALER_SAVE_PATH):
    print(f"\n--- Loading existing scaler from {SCALER_SAVE_PATH} ---")
    scaler = joblib.load(SCALER_SAVE_PATH)
else:
    print("\n--- Fitting new StandardScaler on Training Data ---")
    scaler = StandardScaler()
    for i, case_id in enumerate(train_case_ids):
        print(f"  ...processing patient {i+1}/{len(train_case_ids)} for scaler...")
        file_path = os.path.join(VITAL_FILES_LOCAL_DIR, f'{case_id:04d}.vital')
        if not os.path.exists(file_path): continue
        vf = vitaldb.VitalFile(file_path)
        df = vf.to_pandas(INPUT_VITALS, interval=1)
        df.ffill(inplace=True)
        df.dropna(inplace=True)
        if not df.empty:
            scaler.partial_fit(df[INPUT_VITALS])
    print(f"--- Scaler fitting complete. Saving to {SCALER_SAVE_PATH} ---")
    joblib.dump(scaler, SCALER_SAVE_PATH)

print("--- Data preparation and scaling complete. ---")

Downloading 100 files...
--- Download complete. ---

Patient Split: 80 for Training, 20 for Validation.

--- Fitting new StandardScaler on Training Data ---
  ...processing patient 1/80 for scaler...
  ...processing patient 2/80 for scaler...
  ...processing patient 3/80 for scaler...
  ...processing patient 4/80 for scaler...
  ...processing patient 5/80 for scaler...
  ...processing patient 6/80 for scaler...
  ...processing patient 7/80 for scaler...
  ...processing patient 8/80 for scaler...
  ...processing patient 9/80 for scaler...
  ...processing patient 10/80 for scaler...
  ...processing patient 11/80 for scaler...
  ...processing patient 12/80 for scaler...
  ...processing patient 13/80 for scaler...
  ...processing patient 14/80 for scaler...
  ...processing patient 15/80 for scaler...
  ...processing patient 16/80 for scaler...
  ...processing patient 17/80 for scaler...
  ...processing patient 18/80 for scaler...
  ...processing patient 19/80 for scaler...
  ...processing 

In [5]:
# --- Create Data Generators ---
# This step is now much faster as the heavy lifting (scaling) is done.
train_generator = MemorySafeVitalDBGenerator("Training", train_case_ids, VITAL_FILES_LOCAL_DIR, BATCH_SIZE, SEQUENCE_LENGTH, INPUT_VITALS, TARGET_VITAL, scaler=scaler)
val_generator = MemorySafeVitalDBGenerator("Validation", val_case_ids, VITAL_FILES_LOCAL_DIR, BATCH_SIZE, SEQUENCE_LENGTH, INPUT_VITALS, TARGET_VITAL, scaler=scaler)

# --- Check if generators are valid ---
if len(train_generator.index_map) == 0:
    print("FATAL: Training generator is empty. Cannot proceed.")
else:
    # Handle potentially empty validation generator
    validation_data_for_fit = val_generator if len(val_generator.index_map) > 0 else None
    if validation_data_for_fit is None:
        print("\nWARNING: Validation generator is empty. Proceeding with training but without validation.")

    # --- Build and Compile Model ---
    transformer_model = build_transformer_model(SEQUENCE_LENGTH, len(INPUT_VITALS))
    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE, clipnorm=CLIPNORM)
    transformer_model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['mean_absolute_error'])
    transformer_model.summary()

    # --- Train for 1 Epoch ---
    print(f"\n--- Starting Training for {EPOCHS} epoch(s) ---")
    history = transformer_model.fit(
        train_generator, 
        epochs=EPOCHS, 
        validation_data=validation_data_for_fit, # Use the potentially None variable
        verbose=1
    )
    print("--- Training complete. ---")

    # --- Final Verification ---
    final_loss = history.history['loss'][-1]
    if np.isnan(final_loss):
        print("\n\033[91mERROR: Smoke test FAILED. Loss is NaN.\033[0m")
    else:
        print(f"\n\033[92mSUCCESS: Smoke test PASSED. Final training loss is {final_loss:.4f}.\033[0m")
        print("The model is numerically stable. It is now safe to increase patient count and epochs.")


--- Preparing Training Data Generator (80 patients) ---
--- Training Generator ready. Total sequences: 695274 ---

--- Preparing Validation Data Generator (20 patients) ---
--- Validation Generator ready. Total sequences: 159529 ---

--- Building model with increased capacity: embed_dim=128, ff_dim=256, num_heads=8 ---


I0000 00:00:1752208322.972449      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1752208322.973066      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5



--- Starting Training for 25 epoch(s) ---
Epoch 1/25


  self._warn_if_super_not_called()
I0000 00:00:1752208335.835287     495 service.cc:148] XLA service 0x7e8f2000d4c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1752208335.836712     495 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1752208335.836731     495 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1752208337.336461     495 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1752208349.712315     495 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m2716/2716[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1487s[0m 539ms/step - loss: 1270.5769 - mean_absolute_error: 30.5969 - val_loss: 210.4020 - val_mean_absolute_error: 9.0794
Epoch 2/25
[1m2716/2716[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1474s[0m 543ms/step - loss: 236.4269 - mean_absolute_error: 10.1908 - val_loss: 195.2508 - val_mean_absolute_error: 8.8207
Epoch 3/25
[1m2716/2716[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1474s[0m 543ms/step - loss: 246.4339 - mean_absolute_error: 10.2140 - val_loss: 193.0101 - val_mean_absolute_error: 8.7294
Epoch 4/25
[1m2716/2716[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1477s[0m 544ms/step - loss: 227.3939 - mean_absolute_error: 9.8586 - val_loss: 207.1676 - val_mean_absolute_error: 9.5778
Epoch 5/25
[1m2716/2716[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1477s[0m 544ms/step - loss: 234.2221 - mean_absolute_error: 9.8423 - val_loss: 197.4386 - val_mean_absolute_error: 9.0717
Epoch 6/25
[1m2716/2716[0m [3