In [33]:

import matplotlib.pyplot as plt
import time
import os
import wandb


def autocorrelation_fft(x, max_lag=None):
    """
    Efficiently compute the normalized autocorrelation function using FFT.

    The series is centred, scaled to unit variance and zero-padded to twice
    its length so that the FFT product gives a linear (not circular)
    correlation. The result is divided by ``len(x)``, i.e. this is the biased
    estimator, for which ``acf[0] == 1``.

    Parameters:
    -----------
    x : array
        1D array of samples
    max_lag : int, optional
        Number of lags to return (default: min(len(x)//3, 20000))

    Returns:
    --------
    acf : array
        Autocorrelation function values for lags 0 .. max_lag-1.
        An empty input gives an empty array. A series with exactly zero
        variance has no defined autocorrelation; for it the function returns
        1 at lag 0 and 0 at every other lag instead of an array of NaNs.
    """
    x = np.asarray(x)
    n = len(x)
    if max_lag is None:
        max_lag = min(n // 3, 20000)  # Cap at 20000 to prevent slow computation

    # Nothing to correlate; the FFT below would reject a zero-length transform.
    if n == 0:
        return np.zeros(0)

    # Remove mean
    centered = x - np.mean(x)
    var = np.var(centered)

    # Constant series: dividing by sqrt(0) would fill the result with NaN.
    if var == 0:
        acf = np.zeros(n)
        acf[0] = 1.0
        return acf[:max_lag]

    x_norm = centered / np.sqrt(var)

    # Compute autocorrelation using FFT
    # Pad the signal with zeros to avoid circular correlation
    fft = np.fft.fft(x_norm, n=2 * n)
    acf = np.fft.ifft(fft * np.conjugate(fft))[:n]
    acf = acf.real / n  # Normalize so that lag 0 equals 1

    return acf[:max_lag]

def _self_consistent_tau(acf, M):
    """Return tau(W) = 1 + 2*sum(acf[1..W]) at the last W of the initial run with W <= M*tau(W)."""
    if len(acf) < 2:
        return 1.0

    # tau_by_window[i] is the estimate obtained with window i + 1.
    tau_by_window = 1.0 + 2.0 * np.cumsum(acf[1:])
    windows = np.arange(1, len(acf))

    # Written as a negated "<=" so that NaN estimates count as inconsistent.
    inconsistent = np.flatnonzero(~(windows <= M * tau_by_window))
    n_consistent = inconsistent[0] if inconsistent.size else len(windows)
    if n_consistent == 0:
        return 1.0  # Even the first window fails: keep only the lag-0 term
    return float(tau_by_window[n_consistent - 1])


def integrated_autocorr_time(x, M=5, c=10):
    """
    Estimate the integrated autocorrelation time using a self-consistent window.
    Based on the algorithm described by Goodman and Weare.

    The estimate is tau = 1 + 2*sum(acf[1..W]) with the window W grown while
    W <= M*tau. If the series is shorter than c*tau the series is averaged
    pairwise (halving its length) and the estimate is repeated, up to 10
    times; an estimate found after k reductions is scaled by 2**k.

    If no reduction level yields a robust estimate, tau falls back to
    1 + 2*sum of the first int(M) lags of the ORIGINAL series. That value is
    already on the original time scale, so it is not rescaled.

    Parameters:
    -----------
    x : array
        1D array of samples
    M : int, default=5
        Window size multiplier (typically 5-10)
    c : int, default=10
        The estimate is accepted once the series length is >= c * tau

    Returns:
    --------
    tau : float
        Integrated autocorrelation time
    acf : array
        Autocorrelation function values that produced tau: those of the
        (possibly pairwise-reduced) series on success, those of the original
        series when the fallback is used
    ess : float
        Effective sample size, len(x) / tau
    """
    orig_x = np.array(x)  # Private copy; the input is never modified
    series = orig_x

    k = 0  # Number of pairwise reductions applied so far
    max_iterations = 10  # Prevent infinite loop
    converged = False

    while k < max_iterations:
        # Calculate autocorrelation function and the self-consistent estimate
        acf = autocorrelation_fft(series)
        tau = _self_consistent_tau(acf, M)

        # If we have a robust estimate, we're done
        if len(series) >= c * tau:
            # Scale tau back to the original time scale: tau_0 = 2^k * tau_k
            tau = tau * (2**k)
            converged = True
            break

        # Otherwise average neighbouring pairs; a trailing unpaired sample is dropped
        n_half = len(series) // 2
        if n_half == 0:
            break  # Too short to reduce any further
        k += 1
        series = 0.5 * (series[0:2 * n_half:2] + series[1:2 * n_half:2])

    # Use an explicit flag here: comparing the reduced length with the
    # rescaled tau would discard estimates that did converge at k > 0.
    if not converged:
        acf = autocorrelation_fft(orig_x)
        tau = 1.0 + 2.0 * float(np.sum(acf[1:int(M) + 1]))

    # Calculate effective sample size using original series length
    ess = len(orig_x) / tau

    return tau, acf, ess


class FFTCovarianceOperator:
    """
    FFT-based covariance operator for periodic boundary conditions.
    
    We use Taylor expansion of V(u) = (1-u²)² around u=0:
    V(u) = 1 - 2u² + u⁴
    
    The issue: To eliminate the quadratic term completely, we'd need α = -4,
    but this gives negative eigenvalues (unstable).
    
    Instead, we use a small positive α for stability, and handle the 
    remaining quadratic term in the acceptance ratio.
    
    Prior: exp(-∫[½|∇u|² + ½αu²]dx) with small α > 0
    Target: exp(-∫[½|∇u|² + 1 - 2u² + u⁴]dx)
    
    Acceptance ratio involves: (1 - 2u² + u⁴) - ½αu² = 1 + (-2 - ½α)u² + u⁴

    The operator is diagonal in the discrete Fourier basis: mode k has
    precision eigenvalue (2πk)² + α and covariance eigenvalue 1/((2πk)² + α).
    All ``apply_*`` methods act along the last axis of their argument, so a
    single vector of length N and a stack of such vectors are both accepted.
    The grid spacing ``h`` is stored but does not enter the eigenvalues.
    """
    
    def __init__(self, N, h, alpha=0.01):
        """
        Initialize FFT-based covariance operator with small regularization.
        
        Parameters:
        -----------
        N : int
            Number of grid points
        h : float
            Grid spacing
        alpha : float
            Small regularization parameter for stability (default: 0.01)

        Raises:
        -------
        ValueError
            If alpha is not strictly positive. The constant mode has
            eigenvalue alpha, so alpha <= 0 makes the covariance undefined.
        """
        if not alpha > 0:
            raise ValueError(
                f"alpha must be > 0 so that (-Δ + α) is positive definite, got {alpha}"
            )

        self.N = N
        self.h = h
        self.alpha = alpha
        
        # Compute eigenvalues of periodic Laplacian: -Δ + α
        k = np.arange(N)
        # For k > N/2, use k - N to get negative frequencies  
        k = np.where(k > N//2, k - N, k)
        
        # Exact continuous eigenvalues on the unit interval: λₖ = (2πk)²
        # (the finite-difference alternative would be 4sin²(πk/N)/h²)
        laplacian_eigenvals = (2 * np.pi * k)**2
        
        # Add mass term: eigenvalues of (-Δ + α)
        self.eigenvalues = laplacian_eigenvals + alpha
        
        # All eigenvalues are > 0 because alpha > 0, so covariance is well-defined
        self.cov_eigenvalues = 1.0 / self.eigenvalues
        self.cov_sqrt_eigenvalues = np.sqrt(self.cov_eigenvalues)
        
        print(f"  FFT covariance operator with quadratic regularization")
        print(f"  Taylor expansion: V(u) = 1 - 2u² + u⁴")
        print(f"  Quadratic regularization parameter α = {alpha}")
        print(f"  Min eigenvalue: {np.min(self.eigenvalues):.2e}")
        print(f"  Max eigenvalue: {np.max(self.eigenvalues):.2e}")
        print(f"  Condition number: {np.max(self.eigenvalues)/np.min(self.eigenvalues):.2e}")
        print(f"  Constant mode eigenvalue: {self.eigenvalues[0]:.2e}")

    def _apply_multiplier(self, x, multiplier):
        """Scale the Fourier coefficients of x (along its last axis) by multiplier; return the real part."""
        x_hat = np.fft.fft(np.asarray(x), axis=-1)
        return np.real(np.fft.ifft(multiplier * x_hat, axis=-1))
    
    def apply_covariance(self, x):
        """Apply covariance operator C₀ = (-Δ + α)⁻¹ to vector x (or each row of x) using FFT."""
        return self._apply_multiplier(x, self.cov_eigenvalues)
    
    def apply_sqrt_covariance(self, x):
        """Apply square root covariance operator C₀^{1/2} to vector x (or each row of x) using FFT."""
        return self._apply_multiplier(x, self.cov_sqrt_eigenvalues)
    
    def sample_prior(self, n_samples=1):
        """
        Generate samples from the modified Gaussian prior N(0, C₀) where C₀ = (-Δ + α)⁻¹.

        White noise from the global ``np.random`` state is coloured with
        C₀^{1/2}. Returns an array of shape (N,) when n_samples == 1 and of
        shape (n_samples, N) otherwise.
        """
        if n_samples == 1:
            return self.apply_sqrt_covariance(np.random.randn(self.N))
        # One row of white noise per sample, coloured in a single batched FFT
        return self.apply_sqrt_covariance(np.random.randn(n_samples, self.N))


def pcn_sampler_fft(log_density, initial, n_samples, n_warmup=1000, beta=0.5, 
                   n_thin=1, cov_op=None, h=None):
    """
    Preconditioned Crank-Nicolson (pCN) sampler with FFT-based covariance.
    
    The pCN algorithm is designed for sampling from measures on function spaces.
    Proposal: u' = √(1-β²) * u + β * C₀^(1/2) * ξ where ξ ~ N(0, I)

    The accept/reject step uses the potential relative to the Gaussian prior,
        Φ(u) = h * Σ[u⁴ + (-2 - ½α) u²]      (the constant 1 cancels),
    and accepts with probability min(1, exp(Φ(u) - Φ(u'))).

    NOTE(review): Φ is weighted by the grid spacing h, while the proposal
    noise has whatever scale cov_op.apply_sqrt_covariance produces. If that
    operator draws from N(0, C₀) in the plain Euclidean sense, prior and
    potential are normalised differently by a power of h — confirm that the
    resulting stationary distribution is the intended one.
    
    Parameters:
    -----------
    log_density : callable
        Log density of the target. Kept for interface compatibility; the
        sampler does not call it (the acceptance ratio is computed in closed
        form from the quartic and quadratic terms).
    initial : np.ndarray
        Initial state (not modified)
    n_samples : int
        Number of samples to collect (after warmup)
    n_warmup : int
        Number of warmup iterations
    beta : float
        pCN parameter (step size), fixed; must satisfy 0 < beta <= 1
    n_thin : int
        Keep every n_thin-th state of the sampling phase; must be >= 1
    cov_op : FFTCovarianceOperator
        FFT-based covariance operator; must provide ``alpha`` and
        ``apply_sqrt_covariance``
    h : float, optional
        Grid spacing used as quadrature weight. Defaults to ``cov_op.h``.
    
    Returns:
    --------
    samples : np.ndarray
        Collected samples (after warmup), shape (n_samples, len(initial))
    acceptance_rate : float
        Acceptance rate over the sampling phase (0 if it has no iterations)

    Raises:
    -------
    ValueError
        If cov_op is missing, beta is outside (0, 1] or n_thin < 1.
    """
    
    dim = len(initial)
    
    if cov_op is None:
        raise ValueError("FFT covariance operator must be provided")
    if not 0 < beta <= 1:
        # sqrt(1 - beta**2) is NaN for beta > 1
        raise ValueError(f"beta must lie in (0, 1], got {beta}")
    if n_thin < 1:
        raise ValueError(f"n_thin must be >= 1, got {n_thin}")
    if h is None:
        # Previously a missing h crashed with a TypeError in the first iteration
        h = cov_op.h
    
    # Calculate total iterations
    total_sampling_iterations = n_samples * n_thin
    total_iterations = n_warmup + total_sampling_iterations
    
    # Storage
    samples = np.zeros((n_samples, dim))
    accepts_warmup = 0
    accepts_sampling = 0
    
    # Current state (copied so the caller's array is never aliased)
    current_u = np.array(initial, dtype=float)
    
    # Loop invariants
    contraction = np.sqrt(1 - beta**2)
    quadratic_coeff = -2 - 0.5 * cov_op.alpha
    
    # Sample index
    sample_idx = 0
    
    for i in range(total_iterations):
        is_warmup = i < n_warmup
        
        # Store sample (only during sampling phase), before moving the chain
        if not is_warmup and (i - n_warmup) % n_thin == 0 and sample_idx < n_samples:
            samples[sample_idx] = current_u
            sample_idx += 1
        
        # Generate proposal using pCN with FFT-based prior covariance
        xi = np.random.randn(dim)
        noise = cov_op.apply_sqrt_covariance(xi)
        proposal_u = contraction * current_u + beta * noise
        
        # Target: exp(-∫[½|∇u|² + 1 - 2u² + u⁴]dx)
        # Prior:  exp(-∫[½|∇u|² + ½αu²]dx)
        # Radon-Nikodym: exp(-∫[1 + (-2 - ½α)u² + u⁴]dx)
        # pCN acceptance: min(1, exp(-∫[u'⁴ - u⁴ + (-2-½α)(u'² - u²)]dx))
        quartic_diff = h * np.sum(proposal_u**4 - current_u**4)
        quadratic_diff = quadratic_coeff * h * np.sum(proposal_u**2 - current_u**2)
        log_alpha = -(quartic_diff + quadratic_diff)
        
        # Only exponentiate negative log-ratios: this avoids overflow, and a
        # NaN log-ratio (non-finite proposal) yields a NaN probability that
        # the comparison below rejects instead of accepting it.
        if log_alpha >= 0:
            accept_prob = 1.0
        else:
            accept_prob = np.exp(log_alpha)
        
        # Accept/reject
        accept = bool(np.random.rand() < accept_prob)
        if accept:
            current_u = proposal_u
        
        # Track acceptances
        if is_warmup:
            accepts_warmup += accept
        else:
            accepts_sampling += accept
    
    # Compute acceptance rates (0 when the corresponding phase is empty)
    warmup_accept_rate = accepts_warmup / n_warmup if n_warmup > 0 else 0
    sampling_accept_rate = (
        accepts_sampling / total_sampling_iterations
        if total_sampling_iterations > 0 else 0
    )
    
    print(f"Warmup acceptance rate: {warmup_accept_rate:.3f}")
    print(f"Sampling acceptance rate: {sampling_accept_rate:.3f}")
    
    return samples, sampling_accept_rate


def benchmark_pcn_allen_cahn_fft(N=100, n_samples=10000, burn_in=1000, n_thin=1, save_dir=None):
    """
    Benchmark pCN sampler on the Allen-Cahn SPDE with FFT-based periodic covariance.

    Builds the covariance operator with its default regularization, runs
    ``pcn_sampler_fft`` with beta=0.9 from a start near u = +1, and computes
    diagnostics on the path integral h*Σu of every sample.

    Parameters:
    -----------
    N : int
        Number of grid points on the periodic unit interval (h = 1/N)
    n_samples : int
        Number of samples to collect after burn-in
    burn_in : int
        Number of warmup iterations
    n_thin : int
        Thinning interval passed to the sampler
    save_dir : str, optional
        If given, samples and path integrals are written there as .npy
        files; the directory is created when missing.

    Returns:
    --------
    results : dict
        On success: samples, acceptance rate, path-integral and energy
        statistics, well-mixing rate, autocorrelation, tau, ess, wall time
        and the condition number of the operator. If anything raises during
        the run, the dict is ``{"error": <message>}`` instead.
    """
    # Define discretization parameters  
    h = 1.0 / N
    dim = N  # For periodic BC, we have N points (not N+1)
    
    print(f"Setting up FFT-based pCN for Allen-Cahn SPDE with N={N} (dim={dim})")
    print(f"Discretization step size h={h:.6f}")
    
    # Create FFT-based covariance operator (with its default regularization α)
    print("Creating FFT-based covariance operator...")
    cov_op = FFTCovarianceOperator(N, h)
    
    # Define the log density function for periodic case
    def log_density(u):
        """
        Log density for Allen-Cahn SPDE: -∫[½|∇u|² + V(u)] dx
        where V(u) = (1-u²)² is the double-well potential.
        
        This is the FULL energy (gradient + potential), evaluated with
        periodic forward differences. Accepts one state or a stack of
        states; always returns a 1D array with one value per state.
        """
        if u.ndim == 1:
            u = u.reshape(1, -1)
            
        # Gradient term: ½∫|∇u|² dx using periodic finite differences
        # ∇u ≈ (u[j+1] - u[j])/h for periodic BC
        u_shifted = np.roll(u, -1, axis=1)  # u[j+1] with periodic wraparound
        gradients = (u_shifted - u) / h  # Finite difference approximation
        gradient_term = 0.5 * h * np.sum(gradients**2, axis=1)  # ½∫|∇u|² dx
        
        # Potential term: ∫V(u) dx where V(u) = (1-u²)²
        v_values = (1 - u**2)**2
        potential_term = h * np.sum(v_values, axis=1)
        
        # Log density = -energy (we want to sample from exp(-energy))
        return -(gradient_term + potential_term)
    
    # Function to compute path integral for periodic case
    def compute_path_integral(path):
        """Calculate path integral ∫u(x)dx ≈ h * Σu along the last axis (one value per path)."""
        return h * np.sum(path, axis=-1)
    
    # Initial state - start from a state closer to Allen-Cahn equilibrium
    print("Generating initial state near Allen-Cahn equilibrium...")
    # Start with u ≈ +1 with small perturbations instead of Gaussian prior
    initial = np.ones(dim) + 0.1 * np.random.randn(dim)
    
    # Run pCN sampler
    print("Running FFT-based pCN sampler...")
    start_time = time.time()
    
    # Best-effort benchmark: any failure is reported in the returned dict
    # rather than propagated to the caller.
    try:
        samples, acceptance_rate = pcn_sampler_fft(
            log_density=log_density, 
            initial=initial, 
            n_samples=n_samples, 
            n_warmup=burn_in,
            beta=0.9,  # Even larger beta for exploration between wells
            n_thin=n_thin,
            cov_op=cov_op,
            h=h  # Pass discretization step size
        )
        
        # Only the sampler run is timed, not the diagnostics below
        elapsed = time.time() - start_time
        
        # Calculate path integrals
        path_integrals = compute_path_integral(samples)
        mean_path_integral = np.mean(path_integrals)
        path_integral_std = np.std(path_integrals)
        
        # Compute energies
        energies = -log_density(samples)
        mean_energy = np.mean(energies)
        energy_std = np.std(energies)
        
        # Check well mixing with better thresholds for Allen-Cahn
        positive_well = np.mean(path_integrals > 0.2)  # Threshold for +1 well
        negative_well = np.mean(path_integrals < -0.2)  # Threshold for -1 well
        well_mixing = min(positive_well, negative_well)
        
        # Autocorrelation analysis
        acf = autocorrelation_fft(path_integrals)
        
        # Catch Exception only, so Ctrl-C still interrupts a long run
        try:
            tau, _, ess = integrated_autocorr_time(path_integrals)
        except Exception:
            tau, ess = np.nan, np.nan
            print("Warning: Could not compute integrated autocorrelation time")
        
        # Positive fraction
        positive_fraction = np.mean(samples > 0)
        
        print(f"FFT pCN Results:")
        print(f"  Acceptance rate: {acceptance_rate:.2f}")
        print(f"  Path integral mean: {mean_path_integral:.4f}")
        print(f"  Path integral std: {path_integral_std:.4f}")
        print(f"  Well mixing rate: {well_mixing:.4f}")
        print(f"  Integrated autocorrelation time: {tau:.2f}" if np.isfinite(tau) else "  Integrated autocorrelation time: NaN")
        print(f"  Time: {elapsed:.2f} seconds")
        
        # Condition number over the numerically non-zero eigenvalues
        positive_eigs = cov_op.eigenvalues[cov_op.eigenvalues > 1e-14]
        if positive_eigs.size > 0:
            condition_number = np.max(positive_eigs) / np.min(positive_eigs)
        else:
            condition_number = np.inf
        
        # Store results
        results = {
            "samples": samples,
            "acceptance_rate": acceptance_rate,
            "path_integrals": path_integrals,
            "path_integral_mean": mean_path_integral,
            "path_integral_std": path_integral_std,
            "mean_energy": mean_energy,
            "energy_std": energy_std,
            "well_mixing": well_mixing,
            "positive_fraction": positive_fraction,
            "autocorrelation": acf,
            "tau": tau,
            "ess": ess,
            "time": elapsed,
            "covariance_condition_number": condition_number
        }
        
        if save_dir:
            # np.save does not create missing directories
            os.makedirs(save_dir, exist_ok=True)
            np.save(os.path.join(save_dir, "pcn_fft_samples_allen_cahn.npy"), samples)
            np.save(os.path.join(save_dir, "pcn_fft_path_integrals.npy"), path_integrals)
            
    except Exception as e:
        print(f"Error running FFT pCN: {str(e)}")
        results = {"error": str(e)}
    
    return results


# Example usage
if __name__ == "__main__":
    def _report(outcome):
        """Print the headline metrics of one benchmark run; print nothing if the run failed."""
        if "error" in outcome:
            return
        print(f"Acceptance Rate: {outcome['acceptance_rate']:.3f}")
        print(f"Path Integral Std: {outcome['path_integral_std']:.4f}")
        print(f"Well Mixing: {outcome['well_mixing']:.3f}")
        # ESS is NaN when the autocorrelation time could not be estimated
        if np.isfinite(outcome['ess']):
            print(f"ESS/sec: {outcome['ess']/outcome['time']:.2f}")
        else:
            print("ESS/sec: NaN")

    # Run the benchmark at two grid resolutions
    print("Testing FFT-based pCN for Allen-Cahn SPDE")

    # Coarse grid
    print("\nSmall test (N=32):")
    small_results = benchmark_pcn_allen_cahn_fft(N=32, n_samples=500000, burn_in=50000)
    _report(small_results)

    # Finer grid, longer chain
    print("\nMedium test (N=128):")
    medium_results = benchmark_pcn_allen_cahn_fft(N=128, n_samples=1000000, burn_in=100000)
    _report(medium_results)

Testing FFT-based pCN for Allen-Cahn SPDE

Small test (N=32):
Setting up FFT-based pCN for Allen-Cahn SPDE with N=32 (dim=32)
Discretization step size h=0.031250
Creating FFT-based covariance operator...
  FFT covariance operator with quadratic regularization
  Taylor expansion: V(u) = 1 - 2u² + u⁴
  Quadratic regularization parameter α = 0.01
  Min eigenvalue: 1.00e-02
  Max eigenvalue: 1.01e+04
  Condition number: 1.01e+06
  Constant mode eigenvalue: 1.00e-02
Generating initial state near Allen-Cahn equilibrium...
Running FFT-based pCN sampler...
Warmup acceptance rate: 0.522
Sampling acceptance rate: 0.522
FFT pCN Results:
  Acceptance rate: 0.52
  Path integral mean: 0.0010
  Path integral std: 0.8775
  Well mixing rate: 0.4554
  Integrated autocorrelation time: 3.67
  Time: 32.66 seconds
Acceptance Rate: 0.522
Path Integral Std: 0.8775
Well Mixing: 0.455
ESS/sec: 4165.72

Medium test (N=128):
Setting up FFT-based pCN for Allen-Cahn SPDE with N=128 (dim=128)
Discretization step siz