# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.distributions import Normal  # Gaussian distribution utility class


# -------------------------- Core Parameters and Data Generation --------------------------
# Parameters of the Gaussian that generate_data() draws the observations from
true_mu: float = 3.0      # Mean of the data-generating Gaussian
true_sigma: float = 1.5   # Standard deviation of the data-generating Gaussian

def generate_data(n_samples=100):
    """
    Draw synthetic observations from the data-generating Gaussian
    Input: n_samples - how many observations to draw (default 100)
    Output: torch.Tensor - 1-D tensor holding n_samples draws
    Significance: Stands in for real-world measurements; the draws are taken from
                 N(true_mu, true_sigma²), using the module-level true_mu and true_sigma
    """
    sample_shape = (n_samples,)
    source = Normal(loc=true_mu, scale=true_sigma)  # Data-generating distribution
    return source.sample(sample_shape)


# -------------------------- Probability Model Definition --------------------------
def prior_distribution():
    """
    Build the prior distribution p(θ)
    Output: Normal distribution instance - Gaussian with mean 0 and standard deviation 3
    Significance: Encodes the belief about θ held before any data is observed
                 (values near 0 are considered the most plausible)
    """
    prior_mean = torch.tensor(0.0)
    prior_std = torch.tensor(3.0)
    return Normal(prior_mean, prior_std)

def likelihood(x, theta):
    """
    Log likelihood log p(x|θ) of the observations under a candidate parameter value
    Input:
        x - observation data (torch.Tensor)
        theta - candidate value of parameter θ (scalar)
    Output: scalar - log likelihood summed over every element of x
    Significance: Scores how well a given θ explains the observed data, which is what
                 ties the parameter to the observations
    """
    # Observation model: each data point is Gaussian with mean θ and a fixed standard deviation of 1
    unit_scale = torch.tensor(1.0)
    observation_model = Normal(theta, unit_scale)
    per_point_log_prob = observation_model.log_prob(x)
    return per_point_log_prob.sum()


# -------------------------- Variational Posterior Definition --------------------------
class VariationalPosterior:
    """
    Gaussian approximate posterior q(θ) = N(μ, σ²) standing in for the true posterior p(θ|x)

    The two variational parameters are stored as differentiable scalar tensors so an
    optimizer can adjust them until q(θ) is as close as possible to p(θ|x)
    """
    def __init__(self):
        # mu is the mean μ of q(θ); it starts at 0
        self.mu = torch.zeros((), requires_grad=True)
        # log_sigma is an unconstrained real number, also starting at 0. Despite the name it is
        # not exponentiated: distribution() passes it through softplus, so σ = softplus(log_sigma)
        # is positive for any value the optimizer reaches
        self.log_sigma = torch.zeros((), requires_grad=True)

    def distribution(self):
        """Build the Gaussian q(θ) from the current values of the variational parameters"""
        return Normal(loc=self.mu, scale=torch.nn.functional.softplus(self.log_sigma))

    def sample(self, n_samples=1):
        """Draw n_samples values of θ from q(θ) with rsample (reparameterized, so gradients reach mu and log_sigma)"""
        sample_shape = (n_samples,)
        current_q = self.distribution()
        return current_q.rsample(sample_shape)


# -------------------------- ELBO Calculation and Training --------------------------
def elbo(x, q, prior, n_samples=10):
    """
    Monte Carlo estimate of the Evidence Lower Bound (ELBO): core objective function for variational inference
    ELBO = E_q[log p(x|θ) + log p(θ) - log q(θ)]

    Input:
        x - observation data
        q - variational posterior instance (provides sample() and distribution())
        prior - prior distribution instance (provides log_prob())
        n_samples - number of Monte Carlo samples drawn from q(θ) (default 10)
    Output: scalar tensor - ELBO averaged over the drawn samples
    Significance: ELBO is a lower bound on the log evidence log p(x). The gap between the two is
                 KL(q(θ) || p(θ|x)), so maximizing ELBO moves q(θ) toward the true posterior p(θ|x)
    """
    thetas = q.sample(n_samples)  # Reparameterized draws, so gradients flow back to q's parameters

    # q's parameters are fixed for the duration of one estimate, so build q(θ) once
    # instead of once per sample
    q_dist = q.distribution()

    # One-sample ELBO estimate for each θ: log p(x|θ) + log p(θ) - log q(θ)
    elbo_values = [
        likelihood(x, theta) + prior.log_prob(theta) - q_dist.log_prob(theta)
        for theta in thetas
    ]

    return torch.mean(torch.stack(elbo_values))  # Averaging over samples reduces estimation variance

def train_vi(x, num_epochs=1000, lr=0.01, n_samples=10):
    """
    Fit the variational posterior to the data by maximizing a Monte Carlo estimate of the ELBO

    Input:
        x - observation data
        num_epochs - number of optimization steps (default 1000)
        lr - learning rate passed to Adam (default 0.01)
        n_samples - Monte Carlo samples per ELBO estimate (default 10)
    Output:
        q - variational posterior instance after the final update
        elbo_history - list with the ELBO estimate computed at each step (before that step's update)
    Significance: Adam minimizes the negative ELBO, which step by step moves q(θ)
                 toward the true posterior
    """
    prior = prior_distribution()
    q = VariationalPosterior()
    optimizer = optim.Adam([q.mu, q.log_sigma], lr=lr)
    elbo_history = []

    for epoch in range(num_epochs):
        optimizer.zero_grad()                        # Start the step with clean gradients
        current_elbo = elbo(x, q, prior, n_samples)  # Stochastic ELBO estimate
        (-current_elbo).backward()                   # Minimizing -ELBO is maximizing ELBO
        optimizer.step()

        elbo_history.append(current_elbo.item())

        # Report only on every 100th epoch (epochs are shown 1-based)
        if epoch % 100 != 99:
            continue
        sigma = torch.nn.functional.softplus(q.log_sigma).item()
        print(f"Epoch {epoch+1}/{num_epochs}, ELBO: {current_elbo.item():.2f}, "
              f"mu: {q.mu.item():.2f}, sigma: {sigma:.2f}")

    return q, elbo_history


# -------------------------- Main Function and Visualization --------------------------
def main():
    """
    Run the demo end to end: sample data, fit q(θ) by variational inference,
    print a summary and plot the diagnostics

    The fitted q(θ) is compared with the exact posterior p(θ|x). That posterior is available
    in closed form here because the Gaussian prior on θ is conjugate to the Gaussian
    likelihood with fixed standard deviation. Note that true_sigma describes the spread of
    the data, not the uncertainty about θ, so it is not comparable to the standard
    deviation of q(θ).
    """
    # Generate observation data (core input data)
    x = generate_data(n_samples=100)
    print(f"Sample of generated data: {x[:5].numpy()}")  # Print first 5 data points

    # Train VI model
    q, elbo_history = train_vi(x, num_epochs=1000, lr=0.01)

    # Exact Normal-Normal posterior over θ:
    #   precision = 1/σ_prior² + n/σ_lik²
    #   mean      = (μ_prior/σ_prior² + Σx/σ_lik²) / precision
    prior = prior_distribution()
    likelihood_var = 1.0  # Must match the fixed standard deviation of 1 used in likelihood()
    prior_precision = 1.0 / prior.scale ** 2
    posterior_precision = prior_precision + x.numel() / likelihood_var
    posterior_mean = (prior_precision * prior.loc + x.sum() / likelihood_var) / posterior_precision
    exact_posterior = Normal(posterior_mean, posterior_precision ** -0.5)

    # Output final results (core output)
    final_sigma = torch.nn.functional.softplus(q.log_sigma).item()
    print("\nFinal variational parameters:")
    print(f"True mean: {true_mu}, Estimated mean: {q.mu.item():.2f}")
    print(f"Exact posterior standard deviation: {exact_posterior.scale.item():.2f}, "
          f"Estimated standard deviation: {final_sigma:.2f}")

    # Plot ELBO convergence curve (core visualization output)
    plt.figure(figsize=(10, 6))
    plt.plot(elbo_history)
    plt.title('ELBO Convergence', fontsize=14)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('ELBO Value', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.show()

    # Evaluate both densities on one grid in a single vectorized call each;
    # no_grad keeps the results detached so they can be converted to NumPy
    with torch.no_grad():
        theta_grid = torch.linspace(-1, 7, 1000)
        true_probs = exact_posterior.log_prob(theta_grid).exp().numpy()
        q_probs = q.distribution().log_prob(theta_grid).exp().numpy()
    thetas = theta_grid.numpy()

    # Plot comparison between exact posterior and approximate posterior (core visualization output)
    plt.figure(figsize=(10, 6))
    plt.plot(thetas, true_probs, label='True Posterior', color='blue')
    plt.plot(thetas, q_probs, label='Variational Posterior', color='red', linestyle='--')
    plt.axvline(x=true_mu, color='blue', linestyle=':', label=f'True Mean ({true_mu})')
    plt.axvline(x=q.mu.item(), color='red', linestyle=':', label=f'Estimated Mean ({q.mu.item():.2f})')
    plt.title('True vs Variational Posterior', fontsize=14)
    plt.xlabel('θ', fontsize=12)
    plt.ylabel('Probability Density', fontsize=12)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

if __name__ == "__main__":
    main()