# ehrMGAN: Electronic Health Records Multiple Generative Adversarial Networks
This notebook implements the complete ehrMGAN model for generating synthetic electronic health records data. The model combines VAEs and GANs to generate both continuous (vital signs) and discrete (medications/interventions) time series data.

## 1. Environment Setup
First, let's set up our environment with the necessary imports and check the configuration.

In [1]:
# ====== Import necessary libraries ======
import argparse
import os
import pickle
import sys
import timeit
import warnings

import GPUtil
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm
# Silence Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

# Report the TensorFlow version; on 2.x, enable memory growth on every visible GPU.
print(f"TensorFlow version: {tf.__version__}")
if not tf.__version__.startswith('2'):
    print("Warning: This code requires TensorFlow 2.x")
else:
    physical_devices = tf.config.list_physical_devices('GPU')
    # Looping over an empty list is a no-op, so no explicit length check is needed.
    for gpu_device in physical_devices:
        tf.config.experimental.set_memory_growth(gpu_device, True)
        print(f"Memory growth enabled for {gpu_device}")

TensorFlow version: 2.5.0
Memory growth enabled for PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [2]:
# ====== System Information ======
# Library, platform and interpreter versions
print(f"TensorFlow version: {tf.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"System: {sys.platform}")
print(f"Python version: {sys.version.split()[0]}")

# Processor information
import platform
print(f"Processor: {platform.processor()}")

# Memory information (bytes converted to GiB)
import psutil
print(f"Total memory: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
print(f"Available memory: {psutil.virtual_memory().available / (1024 ** 3):.2f} GB")

# GPU information via GPUtil. Only ordinary errors are caught, so Ctrl-C and
# SystemExit still propagate, and the actual failure reason is reported.
try:
    gpus = GPUtil.getGPUs()
    if not gpus:
        # An empty list means the query succeeded but found no device.
        print("No GPU detected by GPUtil")
    for i, gpu in enumerate(gpus):
        print(f"GPU {i}: {gpu.name}, Memory: {gpu.memoryTotal} MB")
        print(f"   Memory used: {gpu.memoryUsed} MB, Load: {gpu.load*100:.1f}%")
except Exception as gpu_error:
    print(f"GPU query via GPUtil failed: {gpu_error}")

# Check TensorFlow GPU
print("\nTensorFlow GPU availability:")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

TensorFlow version: 2.5.0
NumPy version: 1.20.3
Pandas version: 1.5.2
System: win32
Python version: 3.8.8
Processor: AMD64 Family 25 Model 33 Stepping 0, AuthenticAMD
Total memory: 31.93 GB
Available memory: 16.46 GB
GPU 0: NVIDIA GeForce RTX 3060, Memory: 12288.0 MB
   Memory used: 726.0 MB, Load: 3.0%

TensorFlow GPU availability:
GPU available: True


### 2. Import Model Components
Now let's import the necessary model components from our files.

In [3]:
# ====== Load modules ======
# Import our custom modules
from m3gan_tf2 import m3gan
from networks_tf2 import C_VAE_NET, D_VAE_NET, C_GAN_NET, D_GAN_NET

### 3. Data Loading and Preprocessing
Let's load and preprocess the data.

In [4]:
# ====== Load data ======
# Dataset parameters
# NOTE: the spelling `patinet_num` is kept because the summary cell reads this name.
patinet_num = 16062
filename_postfix = '5_var'
hours_per_record = 24   # number of time steps per patient used for the reshape

# Load continuous data (vital signs)
# The feature axis is inferred (-1) so c_dim really comes from the file contents.
continuous_x = np.loadtxt(f'data/real/mimic/vital_sign_24hrs_{filename_postfix}_mimiciv.txt')
continuous_x = continuous_x.reshape(patinet_num, hours_per_record, -1)
c_dim = continuous_x.shape[-1]

# Load discrete data (medications/interventions)
discrete_x = np.loadtxt(f'data/real/mimic/med_interv_24hrs_{filename_postfix}_mimiciv.txt')
discrete_x = discrete_x.reshape(patinet_num, hours_per_record, -1)
d_dim = discrete_x.shape[-1]

# Load static data (patient demographics); only the first column is used as the label
statics_label = pd.read_csv(f'data/real/mimic/static_data_{filename_postfix}_mimiciv.csv')
statics_label = np.asarray(statics_label)[:, 0].reshape([-1, 1])

# The three arrays are indexed by patient together, so their row counts must agree.
assert statics_label.shape[0] == patinet_num, (
    f"static labels have {statics_label.shape[0]} rows, expected {patinet_num}"
)

print(f"Continuous data shape: {continuous_x.shape}")
print(f"Discrete data shape: {discrete_x.shape}")
print(f"Static labels shape: {statics_label.shape}")

# Data parameters
time_steps = continuous_x.shape[1]
conditional = True   # Whether to use conditional GAN
num_labels = 1 if conditional else 0  # Number of conditional labels

Continuous data shape: (16062, 24, 5)
Discrete data shape: (16062, 24, 1)
Static labels shape: (16062, 1)


### 4. Model Configuration and Hyperparameter Tuning
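Let's define the default hyperparameters. The latent, noise, and encoder/decoder sizes are enlarged automatically when the data has more than 10 continuous or discrete features.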

In [5]:
# ====== Model hyperparameters ======
# --- Schedule ---
batch_size = 64
num_pre_epochs = 50    # pretraining epochs
num_epochs = 200       # adversarial-training epochs

# --- Latent and noise sizes (continuous branch, then discrete branch) ---
c_z_size, c_noise_dim = 32, 64
d_z_size, d_noise_dim = 16, 32

# --- Network widths / depths ---
enc_size = dec_size = 128               # encoder / decoder hidden units
enc_layers = dec_layers = 1             # encoder / decoder layer counts
gen_num_units = dis_num_units = 128     # generator / discriminator hidden units
gen_num_layers = dis_num_layers = 1     # generator / discriminator layer counts
keep_prob = 0.8                         # dropout keep probability
l2_scale = 0.001                        # L2 regularization scale

# --- Update rounds per training step (discriminator, generator, VAE) ---
d_rounds = g_rounds = v_rounds = 1

# --- Learning rates ---
v_lr_pre = 0.001                        # VAE pretraining
v_lr = g_lr = d_lr = 0.0001             # VAE, generator, discriminator

# --- Loss weights ---
alpha_re = 1.0                          # reconstruction
alpha_kl = 0.1                          # KL divergence
alpha_mt = 2.0                          # matching
alpha_ct = alpha_sm = 0.0               # contrastive, smoothness
c_beta_adv = c_beta_fm = 1.0            # continuous adversarial / feature matching
d_beta_adv = d_beta_fm = 1.0            # discrete adversarial / feature matching

# --- Enlarge capacity for wide inputs (more than 10 features) ---
if c_dim > 10:
    c_z_size, c_noise_dim = 64, 128
    enc_size = dec_size = 256

if d_dim > 10:
    d_z_size, d_noise_dim = 32, 64

print("Hyperparameters set based on data dimensions:")
print(f"Continuous latent dim: {c_z_size}, Noise dim: {c_noise_dim}")
print(f"Discrete latent dim: {d_z_size}, Noise dim: {d_noise_dim}")
print(f"Encoder size: {enc_size}, Decoder size: {dec_size}")

Hyperparameters set based on data dimensions:
Continuous latent dim: 32, Noise dim: 64
Discrete latent dim: 16, Noise dim: 32
Encoder size: 128, Decoder size: 128


### 5. Build Model Components
Now let's build the model components.

In [None]:
# ====== Create model instances ======
print("Creating model components...")

# Keyword arguments shared by both VAEs (the VAE classes spell the L2 argument `l2scale`).
vae_shared_kwargs = dict(
    batch_size=batch_size, time_steps=time_steps,
    enc_size=enc_size, dec_size=dec_size,
    enc_layers=enc_layers, dec_layers=dec_layers,
    keep_prob=keep_prob, l2scale=l2_scale,
    conditional=conditional, num_labels=num_labels,
)
c_vae = C_VAE_NET(dim=c_dim, z_dim=c_z_size, **vae_shared_kwargs)
print("Continuous VAE created.")
d_vae = D_VAE_NET(dim=d_dim, z_dim=d_z_size, **vae_shared_kwargs)
print("Discrete VAE created.")

# Keyword arguments shared by both GANs (the GAN classes spell the L2 argument `l2_scale`).
gan_shared_kwargs = dict(
    batch_size=batch_size, time_steps=time_steps,
    gen_num_units=gen_num_units, gen_num_layers=gen_num_layers,
    dis_num_units=dis_num_units, dis_num_layers=dis_num_layers,
    keep_prob=keep_prob, l2_scale=l2_scale,
    conditional=conditional, num_labels=num_labels,
)
c_gan = C_GAN_NET(
    noise_dim=c_noise_dim, dim=c_dim, gen_dim=c_z_size, **gan_shared_kwargs
)
print("Continuous GAN created.")
d_gan = D_GAN_NET(
    noise_dim=d_noise_dim, dim=d_dim, gen_dim=d_z_size, **gan_shared_kwargs
)
print("Discrete GAN created.")

# Checkpoint directory (created if it does not exist yet)
checkpoint_dir = "data/checkpoint/"
os.makedirs(checkpoint_dir, exist_ok=True)

Creating model components...
Continuous VAE created.
Discrete VAE created.
Continuous GAN created.
Discrete GAN created.


### 6. Create and Train the Model
Now let's set up the complete model and train it.

In [8]:
# ====== Create and build the complete model ======
print("Creating the complete M3GAN model...")

# Keyword arguments are grouped by role and merged in the constructor call below.
run_settings = dict(
    batch_size=batch_size,
    time_steps=time_steps,
    num_pre_epochs=num_pre_epochs,
    num_epochs=num_epochs,
    checkpoint_dir=checkpoint_dir,
    epoch_ckpt_freq=100,  # Save checkpoint every 100 epochs
    epoch_loss_freq=10,   # Display loss every 10 epochs
)
continuous_settings = dict(
    c_dim=c_dim,
    c_noise_dim=c_noise_dim,
    c_z_size=c_z_size,
    c_data_sample=continuous_x,
    c_vae=c_vae,
    c_gan=c_gan,
)
discrete_settings = dict(
    d_dim=d_dim,
    d_noise_dim=d_noise_dim,
    d_z_size=d_z_size,
    d_data_sample=discrete_x,
    d_vae=d_vae,
    d_gan=d_gan,
)
optimisation_settings = dict(
    d_rounds=d_rounds,
    g_rounds=g_rounds,
    v_rounds=v_rounds,
    v_lr_pre=v_lr_pre,
    v_lr=v_lr,
    g_lr=g_lr,
    d_lr=d_lr,
)
loss_weights = dict(
    alpha_re=alpha_re,
    alpha_kl=alpha_kl,
    alpha_mt=alpha_mt,
    alpha_ct=alpha_ct,
    alpha_sm=alpha_sm,
    c_beta_adv=c_beta_adv,
    c_beta_fm=c_beta_fm,
    d_beta_adv=d_beta_adv,
    d_beta_fm=d_beta_fm,
)
conditioning_settings = dict(
    conditional=conditional,
    num_labels=num_labels,
    statics_label=statics_label,
)

model = m3gan(
    **run_settings,
    **continuous_settings,
    **discrete_settings,
    **optimisation_settings,
    **loss_weights,
    **conditioning_settings,
)

# Build the model
print("Building the model...")
model.build()

: 

In [None]:
# ====== Train the Model ======
# Runs model.train() when `train_model` is set and reports the elapsed wall-clock time.
print("Starting training...")
train_model = True   # set to False to skip the training phase
start_time = timeit.default_timer()

if not train_model:
    print("Skipping training phase.")
else:
    model.train()
    print("Training completed!")

end_time = timeit.default_timer()
elapsed_minutes = (end_time - start_time) / 60
print(f"Training time: {elapsed_minutes:.2f} minutes")

### 7.2 Addressing First Hour Low Variance Issue
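We first generate synthetic samples, then compare the per-hour standard deviation of the real and generated data and apply a post-processing correction for the low standard deviation in the first hour (with a milder adjustment to the second hour).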

In [None]:
# Generate data first
print("Generating synthetic data...")
num_generated_samples = 1024
d_gen_data, c_gen_data = model.generate_data(num_sample=num_generated_samples)

# Apply renormalization if your data was normalized
print("Renormalizing generated data...")
c_gen_data_renorm = c_gen_data  # If no renormalization needed
# If renormalization needed:
# c_gen_data_renorm = renormlizer(c_gen_data, data_info) # TODO: Implement renormlizer

# Analyze the first hour variance issue
print("Analyzing variance across time steps...")
# Std over the sample axis: one value per (hour, feature).
real_std_by_hour = np.std(continuous_x, axis=0)
# Use the renormalized array so this statistic stays on the same scale as the
# data that is plotted and corrected later (identical while no renormalization is applied).
gen_std_by_hour = np.std(c_gen_data_renorm, axis=0)

# Plot the standard deviation over time, one panel per continuous feature
hours = range(time_steps)
plt.figure(figsize=(12, 6))
for i in range(c_dim):
    plt.subplot(1, c_dim, i+1)
    plt.plot(hours, real_std_by_hour[:, i], 'b-', label='Real')
    plt.plot(hours, gen_std_by_hour[:, i], 'r-', label='Generated')
    plt.title(f'Feature {i+1} Std Dev')
    plt.xlabel('Hour')
    plt.ylabel('Standard Deviation')
    if i == 0:
        # A single legend on the first panel is enough; the colours are shared.
        plt.legend()
plt.tight_layout()
plt.show()

# Fix the low variance issue with a post-processing step
print("\nApplying variance correction to the first hour...")

# Improved variance matching function with consistency preservation
def variance_matching(data, target_std, axis=0, smoothing_factor=0.7):
    """
    Rescale `data` towards a target standard deviation around its own mean.

    The data is standardised along `axis`, rescaled to `target_std`, shifted back
    to the original mean, and finally blended with the untouched input.

    Args:
        data: Array to adjust.
        target_std: Standard deviation the rescaled copy should have.
        axis: Axis along which mean and standard deviation are computed.
        smoothing_factor: Weight (0-1) of the original data in the blend;
                         1 returns the input unchanged, 0 returns the fully
                         rescaled copy.

    Returns:
        Array with the same shape as `data`.
    """
    center = np.mean(data, axis=axis, keepdims=True)
    # The small epsilon keeps the division finite for constant input.
    spread = np.std(data, axis=axis, keepdims=True) + 1e-10

    # Standardise, rescale to the target spread, restore the original mean.
    rescaled = (data - center) / spread * target_std + center

    # Blend original and rescaled values.
    return smoothing_factor * data + (1 - smoothing_factor) * rescaled

# Apply the correction to the first hour with improvements
print("\nApplying enhanced variance correction to the first hour...")
c_gen_data_corrected = c_gen_data_renorm.copy()

# Seeded generator so the jittered targets, and therefore the saved file, are reproducible.
variance_rng = np.random.default_rng(42)

# (hour index, lower/upper jitter on the real std, weight of the original data).
# The first hour gets the strong correction; the second hour a milder one so the
# transition into the uncorrected hours is less abrupt.
hour_corrections = [
    (0, 1.0, 1.15, 0.3),
    (1, 0.95, 1.05, 0.6),
]

for i in range(c_dim):
    for hour, jitter_low, jitter_high, smoothing in hour_corrections:
        if hour >= time_steps:
            # Sequence too short to contain this hour.
            continue
        hour_values = c_gen_data_corrected[:, hour, i].reshape(-1, 1)
        # Target std dev with slight randomization around the real data's value
        target_std = real_std_by_hour[hour, i] * variance_rng.uniform(jitter_low, jitter_high)
        c_gen_data_corrected[:, hour, i] = variance_matching(
            hour_values, target_std, axis=0, smoothing_factor=smoothing
        ).flatten()

# Verify the correction
corrected_std_by_hour = np.std(c_gen_data_corrected, axis=0)

plt.figure(figsize=(15, 6))
for i in range(c_dim):
    plt.subplot(1, c_dim, i+1)
    plt.plot(range(time_steps), real_std_by_hour[:, i], 'b-', label='Real')
    plt.plot(range(time_steps), gen_std_by_hour[:, i], 'r--', label='Original Generated')
    plt.plot(range(time_steps), corrected_std_by_hour[:, i], 'g-', label='Corrected')
    plt.title(f'Feature {i+1} Std Dev')
    plt.xlabel('Hour')
    plt.ylabel('Standard Deviation')
    if i == 0:
        plt.legend()
plt.tight_layout()
plt.show()

# Save the corrected data
save_path_corrected = f"data/fake/gen_data_mimiciv_{filename_postfix}_corrected.npz"
# np.savez does not create missing directories, so make sure the target folder exists.
os.makedirs(os.path.dirname(save_path_corrected), exist_ok=True)
np.savez(save_path_corrected, 
         c_gen_data=c_gen_data_corrected, 
         d_gen_data=d_gen_data)
print(f"Corrected generated data saved to {save_path_corrected}")

### 7.3 Visualization and Analysis

In [None]:
# Sample and visualize a few trajectories
num_samples_to_visualize = 5

# The same indices address the real and the generated arrays, so they are drawn
# from the range that is valid for all of them. The fixed seed keeps the figures
# reproducible across runs.
num_comparable = min(
    continuous_x.shape[0], discrete_x.shape[0],
    c_gen_data_renorm.shape[0], d_gen_data.shape[0],
)
sample_rng = np.random.default_rng(42)
sample_indices = sample_rng.choice(num_comparable, num_samples_to_visualize, replace=False)
hours = range(time_steps)

# Continuous features: one row per sampled index, one column per feature.
# NOTE(review): real and generated series sharing an index are overlaid for display
# only; nothing visible here suggests they are paired records.
plt.figure(figsize=(15, 10))
for i, idx in enumerate(sample_indices):
    for j in range(c_dim):
        plt.subplot(num_samples_to_visualize, c_dim, i*c_dim + j + 1)
        plt.plot(hours, continuous_x[idx, :, j], 'b-', label='Real')
        plt.plot(hours, c_gen_data_renorm[idx, :, j], 'r-', label='Generated')
        if i == 0 and j == 0:
            plt.legend()
        plt.title(f'Sample {i+1}, Feature {j+1}')
        plt.xlabel('Hour')
plt.tight_layout()
plt.show()

# For discrete data (first discrete feature only)
plt.figure(figsize=(15, 5))
for i, idx in enumerate(sample_indices):
    plt.subplot(1, num_samples_to_visualize, i + 1)
    plt.step(hours, discrete_x[idx, :, 0], 'b-', where='post', label='Real')
    plt.step(hours, d_gen_data[idx, :, 0], 'r-', where='post', label='Generated')
    if i == 0:
        plt.legend()
    plt.title(f'Discrete Sample {i+1}')
    plt.xlabel('Hour')
plt.tight_layout()
plt.show()

# Compare distributions (values pooled over all samples and hours, per feature)
plt.figure(figsize=(15, 6))
for i in range(c_dim):
    plt.subplot(1, c_dim, i + 1)
    sns.kdeplot(continuous_x[:, :, i].flatten(), label='Real', color='blue')
    sns.kdeplot(c_gen_data_renorm[:, :, i].flatten(), label='Generated', color='red')
    plt.title(f'Feature {i+1} Distribution')
    if i == 0:
        plt.legend()
plt.tight_layout()
plt.show()

# Discrete data distribution
plt.figure(figsize=(8, 5))
bins = np.linspace(0, 1, 20)
plt.hist(discrete_x.flatten(), bins=bins, alpha=0.5, label='Real', density=True)
plt.hist(d_gen_data.flatten(), bins=bins, alpha=0.5, label='Generated', density=True)
plt.legend()
plt.title('Discrete Data Distribution')
plt.xlabel('Value')
plt.ylabel('Density')
plt.show()

### 8. Advanced Analysis: Preserving Medical Correlations

In [None]:
# Analyze correlations between features
def plot_correlation_heatmap(data, title):
    """Plot and return the feature-by-feature correlation matrix of a 3-D array.

    Args:
        data: Array indexed as (sample, time step, feature); the first two axes
            are pooled before correlating.
        title: Title shown above the heatmap.

    Returns:
        2-D array of Pearson correlation coefficients between features.
    """
    # Reshape to (patients*times, features)
    data_flat = data.reshape(-1, data.shape[2])

    # Columns are the variables. With a single feature np.corrcoef returns a
    # 0-d value, which the heatmap cannot draw, so force a 2-D matrix.
    corr = np.atleast_2d(np.corrcoef(data_flat, rowvar=False))

    # Plot on the fixed [-1, 1] scale so different heatmaps are comparable
    plt.figure(figsize=(8, 6))
    sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
    plt.title(title)
    plt.show()
    return corr

# Draw one heatmap per data set and keep the correlation matrices
real_corr = plot_correlation_heatmap(continuous_x, 'Real Data Correlation')
gen_corr = plot_correlation_heatmap(c_gen_data_renorm, 'Generated Data Correlation')

# Element-wise absolute gap between the two correlation matrices
corr_diff = np.abs(real_corr - gen_corr)
mean_corr_gap = np.mean(corr_diff)
print(f"Mean absolute correlation difference: {mean_corr_gap:.4f}")

# Visualize temporal patterns
def plot_mean_trajectory(data, title):
    """Plot the mean over the first axis of `data`, one panel per feature.

    Args:
        data: Array indexed as (sample, time step, feature).
        title: Figure-level title.
    """
    num_steps, num_features = data.shape[1], data.shape[2]
    averaged = np.mean(data, axis=0)
    step_axis = range(num_steps)

    fig = plt.figure(figsize=(12, 5))
    for feature_idx in range(num_features):
        ax = fig.add_subplot(1, num_features, feature_idx + 1)
        ax.plot(step_axis, averaged[:, feature_idx])
        ax.set_title(f'Feature {feature_idx+1}')
        ax.set_xlabel('Hour')
    fig.suptitle(title)
    fig.tight_layout()
    # Leave headroom so the figure title does not overlap the panel titles.
    fig.subplots_adjust(top=0.85)
    plt.show()

# Mean trajectory of the real data first, then of the generated data
for trajectory_data, trajectory_title in (
    (continuous_x, 'Real Data Mean Trajectory'),
    (c_gen_data_renorm, 'Generated Data Mean Trajectory'),
):
    plot_mean_trajectory(trajectory_data, trajectory_title)

### 9. Quantitative Evaluation

In [None]:
# Train a model to distinguish real from generated data
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score

def train_discriminator(real, gen):
    """Train a small MLP to tell real from generated samples and score it on a held-out split.

    Args:
        real: Array of real samples; all axes after the first are flattened per sample.
        gen: Array of generated samples with the same per-sample size as `real`.

    Returns:
        Tuple `(auc, apr, history)`: ROC AUC and average precision on the 20% test
        split (generated samples are the positive class) and the Keras history
        object returned by `fit`.
    """
    # Flatten every sample into a single feature vector
    real_flat = real.reshape(real.shape[0], -1)
    gen_flat = gen.reshape(gen.shape[0], -1)

    # Combine and create labels: 0 = real, 1 = generated
    X = np.vstack([real_flat, gen_flat])
    y = np.concatenate([np.zeros(len(real_flat)), np.ones(len(gen_flat))])

    # Split into train and test. The two classes can differ a lot in size, so
    # stratify to keep the same real/generated ratio in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y)

    # Create a simple discriminator model
    classifier = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    # The metric is registered as 'acc' because the plotting code reads
    # history.history['acc'] and ['val_acc'].
    classifier.compile(optimizer='adam',
                       loss='binary_crossentropy',
                       metrics=['acc'])

    # Train model; the test split doubles as validation data for the history curves
    history = classifier.fit(
        X_train, y_train,
        epochs=10,
        validation_data=(X_test, y_test),
        verbose=1,
        batch_size=128
    )

    # Evaluate model; predict() returns shape (n, 1), the metrics expect 1-D scores
    y_pred = classifier.predict(X_test).ravel()
    auc = roc_auc_score(y_test, y_pred)
    apr = average_precision_score(y_test, y_pred)

    return auc, apr, history

# Score how easily a classifier separates real from generated data.
# In an ideal GAN, AUC should be close to 0.5 (indistinguishable).
print("Training discriminator for continuous data...")
c_auc, c_apr, c_history = train_discriminator(continuous_x, c_gen_data_renorm)
print(f"Continuous data - AUC: {c_auc:.4f}, APR: {c_apr:.4f}")

print("\nTraining discriminator for discrete data...")
d_auc, d_apr, d_history = train_discriminator(discrete_x, d_gen_data)
print(f"Discrete data - AUC: {d_auc:.4f}, APR: {d_apr:.4f}")

# Accuracy curves: continuous discriminator on the left, discrete on the right
accuracy_panels = [
    (c_history, 'Continuous Discriminator Accuracy'),
    (d_history, 'Discrete Discriminator Accuracy'),
]
plt.figure(figsize=(12, 5))
for panel_number, (fit_history, panel_title) in enumerate(accuracy_panels, start=1):
    plt.subplot(1, 2, panel_number)
    plt.plot(fit_history.history['acc'], label='train')
    plt.plot(fit_history.history['val_acc'], label='test')
    plt.title(panel_title)
    plt.legend()

plt.tight_layout()
plt.show()

### 10. Summary and Recommendations

In [None]:
# Assemble the final report line by line and print it in one go
summary_lines = [
    "M3GAN Model Training Summary",
    "-" * 40,
    f"Dataset: MIMIC-IV with {patinet_num} patients",
    f"Features: {c_dim} continuous, {d_dim} discrete",
    f"Time steps: {time_steps} hours",
    f"Pre-training epochs: {num_pre_epochs}",
    f"Training epochs: {num_epochs}",
    "\nGenerated data quality:",
    f"Continuous data AUC: {c_auc:.4f} (closer to 0.5 is better)",
    f"Discrete data AUC: {d_auc:.4f} (closer to 0.5 is better)",
    f"Mean correlation difference: {np.mean(corr_diff):.4f} (lower is better)",
    "\nTraining files:",
    f"Checkpoint directory: {checkpoint_dir}",
    f"Generated data: data/fake/gen_data_mimiciv_{filename_postfix}_corrected.npz",
]

# Recommendations, numbered in the order given
recommendations = (
    "For low variance in first hour: Applied variance matching as a post-processing step",
    "For better stability: Consider using gradient penalty or spectral normalization",
    "For better feature correlations: Increase the alpha_mt parameter",
    "For more realistic trajectories: Increase the number of GAN training epochs",
    "For more diverse samples: Consider reducing batch size or adding noise during generation",
)
summary_lines.append("\nRecommendations:")
summary_lines.extend(
    f"{rank}. {advice}" for rank, advice in enumerate(recommendations, start=1)
)

print("\n".join(summary_lines))

### 11. Close Session

In [None]:
# Release TensorFlow resources.
# No cell shown in this notebook defines `sess`, so an unconditional `sess.close()`
# raises NameError on a fresh kernel. Close a session only if one exists;
# otherwise reset the Keras backend state instead.
session = globals().get("sess")
if session is not None:
    session.close()
    print("TensorFlow session closed")
else:
    tf.keras.backend.clear_session()
    print("No explicit TensorFlow session found; Keras backend state cleared")