# Evaluation Metrics in ehrMGAN

## Variable Settings

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
from scipy import stats
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score
import tensorflow as tf

# Set plotting style.
# matplotlib >= 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names (plt.style.use raises OSError for unknown styles),
# so try the new name first and fall back for older matplotlib versions.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
sns.set_style("whitegrid")

In [None]:
# ---- Experiment switches ----
GAN_24_FLAG = False
DATA_SOURCE = "mimic_iv"  # alternatives: "eICU", "MIMIC_EXTRACT_mimiciv"
DATA_GEN = "GAN"          # alternatives: "TABLE"

# ---- Discretization / sampling ----
bins = 20              # number of bins for discretization
obs_hrs = 3            # number of observation hours (e.g., current + previous)
subset_size = 10000    # number of distinct stay_id to use (sample a subset)
test_cases_num = 1000  # number of test cases to generate

# ---- Column groups ----
id_cols = ['stay_id', 'before_weaning_hr']
# TODO: consider adding peep, tidal_volume_set, tidal_volume_observed, sbp, dbp, mbp
feature_cols = ['heart_rate', 'resp_rate', 'spo2', 'fio2', 'respiratory_rate_set', 'gender_M', 'age']
selected_cols = [*id_cols, *feature_cols]
state_cols, state_units = ['heart_rate', 'resp_rate', 'spo2'], ['bpm', 'bpm', '%']  # state variables + units
action_cols, action_units = ['fio2', 'respiratory_rate_set'], ['%', 'bpm']          # action variables + units
baseline_cols = ['gender_M', 'age']  # baseline variables

# ---- Paths ----
if GAN_24_FLAG:
    traj_hrs = "24"
else:
    traj_hrs = "all"
prefix_path = "../data/" + DATA_SOURCE
output_dir = "{}/init_{}_hrs_{}_bins_traj_{}_hrs".format(prefix_path, obs_hrs, bins, traj_hrs)

## Read in Data

### Real data
- Only the first 24 hours of each `stay_id` are selected.

In [None]:
# ================================
# Data reading
# ================================
# Per-source data directories (update if needed)
eICU_prefix_path = "../data/eICU"
mimic_iii_prefix_path = "../data/mimic_iii"
mimic_iv_prefix_path = "../data/mimic_iv"

# Every source stores its ground truth under the same file name.
eICU_file, mimic_iv_file, mimic_iii_file = (
    os.path.join(source_dir, "baseline_charttime_ground_truth.csv")
    for source_dir in (eICU_prefix_path, mimic_iv_prefix_path, mimic_iii_prefix_path)
)

# Load each source: eICU first, then MIMIC-IV, then MIMIC-III.
eICU_df = pd.read_csv(eICU_file)
mimic_iv_df = pd.read_csv(mimic_iv_file)
mimic_iii_df = pd.read_csv(mimic_iii_file)

# ================================
# Add "hours_in" column (if missing)
# ================================
def add_hours_in(baseline_df, columns=None):
    """Return a copy of ``baseline_df`` that has an ``hours_in`` column.

    If ``hours_in`` is missing, the frame is restricted to ``columns``
    (default: the notebook-level ``selected_cols``) and ``hours_in`` is
    derived per stay as ``max(before_weaning_hr) - before_weaning_hr``, so the
    row with the largest ``before_weaning_hr`` of each stay gets
    ``hours_in == 0``. Rows are then sorted by ``(stay_id, hours_in)``,
    ``before_weaning_hr`` is dropped and ``hours_in`` is moved to position 1.

    In both cases ``gender_M`` is cast to ``int``.

    Args:
        baseline_df: DataFrame with ``stay_id`` and ``gender_M`` (and
            ``before_weaning_hr`` when ``hours_in`` is absent).
        columns: Columns to keep when deriving ``hours_in``; defaults to
            ``selected_cols``.

    Returns:
        A new DataFrame; the caller's frame is never modified.
    """
    if 'hours_in' not in baseline_df.columns:
        if columns is None:
            columns = selected_cols
        # Explicit copy: assigning into a plain column selection triggers
        # SettingWithCopyWarning and may not stick.
        out = baseline_df[list(columns)].copy()
        # For each stay_id, use the maximum before_weaning_hr as baseline
        max_before_weaning = out.groupby('stay_id')['before_weaning_hr'].transform('max')
        out['hours_in'] = max_before_weaning - out['before_weaning_hr']
        out = out.sort_values(['stay_id', 'hours_in']).drop(columns=['before_weaning_hr'])
        # Move hours_in to position 1 (after stay_id), in place.
        out.insert(1, 'hours_in', out.pop('hours_in'))
    else:
        # Copy so that the cast below does not mutate the caller's frame.
        out = baseline_df.copy()
    out['gender_M'] = out['gender_M'].astype(int)
    return out

# Bring every source to the same hours_in-based layout.
eICU_df, mimic_iv_df, mimic_iii_df = (
    add_hours_in(raw_frame) for raw_frame in (eICU_df, mimic_iv_df, mimic_iii_df)
)

#### MIMIC-IV Real Data

In [None]:
# Keep only the MIMIC-IV rows with hours_in < 24 (the first 24 hours of each stay).
baseline_charttime_ground_truth_first_24_hr_df = mimic_iv_df.loc[mimic_iv_df['hours_in'].lt(24)]
mimic_iv_df

In [None]:
# Display the MIMIC-III frame for inspection.
mimic_iii_df

In [None]:
# Display the eICU frame for inspection.
eICU_df

#### MIMIC-IV Naive Agent in the Ventilation Patient Environment

In [None]:
# Quick raw load of the naive-agent trajectories for ad-hoc inspection.
# NOTE: the next cell reads the same CSV again and coerces the feature columns
# to numeric, so this load is redundant when cells run in order.
naive_agent_generated_df = pd.read_csv(
    os.path.join(prefix_path, "naive_agent_generated_traj_cont_larger_24_hr_10000.csv"))
# naive_agent_generated_df[naive_agent_generated_df.isna().any(axis=1)].groupby('stay_id').size().reset_index(name='counts')
# naive_agent_generated_df['respiratory_rate_set'].unique()

In [None]:
# Load data/<DATA_SOURCE>/naive_agent_generated_traj_cont_larger_24_hr_10000.csv
naive_agent_generated_df = pd.read_csv(
    os.path.join(prefix_path, "naive_agent_generated_traj_cont_larger_24_hr_10000.csv"))
# Coerce every feature column that is present to numeric; unparseable entries become NaN.
naive_agent_generated_df = naive_agent_generated_df.assign(**{
    col: pd.to_numeric(naive_agent_generated_df[col], errors='coerce')
    for col in feature_cols
    if col in naive_agent_generated_df.columns
})
naive_agent_generated_df

In [None]:
# Count missing values per column. Entries that pd.to_numeric(errors='coerce')
# could not parse when loading also show up here as NaN.
# naive_agent_generated_df.dropna(inplace=True)
naive_agent_generated_df.isna().sum()

#### Health Gym

In [None]:
# Load data/health_gym/Health_gym_FakeSepsis.csv
health_gym_raw_df = pd.read_csv("../data/health_gym/Health_gym_FakeSepsis.csv")
health_gym_raw_df

In [None]:
# List the raw Health Gym column names (used to build the column mapping below).
health_gym_raw_df.columns

In [None]:
import pandas as pd

def map_health_gym_to_mimic(health_gym_df):
    """
    Map Health Gym columns onto the MIMIC-IV column names used in this notebook.

    Health Gym provides no ventilator set respiratory rate, so the observed
    ``RR`` is used for both ``resp_rate`` and ``respiratory_rate_set``.
    ``FiO2`` is multiplied by 100 (fraction -> percentage) and ``SpO2`` is
    mapped with ``(SpO2 + 1) * 2 + 80``.

    Args:
        health_gym_df (pd.DataFrame): Health Gym data containing the columns
            PatientID, Timepoints, HR, RR, SpO2, FiO2, Gender and Age.

    Returns:
        pd.DataFrame: A new DataFrame with columns stay_id, hours_in,
        heart_rate, resp_rate, spo2, fio2, respiratory_rate_set, gender_M, age.
        The input DataFrame is left unchanged.
    """
    column_mapping = {
        'PatientID': 'stay_id',
        'Timepoints': 'hours_in',
        'HR': 'heart_rate',
        'RR': 'resp_rate',
        'SpO2': 'spo2',
        'FiO2': 'fio2',
        'Gender': 'gender_M',  # NOTE(review): assumes Gender is already coded 1 = male — confirm
        'Age': 'age',
    }

    # Work on a selected/renamed copy so the caller's DataFrame is never mutated.
    mapped = health_gym_df[list(column_mapping)].rename(columns=column_mapping)
    # No set rate in Health Gym: reuse the observed RR, placed right after fio2
    # so the column order matches the MIMIC-IV layout.
    mapped.insert(mapped.columns.get_loc('fio2') + 1, 'respiratory_rate_set', mapped['resp_rate'])
    mapped['fio2'] = mapped['fio2'] * 100  # Scale FiO2 to percentage
    # TODO: SpO2 looks categorical in Health Gym; this linear mapping to % is a guess.
    mapped['spo2'] = (mapped['spo2'] + 1) * 2 + 80
    return mapped

In [None]:
# Map the raw Health Gym frame onto the MIMIC-IV column names / units
health_gym_df = map_health_gym_to_mimic(health_gym_raw_df)

# Display the resulting DataFrame
health_gym_df

In [None]:
# Summary statistics of the mapped Health Gym data.
health_gym_df.describe()

#### SDV PAR

In [None]:
# Load data/SDV_PAR/synthetic_data_4_epochs.csv
sdv_par_df = pd.read_csv("../data/SDV_PAR/synthetic_data_4_epochs.csv")
sdv_par_df

In [None]:
# Summary statistics of the SDV PAR synthetic data.
sdv_par_df.describe()

### GAN-Generated and TABLE-Generated Data

In [None]:
# Load the generated trajectories. Using the NpzFile as a context manager
# closes the underlying file once the array has been read.
with np.load(f'{output_dir}/fake/gen_data_499.npz') as gan_npz:
    GAN_gen_data = gan_npz['c_gen_data']
with np.load(f'{output_dir}/fake/table_gen_data.npz') as table_npz:
    TABLE_gen_data = table_npz['c_gen_data']

In [None]:
# Display the raw (normalized) GAN output array.
GAN_gen_data

In [None]:
# Display the raw (normalized) TABLE output array.
TABLE_gen_data

## Renormalizing Generated Data Back to Original Range

In [None]:
def renormalize_data(gan_data, table_data, original_df, bins=20, feature_cols=None):
    """
    Renormalize GAN and TABLE generated data back to the original value ranges.

    GAN values (expected in [0, 1]) are mapped linearly onto ``[min, max]`` of
    the matching column of ``original_df``. TABLE values (bin indices
    ``0..bins-1`` divided by ``bins - 1``) are turned back into indices,
    rounded, clipped to the valid range and replaced by the centre of the
    corresponding equal-width bin over ``[min, max]``. NaN inputs stay NaN.

    Only the first ``len(feature_cols)`` entries of the last axis are
    renormalized; any further features are left as 0.

    Args:
        gan_data: Array of shape (n_patients, n_timesteps, n_features), normalized 0-1.
        table_data: Array of the same layout holding normalized bin indices.
        original_df: DataFrame with the real data, used for per-feature min/max.
        bins: Number of discretization bins of the TABLE data (default 20).
        feature_cols: Columns of ``original_df`` matching the leading features;
            defaults to heart_rate, resp_rate, spo2, fio2, respiratory_rate_set.

    Returns:
        Tuple of float arrays (renormalized_gan_data, renormalized_table_data).
    """
    if feature_cols is None:
        feature_cols = ['heart_rate', 'resp_rate', 'spo2', 'fio2', 'respiratory_rate_set']

    # Work in float so bin centres are not truncated when inputs are integer arrays.
    gan_data = np.asarray(gan_data, dtype=float)
    table_data = np.asarray(table_data, dtype=float)

    # Min / max of each feature in the real data
    feature_stats = {}
    for i, col in enumerate(feature_cols):
        col_min = original_df[col].min()
        col_max = original_df[col].max()
        feature_stats[i] = {'min': col_min, 'max': col_max, 'range': col_max - col_min}

    gan_renormalized = np.zeros_like(gan_data)
    table_renormalized = np.zeros_like(table_data)

    for i in range(len(feature_cols)):
        min_val = feature_stats[i]['min']
        max_val = feature_stats[i]['max']

        # GAN: scale from [0, 1] to [min, max]
        gan_renormalized[:, :, i] = gan_data[:, :, i] * (max_val - min_val) + min_val

        # TABLE: normalized value -> bin index -> centre of that bin
        bin_edges = np.linspace(min_val, max_val, bins + 1)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        scaled = table_data[:, :, i] * (bins - 1)
        missing = np.isnan(scaled)
        # Clip so values slightly outside [0, 1] land in the edge bins instead of
        # silently staying 0; NaN is masked here and restored below.
        bin_indices = np.clip(np.round(np.where(missing, 0.0, scaled)), 0, bins - 1).astype(int)
        table_renormalized[:, :, i] = np.where(missing, np.nan, bin_centers[bin_indices])

    print("Data renormalized to original ranges:")
    for i, col in enumerate(feature_cols):
        print(f"  {col}: Original range [{feature_stats[i]['min']:.1f}, {feature_stats[i]['max']:.1f}]")
        print(f"    GAN range: [{np.nanmin(gan_renormalized[:, :, i]):.1f}, {np.nanmax(gan_renormalized[:, :, i]):.1f}]")
        print(f"    TABLE range: [{np.nanmin(table_renormalized[:, :, i]):.1f}, {np.nanmax(table_renormalized[:, :, i]):.1f}]")

    return gan_renormalized, table_renormalized

# Get feature statistics from original data.
# NOTE: this rebinds the notebook-level `feature_cols` (7 columns in the
# settings cell, including gender_M and age) to the 5 time-varying features.
# Their order is assumed to match the last axis of the generated arrays.
feature_cols = ['heart_rate', 'resp_rate', 'spo2', 'fio2', 'respiratory_rate_set']
feature_stats = baseline_charttime_ground_truth_first_24_hr_df[feature_cols].describe()
print("Original data statistics:")
print(feature_stats)

# Renormalize the data (min/max are taken from the first-24h real data)
gan_renormalized, table_renormalized = renormalize_data(
    gan_data=GAN_gen_data,
    table_data=TABLE_gen_data,
    original_df=baseline_charttime_ground_truth_first_24_hr_df
)

# Sample visualization comparing original vs renormalized distributions
import matplotlib.pyplot as plt
import seaborn as sns

# Plot distributions for all features (2 x 3 grid, one panel per feature)
plt.figure(figsize=(15, 10))
for i, col in enumerate(feature_cols):
    plt.subplot(2, 3, i+1)

    # Flatten arrays for distribution plotting (all patients and time steps pooled)
    original_values = baseline_charttime_ground_truth_first_24_hr_df[col].values
    gan_values = gan_renormalized[:,:,i].flatten()
    table_values = table_renormalized[:,:,i].flatten()

    # Plot distributions
    sns.kdeplot(original_values, label='Original', linewidth=2)
    sns.kdeplot(gan_values, label='GAN', linewidth=2)
    sns.kdeplot(table_values, label='TABLE', linewidth=2)

    plt.title(f'{col} Distribution')
    plt.legend()

plt.tight_layout()
plt.show()

# Save the renormalized arrays (stored under the key 'data') next to the raw
# generated files in <output_dir>/fake/
np.savez(f'{output_dir}/fake/renormalized_gan_data.npz', data=gan_renormalized)
np.savez(f'{output_dir}/fake/renormalized_table_data.npz', data=table_renormalized)

print(f"Renormalized data saved to {output_dir}")

In [None]:
# Display the renormalized GAN array.
gan_renormalized

In [None]:
import pandas as pd
import numpy as np

def gan_to_dataframe(gan_renormalized, start_stay_id=30000000):
    """
    Convert gan_renormalized array into a long-format DataFrame.

    One row is produced per (patient, time step), patient-major. Patient ``k``
    gets ``stay_id = start_stay_id + k`` and ``hours_in`` is the index along
    the time axis. Only the first five features are kept, rounded to one
    decimal.

    Args:
        gan_renormalized (np.ndarray): Array of shape (n_patients, n_timesteps, n_features),
            n_features >= 5, ordered heart_rate, resp_rate, spo2, fio2, respiratory_rate_set.
        start_stay_id (int): The starting stay_id for generating fake IDs.

    Returns:
        pd.DataFrame: Columns ['stay_id', 'hours_in', 'heart_rate', 'resp_rate', 'spo2', 'fio2',
        'respiratory_rate_set']. Empty input yields an empty frame with these columns.

    Raises:
        ValueError: If the input is not 3-D or has fewer than five features.
    """
    value_cols = ['heart_rate', 'resp_rate', 'spo2', 'fio2', 'respiratory_rate_set']
    columns = ['stay_id', 'hours_in'] + value_cols

    data = np.asarray(gan_renormalized, dtype=float)
    if data.size == 0:
        return pd.DataFrame(columns=columns)
    if data.ndim != 3 or data.shape[2] < len(value_cols):
        raise ValueError(
            f"expected an array of shape (n_patients, n_timesteps, >=5), got {data.shape}"
        )

    n_patients, n_hours, _ = data.shape
    # Vectorised construction; row order is patient-major like the nested loops it replaces.
    df = pd.DataFrame(np.round(data[:, :, :len(value_cols)].reshape(-1, len(value_cols)), 1),
                      columns=value_cols)
    df.insert(0, 'stay_id', np.repeat(np.arange(n_patients) + start_stay_id, n_hours))
    df.insert(1, 'hours_in', np.tile(np.arange(n_hours), n_patients))
    return df

In [None]:
# Convert gan_renormalized to a long-format DataFrame (one row per patient-hour)
gan_df = gan_to_dataframe(gan_renormalized)

# Display the resulting DataFrame
gan_df

## Evaluation Metrics

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.stats import ks_2samp

# Datasets to compare, keyed by the label used in plots and output files.
# Assumes every frame below was built by the earlier cells.
dataframes = dict(
    mimic_iv=mimic_iv_df,
    mimic_iii=mimic_iii_df,
    eICU=eICU_df,
    physician_policy=naive_agent_generated_df,
    health_gym=health_gym_df,
    sdv_par=sdv_par_df,
    gan=gan_df,
)

# Feature groups: time-varying features plus the baseline (static) ones
common_features = ['heart_rate', 'resp_rate', 'spo2', 'fio2', 'respiratory_rate_set']
additional_features = ['gender_M', 'age']
all_features = [*common_features, *additional_features]

# gender_M is the only categorical feature; everything else is treated as numerical.
categorical_features = ['gender_M']
numerical_features = [feat for feat in all_features if feat not in categorical_features]

# Time horizon in hours (20 for consistency with health_gym_df; adjustable)
max_hours = 20

# Output directory for this section (rebinds the earlier `output_dir`)
output_dir = '../data/evaluation_results'
os.makedirs(output_dir, exist_ok=True)

# Function to filter dataframes by max_hours
def filter_by_hours(df, max_hours):
    """Return the rows of ``df`` whose ``hours_in`` is strictly below ``max_hours``."""
    within_window = df['hours_in'].lt(max_hours)
    return df.loc[within_window]

# Function to plot feature distributions
def plot_feature_distributions(dataframes, numerical_features, categorical_features, output_dir, max_hours):
    """Save one distribution figure per feature to ``output_dir``.

    Every dataset that has the feature is first restricted to
    ``hours_in < max_hours`` via ``filter_by_hours``. Numerical features are
    drawn as overlaid KDE curves (NaNs dropped). Categorical features are drawn
    as bars of normalized value counts, one bar per (dataset, value) pair,
    labelled ``"<dataset>_<value>"``. Each figure is saved as
    ``<feature>_distribution.png`` and then closed.

    NOTE: a later cell redefines ``plot_feature_distributions`` with a
    different signature; whichever definition ran last is the one in use.
    """
    # Plot numerical features with KDE
    for feature in numerical_features:
        plt.figure(figsize=(10, 6))
        for name, df in dataframes.items():
            if feature in df.columns:
                filtered_df = filter_by_hours(df, max_hours)
                sns.kdeplot(filtered_df[feature].dropna(), label=name)
        plt.title(f'Distribution of {feature}')
        plt.legend()
        plt.savefig(os.path.join(output_dir, f'{feature}_distribution.png'))
        plt.close()

    # Plot categorical features with bar plots
    for feature in categorical_features:
        plt.figure(figsize=(10, 6))
        for name, df in dataframes.items():
            if feature in df.columns:
                filtered_df = filter_by_hours(df, max_hours)
                value_counts = filtered_df[feature].value_counts(normalize=True)
                plt.bar([f"{name}_{val}" for val in value_counts.index], value_counts.values, alpha=0.5, label=name)
        plt.title(f'Distribution of {feature}')
        plt.legend()
        plt.xticks(rotation=45)
        plt.savefig(os.path.join(output_dir, f'{feature}_distribution.png'))
        plt.close()

# Function to compute summary statistics
def compute_summary_stats(dataframes, features, output_dir, max_hours):
    """Write ``describe()`` of the available ``features`` to ``<name>_summary.csv`` per dataset.

    Only rows with ``hours_in < max_hours`` are summarised. Datasets that have
    none of the requested features produce no file.
    """
    for dataset_name, frame in dataframes.items():
        window = filter_by_hours(frame, max_hours)
        present = [feat for feat in features if feat in frame.columns]
        if not present:
            continue
        out_path = os.path.join(output_dir, f'{dataset_name}_summary.csv')
        window[present].describe().to_csv(out_path)

# Function to compute KS test against reference (mimic_iv)
def compute_ks_tests(dataframes, features, output_dir, max_hours, reference_name='mimic_iv'):
    """
    Two-sample Kolmogorov-Smirnov test of every dataset against a reference.

    Each feature present in the reference is compared with every other dataset
    that has it. Both samples are restricted to ``hours_in < max_hours`` and
    NaNs are dropped. If either sample is empty, the statistic and p-value are
    recorded as NaN, so one dataset cannot abort the whole run; ``ks_2samp``
    raises on empty input.

    Results are written to ``<output_dir>/ks_results.txt``.

    Args:
        dataframes: Mapping of dataset name -> DataFrame.
        features: Feature columns to test.
        output_dir: Directory for the results file.
        max_hours: Only rows with ``hours_in < max_hours`` are used.
        reference_name: Key of the reference dataset (default ``'mimic_iv'``).

    Returns:
        dict: ``{feature: {name: {'ks_stat': float, 'p_value': float}}}``.
    """
    reference_df = filter_by_hours(dataframes[reference_name], max_hours)
    ks_results = {}
    for feature in features:
        if feature not in reference_df.columns:
            continue
        ref_data = reference_df[feature].dropna()
        ks_results[feature] = {}
        for name, df in dataframes.items():
            if name == reference_name or feature not in df.columns:
                continue
            synth_data = filter_by_hours(df, max_hours)[feature].dropna()
            if ref_data.empty or synth_data.empty:
                ks_stat, p_value = float('nan'), float('nan')
            else:
                ks_stat, p_value = ks_2samp(ref_data, synth_data)
            ks_results[feature][name] = {'ks_stat': ks_stat, 'p_value': p_value}

    with open(os.path.join(output_dir, 'ks_results.txt'), 'w', encoding='utf-8') as f:
        for feature, results in ks_results.items():
            f.write(f'Feature: {feature}\n')
            for name, res in results.items():
                f.write(f'  {name}: KS stat = {res["ks_stat"]:.4f}, p-value = {res["p_value"]:.4f}\n')
    return ks_results

# Function from evaluation_metrics.ipynb: Feature-temporal correlation
def feature_temporal_correlation(real_df, synthetic_dfs, synthetic_names, feature_names, output_dir, max_hours,
                                 filename='temporal_correlation.png'):
    """
    Save side-by-side Pearson correlation heatmaps of ``feature_names``.

    The first panel is the real data (titled "Real"), followed by one panel
    per synthetic dataset. Correlations are computed over all rows with
    ``hours_in < max_hours``, pooled across stays and hours.

    Args:
        real_df: Reference DataFrame.
        synthetic_dfs: Sequence of DataFrames shown after the real one.
        synthetic_names: Panel titles, one per entry of ``synthetic_dfs``.
        feature_names: Columns to correlate.
        output_dir: Directory the figure is saved to.
        max_hours: Only rows with ``hours_in < max_hours`` are used.
        filename: File name inside ``output_dir`` (default ``temporal_correlation.png``).

    Raises:
        ValueError: If ``synthetic_dfs`` and ``synthetic_names`` differ in length.
    """
    if len(synthetic_dfs) != len(synthetic_names):
        raise ValueError(
            f"got {len(synthetic_dfs)} synthetic dataframes but {len(synthetic_names)} names"
        )

    frames = [filter_by_hours(real_df, max_hours)] + [filter_by_hours(df, max_hours) for df in synthetic_dfs]
    titles = ['Real'] + list(synthetic_names)

    num_plots = len(frames)
    # squeeze=False keeps `axes` 2-D, so a single panel (no synthetic data) still iterates.
    fig, axes = plt.subplots(1, num_plots, figsize=(5 * num_plots, 5), gridspec_kw={'wspace': 0.4}, squeeze=False)

    for ax, df, title in zip(axes[0], frames, titles):
        corr = df[feature_names].corr()
        sns.heatmap(corr, annot=True, cmap='coolwarm', ax=ax, square=True, cbar=False)
        ax.set_title(title)

    # Adjust layout and save the figure
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close(fig)

# Function from evaluation_metrics.ipynb: Feature cross-correlation difference
def feature_cross_correlation_difference(real_df, *synthetic_dfs, feature_names, output_dir, max_hours,
                                         synthetic_names=None):
    """
    Compare each synthetic dataset's feature correlation matrix with the real one.

    For every synthetic frame, the element-wise absolute difference between
    its Pearson correlation matrix and that of ``real_df`` is computed. Both
    frames are restricted to ``hours_in < max_hours`` first. Each difference is
    saved as a heatmap ``corr_diff_<name>.png``, and the mean/max differences
    are written to ``corr_diff_summary.csv``.

    Args:
        real_df: Reference (real) DataFrame.
        *synthetic_dfs: DataFrames to compare against ``real_df``.
        feature_names: Columns to correlate.
        output_dir: Directory for the figures and summary CSV.
        max_hours: Only rows with ``hours_in < max_hours`` are used.
        synthetic_names: Optional labels, one per synthetic frame. When
            omitted, each frame is labelled with its key in the notebook-level
            ``dataframes`` dict (matched by identity), falling back to
            ``synthetic_<i>``.

    Returns:
        pd.DataFrame with columns Method, Mean Abs Diff, Max Abs Diff.
    """
    if synthetic_names is None:
        # Label by identity instead of zipping with the key order of
        # `dataframes`; that order need not match the order of *synthetic_dfs.
        registry = globals().get('dataframes', {})
        synthetic_names = [
            next((key for key, frame in registry.items() if frame is df), f'synthetic_{i}')
            for i, df in enumerate(synthetic_dfs)
        ]
    elif len(synthetic_names) != len(synthetic_dfs):
        raise ValueError(
            f"got {len(synthetic_dfs)} synthetic dataframes but {len(synthetic_names)} names"
        )

    real_corr = filter_by_hours(real_df, max_hours)[feature_names].corr()
    corr_diff_summary = []

    for synth_df, name in zip(synthetic_dfs, synthetic_names):
        synth_corr = filter_by_hours(synth_df, max_hours)[feature_names].corr()
        corr_diff = np.abs(real_corr - synth_corr)

        plt.figure(figsize=(8, 6))
        sns.heatmap(corr_diff, annot=True, cmap='coolwarm')
        plt.title(f'Correlation Difference: Real vs {name}')
        plt.savefig(os.path.join(output_dir, f'corr_diff_{name}.png'))
        plt.close()

        mean_diff = corr_diff.values.mean()
        max_diff = corr_diff.values.max()
        corr_diff_summary.append({'Method': f'{name} vs Real', 'Mean Abs Diff': mean_diff, 'Max Abs Diff': max_diff})
        print(f"{name} - Real mean absolute correlation difference: {mean_diff:.4f}")

    summary_df = pd.DataFrame(corr_diff_summary)
    summary_df.to_csv(os.path.join(output_dir, 'corr_diff_summary.csv'))
    print("\nCorrelation Difference Summary:\n", summary_df)
    return summary_df



In [None]:
# Filter all dataframes.
# NOTE(review): `filtered_dataframes` is only printed here (the full repr of
# every frame). Each helper below filters internally via filter_by_hours.
filtered_dataframes = {name: filter_by_hours(df, max_hours) for name, df in dataframes.items()}
print(f"Filtered dataframes for max_hours={max_hours}:\n{filtered_dataframes}")

# 1. Visualize feature distributions
plot_feature_distributions(dataframes, numerical_features, categorical_features, output_dir, max_hours)
print("Feature distributions plotted and saved.")

# 2. Compute summary statistics
compute_summary_stats(dataframes, all_features, output_dir, max_hours)
print("Summary statistics computed and saved.")

# 3. Compute KS tests (every dataset vs. mimic_iv)
compute_ks_tests(dataframes, common_features, output_dir, max_hours)
print("KS tests computed and saved.")

# 4. Run evaluation metric functions from evaluation_metrics.ipynb
real_df = dataframes['mimic_iv']
synthetic_dfs = [dataframes['physician_policy'], dataframes['health_gym'], dataframes['gan'], dataframes['sdv_par']]
synthetic_names = ['Naive Agent', 'Health Gym', 'GAN', 'SDV PAR']
# synthetic_dfs = [dataframes['gan'], dataframes['sdv_par'], dataframes['physician_policy'], dataframes['health_gym']]
feature_names = common_features  # Use only common features for consistency

print("\nGenerating feature-temporal correlation heatmaps...")
feature_temporal_correlation(real_df, synthetic_dfs, synthetic_names, feature_names=feature_names, output_dir=output_dir, max_hours=max_hours)

print("\nGenerating correlation difference heatmaps...")
# NOTE(review): synthetic_names is not passed to this call; check that the
# labels in the corr_diff_<name>.png files match the frames in synthetic_dfs.
feature_cross_correlation_difference(real_df, *synthetic_dfs, feature_names=feature_names, output_dir=output_dir, max_hours=max_hours)


In [None]:
import warnings
# Hide matplotlib's "not compatible with tight_layout" UserWarning, presumably
# raised by the tight_layout() calls in the plotting helpers.
warnings.filterwarnings('ignore', category=UserWarning, message='This figure includes Axes that are not compatible with tight_layout')

In [None]:
# Compare MIMIC-IV against the other real-world sources (MIMIC-III, eICU).
real_df = dataframes['mimic_iv']
# synthetic_dfs = [dataframes['naive_agent'], dataframes['health_gym'], dataframes['gan'], dataframes['sdv_par']]
# synthetic_names = ['Naive Agent', 'Health Gym', 'ehrMGAN', 'SDV PAR']
synthetic_dfs = [dataframes['mimic_iii'], dataframes['eICU']]
synthetic_names = ['MIMIC III', 'eICU']

feature_names = common_features  # Use only common features for consistency

print("\nGenerating feature-temporal correlation heatmaps...")
# NOTE: this call writes temporal_correlation.png into the same output_dir as
# the earlier synthetic-model comparison, overwriting that figure.
feature_temporal_correlation(real_df, synthetic_dfs, synthetic_names, feature_names=feature_names, output_dir=output_dir, max_hours=max_hours)

In [None]:
# Number of distinct respiratory_rate_set values (NaN counted once) in every dataset.
for name, df in dataframes.items():
    if 'respiratory_rate_set' not in df.columns:
        continue
    distinct_values = df['respiratory_rate_set'].unique()
    print(f"{name}: {len(distinct_values)}")

### Evaluation Metrics v2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, mean_squared_error, roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split
from math import sqrt
import os

# Suppress warnings.
# NOTE(review): this silences every warning for the rest of the session.
import warnings
warnings.filterwarnings('ignore')

# Set up matplotlib. matplotlib >= 3.6 renamed the bundled seaborn styles to
# 'seaborn-v0_8-*' and 3.8 removed the old names (unknown styles raise
# OSError), so try the new name first and fall back for older versions.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
sns.set_style("whitegrid")

# Configuration
output_dir = "../data"  # Adjust as needed
evaluation_dir = f"{output_dir}/evaluations_v2/"
os.makedirs(evaluation_dir, exist_ok=True)

max_hours = 20  # Set to 20 for consistency with health_gym_raw_df; adjustable
feature_cols = ['heart_rate', 'resp_rate', 'spo2', 'fio2', 'respiratory_rate_set']  # Common time-varying features
feature_names = ['Heart Rate', 'Resp. Rate', 'SpO2', 'FiO2', 'RR Set']  # For plotting

# Dataframes dictionary
dataframes = {
    'mimic_iv': mimic_iv_df,
    'mimic_iii': mimic_iii_df,
    'eICU': eICU_df,
    'physician_policy': naive_agent_generated_df,
    'health_gym': health_gym_df,
    'sdv_par': sdv_par_df,
    'gan': gan_df
}

# Helper function to convert dataframe to 3D array
def df_to_array(df, feature_cols=feature_cols, max_hours=max_hours):
    """
    Pivot a long-format trajectory frame into a dense (patients, hours, features) array.

    Rows with ``hours_in < max_hours`` are placed at
    ``[patient, int(hours_in), :]``. The patient index follows the order of
    ``df['stay_id'].unique()``, and patients without rows in the window keep an
    all-NaN slice. Rows whose truncated hour is negative are ignored. If
    several rows map to the same (patient, hour) slot, the last one wins. Each
    feature is then forward-filled along the hour axis: gaps after the first
    observed value are filled, while leading gaps stay NaN.

    Args:
        df: Frame with ``stay_id``, ``hours_in`` and every column of ``feature_cols``.
        feature_cols: Feature columns, in output order (last axis).
        max_hours: Number of hour slots.

    Returns:
        np.ndarray of float with shape (n_patients, max_hours, len(feature_cols)).
    """
    cols = list(feature_cols)
    patient_ids = df['stay_id'].unique()
    array_3d = np.full((len(patient_ids), max_hours, len(cols)), np.nan)

    window = df[df['hours_in'] < max_hours]
    # astype(int) truncates toward zero like int(); negative slots would wrap
    # around to the end of the hour axis, so they are discarded.
    hours = window['hours_in'].to_numpy().astype(int)
    valid = hours >= 0
    hours = hours[valid]
    patient_idx = pd.Index(patient_ids).get_indexer(window['stay_id'])[valid]
    values = window[cols].to_numpy(dtype=float)[valid]

    # NumPy gives no ordering guarantee for repeated fancy-index writes, so
    # select the last row of each (patient, hour) slot explicitly.
    slots = patient_idx * max_hours + hours
    _, first_in_reversed = np.unique(slots[::-1], return_index=True)
    keep = len(slots) - 1 - first_in_reversed
    array_3d[patient_idx[keep], hours[keep]] = values[keep]

    # Forward-fill missing values along the hour axis.
    for h in range(1, max_hours):
        current = array_3d[:, h, :]  # view into array_3d
        gaps = np.isnan(current)
        current[gaps] = array_3d[:, h - 1, :][gaps]

    return array_3d

# Convert all dataframes to 3D arrays
arrays = {name: df_to_array(df) for name, df in dataframes.items()}
real_array = arrays['mimic_iv']  # Reference real dataset

# 1. Visualize feature distributions
def plot_feature_distributions(dataframes, feature_cols, max_hours=max_hours):
    """
    Save one distribution figure per feature to ``evaluation_dir``.

    Each dataset is restricted to ``hours_in < max_hours``. ``gender_M`` is
    drawn as grouped bars of category proportions per dataset. Every other
    feature (``feature_cols`` plus ``age``) is drawn as overlaid KDE curves.
    Figures are written as ``<feature>_distribution.png`` at 300 dpi.

    Args:
        dataframes: Mapping of dataset name -> DataFrame.
        feature_cols: Time-varying feature columns to plot.
        max_hours: Hour cut-off (exclusive).
    """
    filtered_dfs = {name: df[df['hours_in'] < max_hours] for name, df in dataframes.items()}
    all_features = list(feature_cols) + ['gender_M', 'age']
    # NOTE(review): assumes gender_M is coded 0 = female, 1 = male in every dataset — confirm.
    gender_labels = {0: 'Female', 1: 'Male'}

    for feature in all_features:
        plt.figure(figsize=(12, 6))
        if feature in ['gender_M']:  # Categorical
            counts = {
                name: df[feature].value_counts(normalize=True)
                for name, df in filtered_dfs.items()
                if feature in df.columns
            }
            if counts:
                # value_counts() orders categories by frequency, so the column
                # order is not guaranteed to be [0, 1]. Sort it and build the
                # legend from the actual category codes.
                count_df = pd.DataFrame(counts).T.sort_index(axis=1)
                count_df.plot(kind='bar', ax=plt.gca())
                plt.title(f'Distribution of {feature}')
                plt.xlabel('Dataset')
                plt.ylabel('Proportion')
                plt.legend([gender_labels.get(code, str(code)) for code in count_df.columns])
        else:  # Numerical
            for name, df in filtered_dfs.items():
                if feature in df.columns:
                    sns.kdeplot(df[feature], label=name, warn_singular=False)
            plt.title(f'Distribution of {feature}')
            plt.xlabel(feature)
            plt.ylabel('Density')
            plt.legend()
        plt.tight_layout()
        plt.savefig(f"{evaluation_dir}/{feature}_distribution.png", dpi=300)
        plt.close()

plot_feature_distributions(dataframes, feature_cols)

# 2.1 Max Mean Discrepancy (MMD)
def calculate_mmd(real_data, synthetic_data, bandwidth_multipliers=(0.2, 0.5, 0.9, 1.3), random_state=None):
    """
    Biased Maximum Mean Discrepancy between two sets of trajectories.

    The result is averaged over several Gaussian-kernel bandwidths.

    Each trajectory is flattened to one vector. Vectors containing any NaN are
    dropped, and at most 1000 per set are randomly sub-sampled. The base
    bandwidth is the median of the full pooled pairwise Euclidean distance
    matrix (diagonal zeros included). For each multiplier ``m`` the kernel
    ``exp(-d^2 / (2 (m * median)^2))`` is used and ``sqrt(max(0, MMD^2))`` is
    recorded.

    Args:
        real_data, synthetic_data: Arrays with a leading sample axis, e.g.
            (n_patients, n_hours, n_features).
        bandwidth_multipliers: Multipliers applied to the median distance.
        random_state: Optional seed for the sub-sampling. ``None`` uses the
            global NumPy RNG, as before.

    Returns:
        float: Mean of the per-bandwidth MMD values.

    Raises:
        ValueError: If either set has no NaN-free trajectory, or the median
            distance is NaN or zero.
    """
    from scipy.spatial.distance import cdist

    rng = np.random if random_state is None else np.random.RandomState(random_state)

    def _complete_rows(data):
        # Flatten each trajectory to one vector and drop vectors containing NaN.
        data = np.asarray(data, dtype=float)
        flat = data.reshape(data.shape[0], int(np.prod(data.shape[1:])))
        return flat[~np.isnan(flat).any(axis=1)]

    real_flat = _complete_rows(real_data)
    synth_flat = _complete_rows(synthetic_data)

    # Sample if too large (for computational efficiency)
    max_samples = 1000
    if len(real_flat) > max_samples:
        real_flat = real_flat[rng.choice(len(real_flat), max_samples, replace=False)]
    if len(synth_flat) > max_samples:
        synth_flat = synth_flat[rng.choice(len(synth_flat), max_samples, replace=False)]

    if len(real_flat) == 0 or len(synth_flat) == 0:
        # An empty set gives NaN kernel means, which max(0, ...) would silently
        # turn into an MMD of 0 (a "perfect match").
        raise ValueError("No NaN-free trajectories left to compare; cannot compute MMD.")

    # Median heuristic. cdist avoids materialising an (N, N, D) difference
    # tensor, which is several GB for 2000 samples x 100 dims.
    pooled = np.vstack([real_flat, synth_flat])
    median_dist = np.median(cdist(pooled, pooled))
    if np.isnan(median_dist) or median_dist == 0:
        raise ValueError("Median distance is NaN or zero. Check your data for invalid values.")

    sq_xx = cdist(real_flat, real_flat, 'sqeuclidean')
    sq_yy = cdist(synth_flat, synth_flat, 'sqeuclidean')
    sq_xy = cdist(real_flat, synth_flat, 'sqeuclidean')

    mmd_values = []
    for bandwidth_multiplier in bandwidth_multipliers:
        bandwidth = bandwidth_multiplier * median_dist
        gamma = 1.0 / (2 * bandwidth ** 2)
        # Biased MMD^2 with the RBF kernel exp(-gamma * ||x - y||^2)
        mmd_sq = np.exp(-gamma * sq_xx).mean() - 2 * np.exp(-gamma * sq_xy).mean() + np.exp(-gamma * sq_yy).mean()
        # Clip first: finite-precision MMD^2 can be slightly negative.
        mmd_values.append(np.sqrt(max(0.0, mmd_sq)))

    # Return average MMD across bandwidths
    return float(np.mean(mmd_values))

import numpy as np
from scipy.spatial.distance import cdist
# 2.1 Max Mean Discrepancy (MMD) [update] aligned with transition level evaluation
# Unbiased alternative to `calculate_mmd`. It is defined under its own name so
# that it does not override the biased estimator above; call it explicitly.
def calculate_mmd_unbiased(real_data, synthetic_data, bandwidth_multipliers=(0.2, 0.5, 0.9, 1.3), random_state=None):
    """
    Calculate Maximum Mean Discrepancy (MMD) between real_data and synthetic_data
    using an unbiased estimator and Gaussian kernel, similar to MMD_evaluation.

    Trajectories are flattened to vectors. Rows containing NaN are dropped, and
    at most 1000 rows per set are sub-sampled. The base bandwidth is the median
    of the strictly positive pairwise Euclidean distances of the pooled
    samples. For each multiplier, the unbiased MMD^2 is clipped at 0 and then
    square-rooted.

    Args:
        real_data, synthetic_data: Arrays with a leading sample axis.
        bandwidth_multipliers: Multipliers applied to the median distance.
        random_state: Optional seed for the sub-sampling (``None`` = global NumPy RNG).

    Returns:
        float: Mean MMD over the bandwidth multipliers.

    Raises:
        ValueError: If either set is empty after NaN removal, or no positive
            pairwise distance exists.
    """
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    def _complete_rows(data):
        # Flatten each trajectory to one vector and drop vectors containing NaN.
        data = np.asarray(data, dtype=float)
        flat = data.reshape(data.shape[0], int(np.prod(data.shape[1:])))
        return flat[~np.isnan(flat).any(axis=1)]

    real_flat = _complete_rows(real_data)
    synth_flat = _complete_rows(synthetic_data)

    # Sample if too large (for computational efficiency)
    max_samples = 1000
    if len(real_flat) > max_samples:
        real_flat = real_flat[rng.choice(len(real_flat), max_samples, replace=False)]
    if len(synth_flat) > max_samples:
        synth_flat = synth_flat[rng.choice(len(synth_flat), max_samples, replace=False)]

    m, n = len(real_flat), len(synth_flat)
    if m == 0 or n == 0:
        raise ValueError("No NaN-free trajectories left to compare; cannot compute MMD.")

    # One pooled distance matrix; the xx / yy / xy blocks are slices of it.
    pooled = np.vstack([real_flat, synth_flat])
    distances = cdist(pooled, pooled, 'euclidean')
    positive = distances[distances > 0]
    if positive.size == 0:
        raise ValueError("Median distance is NaN or zero. Check your data for invalid values.")
    sigma = np.median(positive)  # median heuristic, zero distances excluded

    squared = distances ** 2
    xx, yy, xy = squared[:m, :m], squared[m:, m:], squared[:m, m:]

    mmd_values = []
    for bandwidth_multiplier in bandwidth_multipliers:
        bandwidth = bandwidth_multiplier * sigma
        # Gaussian kernel
        k_xx = np.exp(-xx / (2 * bandwidth ** 2))
        k_yy = np.exp(-yy / (2 * bandwidth ** 2))
        k_xy = np.exp(-xy / (2 * bandwidth ** 2))

        # Unbiased MMD^2: within-set means exclude the diagonal (self-pairs).
        term1 = (np.sum(k_xx) - np.trace(k_xx)) / (m * (m - 1)) if m > 1 else 0.0
        term2 = (np.sum(k_yy) - np.trace(k_yy)) / (n * (n - 1)) if n > 1 else 0.0
        term3 = np.sum(k_xy) / (m * n)
        # sqrt for consistency with MMD_evaluation; clip first because the
        # unbiased estimate can be negative.
        mmd_values.append(np.sqrt(max(0.0, term1 + term2 - 2 * term3)))

    return float(np.mean(mmd_values))

# MMD of every non-reference dataset against MIMIC-IV. A dataset whose MMD
# cannot be computed (calculate_mmd raises ValueError) is reported as NaN
# instead of aborting the comparison for all remaining datasets.
mmd_results = {}
for name, array in arrays.items():
    if name == 'mimic_iv':
        continue
    try:
        mmd_results[name] = calculate_mmd(real_array, array)
    except ValueError as err:
        print(f"MMD skipped for {name}: {err}")
        mmd_results[name] = float('nan')
print("MMD results:")
for name, mmd in mmd_results.items():
    print(f"{name}: {mmd:.6f}")

# 2.2 Pearson Correlation
def compute_correlation_matrix(data_array):
    """
    Pearson correlation between features, pooling every (patient, hour) row.

    Rows containing any NaN are dropped first; otherwise a single missing value
    (e.g. a leading gap left after forward-filling) turns the whole matrix into
    NaN.

    Args:
        data_array: Array whose last axis indexes features, e.g.
            (n_patients, n_hours, n_features).

    Returns:
        np.ndarray of shape (n_features, n_features).
    """
    data = np.asarray(data_array, dtype=float)
    flat_data = data.reshape(-1, data.shape[-1])
    flat_data = flat_data[~np.isnan(flat_data).any(axis=1)]
    return np.corrcoef(flat_data, rowvar=False)

# Frobenius norm of (synthetic - real) feature-correlation matrix, per synthetic dataset
real_corr = compute_correlation_matrix(real_array)
corr_diffs = {
    name: np.linalg.norm(compute_correlation_matrix(array) - real_corr)
    for name, array in arrays.items()
    if name != 'mimic_iv'
}

print("\nCorrelation difference (Frobenius norm):")
for name, diff in corr_diffs.items():
    print(f"{name}: {diff:.4f}")

# Correlation heatmaps for the real cohort plus a selection of synthetic datasets.
# NOTE(review): key 'naive_agent_df' holds arrays['physician_policy'] (an earlier revision
# used arrays['naive_agent']); the key is also the plot label downstream — confirm naming.
selected_datasets = {
    'mimic_iv': real_array,
    'naive_agent_df': arrays['physician_policy'],
    'health_gym_df': arrays['health_gym'],
    'gan': arrays['gan'],
    'sdv_par': arrays['sdv_par'],
}
n_panels = len(selected_datasets)
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5))
for ax, (name, array) in zip(axes, selected_datasets.items()):
    corr = compute_correlation_matrix(array)
    sns.heatmap(
        corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1, ax=ax,
        xticklabels=feature_names, yticklabels=feature_names,
    )
    ax.set_title(f'Correlation - {name}')
plt.tight_layout()
plt.savefig(f"{evaluation_dir}/correlation_selected.png", dpi=300)
plt.close()

# 2.3 Dimension-wise Probability: compare per-feature means (pooled over patients
# and time steps) of each synthetic dataset with those of the real cohort.
dim_stats = {}
real_means = np.nanmean(real_array, axis=(0, 1))
for name, array in arrays.items():
    if name != 'mimic_iv':
        synth_means = np.nanmean(array, axis=(0, 1))
        # Pearson correlation, across features, between the two mean vectors
        cc = np.corrcoef(real_means, synth_means)[0, 1]
        # RMSE via NumPy so this cell does not rely on `sqrt` / `mean_squared_error`,
        # which are imported in the later "Old evaluation metric" cell, not the top import cell
        rmse = float(np.sqrt(np.mean((real_means - synth_means) ** 2)))
        dim_stats[name] = {'cc': cc, 'rmse': rmse}

print("\nDimension-wise statistics:")
# Loop variable is `dim_stat`, not `stats`, so scipy.stats (imported at the top) is not shadowed
for name, dim_stat in dim_stats.items():
    print(f"{name}: CC={dim_stat['cc']:.4f}, RMSE={dim_stat['rmse']:.4f}")

# Plot per-feature means of every dataset (real and synthetic) as grouped bars
means = {name: np.nanmean(array, axis=(0, 1)) for name, array in arrays.items()}
# Number of bar groups comes from the data itself: `feature_cols` from the setup cell
# lists 7 columns, which can exceed the features actually present in the arrays.
n_bar_groups = len(next(iter(means.values())))
x = np.arange(n_bar_groups)
width = 0.1
fig, ax = plt.subplots(figsize=(12, 6))
for i, (name, mean) in enumerate(means.items()):
    # Center each group of bars on its tick
    offset = width * (i - (len(means) - 1) / 2)
    ax.bar(x + offset, mean, width, label=name)
ax.set_xticks(x)
ax.set_xticklabels(feature_names, rotation=45)
ax.set_ylabel('Mean Value')
ax.set_title('Mean Feature Values Comparison')
ax.legend()
plt.tight_layout()
plt.savefig(f"{evaluation_dir}/feature_means.png", dpi=300)
plt.close()



In [None]:
# 2.4 Discriminative Score
def train_discriminator(real_data, synthetic_data):
    """
    Train a discriminator to distinguish between real and synthetic data.

    Samples are stacked (real labelled 1, synthetic 0), flattened over their
    trailing axes and split 80/20 (random_state=42). A small dense network is
    trained for 10 epochs and scored on the held-out split.

    Args:
        real_data: Real samples; the first axis indexes samples.
        synthetic_data: Synthetic samples with the same trailing shape.

    Returns:
        (auc, apr): ROC-AUC and average precision on the held-out split, or
        (nan, nan) when they are undefined (the held-out split holds a single
        class, or no non-NaN predictions remain).
    """
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import roc_auc_score, average_precision_score
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense

    # Combine real and synthetic data
    real_labels = np.ones(len(real_data))
    synthetic_labels = np.zeros(len(synthetic_data))
    X = np.vstack([real_data, synthetic_data])
    y = np.hstack([real_labels, synthetic_labels])

    # Flatten the input data (combine time steps and features)
    X = X.reshape(X.shape[0], -1)

    # Split into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # roc_auc_score raises when only one class is present; skip training in that case
    if len(np.unique(y_test)) < 2:
        print("Warning: test split contains a single class; AUC/APR are undefined.")
        return float('nan'), float('nan')

    # A single NaN input makes the network weights NaN during training, so every
    # prediction would be NaN. Impute missing entries with the per-column mean of
    # the training split (0 for columns that are entirely missing).
    observed_train = ~np.isnan(X_train)
    col_counts = observed_train.sum(axis=0)
    col_sums = np.where(observed_train, X_train, 0.0).sum(axis=0)
    col_means = np.divide(col_sums, col_counts, out=np.zeros(X_train.shape[1]), where=col_counts > 0)
    X_train = np.where(observed_train, X_train, col_means)
    X_test = np.where(np.isnan(X_test), col_means, X_test)

    # Build a simple neural network discriminator
    model = Sequential([
        Dense(64, activation='relu', input_dim=X_train.shape[1]),
        Dense(32, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # Train the model
    model.fit(X_train, y_train, epochs=10, batch_size=32, validation_data=(X_test, y_test), verbose=0)

    # Predict probabilities
    y_pred = model.predict(X_test, verbose=0).flatten()

    # Defensive: drop any remaining NaN predictions
    valid_indices = ~np.isnan(y_pred)
    y_test = y_test[valid_indices]
    y_pred = y_pred[valid_indices]

    if len(y_test) == 0 or len(np.unique(y_test)) < 2:
        print("Warning: No valid samples after filtering NaN values.")
        return float('nan'), float('nan')

    # Calculate AUC and Average Precision
    auc = roc_auc_score(y_test, y_pred)
    apr = average_precision_score(y_test, y_pred)
    return auc, apr

# Fit one discriminator per synthetic dataset against the real cohort
disc_scores = {}
for name, array in arrays.items():
    if name == 'mimic_iv':
        continue  # the real cohort is the reference, not a candidate
    print(f"\nTraining discriminator for {name}...")
    auc, apr = train_discriminator(real_array, array)
    disc_scores[name] = dict(auc=auc, apr=apr)

print("\nDiscriminative scores:")
for name, scores in disc_scores.items():
    print(f"{name}: AUC={scores['auc']:.4f}, APR={scores['apr']:.4f}")

# 2.5 Visualization Functions
def visualize_trajectories(data_dict, feature_names=feature_names, n_patients=5):
    """
    Plot example trajectories for each dataset in a (feature x dataset) grid.

    For every dataset, up to ``n_patients`` patients are drawn at random
    (without replacement). Each feature's trajectory is drawn in its own row,
    with one column per dataset. The figure is saved to
    ``{evaluation_dir}/trajectories.png``.

    Args:
        data_dict: Mapping of dataset name -> array indexed as [patient, time, feature].
        feature_names: Feature labels, in the order of the array's last axis.
        n_patients: Patients to draw per dataset (capped at the dataset size).
    """
    n_features = len(feature_names)
    n_datasets = len(data_dict)
    # squeeze=False always yields a 2-D (n_features, n_datasets) grid. np.atleast_2d on
    # the squeezed result gave a (1, n_features) row for a single dataset, so
    # axes[row, 0] raised IndexError for row >= 1.
    fig, axes = plt.subplots(n_features, n_datasets, figsize=(6*n_datasets, 4*n_features), squeeze=False)

    for col, (name, data) in enumerate(data_dict.items()):
        # choice() without replacement fails when asked for more patients than exist
        patients = np.random.choice(data.shape[0], min(n_patients, data.shape[0]), replace=False)
        for row in range(n_features):
            ax = axes[row, col]
            for p in patients:
                ax.plot(data[p, :, row], alpha=0.7)
            ax.set_title(f'{feature_names[row]} - {name}')
            ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/trajectories.png", dpi=300)
    plt.close()

visualize_trajectories(data_dict=selected_datasets)  # example trajectories per selected dataset

def visualize_dim_reduction(data_dict, max_samples=1000):
    """
    Scatter a 2-D t-SNE embedding of all datasets in ``data_dict``.

    Each sample is flattened over its trailing axes. Each dataset is randomly
    subsampled to at most ``max_samples`` samples, and points are colored by
    dataset index. The figure is saved to ``{evaluation_dir}/tsne.png``.

    Args:
        data_dict: Mapping of dataset name -> array whose first axis indexes samples.
        max_samples: Per-dataset cap on samples fed to t-SNE (previously fixed at 1000).
    """
    combined = []
    labels = []
    for i, data in enumerate(data_dict.values()):
        flat = data.reshape(data.shape[0], -1)
        # Subsample large datasets to keep t-SNE tractable
        if len(flat) > max_samples:
            indices = np.random.choice(len(flat), max_samples, replace=False)
            flat = flat[indices]
        combined.append(flat)
        labels.append(np.full(len(flat), i))

    combined = np.vstack(combined)
    labels = np.concatenate(labels)
    tsne = TSNE(n_components=2, random_state=42)
    embedded = tsne.fit_transform(combined)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='viridis', alpha=0.7, s=5)
    # Legend handles come from the scatter's color mapping; labels follow data_dict order
    plt.legend(handles=scatter.legend_elements()[0], labels=list(data_dict.keys()))
    plt.title('t-SNE Visualization')
    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/tsne.png", dpi=300)
    plt.close()

visualize_dim_reduction(data_dict=selected_datasets)  # t-SNE of all selected datasets

# 2.6 Feature Temporal Correlation
def feature_temporal_correlation(data_dict, feature_names=feature_names, max_hours=None):
    """
    Plot lower-triangle correlation heatmaps across (feature, timepoint) pairs.

    For each dataset, six evenly spaced timepoints in [0, max_hours - 1] are
    taken. Every (feature, timepoint) column is correlated with every other
    across patients, one heatmap per dataset. The figure is saved to
    ``{evaluation_dir}/temporal_correlation.png``.

    Args:
        data_dict: Mapping of dataset name -> array indexed as [patient, time, feature].
        feature_names: Feature labels, in the order of the array's last axis.
        max_hours: Exclusive upper bound of the window the timepoints are drawn from.
            Defaults to, and is capped at, the shortest time axis in ``data_dict``.
    """
    fig, axes = plt.subplots(1, len(data_dict), figsize=(7*len(data_dict), 7))
    axes = np.atleast_1d(axes)

    # This previously read a module-level `max_hours` that is not defined in this
    # section; derive the window from the data so every timepoint exists in every dataset.
    shortest = min(data.shape[1] for data in data_dict.values())
    max_hours = shortest if max_hours is None else min(max_hours, shortest)
    timepoints = np.linspace(0, max_hours-1, 6, dtype=int)

    for ax, (name, data) in zip(axes, data_dict.items()):
        selected_data = []
        combined_names = []
        for i, feature in enumerate(feature_names):
            for tp in timepoints:
                selected_data.append(data[:, tp, i])
                combined_names.append(f"{feature}_{tp:02d}")

        # (n_patients, n_features * n_timepoints)
        selected_data = np.array(selected_data).T
        df = pd.DataFrame(data=selected_data, columns=combined_names)
        corr_matrix = df.corr()
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
        sns.heatmap(corr_matrix, ax=ax, mask=mask, cmap="coolwarm", vmin=-1, vmax=1, center=0)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=75)
        ax.set_title(f"{name}")
    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/temporal_correlation.png", dpi=300)
    plt.close()

feature_temporal_correlation(data_dict=selected_datasets)  # (feature, time) correlation heatmaps

# Summary: one row per synthetic dataset (the datasets scored by MMD above)
summary_rows = []
for name, mmd in mmd_results.items():
    disc = disc_scores.get(name, {})
    summary_rows.append({
        'Dataset': name,
        'MMD': mmd,
        'Corr_Diff': corr_diffs[name],
        'Dim_CC': dim_stats[name]['cc'],
        'Dim_RMSE': dim_stats[name]['rmse'],
        # Datasets without a discriminator run are reported as NaN
        'Disc_AUC': disc.get('auc', np.nan),
        'Disc_APR': disc.get('apr', np.nan),
    })
metrics_summary = pd.DataFrame(
    summary_rows,
    columns=['Dataset', 'MMD', 'Corr_Diff', 'Dim_CC', 'Dim_RMSE', 'Disc_AUC', 'Disc_APR'],
)
print("\nSummary of Evaluation Metrics:")
print(metrics_summary)
metrics_summary.to_csv(f"{evaluation_dir}/metrics_summary.csv", index=False)

In [None]:
def visualize_mean_variance(data_dict, feature_names=feature_names, max_hours=20):
    """
    Visualize the mean and variance of features over the first max_hours timestamps for each dataset.

    Statistics are taken across patients (NaN-aware) at each time step. A
    dataset with fewer than ``max_hours`` time steps is plotted over the steps
    it has. The figure is saved to ``{evaluation_dir}/mean_variance.png``.

    Args:
        data_dict (dict): Dictionary of datasets (3D arrays) with dataset names as keys.
        feature_names (list): List of feature names corresponding to the last dimension of the arrays.
        max_hours (int): Number of timestamps to consider (e.g., 20).
    """
    n_features = len(feature_names)
    # squeeze=False keeps a 2-D (n_features, 2) axes grid even for a single feature
    fig, axes = plt.subplots(n_features, 2, figsize=(12, 4 * n_features), squeeze=False)  # columns: mean, variance

    for feature_idx in range(n_features):
        for col, stat in enumerate(['Mean', 'Variance']):
            ax = axes[feature_idx, col]
            for name, data in data_dict.items():
                window = data[:, :max_hours, feature_idx]
                if stat == 'Mean':
                    values = np.nanmean(window, axis=0)
                else:
                    values = np.nanvar(window, axis=0)
                # x from the actual window length; range(max_hours) mismatched shorter datasets
                ax.plot(np.arange(len(values)), values, label=name, alpha=0.8)

            ax.set_title(f'{feature_names[feature_idx]} - {stat}')
            ax.set_xlabel('Time (hours)')
            ax.set_ylabel(stat)
            ax.grid(True, alpha=0.3)
            ax.legend()

    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/mean_variance.png", dpi=300)
    plt.close()

# Mean/variance curves over the first 20 hours for every selected dataset
visualize_mean_variance(selected_datasets, max_hours=20, feature_names=feature_names)

### Old evaluation metrics

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.model_selection import train_test_split
from math import sqrt
import os

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Set up matplotlib. Matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8-*' and 3.8 removed the old names, so plt.style.use('seaborn-whitegrid')
# raises on current releases; use whichever name this installation provides.
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
sns.set_style("whitegrid")

# Create output directory for evaluation results (no trailing slash, so the
# f"{evaluation_dir}/file.png" paths below do not contain '//')
evaluation_dir = os.path.join(output_dir, "evaluations")
os.makedirs(evaluation_dir, exist_ok=True)

# Helper function to convert dataframe to 3D array for comparison
def df_to_array(df, feature_cols=None, n_timesteps=24):
    """Convert DataFrame to 3D numpy array (n_patients, n_timesteps, n_features).

    Patients are ordered by first appearance of ``stay_id``. Each row is placed at
    [patient, int(hours_in), :]. Rows whose hour falls outside [0, n_timesteps) are
    dropped. When several rows map to the same (patient, hour), the last one wins.
    Missing values are then forward-filled along the time axis; values before a
    patient's first observation of a feature stay NaN.

    Args:
        df: DataFrame with 'stay_id', 'hours_in' and the feature columns.
        feature_cols: Feature columns to extract (defaults to the five vitals/settings).
        n_timesteps: Length of the time axis (default 24, the previous fixed value).

    Returns:
        Float array of shape (n_patients, n_timesteps, len(feature_cols)).
    """
    if feature_cols is None:
        feature_cols = ['heart_rate', 'resp_rate', 'spo2', 'fio2', 'respiratory_rate_set']

    patient_ids = df['stay_id'].unique()
    n_patients = len(patient_ids)
    n_features = len(feature_cols)
    array_3d = np.full((n_patients, n_timesteps, n_features), np.nan)
    if n_patients == 0:
        return array_3d

    # Vectorized placement instead of a per-row iterrows() loop.
    # astype(int) truncates toward zero like int() and raises on NaN hours.
    hours = df['hours_in'].astype(int).to_numpy()
    patient_idx = pd.Index(patient_ids).get_indexer(df['stay_id'])
    values = df[feature_cols].to_numpy(dtype=float)

    in_range = (hours >= 0) & (hours < n_timesteps)
    hours, patient_idx, values = hours[in_range], patient_idx[in_range], values[in_range]

    # Keep only the last row per (patient, hour). NumPy does not guarantee which value
    # wins for repeated indices in a fancy-index assignment, so deduplicate first.
    cell_keys = patient_idx * n_timesteps + hours
    _, first_in_reversed = np.unique(cell_keys[::-1], return_index=True)
    last_rows = len(cell_keys) - 1 - first_in_reversed
    array_3d[patient_idx[last_rows], hours[last_rows]] = values[last_rows]

    # Forward-fill along time: each cell takes the value at the most recent observed
    # step. Cells with no earlier observation point at step 0, which is NaN for them.
    observed = ~np.isnan(array_3d)
    source_step = np.where(observed, np.arange(n_timesteps)[None, :, None], 0)
    source_step = np.maximum.accumulate(source_step, axis=1)
    return np.take_along_axis(array_3d, source_step, axis=1)

# Build the real-data reference tensor (patients x time steps x features) from the
# long-format ground-truth DataFrame.
# TODO: not sure use all or just first 24 hours
real_array = df_to_array(df=baseline_charttime_ground_truth_first_24_hr_df)
print(f"Real data array shape: {real_array.shape}")

# 1. MAX MEAN DISCREPANCY (MMD)
def calculate_mmd(real_data, synthetic_data, bandwidth_multipliers=(0.2, 0.5, 0.9, 1.3)):
    """
    Calculate the (biased) squared Maximum Mean Discrepancy with a Gaussian kernel.

    Each sample is flattened over all trailing axes, so a (time, feature)
    trajectory becomes one vector. At most 1000 samples per set are drawn at
    random. The kernel bandwidth is ``multiplier * median pairwise distance``
    over the pooled samples, and the estimate
    mean(K_XX) - 2 mean(K_XY) + mean(K_YY) is averaged over all multipliers.

    Args:
        real_data: Real data array, shape (n_samples, ...).
        synthetic_data: Synthetic data array, shape (m_samples, ...) with the
            same number of elements per sample as ``real_data``.
        bandwidth_multipliers: Iterable of multipliers for the median heuristic.

    Returns:
        float: the MMD^2 estimate averaged over bandwidths (each clamped at 0).
        NaN if the inputs contain NaN; 0.0 if every sample is identical.
    """
    from scipy.spatial.distance import cdist

    # Flatten data to 2D
    real_flat = real_data.reshape(real_data.shape[0], -1)
    synth_flat = synthetic_data.reshape(synthetic_data.shape[0], -1)

    # Sample if too large (for computational efficiency)
    max_samples = 1000
    if len(real_flat) > max_samples:
        indices = np.random.choice(len(real_flat), max_samples, replace=False)
        real_flat = real_flat[indices]

    if len(synth_flat) > max_samples:
        indices = np.random.choice(len(synth_flat), max_samples, replace=False)
        synth_flat = synth_flat[indices]

    # One pooled distance matrix: cdist avoids the (N, N, D) broadcast temporary
    # (several GB for 2000 pooled samples of long trajectories) and all kernel
    # blocks are sliced from it.
    m = len(real_flat)
    pooled = np.vstack([real_flat, synth_flat])
    dist = cdist(pooled, pooled, 'euclidean')

    # Median heuristic over the full matrix (diagonal zeros included, as before)
    median_dist = np.median(dist)
    if np.isnan(median_dist):
        # Previously max(0, nan) clamped this to 0, reporting a perfect match
        return float('nan')
    if median_dist == 0:
        positive = dist[dist > 0]
        if positive.size == 0:
            return 0.0  # every sample in both sets is identical
        median_dist = np.median(positive)

    sq_dist = dist ** 2
    sq_xx = sq_dist[:m, :m]
    sq_yy = sq_dist[m:, m:]
    sq_xy = sq_dist[:m, m:]

    mmd_values = []
    for bandwidth_multiplier in bandwidth_multipliers:
        bandwidth = bandwidth_multiplier * median_dist
        denom = 2 * bandwidth ** 2  # RBF kernel exp(-||x - y||^2 / (2 * bandwidth^2))
        mmd = (np.mean(np.exp(-sq_xx / denom))
               - 2 * np.mean(np.exp(-sq_xy / denom))
               + np.mean(np.exp(-sq_yy / denom)))
        mmd_values.append(max(0.0, mmd))  # MMD should be non-negative

    # Return average MMD across bandwidths
    return float(np.mean(mmd_values))

# Calculate MMD for GAN and TABLE against the same real reference (GAN first)
mmd_gan, mmd_table = (
    calculate_mmd(real_array, synthetic)
    for synthetic in (gan_renormalized, table_renormalized)
)

print(f"MMD (GAN): {mmd_gan:.6f}")
print(f"MMD (TABLE): {mmd_table:.6f}")

# 2. PEARSON CORRELATION 
def pearson_correlation_comparison(real_data, gan_data, table_data, feature_names=None):
    """
    Compare feature-correlation matrices between real, GAN, and TABLE data.

    Each dataset is flattened to (-1, n_features) over its last axis, rows with
    any NaN are dropped, and the Pearson correlation matrix is computed. The
    Frobenius norms of (GAN - real) and (TABLE - real) are printed. A 2x3 grid of
    heatmaps (three correlation matrices, three pairwise differences) is saved
    to ``{evaluation_dir}/correlation_comparison.png`` and shown.

    Args:
        real_data: 3D array of real data
        gan_data: 3D array of GAN-generated data
        table_data: 3D array of TABLE-generated data
        feature_names: List of feature names (heatmap tick labels)

    Returns:
        (real_corr, gan_corr, table_corr) correlation matrices.
    """
    if feature_names is None:
        feature_names = ['HR', 'RR', 'SpO2', 'FiO2', 'RR_set']

    def _corr(data):
        # Pool all samples and time steps; np.corrcoef propagates NaN, so keep complete rows only
        flat = data.reshape(-1, data.shape[-1])
        return np.corrcoef(flat[~np.isnan(flat).any(axis=1)].T)

    real_corr = _corr(real_data)
    gan_corr = _corr(gan_data)
    table_corr = _corr(table_data)

    gan_diff = gan_corr - real_corr
    table_diff = table_corr - real_corr
    gan_table_diff = gan_corr - table_corr

    # Frobenius norm of the difference matrices
    gan_frob = np.linalg.norm(gan_diff)
    table_frob = np.linalg.norm(table_diff)
    print(f"Correlation difference (Frobenius norm) - GAN: {gan_frob:.4f}, TABLE: {table_frob:.4f}")

    # Top row: correlation matrices on [-1, 1]; bottom row: differences on [-0.5, 0.5]
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    panels = [
        (axes[0, 0], real_corr, 'Real Data Correlation', dict(vmin=-1, vmax=1)),
        (axes[0, 1], gan_corr, 'GAN Data Correlation', dict(vmin=-1, vmax=1)),
        (axes[0, 2], table_corr, 'TABLE Data Correlation', dict(vmin=-1, vmax=1)),
        (axes[1, 0], gan_diff, 'GAN - Real Difference', dict(vmin=-0.5, vmax=0.5, center=0)),
        (axes[1, 1], table_diff, 'TABLE - Real Difference', dict(vmin=-0.5, vmax=0.5, center=0)),
        (axes[1, 2], gan_table_diff, 'GAN - TABLE Difference', dict(vmin=-0.5, vmax=0.5, center=0)),
    ]
    for ax, matrix, title, color_scale in panels:
        sns.heatmap(matrix, annot=True, cmap='coolwarm',
                    xticklabels=feature_names, yticklabels=feature_names, ax=ax, **color_scale)
        ax.set_title(title)

    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/correlation_comparison.png", dpi=300)
    plt.show()

    return real_corr, gan_corr, table_corr

# Run correlation analysis. Note: this rebinds the module-level `feature_names`.
feature_names = ['Heart Rate', 'Resp. Rate', 'SpO2', 'FiO2', 'RR Set']
real_corr, gan_corr, table_corr = pearson_correlation_comparison(
    real_array,
    gan_renormalized,
    table_renormalized,
    feature_names=feature_names,
)

# 3. DIMENSION-WISE PROBABILITY
def dimension_wise_probability(real_data, gan_data, table_data, feature_names=None):
    """
    Calculate and plot dimension-wise statistics for real, GAN and TABLE data.

    Per-feature means are pooled over patients and time steps. The Pearson
    correlation (CC) and RMSE between the real mean vector and each synthetic
    one summarize agreement. Three figures are saved to ``evaluation_dir`` and
    shown: grouped mean bars, per-feature temporal mean curves, and
    real-vs-synthetic mean scatter plots.

    Args:
        real_data, gan_data, table_data: Arrays indexed as [patient, time, feature].
        feature_names: Feature labels for the last axis (defaults to the five vitals/settings).

    Returns:
        dict with 'gan_cc', 'table_cc', 'gan_rmse' and 'table_rmse'.
    """
    if feature_names is None:
        feature_names = ['Heart Rate', 'Resp. Rate', 'SpO2', 'FiO2', 'RR Set']

    n_features = real_data.shape[2]

    # Feature means, averaged over patients and time steps
    real_means = np.nanmean(real_data, axis=(0, 1))
    gan_means = np.nanmean(gan_data, axis=(0, 1))
    table_means = np.nanmean(table_data, axis=(0, 1))

    # Temporal means: averaged over patients only -> (time, feature)
    real_temporal_means = np.nanmean(real_data, axis=0)
    gan_temporal_means = np.nanmean(gan_data, axis=0)
    table_temporal_means = np.nanmean(table_data, axis=0)

    # Agreement between the real and synthetic mean vectors
    gan_cc = np.corrcoef(real_means, gan_means)[0, 1]
    table_cc = np.corrcoef(real_means, table_means)[0, 1]

    gan_rmse = sqrt(mean_squared_error(real_means, gan_means))
    table_rmse = sqrt(mean_squared_error(real_means, table_means))

    # Plot feature means comparison
    plt.figure(figsize=(10, 6))
    x = np.arange(len(feature_names))
    width = 0.25

    plt.bar(x - width, real_means, width, label='Real')
    plt.bar(x, gan_means, width, label='GAN')
    plt.bar(x + width, table_means, width, label='TABLE')

    plt.xlabel('Features')
    plt.ylabel('Mean Value')
    plt.title('Mean Feature Values Comparison')
    plt.xticks(x, feature_names, rotation=45)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/feature_means_comparison.png", dpi=300)
    plt.show()

    # Temporal pattern per feature; squeeze=False keeps `axes` 2-D even for one feature
    fig, axes = plt.subplots(n_features, 1, figsize=(12, 4*n_features), squeeze=False)
    for i in range(n_features):
        ax = axes[i, 0]
        ax.plot(real_temporal_means[:, i], 'b-', label='Real')
        ax.plot(gan_temporal_means[:, i], 'r-', label='GAN')
        ax.plot(table_temporal_means[:, i], 'g-', label='TABLE')
        ax.set_title(f'{feature_names[i]} - Temporal Pattern')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Mean Value')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/temporal_patterns.png", dpi=300)
    plt.show()

    # Scatter of real vs. synthetic means, with the identity line for reference
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.scatter(real_means, gan_means, c=range(len(feature_names)), cmap='viridis')
    for i in range(len(feature_names)):
        plt.annotate(feature_names[i], (real_means[i], gan_means[i]))
    plt.plot([min(real_means), max(real_means)], [min(real_means), max(real_means)], 'k--')
    plt.title(f'GAN vs. Real (CC={gan_cc:.4f}, RMSE={gan_rmse:.4f})')
    plt.xlabel('Real Data Mean')
    plt.ylabel('GAN Data Mean')

    plt.subplot(1, 2, 2)
    plt.scatter(real_means, table_means, c=range(len(feature_names)), cmap='viridis')
    for i in range(len(feature_names)):
        plt.annotate(feature_names[i], (real_means[i], table_means[i]))
    plt.plot([min(real_means), max(real_means)], [min(real_means), max(real_means)], 'k--')
    plt.title(f'TABLE vs. Real (CC={table_cc:.4f}, RMSE={table_rmse:.4f})')
    plt.xlabel('Real Data Mean')
    plt.ylabel('TABLE Data Mean')

    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/feature_scatter_comparison.png", dpi=300)
    plt.show()

    return {
        'gan_cc': gan_cc,
        'table_cc': table_cc,
        'gan_rmse': gan_rmse,
        'table_rmse': table_rmse
    }

# Run dimension-wise probability analysis.
# NOTE: rebinds `dim_stats` with GAN/TABLE keys, replacing the per-dataset dict built earlier.
dim_stats = dimension_wise_probability(
    real_data=real_array,
    gan_data=gan_renormalized,
    table_data=table_renormalized,
)

# 4. DISCRIMINATIVE SCORE
def train_discriminator(real_data, synthetic_data):
    """
    Train a model to distinguish real from synthetic data.

    Samples are flattened over their trailing axes; real samples are labelled
    0 and synthetic samples 1. A dense network with dropout is trained for 10
    epochs on 80% of the data (random_state=42) and scored on the other 20%.

    Returns:
        (auc, apr, history): ROC-AUC and average precision on the held-out split
        (NaN when the held-out split holds a single class), plus the Keras
        ``History`` of the training run.
    """
    from sklearn.metrics import roc_auc_score, average_precision_score

    # Flatten data while preserving patient structure
    real_flat = real_data.reshape(real_data.shape[0], -1)
    syn_flat = synthetic_data.reshape(synthetic_data.shape[0], -1)

    # Combine and create labels
    X = np.vstack([real_flat, syn_flat])
    y = np.concatenate([np.zeros(len(real_flat)), np.ones(len(syn_flat))])

    # Split train/test
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Impute NaNs with per-column training means (0 for all-missing columns). NaN
    # inputs make the network output NaN, which roc_auc_score rejects.
    observed_train = ~np.isnan(X_train)
    col_counts = observed_train.sum(axis=0)
    col_sums = np.where(observed_train, X_train, 0.0).sum(axis=0)
    col_means = np.divide(col_sums, col_counts, out=np.zeros(X_train.shape[1]), where=col_counts > 0)
    X_train = np.where(observed_train, X_train, col_means)
    X_test = np.where(np.isnan(X_test), col_means, X_test)

    # Create model
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    # Compile model
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # Train model
    history = model.fit(
        X_train, y_train,
        epochs=10,
        batch_size=32,
        validation_data=(X_test, y_test),
        verbose=1
    )

    # Evaluate model; ravel the (n, 1) sigmoid output to the 1-D scores sklearn expects
    y_pred = model.predict(X_test).ravel()

    if len(np.unique(y_test)) < 2:
        print("Warning: test split contains a single class; AUC/APR are undefined.")
        return float('nan'), float('nan'), history

    auc = roc_auc_score(y_test, y_pred)
    apr = average_precision_score(y_test, y_pred)

    return auc, apr, history

# Calculate discriminative scores (one discriminator per generator, GAN first)
print("\nTraining discriminator for GAN data...")
gan_auc, gan_apr, gan_history = train_discriminator(real_array, gan_renormalized)
print(f"GAN - AUC: {gan_auc:.4f}, APR: {gan_apr:.4f}")

print("\nTraining discriminator for TABLE data...")
table_auc, table_apr, table_history = train_discriminator(real_array, table_renormalized)
print(f"TABLE - AUC: {table_auc:.4f}, APR: {table_apr:.4f}")

# Plot discriminator results: per-epoch train vs. held-out accuracy for each generator
discriminator_panels = [
    (gan_history, f'GAN Discriminator (AUC={gan_auc:.4f})'),
    (table_history, f'TABLE Discriminator (AUC={table_auc:.4f})'),
]
plt.figure(figsize=(12, 5))
for panel_idx, (run_history, panel_title) in enumerate(discriminator_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.plot(run_history.history['accuracy'], label='train')
    plt.plot(run_history.history['val_accuracy'], label='test')
    plt.title(panel_title)
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend()

plt.tight_layout()
plt.savefig(f"{evaluation_dir}/discriminator_comparison.png", dpi=300)
plt.show()

# 5. VISUALIZATION FUNCTIONS
def visualize_trajectories(real_data, gan_data, table_data, feature_names=None, n_patients=5):
    """
    Visualize example trajectories from real, GAN, and TABLE data side by side.

    Up to ``n_patients`` patients are drawn at random (without replacement) per
    dataset, in the order real, GAN, TABLE. Each feature gets its own row and
    each dataset its own column (real: blue, GAN: red, TABLE: green). The figure
    is saved to ``{evaluation_dir}/trajectory_visualization.png`` and shown.
    """
    if feature_names is None:
        feature_names = ['Heart Rate', 'Resp. Rate', 'SpO2', 'FiO2', 'RR Set']

    n_features = len(feature_names)
    # squeeze=False keeps `axes` 2-D even when there is a single feature
    fig, axes = plt.subplots(n_features, 3, figsize=(18, 4*n_features), squeeze=False)

    # Randomly select patients; capping at the dataset size keeps choice() without
    # replacement from failing on small datasets
    panels = []
    for data, line_style, label in ((real_data, 'b-', 'Real'),
                                    (gan_data, 'r-', 'GAN'),
                                    (table_data, 'g-', 'TABLE')):
        patients = np.random.choice(data.shape[0], min(n_patients, data.shape[0]), replace=False)
        panels.append((data, patients, line_style, label))

    # Plot each feature for every dataset
    for i in range(n_features):
        for col, (data, patients, line_style, label) in enumerate(panels):
            ax = axes[i, col]
            for p in patients:
                ax.plot(data[p, :, i], line_style, alpha=0.7)
            ax.set_title(f'{feature_names[i]} - {label}')
            ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/trajectory_visualization.png", dpi=300)
    plt.show()

# Example trajectories: real vs. GAN vs. TABLE
visualize_trajectories(real_data=real_array, gan_data=gan_renormalized, table_data=table_renormalized)

# 6. DIMENSIONALITY REDUCTION VISUALIZATION
def visualize_dim_reduction(real_data, gan_data, table_data):
    """
    Project real, GAN and TABLE samples to 2-D with t-SNE and scatter-plot them.

    Every sample is flattened over its trailing axes. Each dataset is randomly
    subsampled to at most 1000 samples (in the order real, GAN, TABLE) before
    embedding. Points are colored by source (0 = Real, 1 = GAN, 2 = TABLE), and
    the figure is saved to ``{evaluation_dir}/tsne_visualization.png`` and shown.
    """
    cap = 1000
    blocks = []
    for source in (real_data, gan_data, table_data):
        flat = source.reshape(source.shape[0], -1)
        if len(flat) > cap:
            flat = flat[np.random.choice(len(flat), cap, replace=False)]
        blocks.append(flat)

    combined = np.vstack(blocks)
    labels = np.concatenate([np.full(len(block), float(code)) for code, block in enumerate(blocks)])

    print("Computing t-SNE embedding...")
    embedded = TSNE(n_components=2, random_state=42).fit_transform(combined)

    plt.figure(figsize=(10, 8))
    plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='viridis', alpha=0.7, s=5)

    # Hand-built legend; the hex colors are meant to match viridis at the three label values
    legend_specs = [('#440154', 'Real'), ('#21918c', 'GAN'), ('#fde725', 'TABLE')]
    plt.legend(handles=[
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=face, label=text, markersize=8)
        for face, text in legend_specs
    ])

    plt.title('t-SNE Visualization of Real and Generated Data')
    plt.tight_layout()
    plt.savefig(f"{evaluation_dir}/tsne_visualization.png", dpi=300)
    plt.show()

# t-SNE of real vs. GAN vs. TABLE samples
visualize_dim_reduction(real_data=real_array, gan_data=gan_renormalized, table_data=table_renormalized)

# 7. SUMMARY METRICS TABLE
# One (metric, GAN value, TABLE value) record per evaluation metric
summary_records = [
    ('Max Mean Discrepancy (MMD)', mmd_gan, mmd_table),
    ('Correlation Matrix Difference (Frobenius Norm)',
     np.linalg.norm(gan_corr - real_corr), np.linalg.norm(table_corr - real_corr)),
    ('Feature Mean Correlation Coefficient (CC)', dim_stats['gan_cc'], dim_stats['table_cc']),
    ('Feature Mean RMSE', dim_stats['gan_rmse'], dim_stats['table_rmse']),
    ('Discriminator AUC', gan_auc, table_auc),
    ('Discriminator APR', gan_apr, table_apr),
]
metrics_summary = pd.DataFrame(summary_records, columns=['Metric', 'GAN', 'TABLE'])

print("\nSummary of Evaluation Metrics:")
print(metrics_summary)

# Persist the table next to the figures
metrics_summary.to_csv(f"{evaluation_dir}/metrics_summary.csv", index=False)
print(f"\nEvaluation metrics saved to {evaluation_dir}/metrics_summary.csv")

In [None]:
def feature_temporal_correlation(real_data, gan_data, table_data, feature_names=None,
                                 timepoints=None, output_dir=None):
    """
    Create correlation heatmaps showing temporal relationships between features
    for real, GAN, and TABLE generated data.

    For each dataset, every (feature, timepoint) pair becomes one column. The
    pairwise Pearson correlations across patients are drawn as a lower-triangle
    heatmap, and the three heatmaps are shown side by side.

    Args:
        real_data: Real data array (n_patients, n_timesteps, n_features)
        gan_data: GAN generated data array (n_patients, n_timesteps, n_features)
        table_data: TABLE generated data array (n_patients, n_timesteps, n_features)
        feature_names: Names of features (default: the five vitals/settings labels)
        timepoints: Specific timepoints to include (default: 6 evenly spaced points
            within the shortest time axis of the three datasets)
        output_dir: Directory to save PNG and PDF figures (optional)
    """
    if feature_names is None:
        feature_names = ['Heart Rate', 'Resp. Rate', 'SpO2', 'FiO2', 'RR Set']

    # Select timepoints to analyze (to avoid overcrowding the plot). Use the shortest
    # time axis of the three datasets so each chosen index exists in all of them;
    # previously only real_data's length was used.
    n_timesteps = min(real_data.shape[1], gan_data.shape[1], table_data.shape[1])
    if timepoints is None:
        # Select 6 evenly spaced timepoints
        timepoints = np.linspace(0, n_timesteps-1, 6, dtype=int)

    fig, axes = plt.subplots(1, 3, figsize=(21, 7))

    # Function to process and plot one dataset
    def create_correlation_plot(data, ax, title):
        # Extract data and create feature names with timepoints
        selected_data = []
        combined_feature_names = []

        for i, feature in enumerate(feature_names):
            for tp in timepoints:
                # Extract data for this feature and timepoint
                feature_data = data[:, tp, i]
                selected_data.append(feature_data)
                # Create label: "HR_00", "HR_06", etc.
                combined_feature_names.append(f"{feature}_{tp:02d}")

        # Transpose to get shape (n_patients, n_selected_features)
        selected_data = np.array(selected_data).T

        # Create DataFrame and calculate correlation
        df = pd.DataFrame(data=selected_data, columns=combined_feature_names)
        corr_matrix = df.corr()

        # Create mask for upper triangle
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

        # Plot heatmap - linewidths=0 removes grid lines
        sns.heatmap(corr_matrix, ax=ax, mask=mask, cmap="coolwarm",
                    vmin=-1., vmax=1., center=0, square=True,
                    linewidths=0, linecolor=None,
                    cbar_kws={"shrink": .5})

        # Adjust labels
        ax.set_xticklabels(ax.get_xticklabels(), rotation=75)
        ax.set_title(title)

    # Create plots for real and generated data
    create_correlation_plot(real_data, axes[0], "Real Data - Temporal Correlations")
    create_correlation_plot(gan_data, axes[1], "GAN Data - Temporal Correlations")
    create_correlation_plot(table_data, axes[2], "TABLE Data - Temporal Correlations")

    plt.tight_layout()

    # Save figure if output_dir is provided
    if output_dir:
        plt.savefig(f'{output_dir}/feature_time_correlations.png', dpi=300)
        plt.savefig(f'{output_dir}/feature_time_correlations.pdf', format='pdf')

    plt.show()


def feature_cross_correlation_difference(real_data, gan_data, table_data, feature_names=None, output_dir=None):
    """
    Create correlation difference heatmaps showing how feature correlations
    differ between real and generated data (both GAN and TABLE).

    Each array is flattened over patients and timesteps, so every
    (patient, timestep) row is treated as one observation when computing the
    pairwise Pearson correlation between features.

    Args:
        real_data: Real data array (n_patients, n_timesteps, n_features)
        gan_data: GAN generated data array (n_patients, n_timesteps, n_features)
        table_data: TABLE generated data array (n_patients, n_timesteps, n_features)
        feature_names: Names of features, one per entry of the last axis.
            Defaults to ['HR', 'RR', 'SpO2', 'FiO2', 'RR_Set'] when the data has
            exactly 5 features, otherwise to generic names 'F0', 'F1', ...
        output_dir: Directory to save figures (optional; created if missing)

    Returns:
        pd.DataFrame with columns 'Method', 'Mean Abs Diff', 'Max Abs Diff' for
        'GAN vs Real', 'TABLE vs Real' and 'GAN vs TABLE'. Both statistics are
        computed over the off-diagonal feature pairs only (the diagonal of a
        correlation-difference matrix carries no information) and ignore NaN
        entries, e.g. those produced by constant features.

    Raises:
        ValueError: if the three arrays disagree on the number of features, or
            if feature_names does not provide exactly one name per feature.
    """
    arrays = [np.asarray(real_data), np.asarray(gan_data), np.asarray(table_data)]
    n_features = arrays[0].shape[-1]
    if any(arr.shape[-1] != n_features for arr in arrays[1:]):
        raise ValueError(
            "real_data, gan_data and table_data must have the same number of "
            f"features (last axis); got {[arr.shape[-1] for arr in arrays]}"
        )

    if feature_names is None:
        default_names = ['HR', 'RR', 'SpO2', 'FiO2', 'RR_Set']
        if n_features == len(default_names):
            feature_names = default_names
        else:
            feature_names = [f'F{i}' for i in range(n_features)]
    feature_names = list(feature_names)
    # seaborn silently mislabels ticks on a length mismatch, so fail loudly instead.
    if len(feature_names) != n_features:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but the data has "
            f"{n_features} features"
        )

    # Reshape to 2D (patients*timesteps, features) and compute Pearson correlations.
    # pandas uses pairwise-complete observations, so missing values (NaN) only drop
    # the affected rows for that pair instead of turning an entire row/column of
    # the matrix into NaN as np.corrcoef would.
    real_corr, gan_corr, table_corr = (
        pd.DataFrame(arr.reshape(-1, n_features)).corr().to_numpy() for arr in arrays
    )

    # Calculate differences
    gan_diff = gan_corr - real_corr
    table_diff = table_corr - real_corr
    gan_table_diff = gan_corr - table_corr

    # One heatmap per comparison, sharing the same colour scale.
    fig, axes = plt.subplots(1, 3, figsize=(21, 7))
    panels = [
        (gan_diff, 'GAN - Real Correlation Difference'),
        (table_diff, 'TABLE - Real Correlation Difference'),
        (gan_table_diff, 'GAN - TABLE Correlation Difference'),
    ]
    for ax, (diff, title) in zip(axes, panels):
        sns.heatmap(diff, annot=True, cmap='coolwarm', vmin=-0.5, vmax=0.5, center=0,
                    square=True, ax=ax, linewidths=0, linecolor=None,
                    xticklabels=feature_names, yticklabels=feature_names)
        ax.set_title(title)

    fig.tight_layout()

    # Save figure if output_dir is provided
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, 'correlation_differences.png'), dpi=300)
        fig.savefig(os.path.join(output_dir, 'correlation_differences.pdf'), format='pdf')

    plt.show()

    # Summary statistics use only the strict upper triangle: the matrices are
    # symmetric and the diagonal is (numerically) 0, so including it would dilute
    # the mean by a factor of (n-1)/n. NaN entries are skipped rather than making
    # the whole statistic NaN.
    upper = np.triu_indices(n_features, k=1)

    def _abs_offdiag_stats(diff):
        """Return (mean, max) of |diff| over non-NaN off-diagonal pairs."""
        vals = np.abs(diff[upper])
        vals = vals[~np.isnan(vals)]
        if vals.size == 0:
            return np.nan, np.nan
        return float(vals.mean()), float(vals.max())

    gan_mean_abs_diff, gan_max_abs_diff = _abs_offdiag_stats(gan_diff)
    table_mean_abs_diff, table_max_abs_diff = _abs_offdiag_stats(table_diff)
    gan_table_mean_abs_diff, gan_table_max_abs_diff = _abs_offdiag_stats(gan_table_diff)

    print(f"GAN - Real mean absolute correlation difference: {gan_mean_abs_diff:.4f}")
    print(f"TABLE - Real mean absolute correlation difference: {table_mean_abs_diff:.4f}")

    # Create a summary dataframe
    diff_summary = pd.DataFrame({
        'Method': ['GAN vs Real', 'TABLE vs Real', 'GAN vs TABLE'],
        'Mean Abs Diff': [
            gan_mean_abs_diff,
            table_mean_abs_diff,
            gan_table_mean_abs_diff,
        ],
        'Max Abs Diff': [
            gan_max_abs_diff,
            table_max_abs_diff,
            gan_table_max_abs_diff,
        ],
    })

    print("\nCorrelation Difference Summary:")
    print(diff_summary)

    return diff_summary

In [None]:
# --- Feature x timepoint correlation heatmaps (real vs. GAN vs. TABLE) ---
print("\nGenerating feature-temporal correlation heatmaps...")
# Both analyses compare the same three arrays, in the same order.
_corr_inputs = (real_array, gan_renormalized, table_renormalized)
feature_temporal_correlation(
    *_corr_inputs,
    feature_names=['Heart Rate', 'Resp. Rate', 'SpO2', 'FiO2', 'RR Set'],
    output_dir=evaluation_dir,
)

# --- Cross-feature correlation differences; keeps the returned summary table ---
print("\nGenerating correlation difference heatmaps...")
corr_diff_summary = feature_cross_correlation_difference(
    *_corr_inputs,
    feature_names=['HR', 'RR', 'SpO2', 'FiO2', 'RR_Set'],
    output_dir=evaluation_dir,
)