# 🏥 EHR-M-GAN Training on Google Colab

**Version:** 2.0 (Production-Ready)

**Goal:** Generate high-quality synthetic medical data

**Dataset:** eICU-CRD Demo (1,650 patients)

---

## ⚙️ Configuration

**Recommended GPU:** T4 or A100

**Runtime:** GPU (required)

**Estimated time:** 6-8 hours in total

## 📦 Step 1: Initial Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Check access
!ls "/content/drive/MyDrive/" | head -10

In [None]:
# Check that a GPU is available
!nvidia-smi

In [None]:
# Clone the repository
import os

# Remove any previous copy
!rm -rf /content/ehrMGAN

# Clone
!git clone https://github.com/jli0117/ehrMGAN.git /content/ehrMGAN

# Move into the repository folder
%cd /content/ehrMGAN

# Check the folder structure
!ls -la

## 🔧 Step 2: Install TensorFlow 1.15
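Before running the install below, a quick pre-flight check can tell you whether it has any chance of succeeding. This is a minimal sketch, assuming that TensorFlow 1.15 wheels were only built for CPython 3.5 to 3.7 (verify against the file list of `tensorflow-gpu==1.15.5` on PyPI); `tf115_compatible` is a hypothetical helper, not part of the repository.

```python
import sys


def tf115_compatible(version_info=sys.version_info):
    """Return True if this Python falls in the range TF 1.15 wheels were built for.

    Assumption: tensorflow-gpu 1.15.x only ships wheels for CPython 3.5-3.7.
    """
    major, minor = version_info[0], version_info[1]
    return (3, 5) <= (major, minor) <= (3, 7)


current = f"{sys.version_info[0]}.{sys.version_info[1]}"
if tf115_compatible():
    print(f"Python {current}: tensorflow-gpu==1.15.5 should be installable")
else:
    print(f"Python {current}: no TF 1.15 wheels expected, the pip install will likely fail")
```

If it reports an incompatible version, the `assert` in the next cell will fail on that runtime.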

In [None]:
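# CRITICAL: downgrade to TensorFlow 1.15
# TF 1.15 wheels were only published for Python 3.7 and older, so this install
# fails on a runtime with a newer Python (check with `!python --version`).
!pip uninstall -y tensorflow tensorflow-gpu -q
!pip install tensorflow-gpu==1.15.5 -q

# Check the version
import tensorflow as tf
print(f"✅ TensorFlow version: {tf.__version__}")
assert tf.__version__.startswith('1.15'), "❌ TensorFlow 1.15 required!"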

## 📚 Step 3: Install Dependencies

In [None]:
# Install all dependencies (tested versions)
!pip install --upgrade pip setuptools wheel -q

# Core dependencies
!pip install numpy==1.19.5 -q
!pip install pandas==1.1.5 -q
!pip install scipy==1.5.4 -q
!pip install scikit-learn==0.24.2 -q
!pip install matplotlib==3.3.4 -q
!pip install seaborn==0.11.2 -q
!pip install h5py==2.10.0 -q
!pip install tqdm==4.64.1 -q
!pip install pyyaml==5.4.1 -q

# PyTorch (for the contrastive loss)
!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 -q

# Check imports.
# Importing TensorFlow in Step 2 already loaded NumPy, and pip does not reload modules
# that are already imported: restart the runtime (Runtime menu > Restart session)
# after this cell so the pinned versions are the ones in use, then continue with Step 4.
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
print(f"✅ All dependencies installed: numpy {np.__version__}, pandas {pd.__version__}, "
      f"tensorflow {tf.__version__}, torch {torch.__version__}")
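Because pip replaces packages on disk without reloading modules that are already imported, it is worth confirming which versions the running kernel actually sees. A minimal sketch: `check_pins` is a hypothetical helper that compares each module's `__version__` with the pins above. Note that some import names differ from the pip names (`sklearn` for scikit-learn, `yaml` for pyyaml).

```python
import importlib

# Import name -> pinned version (taken from the pip commands above)
PINS = {
    "numpy": "1.19.5",
    "pandas": "1.1.5",
    "scipy": "1.5.4",
    "sklearn": "0.24.2",
    "matplotlib": "3.3.4",
    "seaborn": "0.11.2",
    "h5py": "2.10.0",
    "tqdm": "4.64.1",
    "yaml": "5.4.1",
}


def check_pins(pins):
    """Return {module: (expected, found)} for every module whose loaded version differs."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            module = importlib.import_module(name)
        except Exception:  # ImportError, or any error raised by a broken install
            mismatches[name] = (expected, None)
            continue
        found = getattr(module, "__version__", None)
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


problems = check_pins(PINS)
if problems:
    for name, (expected, found) in problems.items():
        print(f"⚠️  {name}: expected {expected}, loaded {found}")
    print("Restart the runtime and re-run this check.")
else:
    print("✅ All pinned versions are loaded")
```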

## 📁 Step 4: Prepare the Data

In [None]:
# Path configuration
import os
import shutil

# Drive paths (adapt to your own folder structure)
DRIVE_DATA = "/content/drive/MyDrive/ehrMGAN_data/eicu-crd-demo-2.0.1"
DRIVE_FIXES = "/content/drive/MyDrive/ehrMGAN_fixes"
DRIVE_CHECKPOINT = "/content/drive/MyDrive/ehrMGAN_checkpoints"

# Local paths
LOCAL_DATA = "/content/ehrMGAN/preprocessing_physionet-main/eicu_preprocess/data"

# Create folders
os.makedirs(LOCAL_DATA, exist_ok=True)
os.makedirs(DRIVE_CHECKPOINT, exist_ok=True)

print("✅ Folders created")

In [None]:
# Copy the eICU data from Drive
required_files = [
    "patient.csv.gz",
    "vitalPeriodic.csv.gz",
    "infusionDrug.csv.gz",
    "respiratoryCare.csv.gz"
]

print("Copying eICU data...")
for file in required_files:
    src = os.path.join(DRIVE_DATA, file)
    dst = os.path.join(LOCAL_DATA, file)
    if os.path.exists(src):
        shutil.copy2(src, dst)
        print(f"  ✅ {file}")
    else:
        print(f"  ❌ MISSING: {file}")
        print("     Download it from PhysioNet and upload it to Drive")

# Check
!ls -lh {LOCAL_DATA}
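A `.csv.gz` file that was only partially uploaded to Drive copies without complaint and only fails later, during preprocessing. As a sketch (`gzip_is_complete` is a hypothetical helper, not part of the repository), each file can be fully decompressed once to confirm it is intact:

```python
import gzip
import zlib


def gzip_is_complete(path, chunk_size=1 << 20):
    """Decompress the whole file in chunks; return False if it is truncated or corrupt."""
    try:
        with gzip.open(path, "rb") as f:
            while f.read(chunk_size):
                pass  # discard the data: we only care that decompression succeeds
        return True
    except (OSError, EOFError, zlib.error):
        # OSError covers "not a gzip file" and CRC errors; EOFError covers truncation
        return False
```

In the notebook, call it on `os.path.join(LOCAL_DATA, file)` for each name in `required_files`; decompressing the large `vitalPeriodic.csv.gz` takes a little while but needs no extra memory.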

In [None]:
# Copy the fix files from Drive
print("Applying fixes...")

# visualise.py
if os.path.exists(f"{DRIVE_FIXES}/visualise.py"):
    !cp "{DRIVE_FIXES}/visualise.py" /content/ehrMGAN/evaluation_metrics/
    print("  ✅ visualise.py")
else:
    print("  ⚠️  visualise.py missing (may cause errors later)")

# utils/
if os.path.exists(f"{DRIVE_FIXES}/utils"):
    !cp -r "{DRIVE_FIXES}/utils/"*.py /content/ehrMGAN/preprocessing_physionet-main/eicu_preprocess/utils/
    print("  ✅ utils/*.py")
else:
    print("  ⚠️  utils/ missing")

# preprocessing_eicu_complete.py
if os.path.exists(f"{DRIVE_FIXES}/preprocessing_eicu_complete.py"):
    !cp "{DRIVE_FIXES}/preprocessing_eicu_complete.py" /content/ehrMGAN/preprocessing_physionet-main/eicu_preprocess/
    print("  ✅ preprocessing_eicu_complete.py")
else:
    print("  ❌ preprocessing_eicu_complete.py missing (Step 5 needs it)")

print("\n✅ Fixes applied")

## 🔄 Step 5: Preprocessing

In [None]:
# Move into the preprocessing folder
%cd /content/ehrMGAN/preprocessing_physionet-main/eicu_preprocess

# Run preprocessing
!python preprocessing_eicu_complete.py \
  --data_path ./data \
  --output_path ../../data/real/eicu \
  --time_window 24 \
  --min_length 12 \
  --max_length 240 \
  --age_min 18 \
  --verbose

# Check the outputs
print("\n" + "="*50)
print("Generated files:")
!ls -lh ../../data/real/eicu/

In [None]:
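# OPTIONAL: fix the pickle protocol if loading fails
# Rewrites each file with protocol 4 (readable from Python 3.4+). Reading a
# protocol-5 file requires Python 3.8+, so run this where the files can be loaded.
import os
import pickle

def convert_pickle_protocol(file_path):
    """Rewrite a pickle file with protocol 4."""
    tmp_path = file_path + ".tmp"
    try:
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        # Write to a temporary file first so a failure cannot corrupt the original
        with open(tmp_path, 'wb') as f:
            pickle.dump(data, f, protocol=4)
        os.replace(tmp_path, file_path)
        print(f"✅ Converted: {file_path}")
    except Exception as e:
        print(f"❌ Error: {e}")

# Convert all .pkl files
pkl_files = [
    "../../data/real/eicu/vital_sign_24hrs.pkl",
    "../../data/real/eicu/med_interv_24hrs.pkl",
    "../../data/real/eicu/statics.pkl"
]

for pkl_file in pkl_files:
    if os.path.exists(pkl_file):
        convert_pickle_protocol(pkl_file)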
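To decide whether the conversion is needed at all, the protocol can be read from the file header without unpickling it: pickles written with protocol 2 or higher start with the `PROTO` opcode (byte `0x80`) followed by one byte holding the protocol number. A minimal sketch, where `pickle_protocol` is a hypothetical helper:

```python
def pickle_protocol(path):
    """Return the protocol number stored in a pickle's PROTO header, or None.

    Protocols 0 and 1 have no PROTO opcode, so None means "protocol 0 or 1".
    """
    with open(path, "rb") as f:
        head = f.read(2)
    if len(head) == 2 and head[0] == 0x80:  # 0x80 is the PROTO opcode
        return head[1]
    return None
```

A file reporting 5 is the case the conversion cell handles; files reporting 4 or lower can already be read by Python 3.7.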

## 🎓 Step 6: Training Configuration

In [None]:
# Return to the main folder
%cd /content/ehrMGAN

# Parameters tuned for Colab
BATCH_SIZE = 128          # Reduced to avoid out-of-memory (OOM) errors
NUM_PRE_EPOCHS = 500      # VAE pretraining
NUM_EPOCHS = 800          # Adversarial training
CHECKPOINT_FREQ = 50      # Save a checkpoint every 50 epochs
Z_DIM = 25                # Latent dimension

# Create output folders
!mkdir -p data/checkpoint
!mkdir -p data/fake
!mkdir -p logs/visualizations

print(f"""Training configuration:
- Batch Size: {BATCH_SIZE}
- Pretraining Epochs: {NUM_PRE_EPOCHS}
- Training Epochs: {NUM_EPOCHS}
- Checkpoint Freq: {CHECKPOINT_FREQ}
- Latent Dim: {Z_DIM}
""")

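The epoch counts above translate into optimizer steps and checkpoint files roughly as follows. This sketch assumes one step per mini-batch and one checkpoint every `CHECKPOINT_FREQ` epochs; whether `main_train.py` keeps or drops the last partial batch is not shown here, hence the `drop_last` switch. `training_budget` is a hypothetical helper, and 1,650 is the patient count quoted at the top, which may change after the preprocessing filters.

```python
import math


def training_budget(n_samples, batch_size, epochs, ckpt_freq, drop_last=False):
    """Estimate optimizer steps and checkpoint count for one training phase."""
    if drop_last:
        steps_per_epoch = n_samples // batch_size
    else:
        steps_per_epoch = math.ceil(n_samples / batch_size)
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * epochs,
        "checkpoints": epochs // ckpt_freq,
    }


pretrain = training_budget(1650, 128, 500, 50)     # 13 steps/epoch, 6,500 steps, 10 checkpoints
adversarial = training_budget(1650, 128, 800, 50)  # 13 steps/epoch, 10,400 steps, 16 checkpoints
print(pretrain)
print(adversarial)
```

In the notebook, the literals can be replaced by `BATCH_SIZE`, `NUM_PRE_EPOCHS`, `NUM_EPOCHS` and `CHECKPOINT_FREQ`.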
In [None]:
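# Keep the Colab session from disconnecting when idle
from IPython.display import display, Javascript

# Relies on Colab's page structure, which changes over time:
# if the selector no longer matches, this silently does nothing.
display(Javascript('''
    function KeepClicking(){
        console.log("Keeping session alive");
        const button = document.querySelector("colab-toolbar-button#connect");
        if (button) { button.click(); }
    }
    setInterval(KeepClicking, 60000);
'''))

print("✅ Keep-alive enabled (clicks every 60 s)")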

## 🚀 Step 7: Phase 1 - VAE Pretraining

In [None]:
# Launch VAE pretraining
!python main_train.py \
  --dataset eicu \
  --data_path ./data/real/eicu \
  --batch_size {BATCH_SIZE} \
  --num_pre_epochs {NUM_PRE_EPOCHS} \
  --num_epochs 0 \
  --epoch_ckpt_freq {CHECKPOINT_FREQ} \
  --z_dim {Z_DIM} \
  --conditional False

print("\n" + "="*50)
print("✅ VAE pretraining step finished (check the log above for errors)")
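A `!python ...` cell does not stop the notebook when the script fails, so a success message printed afterwards appears even after a crash. One way to make failures visible is to launch the script through `subprocess` and raise on a non-zero exit code. A minimal sketch; `run_and_stream` is a hypothetical helper, not part of the repository:

```python
import subprocess


def run_and_stream(args):
    """Run a command, echo its combined stdout/stderr line by line, raise on failure."""
    with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          universal_newlines=True) as proc:
        for line in proc.stdout:
            print(line, end="")
    # Leaving the with-block waits for the process, so returncode is set here
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, args)
    return proc.returncode
```

For example, `run_and_stream([sys.executable, "main_train.py", "--dataset", "eicu", "--num_pre_epochs", str(NUM_PRE_EPOCHS), ...])`, with the remaining flags copied from the pretraining cell and every value passed as a string.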

In [None]:
# Save the pretraining checkpoint to Drive
import shutil
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_name = f"pretraining_complete_{timestamp}"
checkpoint_dir = f"{DRIVE_CHECKPOINT}/{checkpoint_name}"

# checkpoint_dir is a new timestamped folder that copytree creates itself, so
# dirs_exist_ok is not needed (it only exists on Python 3.8+, not on 3.7)
shutil.copytree("data/checkpoint", checkpoint_dir)
print(f"✅ Checkpoint saved: {checkpoint_dir}")

## 🔥 Step 8: Phase 2 - Adversarial Training

In [None]:
# Launch GAN training
!python main_train.py \
  --dataset eicu \
  --data_path ./data/real/eicu \
  --batch_size {BATCH_SIZE} \
  --num_pre_epochs 0 \
  --num_epochs {NUM_EPOCHS} \
  --epoch_ckpt_freq {CHECKPOINT_FREQ} \
  --z_dim {Z_DIM} \
  --conditional False \
  --resume_training True

print("\n" + "="*50)
print("✅ GAN training step finished (check the log above for errors)")
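Both training cells pass booleans as text (`--conditional False`, `--resume_training True`). Whether `main_train.py` parses them correctly depends on how it declares those flags, which this notebook does not show. If a flag is declared with `type=bool`, argparse calls `bool("False")`, which is `True` for any non-empty string. The sketch below demonstrates the pitfall and a common workaround; `str2bool` and the two flag names are hypothetical, not part of the repository.

```python
import argparse


def str2bool(value):
    """Parse common spellings of true/false; reject anything else."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--naive_flag", type=bool, default=False)     # pitfall
parser.add_argument("--safe_flag", type=str2bool, default=False)  # workaround
args = parser.parse_args(["--naive_flag", "False", "--safe_flag", "False"])
print(args.naive_flag, args.safe_flag)  # True False
```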

In [None]:
# Save the final results to Drive
import os
import shutil
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
final_export = f"{DRIVE_CHECKPOINT}/FINAL_EXPORT_{timestamp}"
os.makedirs(final_export, exist_ok=True)

# Copy everything. The target subfolders do not exist yet, so dirs_exist_ok
# (Python 3.8+ only) is not needed.
shutil.copytree("data/checkpoint", f"{final_export}/checkpoints")
shutil.copytree("data/fake", f"{final_export}/synthetic_data")
shutil.copytree("logs", f"{final_export}/logs")

print(f"✅ Final export complete: {final_export}")
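Copies to a mounted Drive can end up incomplete, for instance if the session drops mid-copy. A minimal sketch for comparing a local folder with its Drive copy by relative path and file size (a cheap check, not a checksum); `compare_trees` is a hypothetical helper:

```python
import os


def compare_trees(src, dst):
    """Return relative paths that are missing in dst or whose size differs from src."""
    problems = []
    for root, _dirs, files in os.walk(src):
        for name in files:
            src_path = os.path.join(root, name)
            rel = os.path.relpath(src_path, src)
            dst_path = os.path.join(dst, rel)
            if not os.path.isfile(dst_path):
                problems.append(f"missing: {rel}")
            elif os.path.getsize(dst_path) != os.path.getsize(src_path):
                problems.append(f"size differs: {rel}")
    return sorted(problems)
```

For example, `compare_trees("data/checkpoint", f"{final_export}/checkpoints")` should return an empty list once the copy has completed.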

## 📊 Step 9: Validating the Results

In [None]:
# Load synthetic data
import pickle

with open("data/fake/c_gen_data.pkl", "rb") as f:
    c_gen_data = pickle.load(f)

with open("data/fake/d_gen_data.pkl", "rb") as f:
    d_gen_data = pickle.load(f)

# Load real data
with open("data/real/eicu/vital_sign_24hrs.pkl", "rb") as f:
    c_real_data = pickle.load(f)

with open("data/real/eicu/med_interv_24hrs.pkl", "rb") as f:
    d_real_data = pickle.load(f)

print(f"Synthetic - continuous: {c_gen_data.shape}, discrete: {d_gen_data.shape}")
print(f"Real - continuous: {c_real_data.shape}, discrete: {d_real_data.shape}")

In [None]:
# Visualize real vs. synthetic trajectories
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

num_samples = 5
feature_names = ['Heart Rate', 'SpO2', 'SBP', 'DBP', 'Temp', 'Resp Rate', 'GCS']
# Use generic names for any features beyond the seven listed above
feature_names += [f"Feature {j}" for j in range(len(feature_names), c_real_data.shape[2])]

fig, axes = plt.subplots(num_samples, 2, figsize=(14, 10))

for i in range(num_samples):
    # Real patient
    axes[i, 0].plot(c_real_data[i], alpha=0.7, linewidth=0.8)
    axes[i, 0].set_title(f"Patient {i+1} - Real")
    axes[i, 0].set_ylabel("Normalized value")

    # Synthetic sample (not paired with the real patient in the same row)
    axes[i, 1].plot(c_gen_data[i], alpha=0.7, linewidth=0.8)
    axes[i, 1].set_title(f"Sample {i+1} - Synthetic")

    if i == num_samples - 1:
        axes[i, 0].set_xlabel("Hour")
        axes[i, 1].set_xlabel("Hour")

# Each line is one feature: label them once
axes[0, 1].legend(feature_names[:c_gen_data.shape[2]], loc="upper right", fontsize=7)

plt.tight_layout()
plt.savefig(f"{DRIVE_CHECKPOINT}/comparison_trajectories.png", dpi=150)
plt.show()

In [None]:
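# Compute the MMD (Maximum Mean Discrepancy) for each feature
from evaluation_metrics.max_mean_discrepency import mmd_rbf
import numpy as np

# Assumption: mmd_rbf builds full kernel matrices from 2-D (n_samples, n_features)
# inputs, so each feature is reshaped to one column and randomly subsampled:
# ~40,000 points per feature would need a 40,000 x 40,000 kernel matrix.
MAX_POINTS = 2000
rng = np.random.RandomState(42)

mmd_scores = []
print("Per-feature MMD:\n")
print(f"{'Feature':<15} {'MMD':<10}")
print("-" * 25)

for feat_idx in range(c_real_data.shape[2]):
    real_feat = c_real_data[:, :, feat_idx].reshape(-1, 1)
    gen_feat = c_gen_data[:, :, feat_idx].reshape(-1, 1)
    real_feat = real_feat[rng.choice(len(real_feat), min(MAX_POINTS, len(real_feat)), replace=False)]
    gen_feat = gen_feat[rng.choice(len(gen_feat), min(MAX_POINTS, len(gen_feat)), replace=False)]

    mmd = mmd_rbf(real_feat, gen_feat)
    mmd_scores.append(mmd)
    name = feature_names[feat_idx] if feat_idx < len(feature_names) else f"Feature {feat_idx}"
    print(f"{name:<15} {mmd:.6f}")

print("\n" + "="*25)
print(f"Mean MMD: {np.mean(mmd_scores):.6f}")
print("\n🎯 Indicative targets: < 0.05 (excellent), < 0.10 (good); MMD values depend on the kernel bandwidth")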
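For reference, the quantity being estimated is the squared MMD between two samples under a kernel k: MMD² = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]. The sketch below is a self-contained NumPy version of the standard biased estimator with an RBF kernel k(a, b) = exp(-gamma ||a - b||²). It is not the repository's `mmd_rbf`, whose bandwidth, normalization (MMD vs. MMD²) and subsampling may differ, so its numbers are not directly comparable with the table above; `mmd2_rbf`, `gamma` and `max_points` are assumptions of this sketch.

```python
import numpy as np


def mmd2_rbf(x, y, gamma=1.0, max_points=2000, seed=0):
    """Biased estimate of squared MMD with the RBF kernel exp(-gamma * ||a - b||^2).

    x, y: arrays of shape (n,) or (n, d). Each is randomly subsampled to at most
    max_points rows, because the kernel matrices grow quadratically with n.
    """
    rng = np.random.RandomState(seed)

    def prepare(a):
        a = np.asarray(a, dtype=float)
        if a.ndim == 1:
            a = a[:, None]  # treat a 1-D array as n samples of one feature
        if len(a) > max_points:
            a = a[rng.choice(len(a), max_points, replace=False)]
        return a

    def kernel(a, b):
        # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b, clipped at 0
        sq = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
        return np.exp(-gamma * np.maximum(sq, 0.0))

    x, y = prepare(x), prepare(y)
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()
```

Feeding it `c_real_data[:, :, j].reshape(-1)` and `c_gen_data[:, :, j].reshape(-1)` gives a per-feature value; `gamma=1.0` implicitly assumes features on roughly unit scale, so the result depends on how the data were normalized.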

In [None]:
# Discriminative Score: can a classifier tell real from synthetic sequences?
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Prepare the data: flatten each sequence into one feature vector
X_real = c_real_data.reshape(len(c_real_data), -1)
X_gen = c_gen_data.reshape(len(c_gen_data), -1)

# Use as many synthetic as real samples, so that 0.50 really is chance level
n = min(len(X_real), len(X_gen))
X = np.vstack([X_real[:n], X_gen[:n]])
y = np.hstack([np.ones(n), np.zeros(n)])

# Stratified split keeps the 50/50 class balance in both sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Train
clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
clf.fit(X_train, y_train)

# Predict
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)

print(f"\nDiscriminative Score: {accuracy:.4f}")
print("🎯 Target: ~0.50 (ideal = real and synthetic are indistinguishable)")

In [None]:
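# Create the final report
import subprocess
from datetime import datetime

# Query the GPU name in Python: IPython's "!" syntax cannot be used inside an f-string
try:
    gpu_name = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        stdout=subprocess.PIPE, universal_newlines=True,
    ).stdout.strip().splitlines()[0]
except (OSError, IndexError):
    gpu_name = "unknown"

report = f"""
{'='*60}
EHR-M-GAN TRAINING REPORT
{'='*60}

Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Dataset: eICU-CRD Demo
GPU: {gpu_name}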

{'='*60}
CONFIGURATION
{'='*60}
- Batch Size: {BATCH_SIZE}
- Pretraining Epochs: {NUM_PRE_EPOCHS}
- Training Epochs: {NUM_EPOCHS}
- Latent Dimension: {Z_DIM}
- Checkpoint Frequency: {CHECKPOINT_FREQ}

{'='*60}
DATA
{'='*60}
- Patients processed: {len(c_real_data)}
- Time window: 24 hours
- Continuous features: {c_real_data.shape[2]} (vital signs)
- Discrete features: {d_real_data.shape[2]} (interventions)

{'='*60}
QUALITY METRICS
{'='*60}
- Mean MMD: {np.mean(mmd_scores):.6f} {'✅' if np.mean(mmd_scores) < 0.10 else '⚠️'}
- Discriminative Score: {accuracy:.4f} {'✅' if 0.45 <= accuracy <= 0.55 else '⚠️'}

Per-feature MMD:
"""

for fname, mmd in zip(feature_names, mmd_scores):
    report += f"  - {fname:<15}: {mmd:.6f}\n"

report += f"""
{'='*60}
GENERATED FILES
{'='*60}
- Checkpoints: {final_export}/checkpoints/
- Synthetic Data: {final_export}/synthetic_data/
- Logs: {final_export}/logs/
- Visualizations: {DRIVE_CHECKPOINT}/comparison_trajectories.png

{'='*60}
STATUS
{'='*60}
✅ Training completed successfully
✅ Synthetic data generated
✅ Saved to Drive

{'='*60}
"""

print(report)

# Save the report
with open(f"{final_export}/TRAINING_REPORT.txt", "w", encoding="utf-8") as f:
    f.write(report)

print(f"\n✅ Report saved: {final_export}/TRAINING_REPORT.txt")
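The text report is meant for reading; for comparing several runs, saving the same metrics as JSON is easier to parse later. A minimal sketch, where `save_metrics` and the `metrics.json` file name are assumptions of this sketch; values are cast to plain `float` because some NumPy scalar types (for example `np.float32`) are not JSON-serializable.

```python
import json
from datetime import datetime


def save_metrics(path, mmd_by_feature, discriminative_accuracy, config):
    """Write metrics and configuration to a JSON file and return the written dict."""
    record = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "config": config,
        "mmd_by_feature": {name: float(v) for name, v in mmd_by_feature.items()},
        "mmd_mean": float(sum(mmd_by_feature.values()) / len(mmd_by_feature)),
        "discriminative_accuracy": float(discriminative_accuracy),
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(record, f, indent=2)
    return record
```

In the notebook: `save_metrics(f"{final_export}/metrics.json", dict(zip(feature_names, mmd_scores)), accuracy, {"batch_size": BATCH_SIZE, "num_pre_epochs": NUM_PRE_EPOCHS, "num_epochs": NUM_EPOCHS, "z_dim": Z_DIM})`.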

## ✅ Final Checklist

In [None]:
# Check that everything was saved
import os

checklist = [
    (f"{final_export}/checkpoints", "Model checkpoints"),
    (f"{final_export}/synthetic_data", "Synthetic data"),
    (f"{final_export}/logs", "Training logs"),
    (f"{final_export}/TRAINING_REPORT.txt", "Final report"),
    (f"{DRIVE_CHECKPOINT}/comparison_trajectories.png", "Visualizations")
]

print("Checking files:\n")
all_ok = True
for path, desc in checklist:
    exists = os.path.exists(path)
    status = "✅" if exists else "❌"
    print(f"{status} {desc}: {path}")
    if not exists:
        all_ok = False

print("\n" + "="*60)
if all_ok:
    print("✅ EVERYTHING IS SAVED - you can close the notebook")
else:
    print("⚠️  WARNING - missing files, check the errors above")

---

## 🎉 END OF TRAINING

### Next steps:

1. **Download the results** from Drive
2. **Analyze the metrics** (MMD, Discriminative Score)
3. **Validate on downstream tasks** (predictive models)
4. **Publish the results** (paper, GitHub)

### Support:

- **Documentation**: [COLAB_SETUP_GUIDE.md](COLAB_SETUP_GUIDE.md)
- **GitHub**: https://github.com/jli0117/ehrMGAN
- **Paper**: https://arxiv.org/abs/2112.12047

---

*Notebook created by [Your Name] - Version 2.0 (February 2026)*