# Workflow: joint Gibbs sampling of the simulated setting/rising TODs and posterior-correlation diagnostics

In [1]:
import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
import sys

sys.path.append('../hydra_tod/')

from astropy.coordinates import EarthLocation, AltAz, SkyCoord
from astropy.time import Time, TimeDelta
import astropy.units as u
from utils import Leg_poly_proj, view_samples
from flicker_model import sim_noise, flicker_cov


In [2]:
import pickle
from simulation import MultiTODSimulation
# Load the pre-computed simulation object. Unpickling needs the pickled class
# (`simulation.MultiTODSimulation`) to be importable and can execute code
# stored in the file, so only load pickles you produced yourself.
SIM_PICKLE_PATH = 'multi_tod_simulation_data.pkl'
with open(SIM_PICKLE_PATH, mode='rb') as pickle_file:
    multi_tod_sim = pickle.load(pickle_file)

In [3]:
# Make all the individual variables available for backward compatibility:
# every attribute of the loaded MultiTODSimulation is bound to a module-level
# name so the cells below can refer to it as a plain global.
# Naming convention: `_setting` / `_rising` suffixes denote the two TODs, which
# are passed to the sampler below in the order [setting, rising].

# Time samples, per-scan angles (theta_c/phi_c — presumably the pointing
# centres; confirm in simulation.py), pixel masks and beam integrals.
t_list = multi_tod_sim.t_list
theta_c_setting = multi_tod_sim.theta_c_setting
phi_c_setting = multi_tod_sim.phi_c_setting
theta_c_rising = multi_tod_sim.theta_c_rising
phi_c_rising = multi_tod_sim.phi_c_rising
bool_map_setting = multi_tod_sim.bool_map_setting
bool_map_rising = multi_tod_sim.bool_map_rising
integrated_beam_setting = multi_tod_sim.integrated_beam_setting
integrated_beam_rising = multi_tod_sim.integrated_beam_rising
full_bool_map = multi_tod_sim.full_bool_map
pixel_indices = multi_tod_sim.pixel_indices
integrated_beam = multi_tod_sim.integrated_beam
# Sky operators (one per TOD) and the shared sky parameters; the
# noise-diode/receiver operator and its parameters.
Tsky_operator_setting = multi_tod_sim.Tsky_operator_setting
Tsky_operator_rising = multi_tod_sim.Tsky_operator_rising
sky_params = multi_tod_sim.sky_params
ntime = multi_tod_sim.ntime
ndiode_proj = multi_tod_sim.ndiode_proj
nd_rec_operator = multi_tod_sim.nd_rec_operator
nd_rec_params = multi_tod_sim.nd_rec_params
# Gain model: one projection matrix shared by both TODs (see its use below),
# plus per-TOD gain coefficients and gains.
gain_proj = multi_tod_sim.gain_proj
gain_params_setting = multi_tod_sim.gain_params_setting
gain_params_rising = multi_tod_sim.gain_params_rising
gains_setting = multi_tod_sim.gains_setting
gains_rising = multi_tod_sim.gains_rising
# Noise-model parameters (names follow the simulation module) and the noise
# realisations of each TOD.
fc = multi_tod_sim.fc
logfc = multi_tod_sim.logfc
f0 = multi_tod_sim.f0
logf0 = multi_tod_sim.logf0
noise_setting = multi_tod_sim.noise_setting
noise_rising = multi_tod_sim.noise_rising
# Tsys and the simulated time-ordered data for each scan.
Tsys_setting = multi_tod_sim.Tsys_setting
Tsys_rising = multi_tod_sim.Tsys_rising
TOD_setting = multi_tod_sim.TOD_setting
TOD_rising = multi_tod_sim.TOD_rising
# Calibration pixels; `calibration_5_indices` is the set whose sky prior is
# made effectively fixed further down.
pixels_c_setting = multi_tod_sim.pixels_c_setting
bool_map_c_setting = multi_tod_sim.bool_map_c_setting
calibration_1_index = multi_tod_sim.calibration_1_index
calibration_5_indices = multi_tod_sim.calibration_5_indices

# Additional constants from the original code
T_ndiode = multi_tod_sim.T_ndiode
rec_params = multi_tod_sim.rec_params
dtime = multi_tod_sim.dtime
alpha = multi_tod_sim.alpha  # log_prior_noise below pins the sampled alpha to this value
logf0_list = [multi_tod_sim.logf0]  # single-element list; not referenced in the cells shown here
sigma_2 = multi_tod_sim.sigma_2

In [4]:
def log_gain_params(proj, gains):
    """Least-squares coefficients of ``log(gains)`` in the basis ``proj``.

    Solves ``proj @ c ≈ log(gains)`` with ``np.linalg.lstsq``.

    Parameters
    ----------
    proj : array_like
        Design matrix whose columns are the basis vectors.
    gains : array_like
        Gain values; should be positive (``np.log`` gives nan/-inf otherwise).

    Returns
    -------
    ndarray
        The least-squares coefficient vector ``c``.
    """
    solution, *_ = np.linalg.lstsq(proj, np.log(gains), rcond=None)
    return solution

def log_local_params(nd_rec_proj, local_params):
    """Convert local (noise-diode + residual) parameters to log-space coefficients.

    The first parameter is log-transformed directly. The remaining parameters
    define the profile ``nd_rec_proj[:, 1:] @ local_params[1:]``; its log is
    re-fitted by least squares in the same basis.

    Parameters
    ----------
    nd_rec_proj : ndarray, 2-D
        Projection matrix; column 0 belongs to the first parameter and is not
        used in the fit.
    local_params : ndarray, 1-D
        First entry must be positive; the residual profile must be positive
        everywhere for the log to be finite.

    Returns
    -------
    ndarray
        ``[log(local_params[0]), *fitted_log_coefficients]``.
    """
    residual_basis = nd_rec_proj[:, 1:]
    residual_profile = residual_basis @ local_params[1:]
    fitted_log_coeffs, *_ = np.linalg.lstsq(
        residual_basis, np.log(residual_profile), rcond=None
    )
    return np.concatenate(([np.log(local_params[0])], fitted_log_coeffs))

In [5]:
# Priors on the log-gain coefficients of each TOD, centred on the
# least-squares fit of log(gain) to the simulated gains.
log_gain_params_s = log_gain_params(gain_proj, gains_setting)
log_gain_params_r = log_gain_params(gain_proj, gains_rising)

gain_prior_mean_s, gain_prior_mean_r = log_gain_params_s, log_gain_params_r
gain_prior_mean_list = [gain_prior_mean_s, gain_prior_mean_r]

# Precision (inverse variance) = 1 / (10% of the mean)^2 per coefficient,
# except the first coefficient (presumably the constant basis term), whose
# precision is fixed at 10.
gain_prior_cov_inv_list = []
for _prior_mean in gain_prior_mean_list:
    _precision = 1. / (0.1 * _prior_mean)**2
    _precision[0] = 10.0
    gain_prior_cov_inv_list.append(_precision)
gain_prior_cov_inv_s, gain_prior_cov_inv_r = gain_prior_cov_inv_list



In [6]:
# Gaussian priors, written as precisions (inverse variances) whose standard
# deviation is a fixed fraction of the prior mean.
prior_std_frac = 0.1
CALIBRATION_PRECISION = 1e20  # effectively fixes the calibration pixels at their prior mean

Tsky_prior_cov_inv = 1. / (prior_std_frac * sky_params)**2
calibration_indices = calibration_5_indices
Tsky_prior_cov_inv[calibration_indices] = CALIBRATION_PRECISION

# Local (presumably noise-diode / receiver) parameters in log space. Setting
# and rising get independent array objects, so an in-place update of one list
# entry cannot silently change the other. Whether the sampler mutates its
# inputs is not visible here — TODO confirm.
log_local_params_s = log_local_params(nd_rec_operator, nd_rec_params)
log_local_params_r = log_local_params_s.copy()
Tloc_prior_mean_list = [log_local_params_s, log_local_params_r]

# Same fractional width as the sky prior. A descriptive name is used instead
# of a scratch name like `aux`, which is easy to clobber in a later cell.
Tloc_prior_cov_inv = 1. / (prior_std_frac * log_local_params_s)**2
Tloc_prior_cov_inv_list = [Tloc_prior_cov_inv, Tloc_prior_cov_inv.copy()]

def log_prior_noise(params):
    """Log-prior on a two-element noise-parameter vector.

    The first element is not constrained. The second (``alpha``) gets a very
    stiff quartic penalty around the module-level ``alpha`` loaded from the
    simulation, which effectively pins it there.

    Parameters
    ----------
    params : sequence of length 2
        ``(first_param, alpha)``; any other length raises ``ValueError`` on
        unpacking.

    Returns
    -------
    float
        ``-1e10 * (alpha_est - alpha)**4``.
    """
    _unconstrained, alpha_est = params
    offset = alpha_est - alpha
    return -1e10 * offset**4

# Initial sampler states. Copies are taken so that the initial states are
# distinct objects from the prior means passed to the sampler
# (Tloc_prior_mean_list, Tsky_prior_mean=sky_params). If the sampler updates
# its initial state in place (not visible here — TODO confirm), the priors
# then stay fixed.
init_Tloc_params_list = [log_local_params_s.copy(), log_local_params_r.copy()]
init_Tsky_params = sky_params.copy()
# Per TOD: [first noise param, alpha] — alpha is the second entry, as unpacked
# by log_prior_noise; the first is presumably log f0.
init_noise_params_list = [[-4.8, 2.2], [-4.8, 2.2]]

In [None]:
# Import the joint Gibbs sampler used below.
from full_Gibbs_sampler import TOD_Gibbs_sampler_joint_loc

# Run settings, named so they are easy to find and change.
N_GIBBS_SAMPLES = 2000  # passed as `n_samples`
WNOISE_VAR = 2.5e-6     # passed as `wnoise_var` (presumably the white-noise variance — confirm units)
GIBBS_TOL = 1e-20       # passed as `tol`; meaning defined in full_Gibbs_sampler

# Per-TOD inputs are given as [setting, rising]; the sky parameters are shared.
Tsky_samples, all_gain_samples, all_noise_samples, all_Tloc_samples = \
    TOD_Gibbs_sampler_joint_loc(
        [TOD_setting, TOD_rising],
        [t_list, t_list],
        [gain_proj, gain_proj],
        [Tsky_operator_setting, Tsky_operator_rising],
        [nd_rec_operator, nd_rec_operator],
        init_Tsky_params,
        init_Tloc_params_list,
        init_noise_params_list,
        [logfc, logfc],
        wnoise_var=WNOISE_VAR,
        Tsky_prior_cov_inv=Tsky_prior_cov_inv,
        Tsky_prior_mean=sky_params,
        local_Tloc_prior_cov_inv_list=Tloc_prior_cov_inv_list,
        local_Tloc_prior_mean_list=Tloc_prior_mean_list,
        local_gain_prior_cov_inv_list=gain_prior_cov_inv_list,
        local_gain_prior_mean_list=gain_prior_mean_list,
        local_noise_prior_func_list=[log_prior_noise, log_prior_noise],
        noise_sampler_type="emcee",
        ploc_Jeffreys_prior=True,
        noise_Jeffreys_prior=True,
        n_samples=N_GIBBS_SAMPLES,
        tol=GIBBS_TOL)

Using the emulator for flicker noise correlation function.
Get the JAX version of the emulators
Get the JAX version of the log-det emulator
Get the JAX version of the emulators


  mcmc = MCMC(nuts_kernel, num_warmup=current_warmup, num_samples=N_samples, num_chains=N_chains, progress_bar=False)


Running warmup round: 1500 additional steps (total warmup: 1500)


In [None]:
# First, save the samples as .npy files so the long Gibbs run does not have to
# be repeated. np.save does not create missing directories, so ensure the
# output directory exists. Otherwise np.save raises FileNotFoundError here,
# after the expensive run has finished.
import os

SAMPLES_DIR = "outputs/GSF5_db"
os.makedirs(SAMPLES_DIR, exist_ok=True)

np.save(f"{SAMPLES_DIR}/Tsky_samples_joint_loc.npy", Tsky_samples)
# The `all_*` sample lists (presumably one entry per TOD) are concatenated
# along axis 0 before saving.
np.save(f"{SAMPLES_DIR}/gain_samples_joint_loc.npy", np.concatenate(all_gain_samples, axis=0))
np.save(f"{SAMPLES_DIR}/noise_samples_joint_loc.npy", np.concatenate(all_noise_samples, axis=0))
np.save(f"{SAMPLES_DIR}/Tloc_samples_joint_loc.npy", np.concatenate(all_Tloc_samples, axis=0))

In [None]:
# Reload the saved samples so the analysis below can run without re-sampling.
Tsky_samples, gain_samples, noise_samples, Tloc_samples = [
    np.load(sample_file)
    for sample_file in (
        "outputs/GSF5_db/Tsky_samples_joint_loc.npy",
        "outputs/GSF5_db/gain_samples_joint_loc.npy",
        "outputs/GSF5_db/noise_samples_joint_loc.npy",
        "outputs/GSF5_db/Tloc_samples_joint_loc.npy",
    )
]


In [None]:
from MCMC_diagnostics import diagnostics

In [None]:
# Sanity-check the shape of the loaded Tloc samples.
np.shape(Tloc_samples)

In [None]:
# MCMC diagnostics for the gain coefficients of entry 1 of the loaded gain
# samples (presumably the rising TOD, given the [setting, rising] order passed
# to the sampler — confirm). The number of gain coefficients and the chain
# length are read from the array rather than hard-coded (previously 4 and
# 2000), so this cell keeps working if the sampler settings change.
n_gain_coeffs = gain_samples.shape[-1]
# Leading axis of length 1 — presumably the (n_chains, n_samples, n_params)
# layout that `diagnostics` expects.
gain_chain_rising = gain_samples[1].reshape(1, -1, n_gain_coeffs)
diagnostics(gain_chain_rising, param_names=[rf"$p_{{g,{i}}}$" for i in range(n_gain_coeffs)])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
import pandas as pd

print("Sample shapes:")
for _group_name, _group_samples in (
    ("Gain", gain_samples),
    ("Noise", noise_samples),
    ("Tloc", Tloc_samples),
):
    print(f"{_group_name} samples: {_group_samples.shape}")

def analyze_parameter_correlations(gain_samples, noise_samples, Tloc_samples,
                                   figsize=(16, 12), save_path=None,
                                   threshold=0.1, top_k=10):
    """
    Comprehensive correlation analysis between gain, noise, and Tloc parameters.

    Draws three panels:

    - a lower-triangular heatmap of the Pearson correlation matrix;
    - a histogram of the off-diagonal coefficients;
    - a table of the most strongly correlated pairs.

    Parameters
    ----------
    gain_samples, noise_samples, Tloc_samples : ndarray
        Samples as 1-D, 2-D ``(n_samples, n_params)`` or 3-D arrays. 3-D
        inputs are flattened over their two leading axes into
        ``(-1, n_params)``. All three must flatten to the same sample count.
        NOTE(review): in this notebook the leading axis of the loaded arrays
        appears to index the two TODs (setting/rising) rather than chains of
        one posterior. Pooling them can create correlations that exist only
        between the two groups — confirm before interpreting.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    save_path : str or None
        If given, the figure is saved there at 300 dpi.
    threshold : float, default 0.1
        Pairs with ``|r| > threshold`` are reported as "strong".
    top_k : int, default 10
        Maximum number of strong pairs shown in the table.

    Returns
    -------
    corr_matrix : ndarray
        Correlation matrix; columns are ordered [gain..., noise..., Tloc...].
    param_names : list of str
        Names matching the rows/columns of ``corr_matrix``.
    strong_corrs : list of (str, str, float)
        Every pair above ``threshold``, sorted by ``|r|`` descending. This list
        is not truncated to ``top_k``.
    """
    def _flatten(samples):
        # (a, b, n_params) -> (a*b, n_params); lower-rank inputs are used as-is.
        if len(samples.shape) == 3:
            return samples.reshape(-1, samples.shape[-1])
        return samples

    gain_flat = _flatten(gain_samples)
    noise_flat = _flatten(noise_samples)
    tloc_flat = _flatten(Tloc_samples)

    # Parameter names; a 1-D input counts as a single parameter.
    n_gain_params = gain_flat.shape[1] if len(gain_flat.shape) > 1 else 1
    n_noise_params = noise_flat.shape[1] if len(noise_flat.shape) > 1 else 1
    n_tloc_params = tloc_flat.shape[1] if len(tloc_flat.shape) > 1 else 1

    param_names = (
        [f'Gain_{i}' for i in range(n_gain_params)]
        + [f'Noise_{i}' for i in range(n_noise_params)]
        + [f'Tloc_{i}' for i in range(n_tloc_params)]
    )

    # Columns are ordered [gain..., noise..., Tloc...], matching param_names.
    all_samples = np.column_stack([gain_flat, noise_flat, tloc_flat])
    corr_matrix = np.corrcoef(all_samples.T)

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(2, 2, height_ratios=[2, 1], width_ratios=[2, 1])

    # Main correlation heatmap; the strict upper triangle is masked so each
    # pair appears once.
    ax_main = fig.add_subplot(gs[0, 0])
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool), k=1)
    sns.heatmap(corr_matrix, mask=mask, annot=True, fmt='.3f',
                xticklabels=param_names, yticklabels=param_names,
                cmap='RdBu_r', center=0, square=True, ax=ax_main,
                cbar_kws={'shrink': 0.8, 'label': 'Correlation Coefficient'})
    ax_main.set_title('Parameter Correlation Matrix', fontsize=14, fontweight='bold')

    # Histogram of the off-diagonal coefficients.
    ax_hist = fig.add_subplot(gs[0, 1])
    lower_triangle = corr_matrix[np.tril_indices_from(corr_matrix, k=-1)]
    ax_hist.hist(lower_triangle, bins=30, alpha=0.7, color='steelblue', edgecolor='black')
    ax_hist.axvline(0, color='red', linestyle='--', alpha=0.8)
    ax_hist.set_xlabel('Correlation Coefficient')
    ax_hist.set_ylabel('Frequency')
    ax_hist.set_title('Distribution of\nCorrelation Coefficients')
    ax_hist.grid(True, alpha=0.3)

    # Strong-correlation table.
    ax_table = fig.add_subplot(gs[1, :])
    ax_table.axis('off')

    # Off-diagonal pairs above the threshold, strongest first. The same
    # `threshold` drives the filter and the table title, so they cannot disagree.
    n_total = len(param_names)
    strong_corrs = [
        (param_names[i], param_names[j], corr_matrix[i, j])
        for i in range(n_total)
        for j in range(i + 1, n_total)
        if abs(corr_matrix[i, j]) > threshold
    ]
    strong_corrs.sort(key=lambda pair: abs(pair[2]), reverse=True)

    if strong_corrs:
        table_data = [[p1, p2, f'{corr:.4f}'] for p1, p2, corr in strong_corrs[:top_k]]
        table = ax_table.table(cellText=table_data,
                               colLabels=['Parameter 1', 'Parameter 2', 'Correlation'],
                               cellLoc='center',
                               loc='center',
                               bbox=[0.1, 0.3, 0.8, 0.6])
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 1.5)

        ax_table.set_title(f'Strongest Parameter Correlations (|r| > {threshold:g})',
                           fontsize=12, fontweight='bold', y=0.95)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Correlation analysis saved to: {save_path}")

    plt.show()

    return corr_matrix, param_names, strong_corrs

# Run the analysis
corr_matrix, param_names, strong_corrs = analyze_parameter_correlations(
    gain_samples, noise_samples, Tloc_samples,
    figsize=(18, 12),
    save_path="figures/parameter_correlations_GSF5_db.pdf"
)

In [None]:
def plot_pairwise_correlations(gain_samples, noise_samples, Tloc_samples,
                               top_n=6, figsize=(15, 10), save_path=None):
    """
    Create pairwise hexbin plots, with a least-squares trend line, for the
    ``top_n`` most strongly correlated parameter pairs.

    Parameters
    ----------
    gain_samples, noise_samples, Tloc_samples : ndarray
        2-D ``(n_samples, n_params)`` or 3-D arrays. 3-D inputs are flattened
        over their two leading axes into ``(-1, n_params)``.
    top_n : int, default 6
        Number of pairs to plot, laid out three per row. Must be >= 1.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    save_path : str or None
        If given, the figure is saved there at 300 dpi.

    Returns
    -------
    list of tuple
        ``(i, j, r, p_value, label_i, label_j)`` for the plotted pairs, sorted
        by ``|r|`` descending. NaN correlations (constant columns) sort last.

    Raises
    ------
    ValueError
        If ``top_n < 1`` (no grid can be laid out).
    """
    if top_n < 1:
        raise ValueError(f"top_n must be >= 1, got {top_n}")

    def _flatten(samples):
        # (a, b, n_params) -> (a*b, n_params); 2-D inputs are used as-is.
        if len(samples.shape) == 3:
            return samples.reshape(-1, samples.shape[-1])
        return samples

    gain_flat = _flatten(gain_samples)
    noise_flat = _flatten(noise_samples)
    tloc_flat = _flatten(Tloc_samples)

    # Combine all parameters: columns are [gain..., noise..., Tloc...].
    all_samples = np.column_stack([gain_flat, noise_flat, tloc_flat])

    # Axis labels. They are wrapped in $...$ so matplotlib's mathtext renders
    # the subscripts; without the dollars, "p_{0}" is drawn literally. The
    # first two noise parameters get distinct names: the second is alpha (as
    # unpacked by log_prior_noise in this notebook), and the first is
    # presumably log f0 — confirm.
    noise_names = [r'$\log f_0$', r'$\alpha$']
    param_labels = (
        [rf'$p_{{g,{i}}}$' for i in range(gain_flat.shape[1])]
        + [noise_names[i] if i < len(noise_names) else rf'$\mathrm{{Noise}}_{{{i}}}$'
           for i in range(noise_flat.shape[1])]
        + [rf'$T_{{\mathrm{{loc}},{i}}}$' for i in range(tloc_flat.shape[1])]
    )

    # Pearson correlation and p-value for every pair of columns.
    n_params = all_samples.shape[1]
    correlations = []
    for i in range(n_params):
        for j in range(i + 1, n_params):
            corr, p_val = pearsonr(all_samples[:, i], all_samples[:, j])
            correlations.append((i, j, corr, p_val, param_labels[i], param_labels[j]))

    # Strongest |r| first; NaN is mapped to -1 so it sorts last instead of
    # scrambling the ordering.
    correlations.sort(key=lambda c: -1.0 if np.isnan(c[2]) else abs(c[2]), reverse=True)
    shown = correlations[:top_n]

    # Grid of three columns. squeeze=False always yields a 2-D array of Axes,
    # including the 1x1 case, where a bare Axes would otherwise be returned.
    rows = (top_n + 2) // 3
    cols = min(3, top_n)
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

    for idx, (i, j, corr, p_val, label_i, label_j) in enumerate(shown):
        ax = axes.flat[idx]
        x, y = all_samples[:, i], all_samples[:, j]

        # Density of samples plus a linear trend line.
        ax.hexbin(x, y, gridsize=30, cmap='Blues', alpha=0.8)
        slope, intercept = np.polyfit(x, y, 1)
        x_trend = np.linspace(x.min(), x.max(), 100)
        ax.plot(x_trend, slope * x_trend + intercept, "red", alpha=0.8, linewidth=2)

        ax.set_xlabel(label_i, fontsize=11)
        ax.set_ylabel(label_j, fontsize=11)
        p_text = ', p < 0.001' if p_val < 0.001 else f', p = {p_val:.3f}'
        ax.set_title(f'r = {corr:.3f}' + p_text, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3)

    # Remove every unused panel, including when there are fewer pairs than top_n.
    for ax in axes.flat[len(shown):]:
        fig.delaxes(ax)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Pairwise correlations saved to: {save_path}")

    plt.show()

    return shown

# Create pairwise plots
top_correlations = plot_pairwise_correlations(
    gain_samples, noise_samples, Tloc_samples,
    top_n=9,  # 3x3 grid
    figsize=(16, 12),
    save_path="figures/pairwise_correlations_GSF5_db.pdf"
)

In [None]:
def plot_correlation_evolution(gain_samples, noise_samples, Tloc_samples,
                               window_size=200, figsize=(14, 8), save_path=None):
    """
    Plot how correlations between selected parameter pairs evolve along the chain.

    For each pair, the Pearson correlation is computed in a sliding window of
    ``window_size`` samples that advances by a quarter window. The pairs are:

    - first gain vs first noise parameter;
    - first gain vs second noise parameter;
    - second gain vs first noise parameter;
    - first gain vs first Tloc parameter.

    A pair is skipped when its group has too few parameters.

    Parameters
    ----------
    gain_samples, noise_samples, Tloc_samples : ndarray
        2-D ``(n_samples, n_params)`` arrays, or 3-D arrays of which only the
        first slice along the leading axis (``samples[0]``) is used.
        NOTE(review): in this notebook that leading axis appears to index the
        TODs, so ``[0]`` is presumably the setting scan — confirm.
    window_size : int, default 200
        Number of samples per correlation window.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    save_path : str or None
        If given, the figure is saved there at 300 dpi.
    """
    # 3-D inputs: use only the first slice along the leading axis.
    gain_flat, noise_flat, tloc_flat = (
        samples[0] if len(samples.shape) == 3 else samples
        for samples in (gain_samples, noise_samples, Tloc_samples)
    )

    # Columns are ordered [gain..., noise..., Tloc...].
    all_samples = np.column_stack([gain_flat, noise_flat, tloc_flat])
    n_samples = all_samples.shape[0]
    n_gain, n_noise, n_tloc = gain_flat.shape[1], noise_flat.shape[1], tloc_flat.shape[1]
    noise_start = n_gain           # column of the first noise parameter
    tloc_start = n_gain + n_noise  # column of the first Tloc parameter

    # Window end points, stepping a quarter window. The step is at least 1,
    # because a step of 0 makes range() raise when window_size < 4.
    sample_points = list(range(window_size, n_samples, max(1, window_size // 4)))

    # Pair indices are built from the group offsets so each pair really is
    # what its comment says.
    candidate_pairs = [
        ((0, noise_start), n_gain >= 1 and n_noise >= 1),      # first gain vs first noise
        ((0, noise_start + 1), n_gain >= 1 and n_noise >= 2),  # first gain vs second noise
        ((1, noise_start), n_gain >= 2 and n_noise >= 1),      # second gain vs first noise
        ((0, tloc_start), n_gain >= 1 and n_tloc >= 1),        # first gain vs first Tloc
    ]
    interesting_pairs = [pair for pair, available in candidate_pairs if available]

    param_labels = (
        [f'Gain_{i}' for i in range(n_gain)]
        + [f'Noise_{i}' for i in range(n_noise)]
        + [f'Tloc_{i}' for i in range(n_tloc)]
    )

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    axes = axes.flatten()

    for ax, (i, j) in zip(axes, interesting_pairs):
        correlations = []
        for end_point in sample_points:
            start_point = max(0, end_point - window_size)
            corr, _ = pearsonr(all_samples[start_point:end_point, i],
                               all_samples[start_point:end_point, j])
            correlations.append(corr)

        ax.plot(sample_points, correlations, 'b-', linewidth=2, alpha=0.8)
        ax.axhline(0, color='red', linestyle='--', alpha=0.5)
        ax.set_xlabel('Sample Number')
        ax.set_ylabel('Correlation')
        ax.set_title(f'{param_labels[i]} vs {param_labels[j]}')
        ax.grid(True, alpha=0.3)

    # Drop panels that have no pair to show.
    for ax in axes[len(interesting_pairs):]:
        fig.delaxes(ax)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Correlation evolution saved to: {save_path}")

    plt.show()

# Plot correlation evolution
plot_correlation_evolution(
    gain_samples, noise_samples, Tloc_samples,
    window_size=300,
    save_path="figures/correlation_evolution_GSF5_db.pdf"
)

In [None]:
def print_correlation_summary(gain_samples, noise_samples, Tloc_samples, cross_threshold=0.05):
    """
    Print summary statistics for parameter correlations.

    Prints two things:

    - the pairwise Pearson correlations within each parameter group (gain,
      noise, Tloc);
    - every cross-group pair with ``|r| > cross_threshold``.

    Parameters
    ----------
    gain_samples, noise_samples, Tloc_samples : ndarray
        2-D ``(n_samples, n_params)`` or 3-D arrays. Each 3-D input is
        flattened over its two leading axes on its own, so the inputs may have
        different ranks as long as they flatten to the same number of samples.
    cross_threshold : float, default 0.05
        Minimum ``|r|`` for a cross-group pair to be printed.
    """
    def _flatten(samples):
        # (a, b, n_params) -> (a*b, n_params); 2-D inputs are used as-is.
        if len(samples.shape) == 3:
            return samples.reshape(-1, samples.shape[-1])
        return samples

    gain_flat = _flatten(gain_samples)
    noise_flat = _flatten(noise_samples)
    tloc_flat = _flatten(Tloc_samples)

    print("=== PARAMETER CORRELATION SUMMARY ===\n")

    # Within-group correlations (upper triangle of each group's matrix).
    for header, label, flat in (
        ("GAIN PARAMETER CORRELATIONS:", "Gain", gain_flat),
        ("\nNOISE PARAMETER CORRELATIONS:", "Noise", noise_flat),
        ("\nTLOC PARAMETER CORRELATIONS:", "Tloc", tloc_flat),
    ):
        print(header)
        group_corr = np.corrcoef(flat.T)
        n_group = flat.shape[1]
        for i in range(n_group):
            for j in range(i + 1, n_group):
                print(f"  {label}_{i} - {label}_{j}: {group_corr[i, j]:.4f}")

    # Cross-group correlations
    print(f"\nCROSS-GROUP CORRELATIONS (|r| > {cross_threshold:g}):")
    all_samples = np.column_stack([gain_flat, noise_flat, tloc_flat])
    full_corr = np.corrcoef(all_samples.T)

    param_types = ['Gain'] * gain_flat.shape[1] + ['Noise'] * noise_flat.shape[1] + ['Tloc'] * tloc_flat.shape[1]
    param_indices = list(range(gain_flat.shape[1])) + list(range(noise_flat.shape[1])) + list(range(tloc_flat.shape[1]))

    for i in range(len(param_types)):
        for j in range(i + 1, len(param_types)):
            if param_types[i] != param_types[j]:  # Cross-group only
                corr = full_corr[i, j]
                if abs(corr) > cross_threshold:
                    print(f"  {param_types[i]}_{param_indices[i]} - {param_types[j]}_{param_indices[j]}: {corr:.4f}")

# Print the within-group and cross-group correlation summary for the loaded samples.
print_correlation_summary(
    gain_samples=gain_samples,
    noise_samples=noise_samples,
    Tloc_samples=Tloc_samples,
)