In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from scipy.stats import gaussian_kde

from physhapes.plotting import *
from physhapes.mcmc import load_mcmc_results

In [7]:
def _write_named_values(f, title, values, precision):
    """Write `title`, then one indented "name: value" line per item of the mapping `values`."""
    f.write(f"{title}\n")
    for name, value in values.items():
        f.write(f"  {name}: {value:.{precision}f}\n")
    f.write("\n")


def _write_summary_stats(f, title, values, precision):
    """Write `title`, then the minimum, maximum, mean and median of `values`."""
    f.write(f"{title}\n")
    f.write(f"  Minimum: {values.min():.{precision}f}\n")
    f.write(f"  Maximum: {values.max():.{precision}f}\n")
    f.write(f"  Mean: {values.mean():.{precision}f}\n")
    f.write(f"  Median: {np.median(values):.{precision}f}\n")
    f.write("\n")


def export_diagnostics(diagnostics, save_path):
    """
    Export the content of diagnostics dictionary to a text file

    The file is written to ``<save_path>/mcmc_diagnostics.txt`` (overwriting
    any existing file) and the path is printed afterwards. A section is only
    written when its key is present in ``diagnostics``.

    Parameters:
    -----------
    diagnostics : dict
        Dictionary containing diagnostic values. Keys that are read:
        - "parameter_ess", "parameter_rhat": mappings of name -> number
        - "ESS", "Rhat": objects providing .min()/.max()/.mean()
          (presumably 1-D numpy arrays with one entry per tree node;
          TODO confirm against compute_diagnostics)
        - "tree_rhat" (optional): per-node R-hat values used for the
          poor-convergence listing; "Rhat" is used when it is absent
    save_path : str
        Path to save the diagnostics file (the directory must already exist)
    """
    diagnostics_path = os.path.join(save_path, "mcmc_diagnostics.txt")

    with open(diagnostics_path, "w") as f:
        f.write("MCMC Diagnostics Summary\n")
        f.write("=======================\n\n")

        # Write parameter diagnostics
        if "parameter_ess" in diagnostics:
            _write_named_values(f, "Parameter Effective Sample Size (ESS):",
                                diagnostics["parameter_ess"], 2)

        if "parameter_rhat" in diagnostics:
            _write_named_values(f, "Parameter Gelman-Rubin Statistics (R-hat):",
                                diagnostics["parameter_rhat"], 4)

        # Write tree node diagnostics
        if "ESS" in diagnostics:
            _write_summary_stats(f, "Tree Node Effective Sample Size (ESS):",
                                 diagnostics["ESS"], 2)

        if "Rhat" in diagnostics:
            _write_summary_stats(f, "Tree Node Gelman-Rubin Statistics (R-hat):",
                                 diagnostics["Rhat"], 4)

            # List nodes with poor convergence.
            # This section is guarded by the "Rhat" key, so "tree_rhat" may be
            # missing; fall back to "Rhat" instead of raising KeyError.
            # np.asarray makes the lookup positional, matching np.where.
            node_rhat = np.asarray(diagnostics.get("tree_rhat", diagnostics["Rhat"]))
            poor_convergence = np.where(node_rhat > 1.1)[0]
            if len(poor_convergence) > 0:
                f.write("Nodes with poor convergence (R-hat > 1.1):\n")
                for node in poor_convergence:
                    f.write(f"  Node {node}: R-hat = {node_rhat[node]:.4f}\n")
            else:
                f.write("All nodes show good convergence (R-hat <= 1.1)\n")

    print(f"Diagnostics exported to {diagnostics_path}")

In [12]:
def _post_burnin_mean(chains, name, burnin_percent):
    """Pooled mean of chains[i][name] after dropping each chain's own burn-in prefix."""
    kept = []
    for chain in chains:
        samples = np.asarray(chain[name])
        burnin_end = int(samples.shape[0] * burnin_percent)
        kept.append(samples[burnin_end:])
    # Concatenating (rather than stacking) also works when chains differ in length
    return np.mean(np.concatenate(kept))


def _mean_acceptance(chains, key):
    """Average over chains of each chain's mean of the values stored under `key`."""
    return np.mean([np.mean(chain[key]) for chain in chains])


def process_mcmc_results(sim_path, burnin_percent=0.3, node_idx=[0, 1, 2, 6], force_replot=False):
    """
    Process MCMC results for both regular and Procrustes chains

    Walks ``sim_path/<data set>/<mcmc type>/<run>/`` for the MCMC types
    "mcmc" and "mcmc_procrustes". For every run directory it loads the chains
    via ``load_mcmc_results`` using the pattern ``results_*.pkl``, writes
    plots, ``mcmc_diagnostics.txt`` and ``parameter_summary.txt`` into the
    run's ``plots`` sub-folder and prints the mean sigma/alpha.

    A run is skipped when its log-posterior plot for this ``burnin_percent``
    already exists (unless ``force_replot`` is set) or when no chain could be
    loaded. Errors raised while plotting or summarising one run are printed
    and do not stop the remaining runs.

    Parameters:
    -----------
    sim_path : str
        Path to the simulation folder
    burnin_percent : float
        Fraction of each chain to discard as burn-in (e.g. 0.3 for 30%)
    node_idx : list
        Node indices to plot; passed through unchanged to the plotting helpers
    force_replot : bool
        If True, regenerate plots even if they already exist
    """
    # Define types of MCMC to check
    mcmc_types = ["mcmc", "mcmc_procrustes"]
    param_names = ["sigma", "alpha"]

    # sorted(): os.listdir order is arbitrary; sorting keeps runs reproducible
    for runfile in sorted(os.listdir(sim_path)):
        print(f"Processing inference on simulated data set {runfile}")

        # Process both regular and Procrustes MCMC
        for mcmc_type in mcmc_types:
            base_mcmc_path = os.path.join(sim_path, runfile, mcmc_type)

            # isdir (not exists): a plain file with this name cannot be listed
            if not os.path.isdir(base_mcmc_path):
                print(f"{mcmc_type} path does not exist: {base_mcmc_path}")
                continue

            # Keep only directories: a stray file (e.g. .DS_Store) would make
            # os.makedirs below raise and abort the whole run.
            mcmc_runs = sorted(
                entry for entry in os.listdir(base_mcmc_path)
                if os.path.isdir(os.path.join(base_mcmc_path, entry))
            )
            if not mcmc_runs:
                print(f"No {mcmc_type} runs found in: {base_mcmc_path}")
                continue

            print(f"Found {len(mcmc_runs)} {mcmc_type} run directories")

            for subid in mcmc_runs:
                mcmc_path = os.path.join(base_mcmc_path, subid)
                results_path = os.path.join(mcmc_path, "results_*.pkl")
                save_path = os.path.join(mcmc_path, "plots")
                os.makedirs(save_path, exist_ok=True)

                # Check if we've already generated plots
                log_posterior_plot = os.path.join(save_path, f'log_posterior_burnin_percent={burnin_percent}.png')
                if os.path.exists(log_posterior_plot) and not force_replot:
                    print(f"Plots already exist for {mcmc_type}/{subid}, skipping. Use force_replot=True to regenerate.")
                    continue

                # Load results
                chain_results = load_mcmc_results(results_path)

                # Check if we actually got results
                if not chain_results or all(result is None for result in chain_results):
                    print(f"No valid MCMC results found in: {results_path}")
                    continue

                print(f"Found {len(chain_results)} chains with results in {mcmc_type}/{subid}")

                # Chains that failed to load are None; the numeric summaries
                # below must never touch them (chain 0 may itself be None).
                valid_chains = [result for result in chain_results if result is not None]

                try:
                    # Create basic diagnostic plots
                    plot_log_posterior(chain_results, burnin_percent,
                                       save_path=log_posterior_plot)

                    plot_parameter_traces(chain_results, param_names, burnin_percent,
                                          savepath=os.path.join(save_path, f'parameter_traces_burnin_percent={burnin_percent}.png'))

                    # Calculate convergence diagnostics
                    diagnostics = compute_diagnostics(chain_results, burnin_percent)

                    # Export diagnostics to a text file
                    export_diagnostics(diagnostics, save_path)

                    # Create trace plots
                    plot_traces(chain_results, burnin_percent, node_idx=node_idx,
                                save_path=save_path, diagnostics=diagnostics)

                    # Create posterior sample plots. Use the caller's burn-in
                    # like every other step (this was hard-coded to 0.5).
                    plot_samples_from_posterior(chain_results, burnin_percent=burnin_percent,
                                                node_idx=node_idx, sample_every=50,
                                                savepath=save_path, true_values=None)

                    # Compute and print parameter estimates
                    sigma_mean = _post_burnin_mean(valid_chains, 'sigma', burnin_percent)
                    alpha_mean = _post_burnin_mean(valid_chains, 'alpha', burnin_percent)

                    print(f"{mcmc_type}/{subid} - Mean sigma: {sigma_mean:.4f}, Mean alpha: {alpha_mean:.4f}")

                    # Write parameter summary to file
                    with open(os.path.join(save_path, "parameter_summary.txt"), "w") as f:
                        # ':g' avoids float noise such as 30.000000000000004%
                        f.write(f"Burn-in: {burnin_percent * 100:g}%\n")
                        f.write(f"Mean sigma: {sigma_mean:.6f}\n")
                        f.write(f"Mean alpha: {alpha_mean:.6f}\n")
                        f.write("Acceptance rates:\n")
                        f.write(f"  Sigma: {_mean_acceptance(valid_chains, 'acceptsigma'):.4f}\n")
                        f.write(f"  Alpha: {_mean_acceptance(valid_chains, 'acceptalpha'):.4f}\n")
                        f.write(f"  Path: {_mean_acceptance(valid_chains, 'acceptpath'):.4f}\n")

                except Exception as e:
                    # Best effort: report the failure and continue with the next run
                    print(f"Error processing {mcmc_type}/{subid}: {e}")



In [13]:
# Run the post-processing for one experiment folder
sim_path = "sigma=0.6_alpha=0.025_dt=0.05/old_run"  # Change to your experiment path
burnin_percent = 0.3
process_mcmc_results(sim_path, burnin_percent=burnin_percent)

Processing inference on simulated data set seed=4074397383
Found 1 mcmc run directories
Plots already exist for mcmc/id=629540298, skipping. Use force_replot=True to regenerate.
Found 1 mcmc_procrustes run directories
Plots already exist for mcmc_procrustes/id=211875092, skipping. Use force_replot=True to regenerate.
Processing inference on simulated data set seed=3177238945
Found 1 mcmc run directories
Plots already exist for mcmc/id=342962132, skipping. Use force_replot=True to regenerate.
Found 1 mcmc_procrustes run directories
Plots already exist for mcmc_procrustes/id=79118914, skipping. Use force_replot=True to regenerate.
Processing inference on simulated data set seed=3127305017
Found 1 mcmc run directories
Plots already exist for mcmc/id=315697991, skipping. Use force_replot=True to regenerate.
Found 1 mcmc_procrustes run directories
Plots already exist for mcmc_procrustes/id=974688655, skipping. Use force_replot=True to regenerate.
