In [1]:
import torch
import torch.nn as nn
import torch.distributions as dist
from torch.distributions import MultivariateNormal, InverseGamma, Normal
from typing import Tuple, Optional, List, Dict
import numpy as np


class SpectralCARVariationalBayes(nn.Module):
    """
    Variational Bayes for spatial model with spectral polynomial CAR prior.
    
    Model:
        y = X*beta + phi + eps, eps ~ N(0, sigma^2*I)
        phi ~ N(0, (tau^2 * Q(theta))^-1)
        Q(theta) = U * diag(p(lambda_j; theta)) * U^T
        p(lambda; theta) = exp(sum_k theta_k * T_k(lambda))  [Chebyshev]
    
    Marginalizes out phi for computational efficiency.
    """
    
    def __init__(
        self,
        n_obs: int,
        n_features: int,
        eigenvalues: torch.Tensor,
        eigenvectors: torch.Tensor,
        poly_order: int = 5,
        n_mc_samples: int = 20,
        prior_beta_mean: Optional[torch.Tensor] = None,
        prior_beta_std: float = 10.0,
        prior_theta_mean: Optional[torch.Tensor] = None,
        prior_theta_std: float = 1.0,
        prior_tau_a: float = 3.0,
        prior_tau_b: float = 1.0,
        prior_sigma_a: float = 3.0,
        prior_sigma_b: float = 0.5,
    ):
        """
        Args:
            n_obs: Number of observations
            n_features: Number of fixed effect features
            eigenvalues: Eigenvalues of graph Laplacian (n_obs,)
            eigenvectors: Eigenvectors of graph Laplacian (n_obs, n_obs)
            poly_order: Order of Chebyshev polynomial
            n_mc_samples: Number of MC samples for ELBO approximation
            prior_*: Hyperparameters for priors
        """
        super().__init__()
        
        self.n_obs = n_obs
        self.n_features = n_features
        self.poly_order = poly_order
        self.n_mc_samples_initial = n_mc_samples
        self.n_mc_samples = n_mc_samples
        
        # Store eigendecomposition
        self.register_buffer('eigenvalues', eigenvalues)
        self.register_buffer('eigenvectors', eigenvectors)
        
        # Normalize eigenvalues to [-1, 1] for Chebyshev stability
        lambda_min = eigenvalues.min()
        lambda_max = eigenvalues.max()
        self.register_buffer('lambda_min', lambda_min)
        self.register_buffer('lambda_max', lambda_max)
        self.register_buffer(
            'eigenvalues_normalized',
            2 * (eigenvalues - lambda_min) / (lambda_max - lambda_min + 1e-8) - 1
        )
        
        # Prior hyperparameters
        self.register_buffer('prior_beta_mean', 
                           prior_beta_mean if prior_beta_mean is not None 
                           else torch.zeros(n_features))
        self.register_buffer('prior_beta_cov', 
                           torch.eye(n_features) * prior_beta_std**2)
        
        self.register_buffer('prior_theta_mean',
                           prior_theta_mean if prior_theta_mean is not None
                           else torch.zeros(poly_order + 1))
        self.register_buffer('prior_theta_cov',
                           torch.eye(poly_order + 1) * prior_theta_std**2)
        
        self.prior_tau_a = prior_tau_a
        self.prior_tau_b = prior_tau_b
        self.prior_sigma_a = prior_sigma_a
        self.prior_sigma_b = prior_sigma_b
        
        # Variational parameters (will be optimized)
        self._init_variational_parameters()
        
    def _init_variational_parameters(self):
        """Initialize variational distribution parameters with better starting values."""
        # q(beta) = N(mu_beta, Sigma_beta)
        self.mu_beta = nn.Parameter(torch.zeros(self.n_features))
        self.log_diag_sigma_beta = nn.Parameter(torch.zeros(self.n_features))
        
        # q(theta) = N(mu_theta, Sigma_theta)
        # Initialize with small positive first coefficient (smoother spatial field)
        init_theta = torch.zeros(self.poly_order + 1)
        init_theta[0] = 0.5  # Bias towards positive spectral density
        self.mu_theta = nn.Parameter(init_theta)
        self.log_diag_sigma_theta = nn.Parameter(torch.ones(self.poly_order + 1) * (-2))
        
        # q(tau^2) = InverseGamma(a_tau, b_tau)
        # Initialize closer to expected true values
        # For InvGamma(a,b): E[X] = b/(a-1), so b = E[X]*(a-1)
        init_a_tau = 5.0
        init_b_tau = (init_a_tau - 1) * 0.5  # Target E[tau²] ≈ 0.5
        self.log_a_tau = nn.Parameter(torch.log(torch.tensor(init_a_tau)))
        self.log_b_tau = nn.Parameter(torch.log(torch.tensor(init_b_tau)))
        
        # q(sigma^2) = InverseGamma(a_sigma, b_sigma)
        init_a_sigma = 6.0  # Higher shape for more concentration
        init_b_sigma = (init_a_sigma - 1) * 0.25  # Target E[sigma²] ≈ 0.25
        self.log_a_sigma = nn.Parameter(torch.log(torch.tensor(init_a_sigma)))
        self.log_b_sigma = nn.Parameter(torch.log(torch.tensor(init_b_sigma)))
    
    @property
    def sigma_beta(self) -> torch.Tensor:
        """Standard deviations for beta (diagonal covariance)."""
        return torch.exp(self.log_diag_sigma_beta)
    
    @property
    def sigma_theta(self) -> torch.Tensor:
        """Standard deviations for theta (diagonal covariance)."""
        return torch.exp(self.log_diag_sigma_theta)
    
    @property
    def a_tau(self) -> torch.Tensor:
        """Shape parameter for tau^2."""
        return torch.exp(self.log_a_tau)
    
    @property
    def b_tau(self) -> torch.Tensor:
        """Rate parameter for tau^2."""
        return torch.exp(self.log_b_tau)
    
    @property
    def a_sigma(self) -> torch.Tensor:
        """Shape parameter for sigma^2."""
        return torch.exp(self.log_a_sigma)
    
    @property
    def b_sigma(self) -> torch.Tensor:
        """Rate parameter for sigma^2."""
        return torch.exp(self.log_b_sigma)
    
    def chebyshev_polynomials(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Chebyshev polynomials T_0(x), ..., T_K(x).
        
        Args:
            x: Input values, shape (n,)
            
        Returns:
            Tensor of shape (n, K+1) where [:, k] contains T_k(x)
        """
        n = x.shape[0]
        T = torch.zeros(n, self.poly_order + 1, device=x.device)
        T[:, 0] = 1.0
        if self.poly_order >= 1:
            T[:, 1] = x
        for k in range(2, self.poly_order + 1):
            T[:, k] = 2 * x * T[:, k-1] - T[:, k-2]
        return T
    
    def spectral_density(self, theta: torch.Tensor) -> torch.Tensor:
        """
        Compute p(lambda_j; theta) = exp(sum_k theta_k * T_k(lambda_j)).
        
        Args:
            theta: Polynomial coefficients, shape (K+1,) or (batch, K+1)
            
        Returns:
            Spectral density values, shape (n_obs,) or (batch, n_obs)
        """
        T = self.chebyshev_polynomials(self.eigenvalues_normalized)  # (n_obs, K+1)
        
        if theta.dim() == 1:
            # Single sample: T @ theta
            log_p = torch.matmul(T, theta)  # (n_obs,)
        else:
            # Batch of samples: T @ theta.T -> (n_obs, batch) -> transpose to (batch, n_obs)
            log_p = torch.matmul(T, theta.T).T  # (batch, n_obs)
        
        return torch.exp(log_p)
    
    def sample_variational_params(self, n_samples: int) -> dict:
        """Sample from variational distributions."""
        # Sample beta
        beta_samples = self.mu_beta + self.sigma_beta * torch.randn(
            n_samples, self.n_features, device=self.mu_beta.device
        )
        
        # Sample theta
        theta_samples = self.mu_theta + self.sigma_theta * torch.randn(
            n_samples, self.poly_order + 1, device=self.mu_theta.device
        )
        
        # Sample tau^2 (using Inverse Gamma)
        # E[X] = b/(a-1) for InvGamma(a,b), sample via 1/Gamma(a, 1/b)
        gamma_samples = torch.distributions.Gamma(
            self.a_tau, 1.0 / self.b_tau
        ).sample((n_samples,))
        tau2_samples = 1.0 / gamma_samples
        
        # Sample sigma^2
        gamma_samples = torch.distributions.Gamma(
            self.a_sigma, 1.0 / self.b_sigma
        ).sample((n_samples,))
        sigma2_samples = 1.0 / gamma_samples
        
        return {
            'beta': beta_samples,
            'theta': theta_samples,
            'tau2': tau2_samples,
            'sigma2': sigma2_samples
        }
    
    def marginal_log_likelihood(
        self,
        y: torch.Tensor,
        X: torch.Tensor,
        beta: torch.Tensor,
        theta: torch.Tensor,
        tau2: torch.Tensor,
        sigma2: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute log p(y | beta, theta, tau^2, sigma^2) with phi marginalized.
        
        Uses spectral decomposition for O(n) computation.
        
        Args:
            y: Observations (n_obs,)
            X: Design matrix (n_obs, n_features)
            beta, theta, tau2, sigma2: Parameter samples
            
        Returns:
            Log likelihood value (scalar or batch)
        """
        # Compute spectral density p(lambda_j; theta)
        p_lambda = self.spectral_density(theta)  # (n_obs,) or (batch, n_obs)
        
        # Add small epsilon for numerical stability
        eps = 1e-6
        p_lambda = torch.clamp(p_lambda, min=eps)
        
        # Compute diagonal of precision matrix in spectral domain
        # D_jj = tau^2 * p(lambda_j) / (sigma^2 * tau^2 * p(lambda_j) + 1)
        if theta.dim() == 1:
            # Single sample
            denom = sigma2 * tau2 * p_lambda + 1.0
            D_diag = tau2 * p_lambda / denom
            
            # Compute residuals
            residual = y - X @ beta  # (n_obs,)
            
            # Transform to spectral domain
            y_tilde = self.eigenvectors.T @ residual  # (n_obs,)
            
            # Log determinant: sum_j log(sigma^2 * tau^2 * p_j + 1) - n*log(tau^2) - sum_j log(p_j)
            log_det = torch.sum(torch.log(denom)) - self.n_obs * torch.log(tau2) - torch.sum(torch.log(p_lambda))
            
            # Quadratic form: sum_j D_jj * y_tilde_j^2
            quad_form = torch.sum(D_diag * y_tilde**2)
            
        else:
            # Batch of samples
            batch_size = theta.shape[0]
            p_lambda = p_lambda.view(batch_size, self.n_obs)
            tau2 = tau2.view(batch_size, 1)
            sigma2 = sigma2.view(batch_size, 1)
            
            denom = sigma2 * tau2 * p_lambda + 1.0
            D_diag = tau2 * p_lambda / denom
            
            # Compute residuals for each sample
            # beta: (batch, n_features), X: (n_obs, n_features)
            # Need: y - X @ beta for each sample
            Xbeta = torch.matmul(beta, X.T)  # (batch, n_obs)
            residual = y.unsqueeze(0) - Xbeta  # (batch, n_obs)
            
            # Transform to spectral domain
            y_tilde = torch.matmul(residual, self.eigenvectors)  # (batch, n_obs)
            
            log_det = (torch.sum(torch.log(denom), dim=1) 
                      - self.n_obs * torch.log(tau2.squeeze()) 
                      - torch.sum(torch.log(p_lambda), dim=1))
            
            quad_form = torch.sum(D_diag * y_tilde**2, dim=1)
        
        log_lik = -0.5 * self.n_obs * np.log(2 * np.pi) - 0.5 * log_det - 0.5 * quad_form
        return log_lik
    
    def kl_divergence_terms(self) -> dict:
        """Compute KL divergences between variational and prior distributions."""
        # KL(q(beta) || p(beta))
        kl_beta = 0.5 * (
            torch.sum(self.sigma_beta**2 / torch.diag(self.prior_beta_cov))
            + torch.sum(((self.mu_beta - self.prior_beta_mean)**2) / torch.diag(self.prior_beta_cov))
            - self.n_features
            + torch.sum(torch.log(torch.diag(self.prior_beta_cov)))
            - 2 * torch.sum(self.log_diag_sigma_beta)
        )
        
        # KL(q(theta) || p(theta))
        kl_theta = 0.5 * (
            torch.sum(self.sigma_theta**2 / torch.diag(self.prior_theta_cov))
            + torch.sum(((self.mu_theta - self.prior_theta_mean)**2) / torch.diag(self.prior_theta_cov))
            - (self.poly_order + 1)
            + torch.sum(torch.log(torch.diag(self.prior_theta_cov)))
            - 2 * torch.sum(self.log_diag_sigma_theta)
        )
        
        # KL(q(tau^2) || p(tau^2)) for InverseGamma
        kl_tau = (
            self.a_tau * torch.log(self.b_tau / self.prior_tau_b)
            - torch.lgamma(self.a_tau) + torch.lgamma(torch.tensor(self.prior_tau_a))
            + (self.prior_tau_a - self.a_tau) * (torch.log(self.b_tau) - torch.digamma(self.a_tau))
            + self.a_tau * (self.prior_tau_b / self.b_tau - 1.0)
        )
        
        # KL(q(sigma^2) || p(sigma^2))
        kl_sigma = (
            self.a_sigma * torch.log(self.b_sigma / self.prior_sigma_b)
            - torch.lgamma(self.a_sigma) + torch.lgamma(torch.tensor(self.prior_sigma_a))
            + (self.prior_sigma_a - self.a_sigma) * (torch.log(self.b_sigma) - torch.digamma(self.a_sigma))
            + self.a_sigma * (self.prior_sigma_b / self.b_sigma - 1.0)
        )
        
        return {
            'kl_beta': kl_beta,
            'kl_theta': kl_theta,
            'kl_tau': kl_tau,
            'kl_sigma': kl_sigma
        }
    
    def elbo(self, y: torch.Tensor, X: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Compute the Evidence Lower Bound (ELBO) using MC approximation.
        
        Args:
            y: Observations (n_obs,)
            X: Design matrix (n_obs, n_features)
            
        Returns:
            ELBO value (scalar, to be maximized)
            Dictionary of diagnostics
        """
        # Sample from variational distributions
        samples = self.sample_variational_params(self.n_mc_samples)
        
        # Compute expected log likelihood via MC
        log_liks = self.marginal_log_likelihood(
            y, X, 
            samples['beta'], 
            samples['theta'],
            samples['tau2'],
            samples['sigma2']
        )
        expected_log_lik = torch.mean(log_liks)
        
        # Compute KL divergences (analytical)
        kl_terms = self.kl_divergence_terms()
        total_kl = sum(kl_terms.values())
        
        # ELBO = E[log p(y|params)] - KL
        elbo_value = expected_log_lik - total_kl
        
        diagnostics = {
            'expected_log_lik': expected_log_lik.item(),
            'total_kl': total_kl.item(),
            'log_lik_std': torch.std(log_liks).item(),
            **{k: v.item() for k, v in kl_terms.items()}
        }
        
        return elbo_value, diagnostics
    
    def fit(
        self,
        y: torch.Tensor,
        X: torch.Tensor,
        n_iterations: int = 3000,
        learning_rate: float = 0.03,
        n_mc_samples_final: int = 100,
        warmup_iterations: int = 500,
        use_scheduler: bool = True,
        verbose: bool = True,
        print_every: int = 100
    ) -> List[Dict]:
        """
        Fit the model using natural gradient variational inference.
        
        Args:
            y: Observations (n_obs,)
            X: Design matrix (n_obs, n_features)
            n_iterations: Number of optimization iterations
            learning_rate: Initial learning rate
            n_mc_samples_final: Final number of MC samples (ramped up during training)
            warmup_iterations: Number of iterations before ramping up MC samples
            use_scheduler: Whether to use learning rate scheduler
            verbose: Whether to print progress
            print_every: Print frequency
        """
        # Different learning rates for different parameter groups
        optimizer = torch.optim.Adam([
            {'params': [self.mu_beta, self.log_diag_sigma_beta], 'lr': learning_rate},
            {'params': [self.mu_theta, self.log_diag_sigma_theta], 'lr': learning_rate * 0.8},
            {'params': [self.log_a_tau, self.log_b_tau], 'lr': learning_rate * 0.3},
            {'params': [self.log_a_sigma, self.log_b_sigma], 'lr': learning_rate * 0.3}
        ])
        
        # Learning rate scheduler
        if use_scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.7, patience=150, 
                threshold=0.01
            )
        
        history = []
        best_elbo = -float('inf')
        patience_counter = 0
        
        for iteration in range(n_iterations):
            # Gradually increase MC samples after warmup
            if iteration < warmup_iterations:
                self.n_mc_samples = self.n_mc_samples_initial
            else:
                progress = (iteration - warmup_iterations) / (n_iterations - warmup_iterations)
                self.n_mc_samples = int(
                    self.n_mc_samples_initial + 
                    progress * (n_mc_samples_final - self.n_mc_samples_initial)
                )
            
            optimizer.zero_grad()
            
            # Compute negative ELBO (loss to minimize)
            elbo_value, diagnostics = self.elbo(y, X)
            loss = -elbo_value
            
            # Backpropagation
            loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5.0)
            
            optimizer.step()
            
            # Learning rate scheduling
            if use_scheduler:
                scheduler.step(elbo_value.detach())
            
            # Track best model (could implement early stopping here)
            if elbo_value.item() > best_elbo:
                best_elbo = elbo_value.item()
                patience_counter = 0
            else:
                patience_counter += 1
            
            # Logging
            diagnostics['elbo'] = elbo_value.item()
            diagnostics['iteration'] = iteration
            diagnostics['n_mc_samples'] = self.n_mc_samples
            diagnostics['learning_rate'] = optimizer.param_groups[0]['lr']
            
            # Add current parameter values to diagnostics
            with torch.no_grad():
                diagnostics['tau2_current'] = (self.b_tau / (self.a_tau - 1)).item()
                diagnostics['sigma2_current'] = (self.b_sigma / (self.a_sigma - 1)).item()
            
            history.append(diagnostics)
            
            if verbose and (iteration % print_every == 0 or iteration == n_iterations - 1):
                print(f"Iter {iteration:4d} | ELBO: {elbo_value.item():8.2f} | "
                      f"E[log p(y|.)]: {diagnostics['expected_log_lik']:8.2f} | "
                      f"KL: {diagnostics['total_kl']:6.2f} | "
                      f"LL_std: {diagnostics['log_lik_std']:5.2f} | "
                      f"MC: {self.n_mc_samples:2d} | "
                      f"τ²: {diagnostics['tau2_current']:.3f} | "
                      f"σ²: {diagnostics['sigma2_current']:.3f}")
        
        return history
    
    def predict_spatial_effect(self, y: torch.Tensor, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the spatial random effect phi given data.
        
        Returns mean and standard deviation.
        """
        with torch.no_grad():
            # Use posterior mean of hyperparameters
            beta = self.mu_beta
            theta = self.mu_theta
            tau2 = self.b_tau / (self.a_tau - 1)  # E[tau^2]
            sigma2 = self.b_sigma / (self.a_sigma - 1)  # E[sigma^2]
            
            # Compute residuals
            residual = y - X @ beta
            
            # Transform to spectral domain
            y_tilde = self.eigenvectors.T @ residual
            
            # Compute spectral density
            p_lambda = self.spectral_density(theta)
            p_lambda = torch.clamp(p_lambda, min=1e-6)
            
            # Posterior mean in spectral domain: M_jj * y_tilde_j
            # M_jj = 1 / (sigma^2 * tau^2 * p_j + 1)
            M_diag = 1.0 / (sigma2 * tau2 * p_lambda + 1.0)
            alpha_mean = M_diag * y_tilde
            
            # Transform back to spatial domain
            phi_mean = self.eigenvectors @ alpha_mean
            
            # Posterior variance (diagonal approximation)
            phi_var = self.eigenvectors @ (M_diag.unsqueeze(-1) * self.eigenvectors.T)
            phi_std = torch.sqrt(torch.diag(phi_var))
            
            return phi_mean, phi_std
    
    def get_parameter_summary(self) -> dict:
        """Get summary of estimated parameters."""
        with torch.no_grad():
            # Check if a > 2 for valid variance computation
            tau2_mean = (self.b_tau / (self.a_tau - 1)).item()
            tau2_var = self.b_tau**2 / ((self.a_tau - 1)**2 * torch.clamp(self.a_tau - 2, min=1e-6))
            tau2_std = torch.sqrt(tau2_var).item()
            
            sigma2_mean = (self.b_sigma / (self.a_sigma - 1)).item()
            sigma2_var = self.b_sigma**2 / ((self.a_sigma - 1)**2 * torch.clamp(self.a_sigma - 2, min=1e-6))
            sigma2_std = torch.sqrt(sigma2_var).item()
            
            return {
                'beta_mean': self.mu_beta.cpu().numpy(),
                'beta_std': self.sigma_beta.cpu().numpy(),
                'theta_mean': self.mu_theta.cpu().numpy(),
                'theta_std': self.sigma_theta.cpu().numpy(),
                'tau2_mean': tau2_mean,
                'tau2_std': tau2_std,
                'sigma2_mean': sigma2_mean,
                'sigma2_std': sigma2_std,
            }


def create_example_graph_laplacian(n_nodes: int, grid_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create a simple grid graph Laplacian for demonstration.
    
    Args:
        n_nodes: Number of nodes (should be grid_size^2)
        grid_size: Size of square grid
        
    Returns:
        eigenvalues, eigenvectors of the graph Laplacian
    """
    # Create adjacency matrix for grid graph
    W = torch.zeros(n_nodes, n_nodes)
    
    for i in range(grid_size):
        for j in range(grid_size):
            idx = i * grid_size + j
            # Right neighbor
            if j < grid_size - 1:
                W[idx, idx + 1] = 1
                W[idx + 1, idx] = 1
            # Bottom neighbor
            if i < grid_size - 1:
                W[idx, idx + grid_size] = 1
                W[idx + grid_size, idx] = 1
    
    # Degree matrix
    D = torch.diag(W.sum(dim=1))
    
    # Graph Laplacian
    L = D - W
    
    # Eigendecomposition
    eigenvalues, eigenvectors = torch.linalg.eigh(L)
    
    return eigenvalues, eigenvectors


# Example usage
if __name__ == "__main__":
    torch.manual_seed(42)
    
    # Setup
    grid_size = 8
    n_obs = grid_size ** 2
    n_features = 3
    
    print(f"Creating {grid_size}x{grid_size} grid graph ({n_obs} nodes)")
    
    # Create graph structure
    eigenvalues, eigenvectors = create_example_graph_laplacian(n_obs, grid_size)
    
    # Generate synthetic data
    X = torch.randn(n_obs, n_features)
    X[:, 0] = 1.0  # Intercept
    
    true_beta = torch.tensor([2.0, -1.0, 0.5])
    
    # True spatial effect (smooth pattern)
    true_theta = torch.tensor([0.5, -1.0, 0.3, 0.0, 0.0, 0.0])
    true_tau2 = 0.5
    true_sigma2 = 0.25
    
    model_true = SpectralCARVariationalBayes(
        n_obs, n_features, eigenvalues, eigenvectors, poly_order=5
    )
    true_p_lambda = model_true.spectral_density(true_theta)
    
    # Sample spatial effect
    Q_inv = eigenvectors @ torch.diag(1.0 / (true_tau2 * true_p_lambda)) @ eigenvectors.T
    phi_true = torch.distributions.MultivariateNormal(
        torch.zeros(n_obs), Q_inv
    ).sample()
    
    # Generate observations
    y = X @ true_beta + phi_true + torch.randn(n_obs) * torch.sqrt(torch.tensor(true_sigma2))
    
    print(f"\nTrue parameters:")
    print(f"  beta: {true_beta.numpy()}")
    print(f"  tau^2: {true_tau2:.3f}")
    print(f"  sigma^2: {true_sigma2:.3f}")
    
    # Fit model with improved settings
    print(f"\nFitting variational Bayes model...")
    model = SpectralCARVariationalBayes(
        n_obs=n_obs,
        n_features=n_features,
        eigenvalues=eigenvalues,
        eigenvectors=eigenvectors,
        poly_order=5,
        n_mc_samples=20,
        prior_tau_a=3.0,
        prior_tau_b=1.0,
        prior_sigma_a=5.0,  # Stronger prior
        prior_sigma_b=1.0   # Tighter around E[sigma²] = 0.25
    )
    
    history = model.fit(
        y, X, 
        n_iterations=3000,
        learning_rate=0.03,
        n_mc_samples_final=100,
        warmup_iterations=500,
        use_scheduler=True,
        verbose=True, 
        print_every=200
    )
    
    # Results
    print(f"\n{'='*60}")
    print(f"RESULTS")
    print(f"{'='*60}")
    
    param_summary = model.get_parameter_summary()
    
    print(f"\nEstimated parameters (mean ± std):")
    print(f"  beta:")
    for i, (m, s, t) in enumerate(zip(param_summary['beta_mean'], 
                                       param_summary['beta_std'],
                                       true_beta.numpy())):
        print(f"    β_{i}: {m:6.3f} ± {s:5.3f}  (true: {t:6.3f})")
    
    print(f"\n  tau^2:   {param_summary['tau2_mean']:6.3f} ± {param_summary['tau2_std']:5.3f}  (true: {true_tau2:.3f})")
    print(f"  sigma^2: {param_summary['sigma2_mean']:6.3f} ± {param_summary['sigma2_std']:5.3f}  (true: {true_sigma2:.3f})")
    
    # Predict spatial effects
    phi_mean, phi_std = model.predict_spatial_effect(y, X)
    
    print(f"\nSpatial effect prediction:")
    print(f"  Correlation with truth: {torch.corrcoef(torch.stack([phi_true, phi_mean]))[0,1].item():.3f}")
    print(f"  Mean absolute error: {torch.mean(torch.abs(phi_true - phi_mean)).item():.3f}")
    print(f"  RMSE: {torch.sqrt(torch.mean((phi_true - phi_mean)**2)).item():.3f}")
    
    # Print final ELBO convergence
    print(f"\nConvergence:")
    print(f"  Final ELBO: {history[-1]['elbo']:.2f}")
    print(f"  Best ELBO: {max([h['elbo'] for h in history]):.2f}")
    print(f"  Final log-likelihood std: {history[-1]['log_lik_std']:.3f}")
    
    print(f"\n{'='*60}")

Creating 8x8 grid graph (64 nodes)

True parameters:
  beta: [ 2.  -1.   0.5]
  tau^2: 0.500
  sigma^2: 0.250

Fitting variational Bayes model...


Consider using tensor.detach() first. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  current = float(metrics)


Iter    0 | ELBO:  -181.89 | E[log p(y|.)]:  -167.17 | KL:  14.72 | LL_std: 28.83 | MC: 20 | τ²: 0.490 | σ²: 0.245
Iter  200 | ELBO:  -122.91 | E[log p(y|.)]:  -106.39 | KL:  16.51 | LL_std:  4.24 | MC: 20 | τ²: 0.500 | σ²: 0.250
Iter  400 | ELBO:  -132.77 | E[log p(y|.)]:  -116.25 | KL:  16.51 | LL_std: 29.53 | MC: 20 | τ²: 0.500 | σ²: 0.250
Iter  600 | ELBO:  -123.91 | E[log p(y|.)]:  -107.43 | KL:  16.48 | LL_std:  5.63 | MC: 23 | τ²: 0.500 | σ²: 0.250
Iter  800 | ELBO:  -126.27 | E[log p(y|.)]:  -109.54 | KL:  16.73 | LL_std:  7.37 | MC: 29 | τ²: 0.500 | σ²: 0.250
Iter 1000 | ELBO:  -126.34 | E[log p(y|.)]:  -109.60 | KL:  16.73 | LL_std:  8.62 | MC: 36 | τ²: 0.500 | σ²: 0.250
Iter 1200 | ELBO:  -127.32 | E[log p(y|.)]:  -110.88 | KL:  16.44 | LL_std: 16.46 | MC: 42 | τ²: 0.500 | σ²: 0.250
Iter 1400 | ELBO:  -124.82 | E[log p(y|.)]:  -108.35 | KL:  16.48 | LL_std:  7.46 | MC: 48 | τ²: 0.500 | σ²: 0.250
Iter 1600 | ELBO:  -124.88 | E[log p(y|.)]:  -108.42 | KL:  16.46 | LL_std:  7.6