# Script to plot the performance (PR-AUC) and SHAP bar charts (Figure 3a, d)

- First run '/statistical_analysis/quantitative_analysis/base_analysis_bootstrapping.py' to train the classifiers and generate the evaluation results.
- Then set `file_paths` to the performance CSV files produced by that script, and `output_base` to the desired output location.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib as mpl
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.font_manager as fm
from matplotlib.lines import Line2D

# Global Matplotlib configuration:
# - fonttype 42 (TrueType) / svg 'none' keep text as editable text in vector output
# - no LaTeX rendering, no PDF core-14 font substitution, uncompressed PDF streams
# - Arial everywhere at a base size of 14 pt
# (plt.rcParams is the same object as mpl.rcParams.)
mpl.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
    'text.usetex': False,
    'pdf.use14corefonts': False,
    'pdf.compression': False,
    'font.family': 'Arial',
    'font.sans-serif': ['Arial'],
    'font.size': 14,
})

# Define unified color scheme for consistent visualization
colors = {
    'climate': '#109a96',          # Climate - teal green
    'terrain_water': '#336a9a',    # Terrain & Water - blue
    'manual_all': '#d36153',       # All manual features - red
    'embedding_all': '#eab14d',    # All embedding features - golden orange
    'embedding_climate': '#77508f', # Climate embedding - purple
    'embedding_terrain': '#a3507c'  # Terrain & Water embedding - magenta
}

# Define dataset names and model configurations
DATASETS = ['Europe', 'China\n(CHGIS)', 'China\n(walled city)']
MODEL_ORDER = [
    'Attribute',           # All manual feature
    'All embedding',       # All embedding
    'Climate embedding',   # Climate embedding
    'DEM&Water embedding', # DEM&Water embedding
]
MODEL_NAMES = [
    'All manual feature',
    'All embedding',
    'Climate embedding',
    'DEM&Water embedding',
]


def load_and_prepare_data(file_paths):
    """
    Read one CSV per dataset into a dictionary of DataFrames.

    A path that does not exist is reported with a printed warning and stored
    as None instead of raising, so every key of ``file_paths`` gets an entry.

    Args:
        file_paths (dict): Mapping of dataset key -> CSV file path

    Returns:
        dict: Mapping with the same keys (same order) -> pandas DataFrame,
        or None for files that could not be found
    """
    loaded = {}
    for key, path in file_paths.items():
        try:
            frame = pd.read_csv(path)
        except FileNotFoundError:
            print(f"Warning: Could not find file {path}")
            frame = None
        else:
            print(f"Successfully loaded {key} data from {path}")
        loaded[key] = frame
    return loaded


def extract_shap_data(data_dict, shap_types=None, dataset_keys=('eu', 'cnpref', 'cnwalled')):
    """
    Extract SHAP statistics of the "All embedding" model from each dataset.

    Args:
        data_dict (dict): Loaded dataframes keyed by dataset. Each dataframe
            must have a 'model' column plus '<key>_shap_mean',
            '<key>_shap_ci_lower' and '<key>_shap_ci_upper' columns for every
            feature type key.
        shap_types (list, optional): Feature type definitions, each a dict
            with 'name', 'key' and 'color'. Defaults to Climate ('clim') and
            Terrain & Water ('dem').
        dataset_keys (sequence, optional): Keys of ``data_dict`` to read, in
            plotting order.

    Returns:
        tuple: (shap_data, shap_types). ``shap_data`` maps each feature key to
        a dict with 'means', 'ci_lower' and 'ci_upper' lists ordered like
        ``dataset_keys``; ``shap_types`` is the list of definitions used.

    Raises:
        ValueError: If a dataset is absent/None (e.g. its file was not found)
            or contains no 'All embedding' row.
    """
    if shap_types is None:
        # Default feature groups (agriculture is deliberately excluded)
        shap_types = [
            {'name': 'Climate', 'key': 'clim', 'color': colors['climate']},
            {'name': 'Terrain & Water', 'key': 'dem', 'color': colors['terrain_water']},
        ]

    # First 'All embedding' row of each dataset, in plotting order
    all_emb_rows = []
    for ds_key in dataset_keys:
        df = data_dict.get(ds_key)
        if df is None:
            raise ValueError(f"No data loaded for dataset '{ds_key}'")
        matches = df[df['model'] == 'All embedding']
        if matches.empty:
            raise ValueError(f"Dataset '{ds_key}' has no 'All embedding' row")
        all_emb_rows.append(matches.iloc[0])

    # Output field name -> CSV column suffix
    stat_columns = {'means': 'mean', 'ci_lower': 'ci_lower', 'ci_upper': 'ci_upper'}

    shap_data = {}
    for shap in shap_types:
        key = shap['key']
        shap_data[key] = {
            field: [row[f'{key}_shap_{suffix}'] for row in all_emb_rows]
            for field, suffix in stat_columns.items()
        }

    return shap_data, shap_types


def extract_prauc_data(data_dict, model_order=None):
    """
    Extract PR-AUC performance data from datasets.

    Rows are reordered independently for every dataset, so the result is
    correct even when the CSV files list their models in different orders.
    A model is matched when its MODEL_ORDER entry is a substring of the
    value in the 'model' column (first matching row wins).

    Args:
        data_dict (dict): Dataframes keyed by 'eu', 'cnpref' and 'cnwalled',
            each with 'model', 'pr_auc_mean', 'pr_auc_ci_lower' and
            'pr_auc_ci_upper' columns.
        model_order (sequence, optional): Model names to extract, in plotting
            order. Defaults to MODEL_ORDER.

    Returns:
        tuple: (pr_auc_means, pr_auc_err_low, pr_auc_err_high, reorder_indices).
        The first three are lists (eu, cnpref, cnwalled) of arrays ordered
        like ``model_order``; the errors are distances from the mean to the CI
        bounds. ``reorder_indices`` are the row indices used for 'cnpref'.

    Raises:
        ValueError: If a dataset is missing/None or a requested model is not
            found in one of the datasets.
    """
    order = MODEL_ORDER if model_order is None else list(model_order)
    dataset_keys = ('eu', 'cnpref', 'cnwalled')

    missing = [key for key in dataset_keys if data_dict.get(key) is None]
    if missing:
        raise ValueError(f"No data loaded for dataset(s): {', '.join(missing)}")

    def _reorder_indices(models, dataset_key):
        # Row index of the first model whose name contains each requested entry
        indices = []
        for model_name in order:
            match = next(
                (i for i, orig_model in enumerate(models) if model_name in str(orig_model)),
                None,
            )
            if match is None:
                raise ValueError(f"Model '{model_name}' not found in '{dataset_key}' results")
            indices.append(match)
        return indices

    pr_auc_means = []
    pr_auc_err_low = []
    pr_auc_err_high = []
    indices_by_dataset = {}

    for key in dataset_keys:
        df = data_dict[key]
        indices = _reorder_indices(df['model'].tolist(), key)
        indices_by_dataset[key] = indices

        means = df['pr_auc_mean'].values[indices]
        ci_lower = df['pr_auc_ci_lower'].values[indices]
        ci_upper = df['pr_auc_ci_upper'].values[indices]

        pr_auc_means.append(means)
        pr_auc_err_low.append(means - ci_lower)
        pr_auc_err_high.append(ci_upper - means)

    return pr_auc_means, pr_auc_err_low, pr_auc_err_high, indices_by_dataset['cnpref']


def create_shap_plot(ax, shap_data, shap_types):
    """
    Create SHAP values bar plot.

    Draws one group of bars per dataset (ticks labelled with DATASETS) and one
    bar per feature type within each group, centred on the tick, with
    asymmetric confidence-interval error bars.

    Args:
        ax: Matplotlib axis object
        shap_data (dict): SHAP statistics keyed by feature key, each holding
            'means', 'ci_lower' and 'ci_upper' lists (one entry per dataset)
        shap_types (list): Feature type definitions with 'name', 'key' and
            'color'
    """
    # Set up bar plot parameters
    x = np.arange(len(DATASETS)) / 2
    width = 0.155
    spacing = 0.02
    n_types = len(shap_types)

    def _whisker(outer, inner):
        # Error-bar length outer - inner. A missing bound (None, or NaN as read
        # from CSV) draws no whisker, and a CI that does not contain the mean
        # is clipped to 0 because matplotlib rejects negative error values.
        if pd.isna(outer) or pd.isna(inner):
            return 0
        return max(outer - inner, 0)

    # Create bars for each SHAP feature type
    for i, shap in enumerate(shap_types):
        key = shap['key']
        # Centre the group of n_types bars on each tick (i - 0.5 for two bars)
        offset = (width + spacing) * (i - (n_types - 1) / 2)

        means = shap_data[key]['means']
        yerr_low = [_whisker(mean, lower) for mean, lower in zip(means, shap_data[key]['ci_lower'])]
        yerr_high = [_whisker(upper, mean) for mean, upper in zip(means, shap_data[key]['ci_upper'])]

        ax.bar(x + offset, means, width,
               color=shap['color'], alpha=1.0,
               edgecolor='none', linewidth=0.5,
               label=shap['name'])

        ax.errorbar(x + offset, means,
                    yerr=[yerr_low, yerr_high],
                    fmt='none', ecolor='black',
                    capsize=4, elinewidth=1.0)

    # Configure plot appearance
    ax.set_xlabel('Dataset', fontsize=14, labelpad=10)
    ax.set_ylabel('Mean absolute SHAP value', fontsize=14, labelpad=10)
    ax.set_xticks(x)
    ax.set_xticklabels(DATASETS, fontsize=14)
    ax.tick_params(axis='both', which='both', length=7, width=1.5)

    # Grid below the bars; only left/bottom spines, thickened
    ax.yaxis.grid(True, linestyle='-', alpha=0.3, color='gray')
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(1.5)
    ax.spines['left'].set_linewidth(1.5)
    ax.set_xlim(-0.3, 1.3)


def create_prauc_plot(ax, pr_auc_means, pr_auc_err_low, pr_auc_err_high):
    """
    Create PR-AUC performance comparison plot.

    Each dataset gets three bar slots around its tick: Climate embedding
    (left), DEM&Water embedding (centre), and All manual feature overlaid on
    All embedding (right). Model indices 0-3 follow MODEL_ORDER / MODEL_NAMES.

    Args:
        ax: Matplotlib axis object
        pr_auc_means (list): Per-dataset sequences of mean PR-AUC values
        pr_auc_err_low (list): Per-dataset distances from mean to lower CI bound
        pr_auc_err_high (list): Per-dataset distances from mean to upper CI bound
    """
    # Model colors, index-aligned with MODEL_ORDER
    model_colors = [
        colors['manual_all'],
        colors['embedding_all'],
        colors['embedding_climate'],
        colors['embedding_terrain']
    ]

    # Set up bar positions
    x2 = np.arange(len(DATASETS)) / 2
    width2 = 0.14
    spacing2 = 0.016
    slot = width2 + spacing2

    # Right-hand slot: All manual feature (0) overlaid on All embedding (1)
    overlay_labels = {0: 'All manual feature', 1: 'All embedding'}
    for dataset_idx in range(len(DATASETS)):
        x_pos = x2[dataset_idx] + slot
        means = pr_auc_means[dataset_idx]

        # Both bars start at 0, so draw the taller one first; otherwise a taller
        # manual bar would completely hide the embedding bar. The stable sort keeps
        # the embedding-then-manual order whenever embedding >= manual.
        draw_order = sorted((1, 0), key=lambda m: means[m], reverse=True)
        for model_idx in draw_order:
            ax.bar(x_pos, means[model_idx], width2,
                   color=model_colors[model_idx], alpha=1,
                   edgecolor='none', linewidth=0,
                   label=overlay_labels[model_idx] if dataset_idx == 0 else "")

        # Error bars for both overlaid models (manual first, then embedding)
        for model_idx in (0, 1):
            ax.errorbar(x_pos, means[model_idx],
                        yerr=[[pr_auc_err_low[dataset_idx][model_idx]],
                              [pr_auc_err_high[dataset_idx][model_idx]]],
                        fmt='none', ecolor='black',
                        capsize=4, elinewidth=1)

    # Left and centre slots: Climate embedding (2) and DEM&Water embedding (3)
    for model_idx in [2, 3]:
        offset = slot * (model_idx - 3)

        for dataset_idx in range(len(DATASETS)):
            mean = pr_auc_means[dataset_idx][model_idx]
            err_low = pr_auc_err_low[dataset_idx][model_idx]
            err_high = pr_auc_err_high[dataset_idx][model_idx]

            ax.bar(x2[dataset_idx] + offset, mean, width2,
                   color=model_colors[model_idx], alpha=1.0,
                   edgecolor='none', linewidth=0.5,
                   label=MODEL_NAMES[model_idx] if dataset_idx == 0 else "")

            ax.errorbar(x2[dataset_idx] + offset, mean,
                        yerr=[[err_low], [err_high]],
                        fmt='none', ecolor='black',
                        capsize=4, elinewidth=1)

    # Keep the tuned 0-0.31 range, but grow it if any bar/error bar would be clipped
    bar_tops = np.concatenate([
        np.asarray(m, dtype=float) + np.asarray(h, dtype=float)
        for m, h in zip(pr_auc_means, pr_auc_err_high)
    ])
    y_top = max(0.31, float(np.nanmax(bar_tops)) * 1.05) if bar_tops.size else 0.31
    yticks = np.arange(0, y_top, 0.1)

    # Configure plot appearance
    ax.set_ylabel('PR-AUC', fontsize=14, labelpad=10)
    ax.set_xticks(x2)
    ax.set_xticklabels(DATASETS, fontsize=14)
    ax.set_ylim(0, y_top)
    ax.set_yticks(yticks)
    ax.set_yticklabels([f'{t:g}' for t in yticks], fontsize=14)
    ax.yaxis.grid(True, linestyle='-', alpha=0.3, color='gray')
    # Draw the grid beneath the bars, as in the SHAP panel
    ax.set_axisbelow(True)

    # Set borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(1.5)
    ax.spines['left'].set_linewidth(1.5)
    ax.tick_params(axis='both', which='both', length=7, width=1.5)


def create_combined_figure(data_dict):
    """
    Create the complete Figure 3 with both SHAP and PR-AUC panels.

    The PR-AUC comparison is drawn in the top panel and the mean absolute SHAP
    values in the bottom panel. Each panel's legend is placed to the right of
    the figure.

    Args:
        data_dict (dict): Loaded dataframes keyed by 'eu', 'cnpref' and
            'cnwalled'

    Returns:
        matplotlib.figure.Figure: Complete figure object
    """
    # Vertical two-panel layout, 6 x 8 inches
    fig = plt.figure(figsize=(6, 8))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.3)
    ax1 = fig.add_subplot(gs[0, 0])  # Top: PR-AUC plot
    ax2 = fig.add_subplot(gs[1, 0])  # Bottom: SHAP values plot

    # Extract and plot SHAP data
    shap_data, shap_types = extract_shap_data(data_dict)
    create_shap_plot(ax2, shap_data, shap_types)

    # Extract and plot PR-AUC data
    pr_auc_means, pr_auc_err_low, pr_auc_err_high, _ = extract_prauc_data(data_dict)
    create_prauc_plot(ax1, pr_auc_means, pr_auc_err_low, pr_auc_err_high)

    def _square_handle(color, label):
        # Marker-only proxy artist drawn as a coloured square in the legend
        return Line2D([], [], marker='s',
                      markerfacecolor=color,
                      markeredgewidth=0,
                      markersize=14,
                      linestyle='None',
                      label=label)

    shap_legend_elements = [_square_handle(shap['color'], shap['name']) for shap in shap_types]

    # Colors index-aligned with MODEL_NAMES
    model_colors = [colors['manual_all'], colors['embedding_all'],
                    colors['embedding_climate'], colors['embedding_terrain']]
    prauc_legend_elements = [_square_handle(color, name) for color, name in zip(model_colors, MODEL_NAMES)]

    # Shared legend styling; legends sit to the right, level with their panel
    legend_style = dict(
        loc='center left',
        frameon=False,
        title_fontsize=14,
        fontsize=12,
        alignment='left',
    )
    fig.legend(handles=shap_legend_elements, title='Feature Groups',
               bbox_to_anchor=(0.83, 0.25), **legend_style)
    fig.legend(handles=prauc_legend_elements, title='Models',
               bbox_to_anchor=(0.83, 0.75), **legend_style)

    # create_shap_plot titles its x-axis 'Dataset'; clear it so neither panel
    # carries an x-axis title and only the dataset tick labels are shown.
    ax2.set_xlabel('')

    return fig


from pathlib import Path

# Results CSVs produced by base_analysis_bootstrapping.py (placeholders - edit before running)
file_paths = {
    'cnpref': '/path/to/urban_niche_statistical_analysis_datasets/hyper_parameter_tuning/best/results_cn-pref5.csv',
    'cnwalled': '/path/to/urban_niche_statistical_analysis_datasets/hyper_parameter_tuning/best/results_cn-walled5.csv',
    'eu': '/path/to/urban_niche_statistical_analysis_datasets/hyper_parameter_tuning/best/results_eu5.csv'
}

# Load data
print("Loading data files...")
data_dict = load_and_prepare_data(file_paths)

# Stop with a clear message instead of a confusing error deep inside plotting
missing = [key for key, df in data_dict.items() if df is None]
if missing:
    raise FileNotFoundError(f"Missing input data for: {', '.join(missing)}. Check file_paths.")

# Create the combined figure
print("Creating Figure 3 visualization...")
fig = create_combined_figure(data_dict)

# Adjust layout, leaving space on the right for the figure-level legends
fig.tight_layout()
fig.subplots_adjust(right=0.85)

# Save the figure
output_base = '/path/to/plotting/results_materials/Figure3/combined_figure'
# savefig does not create directories, so make sure the output folder exists
Path(output_base).parent.mkdir(parents=True, exist_ok=True)

fig.savefig(f'{output_base}.pdf',
            dpi=300, bbox_inches='tight', transparent=True)
fig.savefig(f'{output_base}.png',
            dpi=300, bbox_inches='tight')

print(f"Figure 3 saved to {output_base}.pdf and {output_base}.png")

plt.show()