# Drawing historical traces of predictive power, temperature and SHAP values

Visualization script for temporal analysis of geographical influence on city formation.

Creates multi-panel plots showing PR-AUC performance, temperature variations, and SHAP values across Chinese historical periods.

In [None]:
# Replicate figure 4
# After running the script "/statistical_analysis/quantitative_analysis/temporal_analysis_cn_fixed_window_subsample_pref_shap_ray.py", the results are stored in a folder containing two CSV files:
# - time_window_analysis_with_2nd_nature.csv
# - time_window_analysis_without_2nd_nature.csv
# Please set the path to the folder in the variable "base_path" below.
# In addition, set the path to the TraCE-21k temperature data in the variable "temp_file" below. TraCE-21k data can be downloaded from https://www.earthsystemgrid.org/project/trace.html and processed with the script "/data_processing/climate_data_preprocessing/extract_trace21k_trace.py".

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import os

# Set matplotlib parameters for publication quality.
# Arial is requested first through a sans-serif fallback chain (the same chain
# the Extended Data cell uses) rather than a bare 'font.family': 'Arial', so
# machines without Arial fall back to Helvetica / DejaVu Sans instead of
# failing the font lookup.
plt.rcParams.update({
    'font.size': 7,
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    # Embed fonts as TrueType (Type 42) in PS/PDF output rather than Type 3.
    'ps.fonttype': 42,
    'pdf.fonttype': 42,
})


# Chinese dynasty periods used for background shading. Consecutive dynasties
# alternate between a light-gray fill (color1) and no fill (color2).
# Years BCE are negative.
color1 = "#E7E7E7"
color2 = "none"
dynasties = [
    {"name": name, "start": start, "end": end, "color": (color1, color2)[i % 2]}
    for i, (name, start, end) in enumerate([
        ("Qin", -221, -206),
        ("Han", -206, 220),
        ("Wei-Jin-N&S", 220, 581),
        ("Sui", 581, 618),
        ("Tang", 618, 907),
        ("Five Dynasties", 907, 960),
        ("Song", 960, 1279),
        ("Yuan", 1279, 1368),
        ("Ming", 1368, 1644),
        ("Qing", 1644, 1911),
    ])
]


def calculate_statistics(df, metric):
    """
    Summarize a metric per time window as mean and standard error of the mean.

    Only windows whose 'end_year' (taken from the first row of the window) is
    at most 1971 are kept.

    Args:
        df (pd.DataFrame): Rows of repeated runs; needs 'time_mid' and
            'end_year' columns plus the ``metric`` column.
        metric (str): Name of the column to summarize.

    Returns:
        tuple: (time_points, means, standard_errors) as three parallel lists,
        with time_points in ascending order. The standard error uses the
        sample standard deviation (ddof=1) divided by sqrt(n).
    """
    kept_times, means, ses = [], [], []

    for t in sorted(df['time_mid'].unique()):
        window_rows = df[df['time_mid'] == t]
        # Drop windows that end after 1971.
        if window_rows['end_year'].values[0] > 1971:
            continue

        samples = window_rows[f'{metric}'].values
        kept_times.append(t)
        means.append(np.mean(samples))
        ses.append(np.std(samples, ddof=1) / np.sqrt(len(samples)))

    return kept_times, means, ses


def add_event_lines(ax, events, time_points, y_min, y_max, fontsize=10):
    """
    Add a dashed vertical marker line and a text label for each event.

    Args:
        ax: Matplotlib axis object (anything exposing ``axvline`` and ``text``).
        events (dict): Mapping of event id -> event spec. Each spec needs:
            'period_idx' (int index into ``time_points``),
            'y_position' (fraction of the y-range: 0 -> y_min, 1 -> y_max),
            'offset' (x shift of the label, in data units),
            'text_align' (horizontal alignment of the label) and
            'description' (label text).
        time_points (list): List of time points; ``period_idx`` selects one.
        y_min (float): Minimum y-axis value used to place labels.
        y_max (float): Maximum y-axis value used to place labels.
        fontsize (float, optional): Label font size. Defaults to 10, the
            previously hard-coded value.

    Events whose 'period_idx' falls outside ``[0, len(time_points))`` are
    skipped.
    """
    n_points = len(time_points)
    for event in events.values():
        idx = event['period_idx']
        # Reject negative indices as well: Python would otherwise silently
        # wrap them around to the end of the series.
        if not 0 <= idx < n_points:
            continue

        event_time = time_points[idx]
        ax.axvline(x=event_time, color='gray', linestyle='--', alpha=0.5)

        # Label position: fractional height within [y_min, y_max], shifted
        # horizontally from the event line by 'offset'.
        y_pos = y_min + (y_max - y_min) * event['y_position']
        x_pos = event_time + event['offset']

        ax.text(x_pos, y_pos, event['description'],
                horizontalalignment=event['text_align'],
                verticalalignment='center',
                fontsize=fontsize,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))


def create_climate_plot(time_points, prauc_means_with, prauc_ses_with,
                       prauc_means_without, prauc_ses_without,
                       trace_21k, norm_shap_clim, norm_shap_dem, dynasties,
                       shap_clim_ses=None, shap_dem_ses=None):
    """
    Create the three-panel climate analysis figure (panels stacked vertically).

    Panels, top to bottom:
      1. PR-AUC of the model *without* 2nd-nature geography: mean line with
         markers plus a mean +/- 1.96*SE band. Dynasty names are written
         above this panel.
      2. TraCE-21k temperature: raw series (dashed) and smoothed series.
      3. SHAP values for climate and for terrain & water, each with a
         mean +/- 1.96*SE band when standard errors are available.
    Dynasty periods are shaded on all three panels, which share x-limits.

    Args:
        time_points (list): Time points (window mid-years) for the x-axis.
        prauc_means_with (list): PR-AUC means with 2nd nature geography.
            Kept for interface compatibility; not plotted.
        prauc_ses_with (list): PR-AUC standard errors with 2nd nature
            geography. Kept for interface compatibility; not plotted.
        prauc_means_without (list): PR-AUC means without 2nd nature geography.
        prauc_ses_without (list): PR-AUC standard errors without 2nd nature
            geography.
        trace_21k (pd.DataFrame): Temperature data with 'year',
            'temperature_original' and 'temperature_smoothed' columns.
        norm_shap_clim (array-like): SHAP values for climate, plotted as given
            (this function applies no normalization).
        norm_shap_dem (array-like): SHAP values for terrain & water, plotted
            as given.
        dynasties (list): Dicts with 'name', 'start', 'end' and 'color' keys.
        shap_clim_ses (array-like, optional): Standard errors for
            ``norm_shap_clim``. If omitted, the module-level ``shap_clim_ses``
            is used when it exists (the previous implicit behaviour).
            Otherwise no band is drawn.
        shap_dem_ses (array-like, optional): Same as ``shap_clim_ses``, for
            ``norm_shap_dem``.

    Returns:
        matplotlib.figure.Figure: Complete figure object.
    """
    line_width = 1.3
    xlim = (-300, 1911)

    # Backward compatibility: earlier versions read these two names straight
    # from module globals instead of taking them as arguments.
    if shap_clim_ses is None:
        shap_clim_ses = globals().get('shap_clim_ses')
    if shap_dem_ses is None:
        shap_dem_ses = globals().get('shap_dem_ses')

    # Create figure and a 3x1 grid with no vertical gap between panels
    fig = plt.figure(figsize=(3.7, 5.5), facecolor='none', edgecolor='none')
    gs = gridspec.GridSpec(3, 1, height_ratios=[0.9, 1.0, 0.95], hspace=0)

    ax1 = plt.subplot(gs[0])  # PR-AUC
    ax2 = plt.subplot(gs[1])  # Temperature
    ax3 = plt.subplot(gs[2])  # SHAP values
    all_axes = (ax1, ax2, ax3)

    for ax in all_axes:
        ax.set_xlim(xlim)

    # --- Panel 1: PR-AUC without 2nd-nature geography ---
    prauc_mean = np.asarray(prauc_means_without, dtype=float)
    prauc_half_width = 1.96 * np.asarray(prauc_ses_without, dtype=float)
    prauc_lower = prauc_mean - prauc_half_width
    prauc_upper = prauc_mean + prauc_half_width

    ax1.plot(time_points, prauc_mean, '-',
             color='#734e89', label='W/o 2nd nature',
             linewidth=line_width, zorder=5, markersize=line_width*1.5, marker='o')
    ax1.fill_between(time_points, prauc_lower, prauc_upper,
                     color='#734e89', alpha=0.12, zorder=3)

    # NaN-aware extrema: a window with a single iteration has a NaN standard
    # error, and a NaN limit would make set_ylim raise.
    ax1.set_ylim(0.92*np.nanmin(prauc_lower), 1.08*np.nanmax(prauc_upper))

    # --- Panel 2: TraCE-21k temperature ---
    ax2.plot(trace_21k['year'], trace_21k['temperature_original'],
             linestyle='--', color='#e67e22', alpha=0.5, linewidth=0.5*line_width,
             label='Raw temperature')
    ax2.plot(trace_21k['year'], trace_21k['temperature_smoothed'],
             color='#e67e22', label='Smoothed temperature', linewidth=line_width)
    ax2.set_ylabel('Temperature (°C)', fontsize=6)
    ax2.set_ylim(3.8, 5.1)

    # --- Panel 3: SHAP values (lines first, then bands, as before) ---
    shap_series = (
        (norm_shap_clim, shap_clim_ses, 'Climate', '#00a19e'),
        (norm_shap_dem, shap_dem_ses, 'Terrain&Water', '#366fa3'),
    )
    for values, _, label, color in shap_series:
        ax3.plot(time_points, values, '-',
                 label=label, color=color, linewidth=line_width)
    for values, ses, _, color in shap_series:
        if ses is None:
            continue
        centre = np.asarray(values, dtype=float)
        spread = 1.96 * np.asarray(ses, dtype=float)
        ax3.fill_between(time_points, centre - spread, centre + spread,
                         color=color, alpha=0.12, zorder=3)

    # Dynasty background shading on every panel
    for ax in all_axes:
        for dynasty in dynasties:
            ax.axvspan(dynasty["start"], dynasty["end"],
                       alpha=0.4, facecolor=dynasty["color"], zorder=1, edgecolor='none')

    # Dynasty labels just above the top panel (multi-word names split over lines)
    label_y = ax1.get_ylim()[1]*1.01
    for dynasty in dynasties:
        mid_year = (dynasty["start"] + dynasty["end"]) / 2
        ax1.text(mid_year, label_y,
                 dynasty["name"].replace(" ", "\n"),
                 horizontalalignment='center',
                 verticalalignment='bottom',
                 rotation=0,
                 fontsize=5)

    # Thin spines/ticks and transparent backgrounds on every panel
    for ax in all_axes:
        ax.tick_params(axis='both', length=3, width=0.5, labelsize=6)
        for spine in ax.spines.values():
            spine.set_linewidth(0.5)
        ax.set_facecolor('none')
        ax.patch.set_alpha(0.0)

    # Axis labels
    ax1.set_ylabel('PR-AUC', fontsize=6)
    ax3.set_ylabel('Mean absolute SHAP value', fontsize=6)
    ax3.set_xlabel('Year (CE/BCE)')

    # Legends
    ax2.legend(frameon=False,
               loc='upper center',
               bbox_to_anchor=(0.5, 1),
               ncol=2,
               bbox_transform=ax2.transAxes,
               fontsize=6)
    ax3.legend(frameon=False,
               loc='upper center',
               ncol=3,
               bbox_transform=ax3.transAxes,
               fontsize=6)

    ax3.set_ylim(0.5, 4)

    # Only the bottom panel shows x ticks and tick labels
    for ax in (ax1, ax2):
        ax.tick_params(axis='x', which='both', length=0)
        ax.set_xticklabels([])

    # With hspace=0 the panels touch. Hiding ax1's bottom spine and ax3's top
    # spine leaves ax2's top and bottom spines as the single separator lines.
    ax1.spines['bottom'].set_visible(False)
    ax3.spines['top'].set_visible(False)

    return fig


# Set input file paths
base_path = "/path/to/statistical_model_analysis/scripts/statistical_analysis/model_project/results/time_window_analysis/pref"
run_id = "20250623_153043"

csv_with_2nd = f"{base_path}/{run_id}/time_window_analysis_with_2nd_nature.csv"
csv_without_2nd = f"{base_path}/{run_id}/time_window_analysis_without_2nd_nature.csv"

# Load data
df_with_2nd = pd.read_csv(csv_with_2nd)
df_without_2nd = pd.read_csv(csv_without_2nd)

# Midpoint of each time window, used as the x coordinate
for df in [df_with_2nd, df_without_2nd]:
    df['time_mid'] = (df['start_year'] + df['end_year']) / 2

# Per-window mean and standard error for each metric.
# The shap_*_ses names are module-level on purpose: create_climate_plot reads them.
time_points, prauc_means_with, prauc_ses_with = calculate_statistics(df_with_2nd, 'prauc')
time_points, shap_clim_means, shap_clim_ses = calculate_statistics(df_with_2nd, 'shap_clim')
time_points, shap_dem_means, shap_dem_ses = calculate_statistics(df_with_2nd, 'shap_dem')
time_points_without, prauc_means_without, prauc_ses_without = calculate_statistics(df_without_2nd, 'prauc')

# The "without" PR-AUC series is drawn against the "with" time axis, so both
# runs must use exactly the same windows; otherwise the plot would misalign.
if not np.array_equal(time_points, time_points_without):
    raise ValueError(
        "Time windows differ between the with/without 2nd-nature results; "
        "cannot plot them on a shared time axis."
    )

# SHAP means are plotted as-is (no normalization is applied); the norm_* names
# match create_climate_plot's parameters.
norm_shap_clim = np.array(shap_clim_means)
norm_shap_dem = np.array(shap_dem_means)

# Load TraCE-21k temperature and keep 250 BCE .. 1850 CE (inclusive)
temp_file = '/path/to/china_temperature_series_trace21k1.parquet'
trace_21k = pd.read_parquet(temp_file)
trace_21k = trace_21k[trace_21k['year'].between(-250, 1850)]

# Create the visualization
fig = create_climate_plot(time_points, prauc_means_with, prauc_ses_with,
                          prauc_means_without, prauc_ses_without,
                          trace_21k, norm_shap_clim, norm_shap_dem,
                          dynasties)

# Save the figure
fig.savefig('climate_analysis_vertical.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

In [None]:
# Replicate Extended Data Figure 7
# Please first run the script "/statistical_analysis/quantitative_analysis/temporal_analysis_cn_fixed_window_subsample_pref_shap_ray_no_test_subsample.py" to generate the required CSV file.
# Then set the path to the CSV file in the variable "csv_with_2nd" below.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
from matplotlib.ticker import ScalarFormatter, FormatStrFormatter

# Nature-journal style figure defaults, applied in thematic groups.

# Typography: Arial preferred, then Helvetica, then DejaVu Sans.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'font.size': 7,
    'axes.titlesize': 8,
    'axes.labelsize': 8,
    'xtick.labelsize': 7,
    'ytick.labelsize': 7,
    'legend.fontsize': 7,
})

# Stroke weights and marker size.
plt.rcParams.update({
    'axes.linewidth': 0.7,
    'lines.linewidth': 1.0,
    'patch.linewidth': 0.5,
    'lines.markersize': 3,
})

# Tick geometry: equal widths for major/minor ticks on both axes.
plt.rcParams.update({
    'xtick.major.width': 0.7,
    'ytick.major.width': 0.7,
    'xtick.minor.width': 0.7,
    'ytick.minor.width': 0.7,
    'xtick.major.size': 4,
    'ytick.major.size': 4,
    'xtick.minor.size': 2,
    'ytick.minor.size': 2,
})

# Frame, display resolution and font embedding (Type 42 in PDF/PS).
plt.rcParams.update({
    'axes.spines.top': True,
    'axes.spines.right': True,
    'figure.dpi': 600,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

# Chinese dynasty periods used for background shading. Consecutive dynasties
# alternate between a light-gray fill (color1) and a transparent one (color2).
# Years BCE are negative.
color1 = "#F6F6F6"
color2 = "none"
dynasties = [
    {"name": name, "start": start, "end": end, "color": (color1, color2)[i % 2]}
    for i, (name, start, end) in enumerate([
        ("Qin", -221, -206),
        ("Han", -206, 220),
        ("Wei-Jin-N&S", 220, 581),
        ("Sui", 581, 618),
        ("Tang", 618, 907),
        ("Five Dynasties", 907, 960),
        ("Song", 960, 1279),
        ("Yuan", 1279, 1368),
        ("Ming", 1368, 1644),
        ("Qing", 1644, 1911),
    ])
]


def calculate_statistics(df, metric):
    """
    Summarize a metric per time window as mean and standard error of the mean.

    Only windows whose 'end_year' (taken from the first row of the window) is
    at most 1971 are kept.

    Args:
        df (pd.DataFrame): Rows of repeated runs; needs 'time_mid' and
            'end_year' columns plus the ``metric`` column.
        metric (str): Name of the column to summarize.

    Returns:
        tuple: (time_points, means, standard_errors) as three parallel lists,
        with time_points in ascending order. The standard error uses the
        sample standard deviation (ddof=1) divided by sqrt(n).
    """
    kept_times, means, ses = [], [], []

    for t in sorted(df['time_mid'].unique()):
        window_rows = df[df['time_mid'] == t]
        # Drop windows that end after 1971.
        if window_rows['end_year'].values[0] > 1971:
            continue

        samples = window_rows[f'{metric}'].values
        kept_times.append(t)
        means.append(np.mean(samples))
        ses.append(np.std(samples, ddof=1) / np.sqrt(len(samples)))

    return kept_times, means, ses


def create_dual_axis_plot(time_points, prauc_means_with, prauc_ses_with,
                         positive_counts, positive_ses, dynasties):
    """
    Create a dual-axis plot of PR-AUC (left axis) and positive sample counts
    (right axis) over time, with dynasty shading and labels.

    Args:
        time_points (list): Time points for x-axis
        prauc_means_with (list): PR-AUC mean values
        prauc_ses_with (list): PR-AUC standard errors; drawn as a
            mean +/- 1.96*SE band
        positive_counts (list): Positive sample counts
        positive_ses (list): Positive sample count standard errors; used only
            to size the right-hand y-axis (no band is drawn)
        dynasties (list): Dicts with 'name', 'start', 'end' and 'color' keys

    Returns:
        matplotlib.figure.Figure: Complete figure object
    """
    line_width = 1.5
    prauc_color = '#3783BB'
    count_color = '#32a852'

    # Figure size and dpi chosen for publication
    fig = plt.figure(figsize=(5, 2.5), facecolor='none', edgecolor='none', dpi=600)
    gs = gridspec.GridSpec(1, 1)

    # Primary axis: PR-AUC
    ax1 = plt.subplot(gs[0])
    ax1.set_xlim((-300, 1950))

    prauc_mean = np.asarray(prauc_means_with, dtype=float)
    prauc_half_width = 1.96 * np.asarray(prauc_ses_with, dtype=float)

    line1 = ax1.plot(time_points, prauc_mean, '-',
                     color=prauc_color, label='Prediction accuracy (PR-AUC)',
                     linewidth=line_width, zorder=5, marker='o', markersize=line_width+1.5)
    ax1.fill_between(time_points,
                     prauc_mean - prauc_half_width,
                     prauc_mean + prauc_half_width,
                     color=prauc_color, alpha=0.3, zorder=3, edgecolor='none')

    # Start the PR-AUC axis at 0 with 20% headroom above the upper band.
    # nanmax: single-iteration windows have a NaN standard error.
    ax1.set_ylim(0, 1.2*np.nanmax(prauc_mean + prauc_half_width))

    # Secondary axis: positive sample counts
    ax2 = ax1.twinx()
    line2 = ax2.plot(time_points, positive_counts, '-',
                     color=count_color, label='Historical positive samples (n)',
                     linewidth=line_width, zorder=4, marker='s', markersize=line_width+1.5)

    # Size the count axis from the +/-1.96*SE envelope with 10% padding
    # (never below 0) plus a fixed 200 extra at the top.
    count_mean = np.asarray(positive_counts, dtype=float)
    count_half_width = 1.96 * np.asarray(positive_ses, dtype=float)
    pos_min = np.nanmin(count_mean - count_half_width)
    pos_max = np.nanmax(count_mean + count_half_width)
    padding = (pos_max - pos_min) * 0.1
    ax2.set_ylim(max(0, pos_min - padding), pos_max + padding + 200)

    # Dynasty background shading
    for dynasty in dynasties:
        ax1.axvspan(dynasty["start"], dynasty["end"],
                    alpha=1, facecolor=dynasty["color"], zorder=1, edgecolor='none')

    # Dynasty labels just above the axes. The x coordinate is in data units and
    # the y coordinate in axes units (1.01 = just above the top edge). With the
    # y-axis starting at 0 this matches a data height of ylim_top*1.01. It also
    # stays correct if the set_yticks call below expands the view limits.
    label_transform = ax1.get_xaxis_transform()
    for dynasty in dynasties:
        mid_year = (dynasty["start"] + dynasty["end"]) / 2
        ax1.text(mid_year, 1.01,
                 dynasty["name"].replace(" ", "\n"),
                 transform=label_transform,
                 horizontalalignment='center',
                 verticalalignment='bottom',
                 rotation=0,
                 fontsize=6)

    # Minor ticks on all axes
    ax1.xaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax1.yaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax2.yaxis.set_minor_locator(ticker.AutoMinorLocator())

    # Fixed tick positions.
    # NOTE(review): these ranges look tuned to the published data. Matplotlib
    # may expand the view limits so every given tick is visible.
    ax1.set_yticks(np.arange(0, 0.16, 0.05))
    ax2.set_yticks(np.arange(0, 3900, 1000))
    ax1.set_xticks(np.arange(-300, 1950, 300))

    # Axis labels colour-coded to their series
    ax1.set_ylabel('PR-AUC', color=prauc_color)
    ax1.set_xlabel('Year (CE/BCE)')
    ax2.set_ylabel('Positive samples (n)', color=count_color, labelpad=10, rotation=270)
    ax1.tick_params(axis='y', labelcolor=prauc_color)
    ax2.tick_params(axis='y', labelcolor=count_color)

    # Right axis: always use scientific notation (power limits (0, 0)) with
    # the offset text nudged right
    formatter = ScalarFormatter(useMathText=True)
    formatter.set_powerlimits((0, 0))
    ax2.yaxis.set_major_formatter(formatter)
    ax2.yaxis.offsetText.set_position((1.06, 0))

    # One legend below the axes covering both series
    lines = line1 + line2
    labels = [line.get_label() for line in lines]
    ax1.legend(lines, labels,
               loc='upper center',
               bbox_to_anchor=(0.5, -0.2),
               ncol=2, frameon=False,
               handlelength=1.5,
               fontsize=8,
               columnspacing=1.5)

    return fig



import os

# Set input file path
csv_with_2nd = "/path/to/temporal_results_file.csv"

# Load temporal analysis data
df_with_2nd = pd.read_csv(csv_with_2nd)

# Midpoint of each temporal window, used as the x coordinate
df_with_2nd['time_mid'] = (df_with_2nd['start_year'] + df_with_2nd['end_year']) / 2

# Per-window mean and standard error for PR-AUC and positive sample counts
# (both computed from the same dataframe, so the time points coincide)
time_points, prauc_means_with, prauc_ses_with = calculate_statistics(df_with_2nd, 'pr_auc')
time_points, positive_counts, positive_ses = calculate_statistics(df_with_2nd, 'positive_count')

# Create the dual-axis visualization
fig = create_dual_axis_plot(time_points, prauc_means_with, prauc_ses_with,
                            positive_counts, positive_ses, dynasties)

# Save the figure as PNG and PDF. savefig does not create missing
# directories, so make sure the output folder exists first.
output_base = '/path/to/plotting/results_materials/Extended_data_Fig7/pr_auc_with_positive_count'
output_dir = os.path.dirname(output_base)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
fig.savefig(f'{output_base}.png', bbox_inches='tight', dpi=600)
fig.savefig(f'{output_base}.pdf', bbox_inches='tight')

plt.show()
plt.close(fig)