In [1]:
%pip install pandas numpy scipy scikit-learn sdv sdmetrics matplotlib seaborn plotly kaleido openpyxl python-docx tqdm xlsxwriter

Note: you may need to restart the kernel to use updated packages.


In [26]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
import time
import json
from datetime import datetime
from scipy.stats import ks_2samp, ttest_ind, wasserstein_distance
from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import traceback
import itertools
import re

from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer
from sdmetrics.reports.single_table import DiagnosticReport, QualityReport

# Rename columns whose names collide with reserved metadata keys (e.g. "columns"), because such names confuse sdmetrics.

def _rename_reserved_columns(df: pd.DataFrame) -> pd.DataFrame:
    
    df = df.copy()

    reserved = {"columns", "table", "tables", "constraints", 
                "relationships", "primary_key", "foreign_key"}

    new_cols = []
    for col in df.columns:
        if col.lower() in reserved:
            new_cols.append(f"col_{col}")
        else:
            new_cols.append(col)

    df.columns = new_cols
    return df



#  DATA PRE‑PROCESSING

def _sanitize_numeric(df: pd.DataFrame, num_cols: list[str]) -> pd.DataFrame:
    """Replace ±Inf with NaN and clip extreme outliers in numeric columns."""
    df = df.copy()
    df[num_cols] = df[num_cols].replace([np.inf, -np.inf], np.nan)

    for col in num_cols:
        valid_values = df[col].dropna()
        if len(valid_values) >= 4:  #  threshold to avoid issues with small samples
            q1, q3 = valid_values.quantile(0.25), valid_values.quantile(0.75)
            iqr = q3 - q1
            upper, lower = q3 + 5 * iqr, q1 - 5 * iqr
            df[col] = df[col].clip(lower, upper)

    return df


def _knn_impute_numeric(df: pd.DataFrame, num_cols: list[str],
                        knn_neighbors: int) -> tuple[int, list[str]]:
    """KNN-impute the given numeric columns of ``df`` in place.

    Columns without a single observed value are skipped: KNNImputer drops such
    features by default, which would make the result narrower than the slice it
    is assigned back to.

    Returns:
        ``(n_neighbors actually used, list of columns that were imputed)``.
    """
    imputable = [c for c in num_cols if df[c].notna().any()]
    # Never ask for more neighbours than there are other rows, and never for < 1.
    n_neighbors = max(1, min(knn_neighbors, len(df) - 1))
    if imputable:
        imputer = KNNImputer(n_neighbors=n_neighbors)
        df[imputable] = imputer.fit_transform(df[imputable])
    return n_neighbors, imputable


def load_and_preprocess(csv_path: str | Path, *, knn_neighbors: int = 4) -> tuple[pd.DataFrame, list[str], dict]:
    """Load a CSV file and prepare it for synthetic-data training.

    Steps, in order:
      1. Read the CSV.
      2. Rename reserved column names (see ``_rename_reserved_columns``).
      3. Drop the ``No.`` column if it exists.
      4. Replace ±Inf / clip outliers in numeric columns, then KNN-impute
         missing numeric values.
      5. Apply ``log1p`` to numeric columns whose absolute skewness exceeds 1
         and whose minimum is >= 0.
      6. Re-impute if any NaN is still present in the imputed columns.

    Args:
        csv_path: Path of the CSV file to read.
        knn_neighbors: Requested neighbour count for the KNN imputer. The value
            actually used is capped at ``len(df) - 1`` and floored at 1.

    Returns:
        ``(df, logged_cols, preprocessing_metadata)`` where ``logged_cols`` are
        the columns stored in ``log1p`` space and ``preprocessing_metadata``
        records timings and one entry per executed step.

    Raises:
        Exception: whatever ``pd.read_csv`` raises is printed and re-raised.
    """
    start_time = time.time()
    preprocessing_metadata = {
        "start_time": start_time,
        "original_file": str(csv_path),
        "steps": []
    }
    steps = preprocessing_metadata["steps"]

    # --- read CSV ------------------------------------------------------------
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        print(f"Error reading CSV: {e}")
        raise
    steps.append({
        "action": "read_csv",
        "rows": len(df),
        "columns": len(df.columns)
    })

    # --- rename reserved columns ---------------------------------------------
    df = _rename_reserved_columns(df)
    steps.append({"action": "rename_reserved_columns"})

    # --- drop unnecessary columns --------------------------------------------
    original_cols = list(df.columns)
    df = df.drop(columns=["No."], errors="ignore")
    steps.append({
        "action": "drop_columns",
        "dropped": [c for c in original_cols if c not in df.columns]
    })

    # --- identify numeric columns --------------------------------------------
    num_cols = df.select_dtypes(include="number").columns.tolist()
    steps.append({
        "action": "identify_numeric",
        "numeric_columns": num_cols,
        "count": len(num_cols)
    })

    # --- guard: remove ±Inf, clip outliers -----------------------------------
    df = _sanitize_numeric(df, num_cols)
    steps.append({"action": "sanitize_numeric"})

    # --- KNN impute ----------------------------------------------------------
    imputable: list[str] = []
    if num_cols:  # Only impute if numeric columns exist
        # Track missing values before imputation
        missing_before = df[num_cols].isna().sum().to_dict()
        effective_neighbors, imputable = _knn_impute_numeric(df, num_cols, knn_neighbors)
        steps.append({
            "action": "knn_impute",
            "missing_values_before": missing_before,
            "n_neighbors": knn_neighbors,
            "effective_n_neighbors": effective_neighbors,
            "skipped_all_nan_columns": [c for c in num_cols if c not in imputable]
        })

    # --- log1p skewed numeric columns ----------------------------------------
    skewness = df[num_cols].skew().abs()
    logged_cols = [c for c in skewness[skewness > 1].index if df[c].min() >= 0]

    if logged_cols:
        df[logged_cols] = np.log1p(df[logged_cols])
        steps.append({
            "action": "log1p_transform",
            "transformed_columns": logged_cols,
            "count": len(logged_cols)
        })

    # --- defensive second imputation pass ------------------------------------
    # Only the columns that could be imputed are checked; all-NaN columns stay
    # NaN and would otherwise trigger this pass on every run.
    if imputable and df[imputable].isna().any().any():
        _knn_impute_numeric(df, imputable, knn_neighbors)
        df[num_cols] = df[num_cols].astype(np.float64)
        steps.append({"action": "second_knn_impute"})

    # --- finalize preprocessing metadata -------------------------------------
    end_time = time.time()
    preprocessing_metadata["end_time"] = end_time
    preprocessing_metadata["duration"] = end_time - start_time
    preprocessing_metadata["final_shape"] = df.shape

    return df, logged_cols, preprocessing_metadata


# ──────────────────────────────────────────────────────────────────────────────
#  2)  SYNTHETIC DATA GENERATION
# ──────────────────────────────────────────────────────────────────────────────
def generate_synthetic(real_df: pd.DataFrame,
                       logged_cols: list[str],
                       *,
                       epochs: int = 1000,
                       batch_size: int = 500,
                       gen_dim: tuple[int, ...] = (256, 256),
                       dis_dim: tuple[int, ...] = (256, 256)) -> tuple[pd.DataFrame, CTGANSynthesizer, dict]:
    """Train CTGAN on ``real_df`` and sample a synthetic frame of the same length.

    Args:
        real_df: Training data. It is never modified; if it still contains
            NaN/±Inf in numeric columns, a cleaned private copy is trained on.
        logged_cols: Columns of ``real_df`` that are stored in ``log1p`` space.
            ``expm1`` is applied to them in the sampled frame.
        epochs: Passed to ``CTGANSynthesizer(epochs=...)``.
        batch_size: Passed to ``CTGANSynthesizer(batch_size=...)``.
        gen_dim: Passed as ``generator_dim``.
        dis_dim: Passed as ``discriminator_dim``.

    Returns:
        ``(synthetic_df, synthesizer, metadata)`` where ``metadata`` holds the
        hyper-parameters, timestamps and a list of pipeline events.
    """
    # Start timing and initialize metadata
    start_time = time.time()
    gen_metadata = {
        "start_time": start_time,
        "model": "CTGAN",
        "epochs": epochs,
        "batch_size": batch_size,
        "gen_dim": gen_dim,
        "dis_dim": dis_dim,
        "events": []
    }
    events = gen_metadata["events"]

    # Use SingleTableMetadata
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(real_df)
    events.append({
        "event": "metadata_detection",
        "timestamp": time.time()
    })

    # Create synthesizer
    synth = CTGANSynthesizer(
        metadata,
        enforce_rounding=False,
        epochs=epochs,
        batch_size=batch_size,
        generator_dim=gen_dim,
        discriminator_dim=dis_dim,
        verbose=True
    )

    # Safety net: make sure numeric data is finite before training.
    numeric_cols = real_df.select_dtypes(include="number").columns
    if not np.isfinite(real_df[numeric_cols].to_numpy()).all():
        print("Warning: replacing remaining NaN/Inf values in numeric columns")
        # Work on a copy so the caller's frame is not silently rewritten.
        real_df = real_df.copy()
        # Turn ±Inf into NaN first so the medians are computed from finite
        # values only, then fill every gap with its column median.
        finite_only = real_df[numeric_cols].replace([np.inf, -np.inf], np.nan)
        real_df[numeric_cols] = finite_only.fillna(finite_only.median())
        events.append({
            "event": "replaced_remaining_nans",
            "timestamp": time.time()
        })

    # Train
    print(f"Training CTGAN model with {epochs} epochs...")
    fit_start = time.time()
    synth.fit(real_df)
    fit_end = time.time()

    events.append({
        "event": "training_complete",
        "timestamp": fit_end,
        "training_duration": fit_end - fit_start
    })

    # Sample a synthetic dataset of the same size
    print(f"Generating {len(real_df)} synthetic samples...")
    sample_start = time.time()
    synth_df = synth.sample(num_rows=len(real_df))
    sample_end = time.time()

    events.append({
        "event": "sampling_complete",
        "timestamp": sample_end,
        "sampling_duration": sample_end - sample_start,
        "samples_generated": len(synth_df)
    })

    # Inverse log1p: bring the logged columns back to their original scale
    if logged_cols:
        synth_df[logged_cols] = np.expm1(synth_df[logged_cols])
        events.append({
            "event": "inverse_log_transform",
            "timestamp": time.time(),
            "columns_transformed": len(logged_cols)
        })

    # Complete metadata
    end_time = time.time()
    gen_metadata["end_time"] = end_time
    gen_metadata["total_duration"] = end_time - start_time

    return synth_df, synth, gen_metadata


# ──────────────────────────────────────────────────────────────────────────────
#  3)  METRICS & REPORTS
# ──────────────────────────────────────────────────────────────────────────────
def _report_to_frame(report, fallback_properties: list[str]) -> pd.DataFrame:
    """Flatten an SDMetrics report object into a DataFrame.

    The accessors are tried in this order; the first one that exists wins:
      1. ``report.get_results()`` -> ``pd.json_normalize``.
      2. ``report.get_properties()`` plus ``report.get_score()`` -> rows of
         ``property`` / ``value`` (property names are lower-cased, spaces become
         underscores and ``_score`` is appended; the overall score is stored as
         ``overall_score``).
         NOTE(review): assumes ``get_properties()`` returns either a DataFrame
         with ``Property`` / ``Score`` columns or a dict - confirm against the
         installed SDMetrics version.
      3. ``report.metrics`` -> ``pd.DataFrame``.
      4. ``getattr(report, name, NaN)`` for each name in ``fallback_properties``.
    """
    try:
        return pd.json_normalize(report.get_results())
    except AttributeError:
        pass

    try:
        properties = report.get_properties()
        overall = report.get_score()
    except AttributeError:
        properties, overall = None, np.nan

    pairs = None
    if isinstance(properties, pd.DataFrame) and {"Property", "Score"}.issubset(properties.columns):
        pairs = list(zip(properties["Property"], properties["Score"]))
    elif isinstance(properties, dict):
        pairs = list(properties.items())
    if pairs:
        rows = [
            {"property": "_".join(str(name).lower().split()) + "_score", "value": score}
            for name, score in pairs
        ]
        rows.append({"property": "overall_score", "value": overall})
        return pd.DataFrame(rows)

    try:
        return pd.DataFrame(report.metrics)
    except AttributeError:
        return pd.DataFrame({
            "property": fallback_properties,
            "value": [getattr(report, name, np.nan) for name in fallback_properties]
        })


def compare_real_vs_synth(real_df: pd.DataFrame,
                        synth_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, dict]:
    """
    Compare real and synthetic data using various metrics.

    For every numeric column of ``real_df`` this computes a two-sample KS test
    (on the non-missing values), Welch's t-test (NaNs omitted), the Wasserstein
    distance (NaNs filled with the column median) and the absolute differences
    of mean and standard deviation. A metric that fails is reported as NaN and
    a warning is printed. SDMetrics diagnostic and quality reports are then
    generated; if either fails, its frame holds a single ``error`` column.

    Returns (metrics_df, summary_df, diagnostic_df, quality_df, metadata).
    ``metadata["metrics_calculated"]`` lists each metric name once, in the
    order it was first computed successfully.
    """
    start_time = time.time()
    comparison_metadata = {
        "start_time": start_time,
        "metrics_calculated": []
    }
    # Ordered set of metric names that succeeded for at least one column.
    calculated: dict[str, None] = {}

    metric_columns = ["column", "KS‑stat", "KS‑p", "t‑stat", "t‑test",
                      "Wasserstein", "mean_diff", "std_diff"]

    # Calculate statistical tests for numeric columns
    numeric = real_df.select_dtypes(include="number").columns
    rows = []

    for col in numeric:
        col_metrics = {name: np.nan for name in metric_columns}
        col_metrics["column"] = col

        # KS test (NaNs removed so they cannot distort or poison the statistic)
        try:
            ks_stat, ks_p = ks_2samp(real_df[col].dropna(), synth_df[col].dropna())
            col_metrics["KS‑stat"] = ks_stat
            col_metrics["KS‑p"] = ks_p
            calculated.setdefault("KS_test")
        except Exception as e:
            print(f"Warning: KS test failed for column {col}: {e}")

        # T-test
        try:
            tt_stat, tt_p = ttest_ind(real_df[col], synth_df[col], equal_var=False, nan_policy='omit')
            col_metrics["t‑stat"] = tt_stat
            col_metrics["t‑test"] = tt_p
            calculated.setdefault("t_test")
        except Exception as e:
            print(f"Warning: t-test failed for column {col}: {e}")

        # Wasserstein distance
        try:
            col_metrics["Wasserstein"] = wasserstein_distance(
                real_df[col].fillna(real_df[col].median()),
                synth_df[col].fillna(synth_df[col].median())
            )
            calculated.setdefault("wasserstein_distance")
        except Exception as e:
            print(f"Warning: Wasserstein distance calculation failed for column {col}: {e}")

        # Mean and std differences
        try:
            mean_diff = abs(real_df[col].mean() - synth_df[col].mean())
            std_diff = abs(real_df[col].std() - synth_df[col].std())
            col_metrics["mean_diff"] = mean_diff
            col_metrics["std_diff"] = std_diff
            calculated.setdefault("mean_diff")
            calculated.setdefault("std_diff")
        except Exception as e:
            print(f"Warning: Mean/std calculation failed for column {col}: {e}")

        rows.append(col_metrics)

    # Passing the column list keeps the schema stable even when there are no
    # numeric columns, so the summary below cannot hit a KeyError.
    metrics_df = pd.DataFrame(rows, columns=metric_columns)
    for name in metric_columns[1:]:
        metrics_df[name] = pd.to_numeric(metrics_df[name], errors="coerce").astype(float)
    comparison_metadata["metrics_calculated"] = list(calculated)
    comparison_metadata["column_metrics_complete"] = time.time()

    # Build a summary of metrics
    summary_stats = {
        "KS‑p_mean": metrics_df["KS‑p"].mean(),
        "KS‑p_pass_rate": (metrics_df["KS‑p"] > 0.05).mean(),
        "t‑test_mean": metrics_df["t‑test"].mean(),
        "t‑test_pass_rate": (metrics_df["t‑test"] > 0.05).mean(),
        "Wasserstein_mean": metrics_df["Wasserstein"].mean(),
        "Wasserstein_std": metrics_df["Wasserstein"].std(),
        "mean_diff_avg": metrics_df["mean_diff"].mean(),
        "std_diff_avg": metrics_df["std_diff"].mean(),
    }

    summary_df = (
        pd.Series(summary_stats)
        .rename_axis("metric")
        .reset_index(name="value")
    )
    comparison_metadata["summary_metrics_complete"] = time.time()

    # Create SDMetrics reports
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(real_df)
    meta_dict = metadata.to_dict()

    # Initialize report DataFrames
    diag_df = pd.DataFrame()
    qual_df = pd.DataFrame()

    # Generate diagnostic report
    try:
        print("Generating diagnostic report...")
        diag_start = time.time()
        diag_rpt = DiagnosticReport()
        diag_rpt.generate(real_df, synth_df, meta_dict)
        diag_df = _report_to_frame(
            diag_rpt,
            ["data_validity_score", "data_structure_score", "overall_score"]
        )
        comparison_metadata["diagnostic_report_complete"] = time.time()
        comparison_metadata["diagnostic_duration"] = time.time() - diag_start
    except Exception as e:
        print(f"Warning: DiagnosticReport generation failed: {e}")
        diag_df = pd.DataFrame({"error": [str(e)]})
        comparison_metadata["diagnostic_report_error"] = str(e)

    # Generate quality report
    try:
        print("Generating quality report...")
        qual_start = time.time()
        qual_rpt = QualityReport()
        qual_rpt.generate(real_df, synth_df, meta_dict)
        qual_df = _report_to_frame(
            qual_rpt,
            ["column_shapes_score", "column_pair_trends_score", "overall_score"]
        )
        comparison_metadata["quality_report_complete"] = time.time()
        comparison_metadata["quality_duration"] = time.time() - qual_start
    except Exception as e:
        print(f"Warning: QualityReport generation failed: {e}")
        qual_df = pd.DataFrame({"error": [str(e)]})
        comparison_metadata["quality_report_error"] = str(e)

    # Complete metadata
    comparison_metadata["end_time"] = time.time()
    comparison_metadata["total_duration"] = comparison_metadata["end_time"] - start_time

    return metrics_df, summary_df, diag_df, qual_df, comparison_metadata


# ──────────────────────────────────────────────────────────────────────────────
#  4)  VISUALIZATION FUNCTIONS
# ──────────────────────────────────────────────────────────────────────────────
def generate_comparison_visualizations(real_df: pd.DataFrame,
                                     synth_df: pd.DataFrame,
                                     output_dir: str = "visualizations"):
    """
    Generate visualizations comparing real and synthetic data.

    Written to ``output_dir`` (created if missing):
      * ``distribution_<column>.png`` - overlaid histograms and a Q-Q plot for
        every numeric column of ``real_df``; characters in the column name
        other than word characters, ``.``, ``-`` and space are replaced by
        ``_`` so the name is always a valid file name.
      * ``correlation_comparison.png`` / ``correlation_difference.png``.
      * ``summary_stats_difference.png``.

    Args:
        real_df: Real data DataFrame
        synth_df: Synthetic data DataFrame
        output_dir: Directory to save visualizations

    Returns:
        DataFrame indexed by numeric column with one ``<stat>_pct_diff`` column
        per summary statistic: ``|real - synth| / |real| * 100``, or 0 where the
        real statistic is (numerically) zero. An empty DataFrame is returned
        when ``real_df`` has no numeric columns.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get numeric columns
    numeric_cols = real_df.select_dtypes(include="number").columns
    if len(numeric_cols) == 0:
        print("Warning: no numeric columns found; no visualizations generated")
        return pd.DataFrame()

    # 1. Distribution comparison for each numeric column
    for col in numeric_cols:
        fig, (ax_hist, ax_qq) = plt.subplots(1, 2, figsize=(12, 6))

        # Plot distributions
        sns.histplot(real_df[col], kde=True, color="blue", label="Real", alpha=0.6, ax=ax_hist)
        sns.histplot(synth_df[col], kde=True, color="red", label="Synthetic", alpha=0.6, ax=ax_hist)
        ax_hist.set_title(f"Distribution: {col}")
        ax_hist.legend()

        # Plot Q-Q plot
        real_sorted = np.sort(real_df[col].dropna())
        synth_sorted = np.sort(synth_df[col].dropna())

        # An all-NaN column has no quantiles; leave the panel empty instead of
        # failing on min()/max() of an empty array.
        if len(real_sorted) > 0 and len(synth_sorted) > 0:
            # Thin the longer sample so both sides have the same number of points
            length = min(len(real_sorted), len(synth_sorted))
            if len(real_sorted) > length:
                indices = np.linspace(0, len(real_sorted) - 1, length).astype(int)
                real_sorted = real_sorted[indices]
            if len(synth_sorted) > length:
                indices = np.linspace(0, len(synth_sorted) - 1, length).astype(int)
                synth_sorted = synth_sorted[indices]

            ax_qq.scatter(real_sorted, synth_sorted, alpha=0.5)
            min_val = min(real_sorted.min(), synth_sorted.min())
            max_val = max(real_sorted.max(), synth_sorted.max())
            ax_qq.plot([min_val, max_val], [min_val, max_val], 'r--')
        ax_qq.set_xlabel("Real Data Quantiles")
        ax_qq.set_ylabel("Synthetic Data Quantiles")
        ax_qq.set_title(f"Q-Q Plot: {col}")

        fig.tight_layout()
        # Column names may contain path separators or other unsafe characters.
        safe_col = re.sub(r"[^\w.\- ]+", "_", str(col))
        fig.savefig(os.path.join(output_dir, f"distribution_{safe_col}.png"))
        plt.close(fig)

    # 2. Correlation heatmap comparison
    real_corr = real_df[numeric_cols].corr()
    synth_corr = synth_df[numeric_cols].corr()

    fig, (ax_real, ax_synth) = plt.subplots(1, 2, figsize=(20, 10))
    sns.heatmap(real_corr, annot=False, cmap='coolwarm', vmin=-1, vmax=1, ax=ax_real)
    ax_real.set_title("Real Data Correlation")
    sns.heatmap(synth_corr, annot=False, cmap='coolwarm', vmin=-1, vmax=1, ax=ax_synth)
    ax_synth.set_title("Synthetic Data Correlation")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "correlation_comparison.png"))
    plt.close(fig)

    # 3. Correlation difference heatmap
    fig, ax_diff = plt.subplots(figsize=(12, 10))
    corr_diff = real_corr - synth_corr
    sns.heatmap(corr_diff, annot=False, cmap='coolwarm', vmin=-1, vmax=1, ax=ax_diff)
    ax_diff.set_title("Correlation Difference (Real - Synthetic)")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "correlation_difference.png"))
    plt.close(fig)

    # 4. Summary statistics comparison
    real_stats = real_df[numeric_cols].describe().T
    synth_stats = synth_df[numeric_cols].describe().T

    # Calculate absolute percent differences
    stats_comparison = pd.DataFrame(index=real_stats.index)
    for stat in ['mean', 'std', 'min', '25%', '50%', '75%', 'max']:
        real_val = real_stats[stat]
        synth_val = synth_stats[stat]
        abs_diff = (real_val - synth_val).abs()

        # Divide by the magnitude of the real value so the result is never
        # negative; a (near-)zero reference is reported as 0 % instead of
        # dividing by zero.
        where_zero = real_val.abs() < 1e-10
        pct_diff = abs_diff / real_val.abs().where(~where_zero, 1) * 100
        pct_diff = pct_diff.where(~where_zero, 0)

        stats_comparison[f'{stat}_pct_diff'] = pct_diff

    # Plot heatmap of percent differences
    fig, ax_stats = plt.subplots(figsize=(14, max(8, len(numeric_cols) / 2 + 2)))
    sns.heatmap(stats_comparison, annot=True, cmap='YlOrRd', fmt='.1f', ax=ax_stats)
    ax_stats.set_title("Percent Difference in Summary Statistics (Real vs Synthetic)")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "summary_stats_difference.png"))
    plt.close(fig)

    print(f"Visualizations saved to {output_dir}/")
    return stats_comparison


# ──────────────────────────────────────────────────────────────────────────────
#  5)  EXCEL EXPORT
# ──────────────────────────────────────────────────────────────────────────────
def _safe_for_excel(df: pd.DataFrame | None) -> pd.DataFrame | None:
    """
    If a DataFrame literally has a column named "columns" (again),
    rename it to "column_name" so Excel writer won't choke.
    """
    if df is not None:
        # Handle any problematic column names
        rename_dict = {}
        for col in df.columns:
            if col == "columns":
                rename_dict[col] = "column_name"
            # Excel has a 31 character limit for sheet names, similar issues can happen with column names
            elif len(str(col)) > 200:  # Arbitrary large threshold
                rename_dict[col] = f"col_{hash(col) % 10000}"
        
        if rename_dict:
            df = df.rename(columns=rename_dict)
            
        # Also handle any potentially problematic data types
        for col in df.columns:
            if df[col].dtype == 'object':
                # Check if the column contains dict or list objects
                if any(isinstance(x, (dict, list)) for x in df[col].dropna()):
                    df[col] = df[col].apply(lambda x: str(x) if isinstance(x, (dict, list)) else x)
    return df


def _metadata_to_frame(metadata: dict) -> pd.DataFrame:
    """Flatten a metadata dict into ``key`` / ``value`` rows for an Excel sheet.

    Container values (dict, list, tuple, set) are stored as their ``str()``
    form because the Excel writer only accepts scalar cell values.
    """
    return pd.DataFrame({
        'key': list(metadata.keys()),
        'value': [str(v) if isinstance(v, (dict, list, tuple, set)) else v
                  for v in metadata.values()]
    })


def _lookup_metric(frame, key_col: str, pattern: str) -> tuple[bool, float | None]:
    """Find the first row of ``frame`` whose ``key_col`` matches ``pattern``.

    ``pattern`` is a case-insensitive regular expression. Returns
    ``(found, value)``; ``value`` is ``None`` when nothing matched or when the
    matched cell is missing.
    """
    if frame is None or frame.empty:
        return False, None
    if key_col not in frame.columns or "value" not in frame.columns:
        return False, None

    keys = frame[key_col].astype(str)
    mask = keys.str.contains(pattern, case=False, na=False)
    if not mask.any():
        return False, None

    value = frame.loc[mask, "value"].iloc[0]
    return True, (float(value) if pd.notna(value) else None)


def export_metrics_to_excel(*,
                          real_df,
                          synth_df,
                          metrics_df,
                          summary_df,
                          run_label,
                          excel_path="synthetic_comparison.xlsx",
                          diag_df=None,
                          qual_df=None,
                          preprocessing_metadata=None,
                          generation_metadata=None,
                          comparison_metadata=None):
    """
    Export all dataframes to a single Excel file with multiple sheets.

    The workbook at ``excel_path`` is rewritten on every call. If it already
    exists, its ``AllRuns`` sheet is read first and a summary row for this run
    is appended to it; other sheets of the old workbook are not carried over.

    The summary row holds the validity score (from ``diag_df``), the quality
    score (from ``qual_df``), the KS pass rate and the mean Wasserstein
    distance (from ``summary_df``, falling back to ``metrics_df``). A score
    that cannot be determined is stored as NaN rather than as a made-up
    placeholder.

    Returns:
        True if the workbook was written. On any error the traceback is
        printed, every non-empty frame is saved as
        ``<excel_path without suffix>_<name>.csv`` and False is returned.
    """
    # Apply safety transforms to all dataframes
    real_df = _safe_for_excel(real_df)
    synth_df = _safe_for_excel(synth_df)
    metrics_df = _safe_for_excel(metrics_df)
    summary_df = _safe_for_excel(summary_df)
    diag_df = _safe_for_excel(diag_df)
    qual_df = _safe_for_excel(qual_df)

    # Create a dictionary of all dataframes to save
    dataframes = {
        f"Real_{run_label}": real_df,
        f"Synth_{run_label}": synth_df,
        f"Metrics_{run_label}": metrics_df,
        f"Summary_{run_label}": summary_df,
    }

    # Add diagnostic and quality dataframes if they exist
    if diag_df is not None and not diag_df.empty:
        dataframes[f"Diag_{run_label}"] = diag_df
    if qual_df is not None and not qual_df.empty:
        dataframes[f"Qual_{run_label}"] = qual_df

    # Convert metadata to DataFrame
    for sheet, meta in (("PreprocessingMeta", preprocessing_metadata),
                        ("GenerationMeta", generation_metadata),
                        ("ComparisonMeta", comparison_metadata)):
        if meta:
            dataframes[sheet] = _metadata_to_frame(meta)

    # Determine if we're appending to an existing file
    excel_path = Path(excel_path)
    all_runs = pd.DataFrame()

    try:
        if excel_path.exists():
            # Load the run history from the existing workbook
            try:
                with pd.ExcelFile(excel_path) as xls:
                    if "AllRuns" in xls.sheet_names:
                        all_runs = pd.read_excel(xls, "AllRuns")
            except Exception as e:
                print(f"Warning: Could not read existing Excel file: {e}")

        validity_score = None
        quality_score = None
        ks_p_pass_rate = None
        wasserstein_mean = None

        # Extract from diagnostic report
        try:
            _, validity_score = _lookup_metric(
                diag_df, "property", "validity_score|data_validity")
        except Exception as e:
            print(f"Warning: Error extracting validity score: {e}")

        # Extract from quality report: the first property name that is present wins
        try:
            for prop in ("overall_score", "quality_score", "column_shapes_score"):
                found, quality_score = _lookup_metric(qual_df, "property", prop)
                if found:
                    break
        except Exception as e:
            print(f"Warning: Error extracting quality score: {e}")

        # Extract from summary metrics. The names are matched in full: a loose
        # "ks" match would pick up the "KS‑p_mean" row, which precedes the pass
        # rate. "." stands for the hyphen so either hyphen character matches.
        try:
            _, ks_p_pass_rate = _lookup_metric(summary_df, "metric", r"^KS.p_pass_rate$")
            _, wasserstein_mean = _lookup_metric(summary_df, "metric", r"^Wasserstein_mean$")
        except Exception as e:
            print(f"Warning: Error extracting summary metrics: {e}")

        # If metrics are still None, try to calculate them directly
        if metrics_df is not None and not metrics_df.empty:
            try:
                if ks_p_pass_rate is None and "KS‑p" in metrics_df.columns:
                    ks_p_pass_rate = float((metrics_df["KS‑p"] > 0.05).mean())
                if wasserstein_mean is None and "Wasserstein" in metrics_df.columns:
                    wasserstein_mean = float(metrics_df["Wasserstein"].mean())
            except Exception as e:
                print(f"Warning: Error computing fallback metrics: {e}")

        # Unknown scores are recorded as NaN so the run history never contains
        # numbers that were not actually measured.
        validity_score = np.nan if validity_score is None else float(validity_score)
        quality_score = np.nan if quality_score is None else float(quality_score)
        ks_p_pass_rate = np.nan if ks_p_pass_rate is None else float(ks_p_pass_rate)
        wasserstein_mean = np.nan if wasserstein_mean is None else float(wasserstein_mean)

        # Create run summary
        run_summary = pd.DataFrame({
            'run': [run_label],
            'timestamp': [datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
            'validity_score': [validity_score],
            'quality_score': [quality_score],
            'ks_p_pass_rate': [ks_p_pass_rate],
            'wasserstein_mean': [wasserstein_mean],
            'size': [len(real_df) if real_df is not None else 0]
        })

        if all_runs.empty:
            all_runs = run_summary
        else:
            all_runs = pd.concat([all_runs, run_summary], ignore_index=True)

        # Include AllRuns in our dataframes to write
        dataframes['AllRuns'] = all_runs

        # Create a new Excel file (this will overwrite if it exists)
        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            for sheet_name, df in dataframes.items():
                if df is None or df.empty:
                    continue

                # Excel forbids []:*?/\ in sheet names and limits them to 31 chars
                safe_name = re.sub(r"[\[\]:*?/\\]", "_", str(sheet_name))[:31]
                df.to_excel(writer, sheet_name=safe_name, index=False)

                # Improve formatting for the AllRuns sheet
                if safe_name == "AllRuns":
                    try:
                        worksheet = writer.sheets[safe_name]
                        # Auto-adjust column widths (capped at 40)
                        for idx, col in enumerate(df.columns):
                            max_len = max(df[col].astype(str).str.len().max(),
                                          len(str(col)) + 2)
                            max_len = min(max_len, 40)
                            col_letter = chr(65 + idx) if idx < 26 else f"A{chr(65 + idx - 26)}"
                            try:
                                worksheet.column_dimensions[col_letter].width = max_len
                            except Exception:
                                pass  # Skip if column dimension is out of range
                    except Exception as format_err:
                        print(f"Warning: Could not format AllRuns sheet: {format_err}")

        print(f"✅ All results for '{run_label}' written to a single Excel file → {excel_path}")
        return True

    except Exception as e:
        print(f"❌ Error writing Excel file: {e}")
        traceback.print_exc()
        # Fallback: save as CSV files
        csv_base = excel_path.with_suffix('')
        for name, df in dataframes.items():
            if df is not None and not df.empty:
                csv_path = Path(f"{csv_base}_{name}.csv")
                try:
                    df.to_csv(csv_path, index=False)
                    print(f"✅ Saved {name} as CSV: {csv_path}")
                except Exception as csv_e:
                    print(f"❌ Error saving {name} as CSV: {csv_e}")
        return False


# ──────────────────────────────────────────────────────────────────────────────
#  6)  CONFIGURATION & MAIN EXECUTION
# ──────────────────────────────────────────────────────────────────────────────
def run_synthetic_data_pipeline(config):
    """
    Execute one complete synthetic-data run as described by ``config``.

    Stages run in this order: load/preprocess the CSV, train the generator and
    sample synthetic rows, compare real vs. synthetic data, optionally render
    comparison plots, export everything to Excel, optionally dump the synthetic
    rows to CSV and optionally persist the trained model with joblib.

    Args:
        config: Dictionary of settings. Keys read with ``[]`` (a missing one
            aborts the run with a captured ``KeyError``): 'csv_path',
            'knn_neighbors', 'model', 'epochs', 'gen_dim', 'dis_dim',
            'run_label', 'excel_path'. Keys read with a default:
            'batch_size' (500), 'generate_visualizations' (True),
            'visualization_dir' ('visualizations'),
            'export_synthetic_csv' (True), 'save_model' (False).

    Returns:
        Dictionary holding 'start_time', 'config', 'status' ('success' or
        'error'), 'runtime', the metadata of every stage that completed and,
        on failure, 'error' plus 'error_details' (the formatted traceback).
        Exceptions raised by a stage are recorded here instead of propagating.
    """
    # Local import kept so the function does not depend on the module-level one
    import time

    t0 = time.time()
    outcome = {"start_time": t0, "config": config, "status": "started"}

    def _elapsed():
        # Seconds since the pipeline was entered
        return time.time() - t0

    try:
        # --- Stage 1: load + preprocess ------------------------------------
        source_csv = config['csv_path']
        print(f"[1/5] Loading data from {source_csv}...")
        real_df, logged_cols, preproc_meta = load_and_preprocess(
            source_csv,
            knn_neighbors=config['knn_neighbors'],
        )
        outcome["preprocessing_metadata"] = preproc_meta
        n_rows, n_cols = len(real_df), len(real_df.columns)
        print(f"✅ Loaded {n_rows} rows with {n_cols} columns")
        print(f"✅ Applied log transform to {len(logged_cols)} columns: {logged_cols}")

        # --- Stage 2: train generator and sample ---------------------------
        print(f"[2/5] Generating synthetic data using {config['model']}...")
        synth_df, synth_model, gen_meta = generate_synthetic(
            real_df,
            logged_cols,
            epochs=config['epochs'],
            batch_size=config.get('batch_size', 500),
            gen_dim=config['gen_dim'],
            dis_dim=config['dis_dim'],
        )
        outcome["generation_metadata"] = gen_meta
        print(f"✅ Generated {len(synth_df)} synthetic samples")

        # --- Stage 3: real-vs-synthetic metrics ----------------------------
        print("[3/5] Calculating comparison metrics...")
        comparison = compare_real_vs_synth(real_df, synth_df)
        metrics_df, summary_df, diag_df, qual_df, comp_meta = comparison
        outcome["comparison_metadata"] = comp_meta
        print("✅ Comparison metrics calculated")

        # --- Stage 4: optional plots ---------------------------------------
        if config.get('generate_visualizations', True):
            print("[4/5] Generating comparison visualizations...")
            viz_root = config.get('visualization_dir', 'visualizations')
            viz_dir = f"{viz_root}/{config['run_label']}"

            # The plotting helper is handed an already existing directory
            os.makedirs(viz_dir, exist_ok=True)
            generate_comparison_visualizations(real_df, synth_df, output_dir=viz_dir)

            outcome["visualization_dir"] = viz_dir
            print(f"✅ Visualizations saved to {viz_dir}")

        # --- Stage 5: Excel workbook ---------------------------------------
        print("[5/5] Exporting results to Excel...")
        outcome["excel_export_success"] = export_metrics_to_excel(
            real_df=real_df,
            synth_df=synth_df,
            metrics_df=metrics_df,
            summary_df=summary_df,
            run_label=config['run_label'],
            excel_path=config['excel_path'],
            diag_df=diag_df,
            qual_df=qual_df,
            preprocessing_metadata=preproc_meta,
            generation_metadata=gen_meta,
            comparison_metadata=comp_meta,
        )

        # --- Optional: synthetic rows as CSV (written to the working dir) ---
        if config.get('export_synthetic_csv', True):
            synth_csv_path = f"synthetic_{config['run_label']}.csv"
            synth_df.to_csv(synth_csv_path, index=False)
            outcome["synthetic_csv_path"] = synth_csv_path
            print(f"✅ Synthetic CSV saved to {synth_csv_path}")

        # --- Optional: persist the model; a failure here is non-fatal ------
        if config.get('save_model', False):
            try:
                import joblib
                model_path = f"model_{config['run_label']}.pkl"
                joblib.dump(synth_model, model_path)
                outcome["model_path"] = model_path
                print(f"✅ Model saved to {model_path}")
            except Exception as save_err:
                print(f"❌ Error saving model: {save_err}")
                outcome["model_save_error"] = str(save_err)

        outcome["status"] = "success"
        outcome["runtime"] = _elapsed()
        print(f"✅ Pipeline completed successfully in {outcome['runtime']:.2f} seconds")

    except Exception as exc:
        # Any stage failure ends the run but is reported through the result dict
        error_details = traceback.format_exc()
        outcome.update(
            status="error",
            error=str(exc),
            error_details=error_details,
            runtime=_elapsed(),
        )
        print(f"❌ Pipeline failed after {outcome['runtime']:.2f} seconds: {exc}")
        print(error_details)

    return outcome


# ──────────────────────────────────────────────────────────────────────────────
#  7)  HYPERPARAMETER SEARCH
# ──────────────────────────────────────────────────────────────────────────────
# Default search space (identical to the grid that used to be hard-coded).
_DEFAULT_PARAM_GRID = {
    "epochs": [500, 1000, 5000],
    "batch_size": [200, 500],
    "gen_dim": [(256, 256), (512, 512)],
    "dis_dim": [(256, 256), (512, 512)],
}

# Placeholder scores substituted when a metric cannot be found in any report.
# NOTE(review): these are arbitrary stand-ins, not measurements; a warning is
# printed whenever one is used so that they are not mistaken for real scores.
_FALLBACK_METRICS = {
    "validity_score": 0.7,
    "quality_score": 0.7,
    "ks_p_pass_rate": 0.5,
    "wasserstein_mean": 1.0,
}


def _lookup_report_value(report_df, key_col, patterns, label):
    """Return, as float, the first 'value' whose key matches one of ``patterns``.

    ``patterns`` are tried in order (case-insensitive regex search on
    ``report_df[key_col]`` rendered as text); the first pattern that matches
    any row decides the result. Returns None when the frame is missing/empty,
    lacks the needed columns, nothing matches, or the matched value is NaN.
    The input frame is never modified.
    """
    if report_df is None or report_df.empty:
        return None
    try:
        if key_col not in report_df.columns or "value" not in report_df.columns:
            return None
        # Work on a copy of the keys so the exported report stays untouched
        keys = report_df[key_col].astype(str)
        for pattern in patterns:
            mask = keys.str.contains(pattern, case=False, na=False)
            if mask.any():
                raw = report_df.loc[mask, "value"].iloc[0]
                return float(raw) if pd.notna(raw) else None
    except Exception as e:
        print(f"Warning: Error extracting {label}: {e}")
    return None


def _extract_run_metrics(metrics_df, summary_df, diag_df, qual_df):
    """Collect the four headline scores of a run as a dict of floats.

    Values come from the diagnostic/quality/summary reports first, then from
    the per-column ``metrics_df``, and finally from ``_FALLBACK_METRICS``.
    """
    scores = {
        "validity_score": _lookup_report_value(
            diag_df, "property", ["validity_score|data_validity"], "validity score"
        ),
        "quality_score": _lookup_report_value(
            qual_df, "property",
            ["overall_score", "quality_score", "column_shapes_score"],
            "quality score",
        ),
        "ks_p_pass_rate": _lookup_report_value(
            summary_df, "metric", ["pass_rate|ks"], "summary metrics"
        ),
        "wasserstein_mean": _lookup_report_value(
            summary_df, "metric", ["wasserstein"], "summary metrics"
        ),
    }

    # Second chance: derive the two distribution metrics from per-column rows
    if metrics_df is not None and not metrics_df.empty:
        if scores["ks_p_pass_rate"] is None:
            try:
                # The column name contains a non-ASCII hyphen; it must match the
                # name produced by compare_real_vs_synth exactly.
                if "KS‑p" in metrics_df.columns:
                    scores["ks_p_pass_rate"] = float((metrics_df["KS‑p"] > 0.05).mean())
            except Exception as e:
                print(f"Warning: Could not derive KS pass rate from column metrics: {e}")
        if scores["wasserstein_mean"] is None:
            try:
                if "Wasserstein" in metrics_df.columns:
                    scores["wasserstein_mean"] = float(metrics_df["Wasserstein"].mean())
            except Exception as e:
                print(f"Warning: Could not derive Wasserstein mean from column metrics: {e}")

    # Last resort: placeholders, reported loudly instead of silently
    defaulted = [name for name, value in scores.items() if value is None]
    for name in defaulted:
        scores[name] = _FALLBACK_METRICS[name]
    if defaulted:
        print(f"Warning: No value found for {', '.join(defaulted)}; using placeholder defaults")

    return {name: float(value) for name, value in scores.items()}


def _score_and_rank(results_df):
    """Add 'wasserstein_normalized' and 'overall_score' and sort best-first.

    Only rows with ``success`` set are scored; other rows keep NaN in the new
    columns and therefore sort last.
    """
    ranked = results_df.copy()
    ok = ranked[ranked["success"]].copy()

    # Wasserstein: lower is better, so min-max scale and invert to 0..1
    if "wasserstein_mean" in ok.columns and not ok["wasserstein_mean"].isna().all():
        observed = ok["wasserstein_mean"].dropna()
        lowest, highest = observed.min(), observed.max()
        if highest != lowest:
            ok["wasserstein_normalized"] = 1 - (
                (ok["wasserstein_mean"] - lowest) / (highest - lowest)
            )
        else:
            ok["wasserstein_normalized"] = 1.0  # All runs tie
    else:
        ok["wasserstein_normalized"] = 0.5  # No usable data

    # Weighted composite (0.3 / 0.3 / 0.2 / 0.2); missing values count as 0
    ok["overall_score"] = (
        ok["validity_score"].fillna(0) * 0.3 +
        ok["quality_score"].fillna(0) * 0.3 +
        ok["ks_p_pass_rate"].fillna(0) * 0.2 +
        ok["wasserstein_normalized"].fillna(0) * 0.2
    )

    ranked.loc[ok.index, "wasserstein_normalized"] = ok["wasserstein_normalized"]
    ranked.loc[ok.index, "overall_score"] = ok["overall_score"]
    return ranked.sort_values("overall_score", ascending=False)


def run_hyperparameter_search(
    csv_path,
    output_dir="hyperparameter_search",
    excel_path="hyperparameter_results.xlsx",
    visualize_top_n=3,
    knn_neighbors=4,
    param_grid=None
):
    """
    Run a grid search over generator hyperparameters.

    The data is loaded and preprocessed once; every grid combination then
    trains a generator, is compared against the real data, and has its
    synthetic rows and an Excel report written to ``output_dir``. A failing
    combination is recorded with ``success=False`` and the search continues.

    Args:
        csv_path: Path to the input CSV file
        output_dir: Directory to save per-run results and plots
        excel_path: Path to save consolidated Excel results
        visualize_top_n: Number of top models to visualize
        knn_neighbors: Number of neighbors for KNN imputation
        param_grid: Optional dict overriding parts of the default grid. Allowed
            keys are 'epochs', 'batch_size', 'gen_dim' and 'dis_dim'; each maps
            to an iterable of candidate values. Omitted keys keep the defaults.

    Returns:
        DataFrame with one row per combination. When at least one run
        succeeded it also carries 'wasserstein_normalized' and 'overall_score'
        and is sorted best-first.

    Raises:
        ValueError: If ``param_grid`` contains an unsupported key.
    """
    # Resolve the search space before touching the file system
    grid = {name: list(values) for name, values in _DEFAULT_PARAM_GRID.items()}
    if param_grid is not None:
        unknown = sorted(set(param_grid) - set(grid))
        if unknown:
            raise ValueError(f"Unsupported hyperparameter(s) in param_grid: {unknown}")
        for name, values in param_grid.items():
            grid[name] = list(values)

    os.makedirs(output_dir, exist_ok=True)

    param_keys = list(grid)
    param_combinations = list(itertools.product(*grid.values()))
    total_combinations = len(param_combinations)

    print(f"Starting hyperparameter search with {total_combinations} combinations")
    print(f"Input data: {csv_path}")
    print(f"Output directory: {output_dir}")

    # Load and preprocess data once (outside the loop to save time)
    print("Loading and preprocessing data...")
    real_df, logged_cols, preproc_meta = load_and_preprocess(
        csv_path,
        knn_neighbors=knn_neighbors
    )
    print(f"✅ Loaded {len(real_df)} rows with {len(real_df.columns)} columns")
    print(f"✅ Applied log transform to {len(logged_cols)} columns")

    all_results = []

    for i, combination in enumerate(param_combinations):
        params = dict(zip(param_keys, combination))

        # Unique, file-system friendly label for this combination
        gen_dim_str = "x".join(str(d) for d in params["gen_dim"])
        dis_dim_str = "x".join(str(d) for d in params["dis_dim"])
        run_label = f"E{params['epochs']}_B{params['batch_size']}_G{gen_dim_str}_D{dis_dim_str}"

        print(f"\n[{i+1}/{total_combinations}] Running {run_label}...")

        # Fields shared by successful and failed records
        record = {
            "run_label": run_label,
            "epochs": params["epochs"],
            "batch_size": params["batch_size"],
            "gen_dim": str(params["gen_dim"]),  # Tuple as text for easier storage
            "dis_dim": str(params["dis_dim"]),
        }

        # Set before the try so the failure branch can always compute a runtime
        start_time = time.time()

        try:
            synth_df, _model, gen_meta = generate_synthetic(
                real_df,
                logged_cols,
                epochs=params["epochs"],
                batch_size=params["batch_size"],
                gen_dim=params["gen_dim"],
                dis_dim=params["dis_dim"]
            )

            metrics_df, summary_df, diag_df, qual_df, comp_meta = compare_real_vs_synth(
                real_df, synth_df
            )

            scores = _extract_run_metrics(metrics_df, summary_df, diag_df, qual_df)

            # Runtime covers training + evaluation, not the file exports below
            run_time = time.time() - start_time

            synth_csv_path = os.path.join(output_dir, f"synthetic_{run_label}.csv")
            synth_df.to_csv(synth_csv_path, index=False)

        except Exception as e:
            print(f"❌ Error in run {run_label}: {e}")
            traceback.print_exc()

            record.update(
                validity_score=None,
                quality_score=None,
                ks_p_pass_rate=None,
                wasserstein_mean=None,
                runtime_seconds=time.time() - start_time,
                timestamp=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                success=False,
                error=str(e),
            )
            all_results.append(record)
            continue

        record.update(scores)
        record.update(
            runtime_seconds=run_time,
            timestamp=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            success=True,
        )
        all_results.append(record)

        # The per-run workbook is best effort: a failure here must not add a
        # second (failed) record for a run that was already stored as successful.
        run_excel_path = os.path.join(output_dir, f"results_{run_label}.xlsx")
        try:
            export_metrics_to_excel(
                real_df=real_df,
                synth_df=synth_df,
                metrics_df=metrics_df,
                summary_df=summary_df,
                run_label=run_label,
                excel_path=run_excel_path,
                diag_df=diag_df,
                qual_df=qual_df,
                preprocessing_metadata=preproc_meta,
                generation_metadata=gen_meta,
                comparison_metadata=comp_meta
            )
        except Exception as export_err:
            print(f"Warning: Could not export Excel report for {run_label}: {export_err}")
            traceback.print_exc()

        print(f"✅ Run completed in {run_time:.2f} seconds")
        print(
            f"   Validity: {scores['validity_score']:.4f}, "
            f"Quality: {scores['quality_score']:.4f}, "
            f"KS pass rate: {scores['ks_p_pass_rate']:.4f}"
        )

    results_df = pd.DataFrame(all_results)

    # Raw results are written first so they survive any later failure
    results_df.to_csv(os.path.join(output_dir, "all_hyperparameter_results.csv"), index=False)

    if not results_df.empty and results_df["success"].any():
        results_df = _score_and_rank(results_df)

        # Visualize top N models
        if visualize_top_n > 0:
            for _, row in results_df.head(visualize_top_n).iterrows():
                if not row["success"]:
                    continue
                run_label = row["run_label"]
                synth_csv_path = os.path.join(output_dir, f"synthetic_{run_label}.csv")

                try:
                    # Synthetic rows are re-read from disk rather than kept in memory
                    synth_df = pd.read_csv(synth_csv_path)

                    viz_dir = os.path.join(output_dir, f"viz_{run_label}")
                    os.makedirs(viz_dir, exist_ok=True)

                    generate_comparison_visualizations(
                        real_df,
                        synth_df,
                        output_dir=viz_dir
                    )

                    print(f"✅ Generated visualizations for top model {run_label}")
                except Exception as e:
                    print(f"❌ Error generating visualizations for {run_label}: {e}")

        # Plots are a convenience; never lose the results table because of them
        try:
            plot_parameter_comparisons(results_df, output_dir)
        except Exception as e:
            print(f"Warning: Could not create parameter comparison plots: {e}")
            traceback.print_exc()

    # Export final Excel with all results
    try:
        results_df.to_excel(excel_path, index=False)
        print(f"✅ All results exported to {excel_path}")
    except Exception as e:
        print(f"❌ Error exporting final results to Excel: {e}")

    return results_df


def _dim_to_str(x):
    """Render a layer-size spec such as (256, 256) or "(256, 256)" as "256x256"."""
    if isinstance(x, (tuple, list)):
        return 'x'.join(str(d) for d in x)
    if isinstance(x, str) and ('(' in x or '[' in x):
        # Stringified tuple/list like "(128,)" or "(256, 256)": keep the digit runs
        return 'x'.join(re.findall(r'\d+', x))
    return str(x)


def plot_parameter_comparisons(results_df, output_dir):
    """
    Generate plots to visualize the impact of different hyperparameters.

    Only rows with ``success`` set are used. With fewer than two such rows a
    message is printed and nothing is written. Otherwise PNG files are saved
    to ``<output_dir>/parameter_plots``; each plot is skipped when the columns
    it needs are absent. Every figure is closed even if drawing or saving
    fails, so no figure is left open.

    Args:
        results_df: DataFrame with hyperparameter search results
        output_dir: Directory to save the plots
    """
    # Only use successful runs
    df = results_df[results_df["success"]].copy()
    if len(df) <= 1:
        print("Not enough successful runs to create comparison plots")
        return

    plots_dir = os.path.join(output_dir, "parameter_plots")
    os.makedirs(plots_dir, exist_ok=True)

    has_score = "overall_score" in df.columns

    def _save(fig, filename):
        fig.tight_layout()
        fig.savefig(os.path.join(plots_dir, filename))

    def _score_boxplot(x_col, title, filename, figsize=(10, 6), rotate=False):
        # One box per distinct value of x_col; skipped if a column is missing
        if not has_score or x_col not in df.columns:
            return
        fig, ax = plt.subplots(figsize=figsize)
        try:
            sns.boxplot(x=x_col, y='overall_score', data=df, ax=ax)
            ax.set_title(title)
            if rotate:
                ax.tick_params(axis='x', labelrotation=45)
            _save(fig, filename)
        finally:
            plt.close(fig)

    # 1-2. Overall score by epochs and by batch size
    _score_boxplot('epochs', 'Impact of Epochs on Overall Score', 'epochs_impact.png')
    _score_boxplot('batch_size', 'Impact of Batch Size on Overall Score', 'batch_size_impact.png')

    # 3. Network dimensions, rendered as compact "AxB" labels for the axis
    if "gen_dim" in df.columns:
        df['gen_dim_str'] = df['gen_dim'].apply(_dim_to_str)
    if "dis_dim" in df.columns:
        df['dis_dim_str'] = df['dis_dim'].apply(_dim_to_str)

    _score_boxplot(
        'gen_dim_str', 'Impact of Generator Dimensions on Overall Score',
        'generator_dim_impact.png', figsize=(12, 6), rotate=True,
    )
    _score_boxplot(
        'dis_dim_str', 'Impact of Discriminator Dimensions on Overall Score',
        'discriminator_dim_impact.png', figsize=(12, 6), rotate=True,
    )

    # 4. Heatmap of average score by epochs and batch size (best effort)
    try:
        if has_score and "epochs" in df.columns and "batch_size" in df.columns:
            pivot_epochs_batch = df.pivot_table(
                values='overall_score',
                index='epochs',
                columns='batch_size',
                aggfunc='mean'
            )

            fig, ax = plt.subplots(figsize=(10, 8))
            try:
                sns.heatmap(pivot_epochs_batch, annot=True, cmap='viridis', fmt='.3f', ax=ax)
                ax.set_title('Average Score by Epochs and Batch Size')
                _save(fig, 'epochs_batch_heatmap.png')
            finally:
                plt.close(fig)
    except Exception as e:
        print(f"Warning: Could not create epochs/batch size heatmap: {e}")

    # 5. Runtime analysis
    if "runtime_seconds" in df.columns and "epochs" in df.columns:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            if "batch_size" in df.columns and has_score:
                sns.scatterplot(
                    x='epochs', y='runtime_seconds', hue='batch_size',
                    size='overall_score', data=df, ax=ax,
                )
            else:
                sns.scatterplot(x='epochs', y='runtime_seconds', data=df, ax=ax)
            ax.set_title('Runtime vs Epochs by Batch Size')
            _save(fig, 'runtime_analysis.png')
        finally:
            plt.close(fig)

    # 6. Correlation between different metrics (best effort)
    try:
        metrics = ['validity_score', 'quality_score', 'ks_p_pass_rate', 'wasserstein_normalized', 'overall_score']
        # Filter to columns that actually exist
        available_metrics = [m for m in metrics if m in df.columns]
        if len(available_metrics) > 1:  # Need at least 2 metrics for correlation
            correlations = df[available_metrics].corr()

            fig, ax = plt.subplots(figsize=(10, 8))
            try:
                sns.heatmap(correlations, annot=True, cmap='coolwarm', vmin=-1, vmax=1, ax=ax)
                ax.set_title('Correlation Between Different Metrics')
                _save(fig, 'metrics_correlation.png')
            finally:
                plt.close(fig)
    except Exception as e:
        print(f"Warning: Could not create metrics correlation heatmap: {e}")

    print(f"✅ Parameter comparison plots saved to {plots_dir}")


# ──────────────────────────────────────────────────────────────────────────────
#  8)  MAIN EXECUTION
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Choose mode
    #mode = "hyperparameter_search"
    mode = "single_run"

    if mode == "single_run":
        # Configuration for a single run
        config = {
            # Input/Output settings
            "csv_path": "Nano.csv",
            "run_label": "E1500_KNN4_G512x512",
            "excel_path": "synthetic_comparison.xlsx",
            "export_synthetic_csv": True,
            "save_model": False,
            "generate_visualizations": True,
            "visualization_dir": "visualizations",
            
            # Preprocessing settings
            "knn_neighbors": 5,
            
            # Model settings
            "model": "CTGAN",
            "epochs": 8500,
            "batch_size": 500,
            "gen_dim": (512, 512),
            "dis_dim": (512, 512),
        }
        
        # Run the pipeline
        results = run_synthetic_data_pipeline(config)
        
        # Optionally save the full results metadata
        try:
            import json
            with open(f"results_{config['run_label']}.json", 'w') as f:
                # Convert any non-serializable objects to strings
                serializable_results = {}
                for k, v in results.items():
                    if isinstance(v, (int, float, str, list, dict, bool, type(None))):
                        serializable_results[k] = v
                    else:
                        serializable_results[k] = str(v)
                json.dump(serializable_results, f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save results metadata: {e}")
            
    elif mode == "hyperparameter_search":
        # Configuration for hyperparameter search
        csv_path = "Nano.csv"
        output_dir = "hyperparameter_search"
        excel_path = "hyperparameter_results.xlsx"
        visualize_top_n = 3
        knn_neighbors = 4
        
        # Run the hyperparameter search
        results_df = run_hyperparameter_search(
            csv_path=csv_path,
            output_dir=output_dir,
            excel_path=excel_path,
            visualize_top_n=visualize_top_n,
            knn_neighbors=knn_neighbors
        )
        
        # Print the top 3 configurations
        if not results_df.empty and 'overall_score' in results_df.columns:
            # Sort by overall score if it exists
            top_results = results_df.sort_values('overall_score', ascending=False).head(3)
            
            print("\n=== TOP 3 HYPERPARAMETER CONFIGURATIONS ===")
            for i, (_, row) in enumerate(top_results.iterrows()):
                print(f"\n{i+1}. {row['run_label']}")
                print(f"   Epochs: {row['epochs']}")
                print(f"   Batch Size: {row['batch_size']}")
                print(f"   Generator Dims: {row['gen_dim']}")
                print(f"   Discriminator Dims: {row['dis_dim']}")
                
                # Use get with default to avoid KeyError
                overall_score = row.get('overall_score')
                if overall_score is not None:
                    print(f"   Overall Score: {overall_score:.4f}")
                else:
                    print(f"   Overall Score: N/A")
                
                validity_score = row.get('validity_score')
                if validity_score is not None:
                    print(f"   Validity Score: {validity_score:.4f}")
                else:
                    print(f"   Validity Score: N/A")
                
                quality_score = row.get('quality_score')
                if quality_score is not None:
                    print(f"   Quality Score: {quality_score:.4f}")
                else:
                    print(f"   Quality Score: N/A")
                
                ks_pass_rate = row.get('ks_p_pass_rate')
                if ks_pass_rate is not None:
                    print(f"   KS Pass Rate: {ks_pass_rate:.4f}")
                else:
                    print(f"   KS Pass Rate: N/A")
                
                wasserstein_mean = row.get('wasserstein_mean')
                if wasserstein_mean is not None:
                    print(f"   Wasserstein Mean: {wasserstein_mean:.4f}")
                else:
                    print(f"   Wasserstein Mean: N/A")
                
                runtime = row.get('runtime_seconds')
                if runtime is not None:
                    print(f"   Runtime: {runtime:.2f} seconds")
                else:
                    print(f"   Runtime: N/A")
            
            # Also create a final summary Excel with just the top models
            try:
                top_models_df = results_df.head(10)  # Top 10 models
                top_models_path = os.path.join(output_dir, "top_models.xlsx")
                top_models_df.to_excel(top_models_path, index=False)
                print(f"\n✅ Top models saved to {top_models_path}")
                
                # Create a simple HTML report for easy viewing
                report_path = os.path.join(output_dir, "hyperparameter_report.html")
                with open(report_path, 'w') as f:
                    f.write(f"""<!DOCTYPE html>
<html>
<head>
    <title>Hyperparameter Search Results</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 40px; }}
        h1, h2 {{ color: #2c3e50; }}
        table {{ border-collapse: collapse; width: 100%; margin-bottom: 30px; }}
        th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
        th {{ background-color: #f2f2f2; }}
        tr:hover {{ background-color: #f5f5f5; }}
        .highlight {{ background-color: #e6f7ff; }}
        .container {{ margin-bottom: 40px; }}
        img {{ max-width: 600px; border: 1px solid #ddd; margin: 10px 0; }}
    </style>
</head>
<body>
    <h1>Synthetic Data Hyperparameter Search Results</h1>
    <div class="container">
        <h2>Top 3 Models</h2>
        <table>
            <tr>
                <th>Rank</th>
                <th>Configuration</th>
                <th>Overall Score</th>
                <th>Validity</th>
                <th>Quality</th>
                <th>KS Pass Rate</th>
                <th>Runtime</th>
            </tr>
""")
                    
                    # Add top 3 models
                    for i, (_, row) in enumerate(top_results.iterrows()):
                        f.write(f"""
            <tr class="highlight">
                <td>{i+1}</td>
                <td>{row['run_label']}</td>
                <td>{row.get('overall_score', 'N/A'):.4f if row.get('overall_score') is not None else 'N/A'}</td>
                <td>{row.get('validity_score', 'N/A'):.4f if row.get('validity_score') is not None else 'N/A'}</td>
                <td>{row.get('quality_score', 'N/A'):.4f if row.get('quality_score') is not None else 'N/A'}</td>
                <td>{row.get('ks_p_pass_rate', 'N/A'):.4f if row.get('ks_p_pass_rate') is not None else 'N/A'}</td>
                <td>{row.get('runtime_seconds', 'N/A'):.2f if row.get('runtime_seconds') is not None else 'N/A'} sec</td>
            </tr>""")
                    
                    # Close table and add visualization links for top models
                    f.write("""
        </table>
    </div>
    
    <div class="container">
        <h2>Parameter Impact Visualizations</h2>
        <p>The following visualizations show the impact of different hyperparameters on the model quality:</p>
        <div>
            <h3>Impact of Epochs</h3>
            <img src="parameter_plots/epochs_impact.png" alt="Impact of epochs on overall score">
        </div>
        <div>
            <h3>Impact of Batch Size</h3>
            <img src="parameter_plots/batch_size_impact.png" alt="Impact of batch size on overall score">
        </div>
        <div>
            <h3>Impact of Generator Dimensions</h3>
            <img src="parameter_plots/generator_dim_impact.png" alt="Impact of generator dimensions on overall score">
        </div>
        <div>
            <h3>Impact of Discriminator Dimensions</h3>
            <img src="parameter_plots/discriminator_dim_impact.png" alt="Impact of discriminator dimensions on overall score">
        </div>
        <div>
            <h3>Epochs vs Batch Size Heatmap</h3>
            <img src="parameter_plots/epochs_batch_heatmap.png" alt="Heatmap of epochs vs batch size">
        </div>
        <div>
            <h3>Runtime Analysis</h3>
            <img src="parameter_plots/runtime_analysis.png" alt="Runtime analysis">
        </div>
        <div>
            <h3>Metrics Correlation</h3>
            <img src="parameter_plots/metrics_correlation.png" alt="Correlation between different metrics">
        </div>
    </div>
    
    <div class="container">
        <h2>Top Model Details</h2>
""")
                    
                    # Add top model visualizations
                    for i, (_, row) in enumerate(top_results.head(3).iterrows()):
                        if row["success"]:
                            f.write(f"""
        <div>
            <h3>{i+1}. {row['run_label']}</h3>
            <p>
                <strong>Configuration:</strong> Epochs={row['epochs']}, 
                Batch Size={row['batch_size']}, 
                Generator={row['gen_dim']}, 
                Discriminator={row['dis_dim']}
            </p>
            <div>
                <h4>Correlation Comparison</h4>
                <img src="viz_{row['run_label']}/correlation_comparison.png" alt="Correlation comparison">
            </div>
            <div>
                <h4>Correlation Difference</h4>
                <img src="viz_{row['run_label']}/correlation_difference.png" alt="Correlation difference">
            </div>
            <div>
                <h4>Summary Statistics Difference</h4>
                <img src="viz_{row['run_label']}/summary_stats_difference.png" alt="Summary statistics difference">
            </div>
        </div>
""")
                    
                    # Close the HTML
                    f.write("""
    </div>
    
    <div class="container">
        <h2>All Models</h2>
        <p>Total models evaluated: {}</p>
        <p>Successful models: {}</p>
        <p>For detailed results, see the Excel file: <a href="hyperparameter_results.xlsx">hyperparameter_results.xlsx</a></p>
    </div>
    
    <footer>
        <p>Generated on {}</p>
    </footer>
</body>
</html>
""".format(len(results_df), len(results_df[results_df["success"]]), datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
                
                print(f"✅ HTML report generated: {report_path}")
                
            except Exception as e:
                print(f"Warning: Could not create final summary files: {e}")

[1/5] Loading data from Nano.csv...
✅ Loaded 534 rows with 15 columns
✅ Applied log transform to 8 columns: ['Size', 'Admin', 'DE_tumor', 'DE_heart', 'DE_liver', 'DE_spleen', 'DE_lung', 'DE_kidney']
[2/5] Generating synthetic data using CTGAN...
Training CTGAN model with 8500 epochs...


Gen. (-0.75) | Discrim. (-0.01): 100%|██████████| 8500/8500 [04:46<00:00, 29.69it/s]


Generating 534 synthetic samples...
✅ Generated 534 synthetic samples
[3/5] Calculating comparison metrics...
Generating diagnostic report...
Generating report ...

(1/2) Evaluating Data Validity: |██████████| 15/15 [00:00<00:00, 2979.05it/s]|
Data Validity Score: 77.67%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 698.35it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 88.83%

Generating quality report...
Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 15/15 [00:00<00:00, 1763.05it/s]|
Column Shapes Score: 73.26%

(2/2) Evaluating Column Pair Trends: |██████████| 105/105 [00:00<00:00, 410.99it/s]|
Column Pair Trends Score: 73.04%

Overall Score (Average): 73.15%

✅ Comparison metrics calculated
[4/5] Generating comparison visualizations...
Visualizations saved to visualizations/E1500_KNN4_G512x512/
✅ Visualizations saved to visualizations/E1500_KNN4_G512x512
[5/5] Exporting results to Excel...
✅ All results for 'E1500_KNN4_G512x512' 