In [14]:
import pandas as pd
import numpy as np
import os
from scipy import stats
from scipy.spatial.distance import jensenshannon
from scipy.stats import wasserstein_distance, chi2_contingency, fisher_exact
import matplotlib.pyplot as plt
import warnings

def calculate_metrics_numerical(real_values, synth_values):
    """
    Calculate statistical distance metrics between real and synthetic distributions
    for numerical features.
    
    Args:
        real_values: Array of values from the real dataset
        synth_values: Array of values from the synthetic dataset
        
    Returns:
        Dictionary with calculated metrics
    """
    metrics = {}
    
    # Calculate Kolmogorov-Smirnov statistic
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        ks_stat, ks_pval = stats.ks_2samp(real_values, synth_values)
    metrics['ks_statistic'] = ks_stat
    metrics['ks_pvalue'] = ks_pval
    
    # Calculate Earth Mover's Distance (Wasserstein metric)
    emd = wasserstein_distance(real_values, synth_values)
    metrics['earth_movers_distance'] = emd
    
    # Calculate Jensen-Shannon Divergence for numerical data
    min_val = min(np.min(real_values), np.min(synth_values))
    max_val = max(np.max(real_values), np.max(synth_values))
    bins = np.linspace(min_val, max_val, min(50, len(np.unique(np.concatenate([real_values, synth_values])))))
    
    # Get histogram counts and normalize
    real_hist, _ = np.histogram(real_values, bins=bins, density=True)
    synth_hist, _ = np.histogram(synth_values, bins=bins, density=True)
    
    # Handle zero frequencies by adding a small epsilon
    epsilon = 1e-10
    real_hist = real_hist + epsilon
    synth_hist = synth_hist + epsilon
    
    # Normalize to make proper probability distributions
    real_hist = real_hist / np.sum(real_hist)
    synth_hist = synth_hist / np.sum(synth_hist)
    
    js_div = jensenshannon(real_hist, synth_hist)
    metrics['jensen_shannon'] = js_div
    
    return metrics

def calculate_metrics_categorical(real_values, synth_values):
    """
    Calculate statistical distance metrics between real and synthetic distributions
    for categorical features.
    
    Args:
        real_values: Array of values from the real dataset
        synth_values: Array of values from the synthetic dataset
        
    Returns:
        Dictionary with calculated metrics
    """
    metrics = {}
    
    # Get all unique categories
    all_categories = np.unique(np.concatenate([real_values, synth_values]))
    
    # Create contingency table
    contingency_table = np.zeros((2, len(all_categories)))
    
    for i, cat in enumerate(all_categories):
        contingency_table[0, i] = np.sum(real_values == cat)
        contingency_table[1, i] = np.sum(synth_values == cat)
    
    # Perform Chi-Square test
    if np.all(contingency_table > 0) and len(all_categories) > 1:
        chi2, p_value, dof, expected = chi2_contingency(contingency_table)
        metrics['chi2_statistic'] = chi2
        metrics['chi2_pvalue'] = p_value
    else:
        metrics['chi2_statistic'] = np.nan
        metrics['chi2_pvalue'] = np.nan
    
    # Perform Fisher's Exact test if possible (only for 2x2 tables)
    if len(all_categories) == 2:
        try:
            odds_ratio, p_value = fisher_exact(contingency_table)
            metrics['fisher_pvalue'] = p_value
            metrics['fisher_odds_ratio'] = odds_ratio
        except:
            metrics['fisher_pvalue'] = np.nan
            metrics['fisher_odds_ratio'] = np.nan
    else:
        metrics['fisher_pvalue'] = np.nan
        metrics['fisher_odds_ratio'] = np.nan
    
    # Calculate G-test (Log-Likelihood ratio)
    if np.all(contingency_table > 0) and len(all_categories) > 1:
        # G-test calculation
        observed = contingency_table.flatten()
        row_sums = contingency_table.sum(axis=1).reshape(-1, 1)
        col_sums = contingency_table.sum(axis=0).reshape(1, -1)
        total = contingency_table.sum()
        expected = np.dot(row_sums, col_sums) / total
        expected = expected.flatten()
        
        g_stat = 2 * np.sum(observed * np.log(observed / expected))
        g_pvalue = 1 - stats.chi2.cdf(g_stat, dof)
        
        metrics['g_test_statistic'] = g_stat
        metrics['g_test_pvalue'] = g_pvalue
    else:
        metrics['g_test_statistic'] = np.nan
        metrics['g_test_pvalue'] = np.nan
    
    # Calculate Jensen-Shannon Divergence for categorical data
    real_counts = np.zeros(len(all_categories))
    synth_counts = np.zeros(len(all_categories))
    
    for i, cat in enumerate(all_categories):
        real_counts[i] = np.sum(real_values == cat) / len(real_values)
        synth_counts[i] = np.sum(synth_values == cat) / len(synth_values)
    
    # Handle zero frequencies by adding a small epsilon
    epsilon = 1e-10
    real_counts = real_counts + epsilon
    synth_counts = synth_counts + epsilon
    
    # Normalize to make proper probability distributions
    real_counts = real_counts / np.sum(real_counts)
    synth_counts = synth_counts / np.sum(synth_counts)
    
    js_div = jensenshannon(real_counts, synth_counts)
    metrics['jensen_shannon'] = js_div
    
    return metrics

def analyze_file(file_path, target_feature='income', target_value='>50K', 
                 synthetic_feature='synthetic', output_dir=None, plot=False):
    """
    Analyze a single CSV file comparing real vs synthetic distributions.
    
    Args:
        file_path: Path to the CSV file
        target_feature: Column name for the target feature (default: 'income')
        target_value: Value in target feature to filter for (default: '>50K')
        synthetic_feature: Column name indicating if a record is synthetic (default: 'synthetic')
        output_dir: Directory to save plots to (if plot=True)
        plot: Whether to generate distribution plots
        
    Returns:
        Dictionary with analysis results
    """
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        return {'file': os.path.basename(file_path), 'error': str(e)}
    
    # Check required columns
    if target_feature not in df.columns or synthetic_feature not in df.columns:
        return {'file': os.path.basename(file_path), 
                'error': f"Missing required columns '{target_feature}' or '{synthetic_feature}'"}
    
    # Convert synthetic column to boolean if it's 'yes'/'no'
    if df[synthetic_feature].dtype == object:
        df[synthetic_feature] = df[synthetic_feature].apply(lambda x: 1 if str(x).lower() == 'yes' else 0)
    
    # Filter for target value
    if df[target_feature].dtype == object:
        df_filtered = df[df[target_feature].str.contains(target_value, case=False, na=False)]
    else:
        df_filtered = df[df[target_feature] == target_value]
    
    if len(df_filtered) == 0:
        return {'file': os.path.basename(file_path), 
                'error': f"No instances with {target_feature}={target_value} found"}
    
    # Separate real and synthetic samples
    real_samples = df_filtered[df_filtered[synthetic_feature] == 0]
    synthetic_samples = df_filtered[df_filtered[synthetic_feature] == 1]
    
    if len(real_samples) == 0 or len(synthetic_samples) == 0:
        return {'file': os.path.basename(file_path), 
                'error': f"Missing real or synthetic samples with {target_feature}={target_value}"}
    
    # Define numeric and categorical features
    numeric_features = ['age', 'fnlwgt', 'education_num', 'capital_gain', 
                        'capital_loss', 'hours_per_week']
    categorical_features = ['workclass', 'education', 'marital_status', 'occupation', 
                           'relationship', 'race', 'sex', 'native_country']
    
    # Filter out target and synthetic features, and missing features
    numeric_features = [f for f in numeric_features 
                        if f in df.columns and f != target_feature and f != synthetic_feature]
    categorical_features = [f for f in categorical_features 
                           if f in df.columns and f != target_feature and f != synthetic_feature]
    
    results = {'file': os.path.basename(file_path), 'numeric': {}, 'categorical': {}}
    
    # Process numeric features
    for feature in numeric_features:
        real_values = real_samples[feature].dropna().values
        synth_values = synthetic_samples[feature].dropna().values
        
        if len(real_values) < 2 or len(synth_values) < 2:
            continue
        
        metrics = calculate_metrics_numerical(real_values, synth_values)
        results['numeric'][feature] = metrics
        
        # Create distribution plots if requested
        if plot and output_dir:
            os.makedirs(output_dir, exist_ok=True)
            plt.figure(figsize=(10, 6))
            plt.hist(real_values, bins=30, alpha=0.5, label='Real', density=True)
            plt.hist(synth_values, bins=30, alpha=0.5, label='Synthetic', density=True)
            plt.title(f'Distribution of {feature} - {os.path.basename(file_path)}')
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, f"{os.path.basename(file_path).split('.')[0]}_{feature}_dist.png"))
            plt.close()
    
    # Process categorical features
    for feature in categorical_features:
        real_values = real_samples[feature].dropna().values
        synth_values = synthetic_samples[feature].dropna().values
        
        if len(real_values) < 2 or len(synth_values) < 2:
            continue
        
        metrics = calculate_metrics_categorical(real_values, synth_values)
        results['categorical'][feature] = metrics
        
        # Create bar plots for categorical features if requested
        if plot and output_dir:
            os.makedirs(output_dir, exist_ok=True)
            plt.figure(figsize=(12, 6))
            
            # Get value counts
            real_counts = pd.Series(real_values).value_counts(normalize=True)
            synth_counts = pd.Series(synth_values).value_counts(normalize=True)
            
            # Align indices
            all_categories = sorted(set(real_counts.index) | set(synth_counts.index))
            real_counts = real_counts.reindex(all_categories, fill_value=0)
            synth_counts = synth_counts.reindex(all_categories, fill_value=0)
            
            # Plot
            x = np.arange(len(all_categories))
            width = 0.35
            
            plt.bar(x - width/2, real_counts, width, label='Real')
            plt.bar(x + width/2, synth_counts, width, label='Synthetic')
            
            plt.xlabel('Categories')
            plt.ylabel('Frequency')
            plt.title(f'Distribution of {feature} - {os.path.basename(file_path)}')
            plt.xticks(x, all_categories, rotation=90)
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, f"{os.path.basename(file_path).split('.')[0]}_{feature}_dist.png"))
            plt.close()
    
    return results

def analyze_distributions(csv_files, target_feature='income', target_value='>50K', 
                          synthetic_feature='synthetic', output_dir=None, plot=False):
    """
    Analyze multiple CSV files and compute statistical distance metrics.
    
    Args:
        csv_files: List of paths to CSV files
        target_feature: Column name for the target feature (default: 'income')
        target_value: Value in target feature to filter for (default: '>50K')
        synthetic_feature: Column name indicating if a record is synthetic (default: 'synthetic')
        output_dir: Directory to save plots to (if plot=True)
        plot: Whether to generate distribution plots
        
    Returns:
        List of dictionaries with analysis results
    """
    results = []
    for file_path in csv_files:
        print(f"Processing {file_path}...")
        file_results = analyze_file(
            file_path, 
            target_feature=target_feature, 
            target_value=target_value,
            synthetic_feature=synthetic_feature,
            output_dir=output_dir, 
            plot=plot
        )
        results.append(file_results)
    
    return results

def print_results_detailed(results):
    """
    Print detailed analysis results in a tabular format without tabulate.
    
    Args:
        results: List of dictionaries with analysis results
    """
    for result in results:
        print("\n" + "="*100)
        print(f"File: {result['file']}")
        print("="*100)
        
        if 'error' in result:
            print(f"ERROR: {result['error']}")
            continue
        
        # Print numerical features
        if result['numeric']:
            print("\nNUMERICAL FEATURES")
            print("-"*100)
            
            # Print header
            print(f"{'Feature':<20} {'KS Statistic':<15} {'KS p-value':<15} {'Jensen-Shannon':<15} {'Earth Mover\\'s Dist':<15}")
            print("-"*85)
            
            # Print data
            for feature, metrics in sorted(result['numeric'].items()):
                print(f"{feature:<20} {metrics.get('ks_statistic', '-'):<15.6f} "
                      f"{metrics.get('ks_pvalue', '-'):<15.6f} "
                      f"{metrics.get('jensen_shannon', '-'):<15.6f} "
                      f"{metrics.get('earth_movers_distance', '-'):<15.6f}")
        
        # Print categorical features
        if result['categorical']:
            print("\nCATEGORICAL FEATURES")
            print("-"*100)
            
            # Print header
            print(f"{'Feature':<15} {'Chi² Statistic':<15} {'Chi² p-value':<15} {'G-test Statistic':<15} "
                  f"{'G-test p-value':<15} {'Fisher p-value':<15} {'Jensen-Shannon':<15}")
            print("-"*105)
            
            # Print data
            for feature, metrics in sorted(result['categorical'].items()):
                fisher_pval = metrics.get('fisher_pvalue', np.nan)
                if isinstance(fisher_pval, (list, np.ndarray)) and len(fisher_pval) > 0:
                    fisher_pval = fisher_pval[0]
                elif fisher_pval is None:
                    fisher_pval = np.nan
                    
                print(f"{feature:<15} {metrics.get('chi2_statistic', np.nan):<15.6f} "
                      f"{metrics.get('chi2_pvalue', np.nan):<15.6f} "
                      f"{metrics.get('g_test_statistic', np.nan):<15.6f} "
                      f"{metrics.get('g_test_pvalue', np.nan):<15.6f} "
                      f"{fisher_pval:<15.6f} "
                      f"{metrics.get('jensen_shannon', np.nan):<15.6f}")
    
    print("\n")

def print_results_summary(results):
    """
    Print a summary of results across all files without tabulate.
    
    Args:
        results: List of dictionaries with analysis results
    """
    print("\n" + "="*100)
    print("SUMMARY OF RESULTS")
    print("="*100)
    
    # Collect all numeric and categorical features across files
    all_numeric = set()
    all_categorical = set()
    
    for result in results:
        if 'error' in result:
            continue
        all_numeric.update(result['numeric'].keys())
        all_categorical.update(result['categorical'].keys())
    
    # Print header
    print(f"{'File':<20} {'Avg JS (Num)':<15} {'Avg KS (Num)':<15} {'Avg EMD (Num)':<15} "
          f"{'Avg JS (Cat)':<15} {'Avg Chi² p-val':<15} {'Avg G-test p-val':<15}")
    print("-"*110)
    
    # Print summary data for each file
    for result in results:
        if 'error' in result:
            continue
        
        file_name = result['file']
        
        # Calculate average metrics for each file
        num_js_values = [metrics.get('jensen_shannon', np.nan) for metrics in result['numeric'].values()]
        num_js_avg = np.nanmean(num_js_values) if num_js_values else np.nan
        
        num_ks_values = [metrics.get('ks_statistic', np.nan) for metrics in result['numeric'].values()]
        num_ks_avg = np.nanmean(num_ks_values) if num_ks_values else np.nan
        
        num_emd_values = [metrics.get('earth_movers_distance', np.nan) for metrics in result['numeric'].values()]
        num_emd_avg = np.nanmean(num_emd_values) if num_emd_values else np.nan
        
        cat_js_values = [metrics.get('jensen_shannon', np.nan) for metrics in result['categorical'].values()]
        cat_js_avg = np.nanmean(cat_js_values) if cat_js_values else np.nan
        
        cat_chi2_values = [metrics.get('chi2_pvalue', np.nan) for metrics in result['categorical'].values()]
        cat_chi2_avg = np.nanmean(cat_chi2_values) if cat_chi2_values else np.nan
        
        cat_gtest_values = [metrics.get('g_test_pvalue', np.nan) for metrics in result['categorical'].values()]
        cat_gtest_avg = np.nanmean(cat_gtest_values) if cat_gtest_values else np.nan
        
        print(f"{file_name:<20} {num_js_avg:<15.6f} {num_ks_avg:<15.6f} {num_emd_avg:<15.6f} "
              f"{cat_js_avg:<15.6f} {cat_chi2_avg:<15.6f} {cat_gtest_avg:<15.6f}")
    
    print("\n")

def run_analysis(csv_files=None, directory=None, target_feature='income', target_value='>50K', 
                 synthetic_feature='synthetic', plot=False, plot_dir='distribution_plots',
                 detailed_output=True, summary_output=True):
    """
    Main function to run the distribution analysis.
    
    Args:
        csv_files: List of paths to CSV files (default: None)
        directory: Directory containing CSV files (default: None)
        target_feature: Column name for the target feature (default: 'income')
        target_value: Value in target feature to filter for (default: '>50K')
        synthetic_feature: Column name indicating if a record is synthetic (default: 'synthetic')
        plot: Whether to generate distribution plots (default: False)
        plot_dir: Directory to save plots to (default: 'distribution_plots')
        detailed_output: Whether to print detailed results for each file (default: True)
        summary_output: Whether to print a summary of results across files (default: True)
        
    Returns:
        List of dictionaries with analysis results
    """
    # Get CSV files
    files_to_analyze = []
    
    if csv_files:
        files_to_analyze.extend(csv_files)
    
    if directory:
        dir_files = [os.path.join(directory, f) for f in os.listdir(directory) 
                     if f.endswith('.csv')]
        files_to_analyze.extend(dir_files)
    
    if not files_to_analyze:
        # If no files specified, use current directory
        files_to_analyze = [f for f in os.listdir('.') if f.endswith('.csv')]
        files_to_analyze = [os.path.join('.', f) for f in files_to_analyze]
    
    if not files_to_analyze:
        print("No CSV files found.")
        return []
    
    results = analyze_distributions(
        files_to_analyze, 
        target_feature=target_feature,
        target_value=target_value,
        synthetic_feature=synthetic_feature,
        output_dir=plot_dir if plot else None, 
        plot=plot
    )
    
    if detailed_output:
        print_results_detailed(results)
    
    if summary_output and len(results) > 1:
        print_results_summary(results)
    
    return results

# This code runs if this script is executed directly
if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description='Analyze distributions of real vs synthetic data.')
    parser.add_argument('--files', nargs='+', help='List of CSV files to analyze')
    parser.add_argument('--dir', help='Directory containing CSV files to analyze')
    parser.add_argument('--target-feature', default='income', 
                        help='Target feature column name (default: income)')
    parser.add_argument('--target-value', default='>50K', 
                        help='Target value to filter for (default: >50K)')
    parser.add_argument('--synthetic-feature', default='synthetic', 
                        help='Column indicating synthetic records (default: synthetic)')
    parser.add_argument('--plot', action='store_true', help='Generate distribution plots')
    parser.add_argument('--plot-dir', default='distribution_plots', 
                        help='Directory to save distribution plots')
    parser.add_argument('--no-details', action='store_true', 
                        help='Skip printing detailed results')
    parser.add_argument('--no-summary', action='store_true', 
                        help='Skip printing summary results')
    
    args = parser.parse_args()
    
    run_analysis(
        csv_files=args.files,
        directory=args.dir,
        target_feature=args.target_feature,
        target_value=args.target_value,
        synthetic_feature=args.synthetic_feature,
        plot=args.plot,
        plot_dir=args.plot_dir,
        detailed_output=not args.no_details,
        summary_output=not args.no_summary
    )

SyntaxError: f-string expression part cannot include a backslash (2242483259.py, line 326)

In [11]:
files_to_analyze = ["OutputTrainingSets/augmented_trainVALIDATE1.csv", "OutputTrainingSets/augmented_trainVALIDATE2.csv", "OutputTrainingSets/augmented_trainVALIDATE3.csv", "OutputTrainingSets/augmented_trainVALIDATE4.csv", "OutputTrainingSets/augmented_trainVALIDATE5.csv", "OutputTrainingSets/augmented_trainVALIDATE6.csv", "OutputTrainingSets/augmented_trainVALIDATE7.csv", "OutputTrainingSets/augmented_trainVALIDATE8.csv", "OutputTrainingSets/augmented_trainVALIDATE9.csv", "OutputTrainingSets/augmented_trainVALIDATE10.csv"]
target_feature = 'income'
target_value = '1'


results = run_analysis(
    csv_files=files_to_analyze,
    target_feature='income',
    target_value='1',synhetic_feature = 'synthetic'
)

TypeError: run_analysis() got an unexpected keyword argument 'detailed_output'