In [1]:
import jax
import jax.numpy as jnp
from tqdm import tqdm 
import pandas as pd
import wandb
#import argparse
import scipy
import pickle 
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.backends.backend_pdf as backend_pdf
import subprocess
import time
import glob 

from bridge_sampling.BFFG import backward_filter, forward_guide, forward_guide_edge, get_logpsi
from bridge_sampling.setup_SDEs import Stratonovich_to_Ito, dtsdWsT, dWs
from bridge_sampling.noise_kernel import Q12
from bridge_sampling.helper_functions import *

from mcmc import *
import subprocess
import time

In [None]:
def run_mcmc_for_all_datasets(experiment_path, 
                              num_chains=3, 
                              num_samples=3000, 
                              dt=0.05, 
                              lambd=0.95, 
                              obs_var=0.001, 
                              rb=2, 
                              prior_sigma_min=0.0, 
                              prior_sigma_max=1.0, 
                              prior_alpha_min=0.0, 
                              prior_alpha_max=0.01, 
                              proposal_sigma_tau=0.1, 
                              proposal_alpha_tau=0.0015, 
                              use_wandb=True):
    """Run MCMC for all datasets in the experiment path"""

    
    # Get all dataset folders
    dataset_folders = glob.glob(f"{experiment_path}/seed=*")
    
    print(f"Found {len(dataset_folders)} datasets in {experiment_path}")
    
    # Loop through each dataset folder
    for dataset_folder in dataset_folders:
        # Extract the folder name
        folder_name = os.path.basename(dataset_folder)
        print(f"\nProcessing dataset: {folder_name}")
        
        # Check if the required file exists
        data_file = f"{dataset_folder}/procrustes_aligned.csv"
        if not os.path.exists(data_file):
            print(f"  Skipping: {data_file} not found")
            continue
        
        # Generate a random seed for this batch of chains
        seed_start = np.random.randint(0, 1000_000_000)
        
        # Set up output path
        output_path = f"{dataset_folder}/mcmc_seed={seed_start}_N={num_samples}"
        
        print(f"  Starting {num_chains} MCMC chains with seed {seed_start}")
        
        # Start MCMC chains in screen sessions
        screen_sessions = run_mcmc_in_screens(
            num_chains=num_chains,
            script_path="run_mcmc.py",
            seed_param="--seed_mcmc",
            seed_start=seed_start,
            screen_prefix=f"mcmc_{folder_name}",  # Use unique screen names
            script_args={
                "--outputpath": output_path,
                "--phylopath": "../data/chazot_subtree_rounded.nw",
                "--datapath": data_file,
                "--dt": dt,
                "--lambd": lambd,
                "--obs_var": obs_var,
                "--rb": rb,
                "--N": num_samples,
                "--prior_sigma_min": prior_sigma_min,
                "--prior_sigma_max": prior_sigma_max,
                "--prior_alpha_min": prior_alpha_min,
                "--prior_alpha_max": prior_alpha_max,
                "--proposal_sigma_tau": proposal_sigma_tau,
                "--proposal_alpha_tau": proposal_alpha_tau,
                "--use_wandb": True
            }
        )
        
        print(f"  Started chains for {folder_name}. Screen sessions: {', '.join(screen_sessions)}")
        
        # Optional: Add a delay between datasets to avoid overloading the system
        time.sleep(5)
    
    print(f"\nMCMC chains started for all datasets in {experiment_path}")



In [None]:
# MCMC settings
num_chains = 3
num_samples = 5000
dt = 0.05
lambd = 0.9
obs_var = 0.001
rb = 2
prior_sigma_min = 0.0
prior_sigma_max = 1.5
prior_alpha_min = 0.0
prior_alpha_max = 0.03
proposal_sigma_tau = 0.1
proposal_alpha_tau = 0.015
seed_start = np.random.randint(0,1000_000_000)

In [4]:
# Run the function
experiment_path = "exp_1_sigma=0.7_alpha=0.025_dt=0.05"
run_mcmc_for_all_datasets(experiment_path=experiment_path,
                          num_chains=num_chains,
                          num_samples=num_samples,
                          dt=dt,
                          lambd=lambd,
                          obs_var=obs_var,
                          rb=rb,
                          prior_sigma_min=prior_sigma_min,
                          prior_sigma_max=prior_sigma_max,
                          prior_alpha_min=prior_alpha_min,
                          prior_alpha_max=prior_alpha_max,
                          proposal_sigma_tau=proposal_sigma_tau,
                          proposal_alpha_tau=proposal_alpha_tau)    

Found 1 datasets in exp_1_sigma=0.7_alpha=0.025_dt=0.05

Processing dataset: seed=121197884
  Starting 3 MCMC chains with seed 572716989
Starting chain 1 with seed 572716989 in screen 'mcmc_seed=121197884_1'
Starting chain 2 with seed 572716990 in screen 'mcmc_seed=121197884_2'
Starting chain 3 with seed 572716991 in screen 'mcmc_seed=121197884_3'

3 MCMC chains started in separate screen sessions.
To attach to a screen session: screen -r <screen_name>
To detach from a screen session: Ctrl+A, then D
Screen sessions: mcmc_seed=121197884_1, mcmc_seed=121197884_2, mcmc_seed=121197884_3
  Started chains for seed=121197884. Screen sessions: mcmc_seed=121197884_1, mcmc_seed=121197884_2, mcmc_seed=121197884_3

MCMC chains started for all datasets in exp_1_sigma=0.7_alpha=0.025_dt=0.05


# Visualize results 

In [5]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import glob
import os

def load_mcmc_results(filepath_pattern):
    """
    Load MCMC results from pickle files matching the given pattern.
    
    Args:
        filepath_pattern: Pattern to match pickle files (e.g., "results/chain_*.pkl")
        
    Returns:
        List of loaded results
    """
    results = []
    for filepath in sorted(glob.glob(filepath_pattern)):
        print(f"Loading {filepath}")
        with open(filepath, 'rb') as f:
            results.append(pickle.load(f))
    return results

In [6]:
results_path = "exp_1_sigma=0.7_alpha=0.025_dt=0.05/seed=121197884/mcmc_seed=871636495_N=3000/results_*.pkl"  # Adjust pattern as needed
chain_results = load_mcmc_results(results_path)
param_names = ["sigmas", "alphas"]  # Replace with your actual parameter names
len(chain_results)

0

In [7]:
chain_results[0]['acceptpath']

IndexError: list index out of range

In [None]:
[np.mean(result['acceptpath']) for result in chain_results if result is not None]

In [None]:
[plt.plot(chain_results[i]['sigmas']) for i in range(len(chain_results))]

In [None]:
[plt.plot(chain_results[i]['alphas']) for i in range(len(chain_results))]