In [None]:
# --- Path 1: The Deep Time-Series Learner (v2.3 - Final Data Integrity Fix) ---

# Step 1: Install libraries and set up TensorFlow
# The leading `!` is an IPython/Colab shell escape, so this line only runs inside a notebook.
# NOTE(review): tensorflow and tqdm are imported below but are not installed here, and
# lightgbm, seaborn and imbalanced-learn are installed but not imported anywhere in this
# cell -- presumably the Colab image already provides the former; confirm before running
# elsewhere.
!pip install lightgbm pandas numpy scikit-learn matplotlib seaborn google-colab imbalanced-learn --quiet
print("--- Libraries installed ---")

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.utils import class_weight
import matplotlib.pyplot as plt
from google.colab import drive
import os
import glob
from tqdm.notebook import tqdm
from collections import Counter, defaultdict

# --- 2. Setup and Configuration ---
# Report the runtime first so a missing GPU is visible before any heavy work starts.
print(f"TensorFlow Version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

# Mount Google Drive and locate the pre-processed parquet batches.
drive.mount('/content/drive')
DRIVE_PROJECT_ROOT = '/content/drive/MyDrive/VitalDB_Drift_Focused_Dataset'
BATCH_DIR_PATH = os.path.join(DRIVE_PROJECT_ROOT, 'preprocessed_batches')
batch_files = glob.glob(os.path.join(BATCH_DIR_PATH, 'batch_*.parquet'))
batch_files.sort()  # deterministic file order across runs

# --- Model & Data Configuration ---
# Input columns fed to the model, in the order they appear along the feature axis.
FEATURES_TO_USE = [
    # BIS monitor channels
    'BIS/BIS',
    'BIS/EMG',
    'BIS/SEF',
    'BIS/SR',
    # Solar8000 channels
    'Solar8000/HR',
    'Solar8000/ART_MBP',
    'Solar8000/ART_SBP',
    'Solar8000/ART_DBP',
    'Solar8000/PLETH_SPO2',
    'Solar8000/ETCO2',
    'Solar8000/VENT_RR',
    # Engineered columns expected to be present in the batches
    'Reversion_Pressure',
    'Tension_Index',
]
N_FEATURES = len(FEATURES_TO_USE)

SEQUENCE_LENGTH = 120    # rows per input window
BATCH_SIZE = 256         # windows per training batch
DRIFT_THRESHOLD = 3.0    # |future BIS - current BIS| above this counts as drift

# --- 3. Stratified Split ---
# Patients are split (not rows), stratified by how many rows each patient has,
# so both sides see a similar mix of short and long recordings.
print("\n--- Getting Patient IDs for Stratified Split ---")
patient_row_counts = Counter()
for file_path in tqdm(batch_files, desc="Counting Patient Rows"):
    df_temp = pd.read_parquet(file_path)
    ids_in_file = df_temp.index.get_level_values('patient_id')
    rows_per_patient = ids_in_file.value_counts().to_dict()
    # A patient may appear in several files; Counter.update adds the counts up.
    patient_row_counts.update(rows_per_patient)

patient_counts_df = pd.DataFrame.from_dict(
    patient_row_counts,
    orient='index',
    columns=['row_count'],
)
available_patient_ids = patient_counts_df.index.values

# Quintile bins of recording length; duplicate bin edges are merged rather than raising.
duration_bins = pd.qcut(
    patient_counts_df['row_count'],
    q=5,
    labels=False,
    duplicates='drop',
)
patient_counts_df['duration_bin'] = duration_bins

strata = patient_counts_df.loc[available_patient_ids, 'duration_bin']
train_ids, val_ids = train_test_split(
    available_patient_ids,
    test_size=0.2,
    random_state=42,
    stratify=strata,
)
print(f"Split into {len(train_ids)} training and {len(val_ids)} validation patients.")


# --- 4. The Corrected, Memory-Safe, and ROBUST Keras Sequence Generator ---
class SafeSequenceGenerator(keras.utils.Sequence):
    """Serve fixed-length feature windows with a 3-class BIS-drift label.

    Each sample is ``seq_length`` consecutive rows of one parquet batch file.
    Its label compares ``BIS/BIS`` at the window's last row with the value
    ``horizon`` rows later: 0 = within ``drift_threshold``, 1 = rose by more
    than the threshold, 2 = fell by more than the threshold.

    Only ``(file_path, start_row)`` pairs are kept in memory; feature values
    are loaded from disk on demand in ``__getitem__``.

    NOTE(review): a window is addressed by its start position in the file, so
    this assumes every patient's rows are stored contiguously inside a batch
    file -- confirm against the batch-writing code.

    Args:
        patient_ids: Iterable of patient ids whose rows may be used.
        batch_files: Paths of the parquet batch files to scan.
        features: Column names forming the feature axis, in order. Columns
            missing from a file are served as zeros.
        seq_length: Number of rows per window.
        batch_size: Number of windows per batch.
        name: Label used in progress and log messages.
        horizon: How many rows ahead the label looks (default 30, the value
            previously hard-coded in three places).
        drift_threshold: Absolute BIS change that counts as drift. ``None``
            falls back to the module-level ``DRIFT_THRESHOLD``.
        **kwargs: Forwarded to the Keras base class (e.g. ``workers``).

    Attributes:
        index_map: List of ``(file_path, start_row)`` tuples, one per window.
        class_counts: ``Counter`` of labels over the windows in ``index_map``.
    """

    def __init__(self, patient_ids, batch_files, features, seq_length, batch_size,
                 name="Generator", horizon=30, drift_threshold=None, **kwargs):
        # Keras warns (and skips its own setup) when the base initialiser is not called.
        super().__init__(**kwargs)
        self.patient_ids = set(patient_ids)
        self.batch_files = batch_files
        self.features = features
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.name = name
        self.horizon = horizon
        self.drift_threshold = DRIFT_THRESHOLD if drift_threshold is None else drift_threshold
        self.index_map = []
        self.class_counts = Counter()
        # (file_path, feature_matrix, bis_vector) of the most recently loaded file.
        self._file_cache = (None, None, None)
        self._build_index_map()

    def _build_index_map(self):
        """Scan every batch file once, recording window starts and label counts."""
        print(f"\n[{self.name}] Building index map and calculating class distribution...")
        for file_path in tqdm(self.batch_files, desc=f"[{self.name}] Pre-processing files"):
            df = pd.read_parquet(file_path).reset_index()
            df = df[df['patient_id'].isin(self.patient_ids)]
            # Nothing to index if no selected patient is present or the target column is absent.
            if df.empty or 'BIS/BIS' not in df.columns:
                continue

            # Only these two columns are needed here; copy so the assignments
            # below do not write into a slice of the full frame.
            df = df[['patient_id', 'BIS/BIS']].copy()
            df['BIS/BIS'] = df['BIS/BIS'].fillna(0)

            future_bis = df.groupby('patient_id')['BIS/BIS'].shift(-self.horizon)
            drift = future_bis - df['BIS/BIS']

            labels = np.zeros(len(df), dtype=np.int64)
            labels[(drift > self.drift_threshold).to_numpy()] = 1
            labels[(drift < -self.drift_threshold).to_numpy()] = 2
            df['drift_class'] = labels

            # The last `horizon` rows of each patient have no future value and cannot be labelled.
            df = df[drift.notna()]

            for _, patient_df in df.groupby('patient_id'):
                n_windows = len(patient_df) - self.seq_length + 1
                if n_windows <= 0:
                    continue
                # After reset_index the index label equals the row position in the file.
                starts = patient_df.index[:n_windows].tolist()
                self.index_map.extend((file_path, start) for start in starts)

                # Count only rows that end a window, so the distribution matches
                # the labels __getitem__ actually serves.
                window_labels = patient_df['drift_class'].iloc[self.seq_length - 1:]
                self.class_counts.update(window_labels.value_counts().to_dict())

        print(f"[{self.name}] Found {len(self.index_map)} total possible sequences.")

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.index_map) // self.batch_size

    def _load_arrays(self, file_path):
        """Return ``(feature_matrix, bis_vector)`` for a file, reusing the last one loaded.

        Consecutive batches usually address the same file, so remembering one
        file avoids re-reading and re-cleaning it for every batch.
        """
        cached_path, cached_features, cached_bis = self._file_cache
        if cached_path == file_path:
            return cached_features, cached_bis

        df = pd.read_parquet(file_path).reset_index()
        # Ensure all required feature columns exist, creating them with 0 if they don't.
        for col in self.features:
            if col not in df.columns:
                df[col] = 0

        feature_matrix = df[self.features].fillna(0).to_numpy(dtype=np.float64)
        bis_vector = df['BIS/BIS'].fillna(0).to_numpy(dtype=np.float64)
        # Single tuple assignment keeps the three cached parts consistent with each other.
        self._file_cache = (file_path, feature_matrix, bis_vector)
        return feature_matrix, bis_vector

    def __getitem__(self, index):
        """Build batch ``index``.

        Returns:
            ``(X, y)`` where ``X`` has shape ``(n, seq_length, len(features))``
            and ``y`` is the one-hot label array of shape ``(n, 3)``. ``n`` is
            ``batch_size`` unless a window had to be skipped because it ran
            past the end of its file. Samples are grouped by source file, so
            their order may differ from ``index_map`` order.
        """
        batch_map = self.index_map[index * self.batch_size:(index + 1) * self.batch_size]

        file_to_indices = defaultdict(list)
        for file_path, start_loc in batch_map:
            file_to_indices[file_path].append(start_loc)

        X = np.zeros((len(batch_map), self.seq_length, len(self.features)))
        y = np.zeros(len(batch_map))

        filled = 0
        for file_path, locs in file_to_indices.items():
            feature_matrix, bis_vector = self._load_arrays(file_path)
            n_rows = len(bis_vector)

            for start_loc in locs:
                end_loc = start_loc + self.seq_length
                if end_loc > n_rows:
                    continue  # Safety check for edge cases

                X[filled] = feature_matrix[start_loc:end_loc]

                last_loc = end_loc - 1
                future_loc = last_loc + self.horizon
                # Without a future value the label stays at the default 0.
                if future_loc < n_rows:
                    drift = bis_vector[future_loc] - bis_vector[last_loc]
                    if drift > self.drift_threshold:
                        y[filled] = 1
                    elif drift < -self.drift_threshold:
                        y[filled] = 2

                filled += 1

        # Trim unused slots so skipped windows are not served as all-zero class-0 samples.
        return X[:filled], to_categorical(y[:filled], num_classes=3)

# --- 5. Create Generators and Calculate Class Weights ---
train_generator = SafeSequenceGenerator(train_ids, batch_files, FEATURES_TO_USE, SEQUENCE_LENGTH, BATCH_SIZE, name="TrainGenerator")
val_generator = SafeSequenceGenerator(val_ids, batch_files, FEATURES_TO_USE, SEQUENCE_LENGTH, BATCH_SIZE, name="ValGenerator")

# "Balanced" weighting: each class gets total / (n_classes * class_count), so
# rarer classes contribute proportionally more to the loss.
print("\nCalculating class weights from pre-computed distribution...")
total_samples = sum(train_generator.class_counts.values())
class_weights = {}
for class_id in range(3):
    class_count = train_generator.class_counts[class_id]
    if class_count == 0:
        # A class absent from the training data would otherwise divide by zero;
        # its weight is never applied, so a neutral value is used.
        print(f"WARNING: class {class_id} has no training samples; using weight 1.0")
        class_weights[class_id] = 1.0
    else:
        class_weights[class_id] = (1 / class_count) * (total_samples / 3.0)
print(f"Calculated Class Weights: {class_weights}")


# --- 6. Build and Compile the LSTM Model ---
print("\n--- Building LSTM Model ---")

# Two stacked LSTM layers, each followed by batch-norm and dropout, then a
# small dense head ending in a 3-way softmax.
lstm_stack = [
    keras.Input(shape=(SEQUENCE_LENGTH, N_FEATURES)),
    # First LSTM returns the full sequence so the second LSTM can consume it.
    LSTM(64, return_sequences=True),
    BatchNormalization(),
    Dropout(0.3),
    # Second LSTM collapses the sequence to its final state.
    LSTM(32),
    BatchNormalization(),
    Dropout(0.3),
    Dense(16, activation='relu'),
    Dense(3, activation='softmax'),
]
model = Sequential(lstm_stack)

# Labels are served one-hot, hence categorical (not sparse) cross-entropy.
tracked_metrics = ['accuracy', tf.keras.metrics.AUC(name='auc')]
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=tracked_metrics,
)
model.summary()


# --- 7. Train the Model ---
print("\n--- Starting Model Training ---")

# class_weight rebalances the loss using the weights computed from the
# training generator's class distribution.
fit_options = {
    'validation_data': val_generator,
    'epochs': 5,
    'class_weight': class_weights,
}
history = model.fit(train_generator, **fit_options)

# --- 8. Evaluate the Model ---
print("\n--- Evaluating Final Model ---")
# Plot every logged training/validation metric per epoch.
pd.DataFrame(history.history).plot(figsize=(10, 6))
plt.grid(True)
plt.gca().set_ylim(0, 2)
plt.show()

# Fixed label order so the report and matrix always have three rows, even if
# some class never appears in the labels or the predictions.
CLASS_LABELS = [0, 1, 2]
CLASS_NAMES = ['Stable', 'Drifting Up', 'Drifting Down']

# Walk the validation generator once, taking predictions and true labels from
# the same batch. This halves the disk reads compared with a predict() pass
# followed by a separate label pass, and guarantees the two stay aligned.
print("Predicting on validation set...")
proba_batches = []
y_true = []
for i in tqdm(range(len(val_generator)), desc="Predicting Validation Batches"):
    X_batch, y_batch = val_generator[i]
    proba_batches.append(model.predict_on_batch(X_batch))
    y_true.extend(np.argmax(y_batch, axis=1))

y_pred_proba = np.concatenate(proba_batches, axis=0)
y_pred = np.argmax(y_pred_proba, axis=1)

print("\n--- Classification Report ---")
print(classification_report(y_true, y_pred, labels=CLASS_LABELS, target_names=CLASS_NAMES))

print("\n--- Confusion Matrix ---")
cm = confusion_matrix(y_true, y_pred, labels=CLASS_LABELS)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
fig, ax = plt.subplots(figsize=(8, 8))
disp.plot(cmap=plt.cm.Blues, ax=ax)
plt.show()

--- Libraries installed ---
TensorFlow Version: 2.18.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Mounted at /content/drive

--- Getting Patient IDs for Stratified Split ---


Counting Patient Rows:   0%|          | 0/31 [00:00<?, ?it/s]

Split into 2166 training and 542 validation patients.

[TrainGenerator] Building index map and calculating class distribution...


[TrainGenerator] Pre-processing files:   0%|          | 0/31 [00:00<?, ?it/s]

[TrainGenerator] Found 21330905 total possible sequences.

[ValGenerator] Building index map and calculating class distribution...


[ValGenerator] Pre-processing files:   0%|          | 0/31 [00:00<?, ?it/s]

[ValGenerator] Found 5334072 total possible sequences.

Calculating class weights from pre-computed distribution...
Calculated Class Weights: {0: 0.5664200934296348, 1: 1.5849460267313666, 2: 1.6567548635038714}

--- Building LSTM Model ---



--- Starting Model Training ---


  self._warn_if_super_not_called()


Epoch 1/5
[1m 1630/83323[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m26:00:32[0m 1s/step - accuracy: 0.4494 - auc: 0.6689 - loss: 0.9940