In [2]:
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import wandb


def autocorrelation_fft(x, max_lag=None):
    """
    Efficiently compute autocorrelation function using FFT.
    
    Parameters:
    -----------
    x : array
        1D array of samples
    max_lag : int, optional
        Maximum lag to compute (default: len(x)//3)
        
    Returns:
    --------
    acf : array
        Autocorrelation function values
    """
    n = len(x)
    if max_lag is None:
        max_lag = min(n // 3, 20000)  # Cap at 20000 to prevent slow computation
    
    # Remove mean and normalize
    x_norm = x - np.mean(x)
    var = np.var(x_norm)
    x_norm = x_norm / np.sqrt(var)
    
    # Compute autocorrelation using FFT
    # Pad the signal with zeros to avoid circular correlation
    fft = np.fft.fft(x_norm, n=2*n)
    acf = np.fft.ifft(fft * np.conjugate(fft))[:n]
    acf = acf.real / n  # Normalize
    
    return acf[:max_lag]

def integrated_autocorr_time(x, M=5, c=10):
    """
    Estimate the integrated autocorrelation time using a self-consistent window.
    Based on the algorithm described by Goodman and Weare.
    
    Parameters:
    -----------
    x : array
        1D array of samples
    M : int, default=5
        Window size multiplier (typically 5-10)
    c : int, default=10
        Maximum lag cutoff for window determination
        
    Returns:
    --------
    tau : float
        Integrated autocorrelation time
    acf : array
        Autocorrelation function values
    ess : float
        Effective sample size
    """
    n = len(x)
    orig_x = x.copy()
    
    # Initial pairwise reduction if needed
    k = 0
    max_iterations = 10  # Prevent infinite loop
    
    while k < max_iterations:
        # Calculate autocorrelation function
        acf = autocorrelation_fft(x)
        
        # Calculate integrated autocorrelation time with self-consistent window
        tau = 1.0  # Initialize with the first term
        
        # Find the window size where window <= M * tau
        for window in range(1, len(acf)):
            # Update tau with this window
            tau_window = 1.0 + 2.0 * sum(acf[1:window+1])
            
            # Check window consistency: window <= M*tau
            if window <= M * tau_window:
                tau = tau_window
            else:
                break
        
        # If we have a robust estimate, we're done
        if n >= c * tau:
            # Scale tau back to the original time scale: tau_0 = 2^k * tau_k
            tau = tau * (2**k)
            break
            
        # If we don't have a robust estimate, perform pairwise reduction
        k += 1
        n_half = len(x) // 2
        x_new = np.zeros(n_half)
        for i in range(n_half):
            if 2*i + 1 < len(x):
                x_new[i] = 0.5 * (x[2*i] + x[2*i+1])
            else:
                x_new[i] = x[2*i]
        x = x_new
        n = len(x)
    
    # If we exited without a robust estimate, compute one final estimate
    if k >= max_iterations or n < c * tau:
        acf = autocorrelation_fft(orig_x)
        tau_reduced = 1.0 + 2.0 * sum(acf[1:min(len(acf), int(M)+1)])
        # Scale tau back to the original time scale
        tau = tau_reduced * (2**k)
    
    # Calculate effective sample size using original series length
    ess = len(orig_x) / tau
    
    return tau, acf, ess


def hamiltonian_walk_move_dual_avg(gradient_func, potential_func, initial, n_samples, 
                                  n_chains_per_group=5, epsilon_init=0.01, n_leapfrog=10, 
                                  beta=0.05, n_thin=1, target_accept=0.65, n_warmup=1000,
                                  gamma=0.05, t0=10, kappa=0.75):
    """
    Hamiltonian Walk Move sampler with dual averaging for automatic step size adaptation.
    
    Parameters:
    -----------
    gradient_func : callable
        Function that computes gradients of the log probability
    potential_func : callable  
        Function that computes the negative log probability (potential energy)
    initial : array_like
        Initial state
    n_samples : int
        Number of samples to collect (after warmup)
    n_chains_per_group : int
        Number of chains per group (default: 5)
    epsilon_init : float
        Initial step size (default: 0.01)
    n_leapfrog : int
        Number of leapfrog steps (default: 10)
    beta : float
        Preconditioning parameter (default: 0.05)
    n_thin : int
        Thinning factor - store every n_thin sample (default: 1, no thinning)
    target_accept : float
        Target acceptance rate for dual averaging (default: 0.65)
    n_warmup : int
        Number of warmup iterations for step size adaptation (default: 1000)
    gamma : float
        Dual averaging parameter controlling adaptation rate (default: 0.05)
    t0 : float
        Dual averaging parameter for numerical stability (default: 10)
    kappa : float
        Dual averaging parameter controlling decay (default: 0.75, should be in (0.5, 1])
    
    Returns:
    --------
    samples : ndarray
        Generated samples from all chains (after warmup)
    acceptance_rates : ndarray
        Final acceptance rates for each chain
    step_size_history : ndarray
        History of step sizes during adaptation
    """
    
    # Initialize
    orig_dim = initial.shape
    flat_dim = np.prod(orig_dim)
    total_chains = 2 * n_chains_per_group
    
    # Create initial states with small random perturbations
    states = np.tile(initial.flatten(), (total_chains, 1)) + 0.1 * np.random.randn(total_chains, flat_dim)
    
    # Split into two groups
    group1 = slice(0, n_chains_per_group)
    group2 = slice(n_chains_per_group, total_chains)
    
    # Dual averaging initialization
    log_epsilon = np.log(epsilon_init)
    log_epsilon_bar = 0.0
    H_bar = 0.0
    step_size_history = []
    
    # Calculate total iterations needed based on thinning factor
    total_sampling_iterations = n_samples * n_thin
    total_iterations = n_warmup + total_sampling_iterations
    
    # Storage for samples and acceptance tracking
    samples = np.zeros((total_chains, n_samples, flat_dim))
    accepts_warmup = np.zeros(total_chains)  # Track accepts during warmup
    accepts_sampling = np.zeros(total_chains)  # Track accepts during sampling
    
    # Sample index to track where to store thinned samples
    sample_idx = 0
    
    # Main sampling loop
    for i in range(total_iterations):
        is_warmup = i < n_warmup
        current_epsilon = epsilon_init if not is_warmup else np.exp(log_epsilon)
        
        # Store current state from all chains (only during sampling phase)
        if not is_warmup and (i - n_warmup) % n_thin == 0 and sample_idx < n_samples:
            samples[:, sample_idx] = states
            sample_idx += 1
        
        # Precompute step size terms
        beta_eps = beta * current_epsilon
        beta_eps_half = beta_eps / 2
        
        # Compute centered ensembles for preconditioning
        centered2 = (states[group2] - np.mean(states[group2], axis=0)) / np.sqrt(n_chains_per_group)
        
        # First group update
        p1 = np.random.randn(n_chains_per_group, n_chains_per_group)
        
        # Store current state and energy
        current_q1 = states[group1].copy()
        current_q1_reshaped = current_q1.reshape(n_chains_per_group, *orig_dim)
        current_U1 = potential_func(current_q1_reshaped)
        current_K1 = np.clip(0.5 * np.sum(p1**2, axis=1), 0, 1000)
        
        # Leapfrog integration with preconditioning
        q1 = current_q1.copy()
        p1_current = p1.copy()
        
        # Initial half-step for momentum
        grad1 = gradient_func(q1.reshape(n_chains_per_group, *orig_dim)).reshape(n_chains_per_group, -1)
        grad1 = np.nan_to_num(grad1, nan=0.0)
        p1_current -= beta_eps_half * np.dot(grad1, centered2.T)
        
        # Full leapfrog steps
        for step in range(n_leapfrog):
            # Position update with ensemble preconditioning
            q1 += beta_eps * np.dot(p1_current, centered2)
            
            if step < n_leapfrog - 1:
                # Momentum update
                grad1 = gradient_func(q1.reshape(n_chains_per_group, *orig_dim)).reshape(n_chains_per_group, -1)
                grad1 = np.nan_to_num(grad1, nan=0.0)
                p1_current -= beta_eps * np.dot(grad1, centered2.T)
        
        # Final half-step for momentum
        grad1 = gradient_func(q1.reshape(n_chains_per_group, *orig_dim)).reshape(n_chains_per_group, -1)
        grad1 = np.nan_to_num(grad1, nan=0.0)
        p1_current -= beta_eps_half * np.dot(grad1, centered2.T)
        
        # Compute proposed energy
        proposed_U1 = potential_func(q1.reshape(n_chains_per_group, *orig_dim))
        proposed_K1 = np.clip(0.5 * np.sum(p1_current**2, axis=1), 0, 1000)
        
        # Metropolis acceptance with numerical stability
        dH1 = (proposed_U1 + proposed_K1) - (current_U1 + current_K1)
        
        accept_probs1 = np.ones_like(dH1)
        exp_needed = dH1 > 0
        if np.any(exp_needed):
            safe_dH = np.clip(dH1[exp_needed], None, 100)
            accept_probs1[exp_needed] = np.exp(-safe_dH)
        
        accepts1 = np.random.random(n_chains_per_group) < accept_probs1
        states[group1][accepts1] = q1[accepts1]
        
        # Track acceptances
        if is_warmup:
            accepts_warmup[group1] += accepts1
        else:
            accepts_sampling[group1] += accepts1
        
        # Second group update
        centered1 = (states[group1] - np.mean(states[group1], axis=0)) / np.sqrt(n_chains_per_group)
        
        p2 = np.random.randn(n_chains_per_group, n_chains_per_group)
        
        current_q2 = states[group2].copy()
        current_q2_reshaped = current_q2.reshape(n_chains_per_group, *orig_dim)
        current_U2 = potential_func(current_q2_reshaped)
        current_K2 = np.clip(0.5 * np.sum(p2**2, axis=1), 0, 1000)
        
        q2 = current_q2.copy()
        p2_current = p2.copy()
        
        # Initial half-step for momentum
        grad2 = gradient_func(q2.reshape(n_chains_per_group, *orig_dim)).reshape(n_chains_per_group, -1)
        grad2 = np.nan_to_num(grad2, nan=0.0)
        p2_current -= beta_eps_half * np.dot(grad2, centered1.T)
        
        # Full leapfrog steps
        for step in range(n_leapfrog):
            q2 += beta_eps * np.dot(p2_current, centered1)
            
            if step < n_leapfrog - 1:
                grad2 = gradient_func(q2.reshape(n_chains_per_group, *orig_dim)).reshape(n_chains_per_group, -1)
                grad2 = np.nan_to_num(grad2, nan=0.0)
                p2_current -= beta_eps * np.dot(grad2, centered1.T)
        
        # Final half-step for momentum
        grad2 = gradient_func(q2.reshape(n_chains_per_group, *orig_dim)).reshape(n_chains_per_group, -1)
        grad2 = np.nan_to_num(grad2, nan=0.0)
        p2_current -= beta_eps_half * np.dot(grad2, centered1.T)
        
        # Compute proposed energy
        proposed_U2 = potential_func(q2.reshape(n_chains_per_group, *orig_dim))
        proposed_K2 = np.clip(0.5 * np.sum(p2_current**2, axis=1), 0, 1000)
        
        # Metropolis acceptance
        dH2 = (proposed_U2 + proposed_K2) - (current_U2 + current_K2)
        
        accept_probs2 = np.ones_like(dH2)
        exp_needed = dH2 > 0
        if np.any(exp_needed):
            safe_dH = np.clip(dH2[exp_needed], None, 100)
            accept_probs2[exp_needed] = np.exp(-safe_dH)
        
        accepts2 = np.random.random(n_chains_per_group) < accept_probs2
        states[group2][accepts2] = q2[accepts2]
        
        # Track acceptances
        if is_warmup:
            accepts_warmup[group2] += accepts2
        else:
            accepts_sampling[group2] += accepts2
        
        # Dual averaging step size adaptation during warmup
        if is_warmup:
            # Average acceptance probability across all chains in this iteration
            current_accept_rate = (np.sum(accepts1) + np.sum(accepts2)) / total_chains
            
            # Dual averaging update
            m = i + 1  # iteration number (1-indexed)
            eta = 1.0 / (m + t0)
            
            # Update log step size
            H_bar = (1 - eta) * H_bar + eta * (target_accept - current_accept_rate)
            
            # Compute log step size with shrinkage
            log_epsilon = np.log(epsilon_init) - np.sqrt(m) / gamma * H_bar
            
            # Update log_epsilon_bar for final step size
            eta_bar = m**(-kappa)
            log_epsilon_bar = (1 - eta_bar) * log_epsilon_bar + eta_bar * log_epsilon
            
            # Store step size history
            step_size_history.append(np.exp(log_epsilon))
        
        # After warmup, fix step size to the adapted value
        if i == n_warmup - 1:
            epsilon_init = np.exp(log_epsilon_bar)
            print(f"Warmup complete. Final adapted step size: {epsilon_init:.6f}")
    
    # Reshape final samples to original dimensions
    samples = samples.reshape((total_chains, n_samples) + orig_dim)
    
    # Compute acceptance rates for sampling phase only
    acceptance_rates = accepts_sampling / total_sampling_iterations
    
    return samples, acceptance_rates, np.array(step_size_history)


def create_high_dim_precision(dim, condition_number=100):
    """Create a high-dimensional diagonal precision matrix with given condition number."""
    # For reproducibility
    np.random.seed(42)
    
    # Create diagonal eigenvalues with desired condition number
    eigenvalues = 0.1 * np.linspace(1, condition_number, dim)
    
    # For diagonal matrices, we can just return the eigenvalues
    # This avoids storing the full matrix which is mostly zeros
    return eigenvalues

def benchmark_samplers(dim=40, n_samples=10000, burn_in=1000, condition_number=100, n_thin=1, save_dir=None):
    """
    Benchmark HWM sampler on a high-dimensional Gaussian.
    """
    # Create precision matrix (inverse covariance) - just the diagonal values
    precision_diag = create_high_dim_precision(dim, condition_number)
    
    # Compute covariance matrix diagonal for reference (needed for evaluation)
    # For diagonal matrices, inverse is just reciprocal of diagonal elements
    cov_diag = 1.0 / precision_diag
    
    true_mean = np.ones(dim)
    
    def gradient(x):
        """Optimized gradient of the negative log density with diagonal precision"""
        if x.ndim == 1:
            x = x.reshape(1, -1)
            
        # Vectorized operation for all samples
        centered = x - true_mean
        # For diagonal precision, this is just elementwise multiplication
        result = centered * precision_diag[np.newaxis, :]
            
        return result
    
    def potential(x):
        """Optimized negative log density (potential energy) with diagonal precision"""
        if x.ndim == 1:
            x = x.reshape(1, -1)
            
        # Vectorized operation for all samples
        centered = x - true_mean
        # For diagonal precision, this simplifies to sum of elementwise products
        result = 0.5 * np.sum(centered**2 * precision_diag, axis=1)
            
        return result
    
    # Initial state
    initial = np.zeros(dim)
    
    # Dictionary to store results
    results = {}
    
    # Define samplers to benchmark with burn-in
    total_samples = n_samples
    
    # Define samplers to benchmark - adjust parameters for high-dimensional case
    samplers = {
        "Hamiltonian Walk Move": lambda: hamiltonian_walk_move_dual_avg(
            gradient_func=gradient, 
            potential_func=potential, 
            initial=initial, 
            n_samples=total_samples, 
            n_warmup=burn_in, 
            n_chains_per_group=dim, 
            epsilon_init=1/(dim**(1/4)), 
            n_leapfrog=3, 
            beta=1.0,
            target_accept=0.65, 
            n_thin=n_thin
        ),
    }
    
    for name, sampler_func in samplers.items():
        print(f"Running {name}...")
        start_time = time.time()
        samples, acceptance_rates, step_size_history = sampler_func()
        elapsed = time.time() - start_time
        
        post_burn_in_samples = samples
        
        # Flatten samples from all chains
        flat_samples = post_burn_in_samples.reshape(-1, dim)
        
        # Compute sample mean and covariance
        sample_mean = np.mean(flat_samples, axis=0)
        
        # For MSE calculation, we don't need to compute the full covariance matrix
        # We can compute the diagonal elements directly
        sample_var = np.var(flat_samples, axis=0)
        
        # Calculate mean squared error for mean and covariance
        mean_mse = np.mean((sample_mean - true_mean)**2) / np.mean(true_mean**2)
        # For diagonal covariance, we only compare diagonal elements
        cov_mse = np.sum((sample_var - cov_diag)**2) / np.sum(cov_diag**2)
        
        # Compute autocorrelation for first dimension
        # Average over chains to compute autocorrelation
        acf = autocorrelation_fft(np.mean(post_burn_in_samples[:, :, 0], axis=0))
        
        # Compute integrated autocorrelation time for first dimension
        try:
            tau, _, ess = integrated_autocorr_time(np.mean(post_burn_in_samples[:, :, 0], axis=0))
        except:
            tau, ess = np.nan, np.nan
        
        # Store results
        results[name] = {
            "samples": flat_samples,
            "acceptance_rates": acceptance_rates,
            "mean_mse": mean_mse,
            "cov_mse": cov_mse,
            "autocorrelation": acf,
            "tau": tau,
            "ess": ess,
            "time": elapsed
        }
        
        print(f"  Acceptance rate: {np.mean(acceptance_rates):.2f}")
        print(f"  Mean MSE: {mean_mse:.6f}")
        print(f"  Covariance MSE: {cov_mse:.6f}")
        print(f"  Integrated autocorrelation time: {tau:.2f}")
        print(f"  Time: {elapsed:.2f} seconds")

        if save_dir:
            np.save(os.path.join(save_dir, f"samples_{name}.npy"), post_burn_in_samples)
            np.save(os.path.join(save_dir, f"acf_{name}.npy"), acf)
            
    return results, true_mean, cov_diag


# Main benchmark script
n_samples = 10**4
burn_in = 10**3
# array_dim = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
array_dim = [4, 8, 16, 32, 64]
n_thin = 1

home = "/scratch/yc3400/AffineInvariant/"
timestamp = time.strftime("%Y%m%d-%H%M%S")
folder = f"benchmark_results_HWMdualavg_Gaussian_sample_{timestamp}"

wandb_project = "AffineInvariant"
wandb_entity = 'yifanc96'
wandb_run = wandb.init(
    project=wandb_project,
    entity=wandb_entity,
    resume=None,
    id=None,
    name=folder
)
wandb.run.log_code(".")

print(f'n_sample{n_samples}, burn_in{burn_in}, n_thin{n_thin}')
    
for dim in array_dim:
    print(f"dim={dim}")
    # Create a timestamped directory for this run
    save_dir = os.path.join(home + folder, f"{dim}")
    
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    
    # Run benchmarks and save results
    results, true_mean, cov_diag = benchmark_samplers(
        dim=dim, 
        n_samples=n_samples, 
        burn_in=burn_in, 
        condition_number=1000,
        n_thin=n_thin,
        save_dir=save_dir
    )

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


n_sample10000, burn_in1000, n_thin1
dim=4
Running Hamiltonian Walk Move...
Warmup complete. Final adapted step size: 1.149099
  Acceptance rate: 0.65
  Mean MSE: 0.000065
  Covariance MSE: 0.000221
  Integrated autocorrelation time: 4.75
  Time: 4.67 seconds
dim=8
Running Hamiltonian Walk Move...
Warmup complete. Final adapted step size: 0.967865
  Acceptance rate: 0.67
  Mean MSE: 0.000002
  Covariance MSE: 0.000002
  Integrated autocorrelation time: 2.51
  Time: 5.40 seconds
dim=16
Running Hamiltonian Walk Move...
Warmup complete. Final adapted step size: 0.845588
  Acceptance rate: 0.65
  Mean MSE: 0.000004
  Covariance MSE: 0.000002
  Integrated autocorrelation time: 2.38
  Time: 6.07 seconds
dim=32
Running Hamiltonian Walk Move...
Warmup complete. Final adapted step size: 0.737972
  Acceptance rate: 0.64
  Mean MSE: 0.000001
  Covariance MSE: 0.000002
  Integrated autocorrelation time: 2.23
  Time: 7.95 seconds
dim=64
Running Hamiltonian Walk Move...
Warmup complete. Final adapted