# Variational Monte Carlo (VMC) Methods with NetKet

This notebook walks through Variational Monte Carlo (VMC) with NetKet: building a variational state, sampling it, minimizing the energy, and measuring observables. VMC approximates ground states and expectation values of quantum many-body systems by combining a parameterized wavefunction with Monte Carlo sampling.

## Table of Contents
1. [Introduction to Variational Monte Carlo](#introduction)
2. [Setting up the Problem](#setup)
3. [Basic VMC with RBM](#basic-vmc)
4. [Advanced Optimizers](#optimizers)
5. [Sampling Strategies](#sampling)
6. [Energy Minimization](#energy-minimization)
7. [Computing Observables](#observables)
8. [Performance Analysis](#performance)
9. [Conclusion](#conclusion)

## 1. Introduction to Variational Monte Carlo {#introduction}

The Variational Monte Carlo method uses the variational principle to find approximate ground states:

$$E[\psi] = \frac{\langle \psi | H | \psi \rangle}{\langle \psi | \psi \rangle} \geq E_0$$

For a parameterized ansatz $|\psi(\theta)\rangle$, we minimize:

$$E(\theta) = \sum_\sigma p_\sigma(\theta) E_{\text{loc}}(\sigma, \theta)$$

where:
- $p_\sigma(\theta) = |\psi(\sigma, \theta)|^2 / \sum_{\sigma'} |\psi(\sigma', \theta)|^2$ is the probability distribution
- $E_{\text{loc}}(\sigma, \theta) = \frac{\sum_{\sigma'} H_{\sigma,\sigma'} \psi(\sigma', \theta)}{\psi(\sigma, \theta)}$ is the local energy

### Key Components:
1. **Variational ansatz**: Neural network representation
2. **Monte Carlo sampling**: Generate configurations from $|\psi|^2$
3. **Local energy**: Efficiently compute matrix elements
4. **Optimization**: Minimize energy using gradients
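
The identity $E(\theta) = \sum_\sigma p_\sigma E_{\text{loc}}(\sigma)$ can be checked by brute force on a tiny system. The sketch below is not NetKet code: it assumes a hypothetical one-parameter real ansatz (`toy_log_psi`) on a 4-spin periodic transverse-field Ising chain, enumerates all configurations instead of sampling them, and compares the local-energy average with $\langle\psi|H|\psi\rangle / \langle\psi|\psi\rangle$ computed from an explicit matrix.

```python
import itertools
import numpy as np

J_demo, h_demo, n_demo = 1.0, 0.5, 4  # illustrative values, not tied to the cells below

def toy_log_psi(sigma, a=0.3):
    # Hypothetical real ansatz: log psi(sigma) = a * sum_i sigma_i sigma_{i+1}
    return a * np.sum(sigma * np.roll(sigma, -1))

def local_energy(sigma):
    # Diagonal term: -J sum_i sigma_i sigma_{i+1} with periodic boundaries
    e_loc = -J_demo * np.sum(sigma * np.roll(sigma, -1))
    # Off-diagonal term: sigma^x_i connects sigma to the configuration with spin i flipped
    for i in range(n_demo):
        flipped = sigma.copy()
        flipped[i] *= -1
        e_loc -= h_demo * np.exp(toy_log_psi(flipped) - toy_log_psi(sigma))
    return e_loc

# Enumerate all 2^n configurations and weight the local energies by |psi|^2
configs = np.array(list(itertools.product([-1.0, 1.0], repeat=n_demo)))
psi = np.exp([toy_log_psi(s) for s in configs])
probs = psi**2 / np.sum(psi**2)
energy_from_eloc = np.sum(probs * np.array([local_energy(s) for s in configs]))

# Reference: build H explicitly and evaluate <psi|H|psi> / <psi|psi>
index = {tuple(s): k for k, s in enumerate(configs)}
H = np.zeros((len(configs), len(configs)))
for k, s in enumerate(configs):
    H[k, k] = -J_demo * np.sum(s * np.roll(s, -1))
    for i in range(n_demo):
        flipped = s.copy()
        flipped[i] *= -1
        H[k, index[tuple(flipped)]] -= h_demo
energy_direct = psi @ H @ psi / (psi @ psi)
ground_energy = np.linalg.eigvalsh(H)[0]

print(f"sum_sigma p * E_loc : {energy_from_eloc:.6f}")
print(f"<psi|H|psi>/<psi|psi>: {energy_direct:.6f}")
print(f"exact E_0           : {ground_energy:.6f}")
```

In an actual VMC run the sum over all configurations is replaced by an average over Monte Carlo samples drawn from $p_\sigma$, which is what makes the method usable when the Hilbert space is too large to enumerate.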

In [None]:
# Import necessary libraries
import netket as nk
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import matplotlib.pyplot as plt
import optax
from typing import Any
import warnings
warnings.filterwarnings('ignore')

# Configure JAX
jax.config.update('jax_platform_name', 'cpu')
key = jax.random.PRNGKey(42)

print(f"NetKet version: {nk.__version__}")
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")

## 2. Setting up the Problem {#setup}

We use the 1D transverse-field Ising model on a periodic chain as the test problem:

$$H = -J \sum_{i} \sigma^z_i \sigma^z_{i+1} - h \sum_i \sigma^x_i$$
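
For a chain this small the Hamiltonian can also be written out as a dense matrix, which makes the sign conventions explicit. The following is a minimal NumPy sketch, independent of NetKet; the helper names `tfim_dense` and `site_op` are made up for this illustration, and $\sigma^{x,z}$ are taken to be Pauli matrices with eigenvalues $\pm 1$.

```python
from functools import reduce
import numpy as np

def tfim_dense(n_spins, J=1.0, h=0.5):
    """Dense matrix of H = -J sum_i sz_i sz_{i+1} - h sum_i sx_i on a periodic chain."""
    sx = np.array([[0.0, 1.0], [1.0, 0.0]])
    sz = np.array([[1.0, 0.0], [0.0, -1.0]])
    eye = np.eye(2)

    def site_op(op, i):
        # Kronecker product with `op` on site i and identities elsewhere
        return reduce(np.kron, [op if k == i else eye for k in range(n_spins)])

    H = np.zeros((2**n_spins, 2**n_spins))
    for i in range(n_spins):
        H -= J * site_op(sz, i) @ site_op(sz, (i + 1) % n_spins)
        H -= h * site_op(sx, i)
    return H

H_dense = tfim_dense(8)
e0_dense = np.linalg.eigvalsh(H_dense)[0]
print(f"Matrix shape: {H_dense.shape}")
print(f"Ground-state energy from dense diagonalization: {e0_dense:.8f}")
```

The memory cost grows as $4^N$, so this approach is only a cross-check for small $N$; it is not a substitute for the sparse operator used in the rest of the notebook.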

In [None]:
# Define the system
N = 8  # Number of spins

# Create 1D chain with periodic boundary conditions
lattice = nk.graph.Chain(length=N, pbc=True)

# Define Hilbert space
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)

# Parameters
J = 1.0  # Coupling strength
h = 0.5  # Transverse field

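# Create Hamiltonian.
# NetKet's Ising operator is defined as -h Σ σˣ + J Σ σᶻσᶻ, so the
# ferromagnetic term -J Σ σᶻσᶻ written above requires passing J=-J.
hamiltonian = nk.operator.Ising(hilbert=hilbert, graph=lattice, J=-J, h=h)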

print(f"System: 1D chain with {N} spins")
print(f"Hilbert space dimension: {hilbert.n_states}")
print(f"Parameters: J={J}, h={h}")
print(f"Expected phase: {'Ordered' if h < J else 'Disordered'}")

In [None]:
# For comparison, compute the exact ground-state energy (feasible for small N)
try:
    print("Computing exact ground state...")
    # lanczos_ed returns the k lowest eigenvalues; eigenvectors are only
    # returned when compute_eigenvectors=True is passed
    eigenvalues = nk.exact.lanczos_ed(hamiltonian, k=1)
    exact_energy = float(eigenvalues[0])
    print(f"Exact ground state energy: {exact_energy:.8f}")
    exact_available = True
except Exception as e:
    print(f"Exact diagonalization not available: {e}")
    exact_available = False
    exact_energy = None

## 3. Basic VMC with RBM {#basic-vmc}

Let's start with a basic VMC calculation using a Restricted Boltzmann Machine as our variational ansatz.
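
Before using the NetKet model, it can help to see what an RBM wavefunction computes. The sketch below is a plain NumPy illustration, not NetKet's implementation: it assumes the usual form with the hidden units summed out, $\log\psi(\sigma) = a\cdot\sigma + \sum_j \log\cosh(b_j + \sum_i W_{ji}\sigma_i)$, and uses made-up variable names and real parameters for simplicity.

```python
import numpy as np

def rbm_log_psi(sigma, visible_bias, hidden_bias, weights):
    """log psi(sigma) = a . sigma + sum_j log cosh(b_j + sum_i W_ji sigma_i)."""
    theta = hidden_bias + weights @ sigma
    return visible_bias @ sigma + np.sum(np.log(np.cosh(theta)))

rng = np.random.default_rng(0)
n_visible, alpha_demo = 8, 2          # illustrative sizes
n_hidden = alpha_demo * n_visible

# Small random parameters (real-valued here for simplicity)
a = 0.01 * rng.normal(size=n_visible)
b = 0.01 * rng.normal(size=n_hidden)
W = 0.01 * rng.normal(size=(n_hidden, n_visible))
n_params = a.size + b.size + W.size

sigma = rng.choice([-1.0, 1.0], size=n_visible)
log_amp = rbm_log_psi(sigma, a, b, W)
print(f"Number of parameters: {n_params}")
print(f"log psi(sigma) = {log_amp:.6f}")
```

The parameter count is $N + \alpha N + \alpha N^2$, so the weight matrix dominates as the chain grows. With both biases set to zero the amplitude is invariant under a global spin flip, because $\cosh$ is even.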

In [None]:
# Create the variational ansatz (RBM)
alpha = 2  # Hidden unit density
model = nk.models.RBM(
    alpha=alpha,
    param_dtype=complex,  # complex parameters; older NetKet releases name this argument `dtype`
    use_visible_bias=True,
    use_hidden_bias=True
)

print(f"RBM parameters:")
print(f"- Visible units: {N}")
print(f"- Hidden units: {alpha * N}")
print(f"- Total parameters: ~{N + alpha * N + N * alpha * N}")

In [None]:
# Create Monte Carlo sampler
sampler = nk.sampler.MetropolisLocal(
    hilbert=hilbert,
    n_chains=16,  # Number of parallel chains
    sweep_size=hilbert.size  # Number of local moves per sweep
)

print(f"Sampler configuration:")
print(f"- Type: {type(sampler).__name__}")
print(f"- Chains: {sampler.n_chains}")
print(f"- Sweep size: {sampler.sweep_size}")

In [None]:
# Create the Variational Monte Carlo state
vqs = nk.vqs.MCState(
    sampler=sampler,
    model=model,
    n_samples=1000,  # Number of samples per iteration
    n_discard_per_chain=100,  # Thermalization
    chunk_size=None  # No chunking: evaluate all samples in one batch
)

print(f"Variational state configuration:")
print(f"- Samples per iteration: {vqs.n_samples}")
print(f"- Thermalization: {vqs.n_discard_per_chain}")
print(f"- Model: {type(vqs.model).__name__}")

In [None]:
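# Initialize the variational state.
# The PRNG key must be passed as the keyword-only `seed` argument; a positional
# argument would be interpreted as the initializer function.
key, subkey = jax.random.split(key)
vqs.init_parameters(seed=subkey)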

print("Initialized parameters:")
for name, param in vqs.parameters.items():
    if hasattr(param, 'shape'):
        print(f"- {name}: {param.shape}")
    else:
        for subname, subparam in param.items():
            print(f"- {name}.{subname}: {subparam.shape}")

In [None]:
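# Compute initial energy
initial_energy = vqs.expect(hamiltonian)
# The Stats object exposes the Monte Carlo standard error as `error_of_mean`;
# the mean is complex for complex parameters, so report its real part.
print(f"Initial energy: {initial_energy.mean.real:.6f} ± {initial_energy.error_of_mean:.6f}")

if exact_available:
    error = abs(initial_energy.mean.real - exact_energy)
    print(f"Error vs exact: {error:.6f}")
    print(f"Relative error: {error/abs(exact_energy)*100:.2f}%")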

## 4. Advanced Optimizers {#optimizers}

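NetKet's `nk.optimizer` module exposes standard first-order optimizers such as SGD, Adam, and momentum, which are built on Optax. Any other Optax gradient transformation can be passed to the driver in the same way.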

In [None]:
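# Compare different optimizers.
# These are Optax gradient transformations and do not expose their
# hyperparameters as attributes, so the learning rates are kept in a dict.
learning_rates = {'SGD': 0.05, 'Adam': 0.01, 'AdaMax': 0.01, 'Momentum': 0.02}
optimizers = {
    'SGD': nk.optimizer.Sgd(learning_rate=learning_rates['SGD']),
    'Adam': nk.optimizer.Adam(learning_rate=learning_rates['Adam']),
    'AdaMax': optax.adamax(learning_rate=learning_rates['AdaMax']),  # taken from Optax directly
    'Momentum': nk.optimizer.Momentum(learning_rate=learning_rates['Momentum'], beta=0.9)
}

print("Available optimizers:")
for name in optimizers:
    print(f"- {name} (lr={learning_rates[name]})")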

In [None]:
# Create variational optimization driver
optimizer = optimizers['Adam']
vmc = nk.VMC(
    hamiltonian=hamiltonian,
    optimizer=optimizer,
    variational_state=vqs
)

print("VMC driver created with Adam optimizer")

## 5. Sampling Strategies {#sampling}

Different sampling strategies can significantly affect the efficiency and accuracy of VMC calculations.
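
To make the sampling step concrete, here is a single-chain Metropolis sampler with single-spin-flip proposals written in plain NumPy. It is a sketch of the general idea, not of NetKet's `MetropolisLocal` internals. The target is a hypothetical product state $\log\psi(\sigma) = a\sum_i \sigma_i$, chosen because its exact magnetization $\langle\sigma_i\rangle = \tanh(2a)$ gives something to compare against.

```python
import numpy as np

def metropolis_local(log_psi, n_spins, n_sweeps, rng):
    """Sample |psi(sigma)|^2 with single-spin-flip Metropolis moves (one chain)."""
    sigma = rng.choice([-1.0, 1.0], size=n_spins)
    log_amp = log_psi(sigma)
    samples = np.empty((n_sweeps, n_spins))
    n_accepted = 0
    for sweep in range(n_sweeps):
        for _ in range(n_spins):  # one sweep = n_spins proposed flips
            i = rng.integers(n_spins)
            proposal = sigma.copy()
            proposal[i] *= -1
            new_log_amp = log_psi(proposal)
            # Accept with probability min(1, |psi(sigma') / psi(sigma)|^2)
            if rng.random() < np.exp(2 * (new_log_amp - log_amp)):
                sigma, log_amp = proposal, new_log_amp
                n_accepted += 1
        samples[sweep] = sigma  # record one configuration per sweep
    return samples, n_accepted / (n_sweeps * n_spins)

a_field = 0.2  # hypothetical product-state parameter

def product_log_psi(sigma):
    return a_field * np.sum(sigma)

rng = np.random.default_rng(0)
samples, acceptance = metropolis_local(product_log_psi, n_spins=8, n_sweeps=2000, rng=rng)
mean_spin = samples[200:].mean()  # discard the first 200 sweeps as burn-in

print(f"Acceptance rate: {acceptance:.3f}")
print(f"Sampled <sigma>: {mean_spin:.4f}   exact tanh(2a): {np.tanh(2 * a_field):.4f}")
```

Only the ratio $|\psi(\sigma')/\psi(\sigma)|^2$ enters the acceptance test, so the normalization of $\psi$ is never needed. Proposals that conserve a quantity the Hamiltonian does not, such as spin exchanges for this model, cannot reach every configuration and bias the estimates.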

In [None]:
# Compare different samplers.
# Note: exchange moves conserve the total magnetization, so MetropolisExchange
# is not ergodic for this model; it is included only for contrast.
samplers = {
    'MetropolisLocal': nk.sampler.MetropolisLocal(hilbert, n_chains=16),
    'MetropolisExchange': nk.sampler.MetropolisExchange(hilbert, graph=lattice, n_chains=16),
}

# Test each sampler
sampler_stats = {}

# Use a distinct loop variable so the `sampler` defined earlier is not overwritten
for name, test_sampler in samplers.items():
    print(f"\nTesting {name} sampler:")
    
    # Create temporary VQS
    temp_vqs = nk.vqs.MCState(
        sampler=test_sampler,
        model=model,
        n_samples=500
    )
    temp_vqs.parameters = vqs.parameters  # Use same parameters
    
    # Sample and compute statistics
    samples = temp_vqs.sample()
    # The acceptance rate is tracked by the sampler state, not by the sampler object
    acceptance_rate = temp_vqs.sampler_state.acceptance
    
    print(f"- Samples shape: {samples.shape}")
    print(f"- Acceptance rate: {acceptance_rate}")
    print(f"- Unique configurations: {len(np.unique(samples.reshape(-1, samples.shape[-1]), axis=0))}")
    
    sampler_stats[name] = {
        'samples': samples,
        'acceptance': acceptance_rate
    }

## 6. Energy Minimization {#energy-minimization}

Now let's perform the actual variational optimization to find the ground state.
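
Each optimization step needs the gradient of the energy with respect to the parameters. For a real wavefunction with real parameters this gradient can be written as a covariance, $\partial_k E = 2\,\langle (E_{\text{loc}} - E)\, O_k \rangle$ with $O_k = \partial \log\psi / \partial\theta_k$. The sketch below checks that formula against finite differences on a 4-spin chain, using a hypothetical two-parameter ansatz and exact enumeration in place of sampling; it is meant to illustrate the estimator, not to reproduce what the NetKet driver does internally.

```python
import itertools
import numpy as np

N_DEMO, J_DEMO, H_DEMO = 4, 1.0, 0.5  # illustrative values
configs = np.array(list(itertools.product([-1.0, 1.0], repeat=N_DEMO)))

def features(sigma):
    # O_k(sigma) = d log psi / d theta_k for the ansatz log psi = theta . features
    return np.array([np.sum(sigma * np.roll(sigma, -1)),
                     np.sum(sigma * np.roll(sigma, -2))])

def log_psi(sigma, theta):
    return theta @ features(sigma)

def local_energy(sigma, theta):
    e_loc = -J_DEMO * np.sum(sigma * np.roll(sigma, -1))
    for i in range(N_DEMO):
        flipped = sigma.copy()
        flipped[i] *= -1
        e_loc -= H_DEMO * np.exp(log_psi(flipped, theta) - log_psi(sigma, theta))
    return e_loc

def energy_and_gradient(theta):
    weights = np.exp(2 * np.array([log_psi(s, theta) for s in configs]))
    p = weights / weights.sum()
    e_loc = np.array([local_energy(s, theta) for s in configs])
    O = np.array([features(s) for s in configs])
    energy = p @ e_loc
    # Covariance form of the gradient: 2 <(E_loc - E) O_k>
    grad = 2 * (p * (e_loc - energy)) @ O
    return energy, grad

theta = np.array([0.2, -0.1])
energy, grad = energy_and_gradient(theta)

# Central finite differences as an independent check
eps = 1e-6
fd_grad = np.array([
    (energy_and_gradient(theta + eps * np.eye(2)[k])[0]
     - energy_and_gradient(theta - eps * np.eye(2)[k])[0]) / (2 * eps)
    for k in range(2)
])

# One plain gradient-descent step with a small learning rate
energy_new = energy_and_gradient(theta - 0.01 * grad)[0]

print(f"Covariance gradient:        {grad}")
print(f"Finite-difference gradient: {fd_grad}")
print(f"Energy before / after one step: {energy:.6f} / {energy_new:.6f}")
```

In VMC the averages over `p` become sample means, so the gradient carries statistical noise that shrinks as the number of samples grows. For complex parameters the formula involves the complex conjugate of $O_k$ and a real part, which this real-valued sketch does not cover.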

In [None]:
# Run VMC optimization
n_iter = 200
print(f"Running VMC optimization for {n_iter} iterations...")

# Storage for results
energies = []
errors = []
iterations = []

# Run optimization loop
for i in range(n_iter):
    # Perform one optimization step
    vmc.advance()
    
    # Store results (real part of the mean; error_of_mean is the MC standard error)
    energy = vmc.energy
    energies.append(float(energy.mean.real))
    errors.append(float(energy.error_of_mean))
    iterations.append(i + 1)
    
    # Print progress
    if (i + 1) % 50 == 0 or i == 0:
        print(f"Iteration {i+1:3d}: E = {energies[-1]:.6f} ± {errors[-1]:.6f}")
        if exact_available:
            error = abs(energies[-1] - exact_energy)
            print(f"              Error vs exact: {error:.6f}")

print("\nOptimization completed!")
final_energy = energies[-1]
final_error = errors[-1]
print(f"Final energy: {final_energy:.6f} ± {final_error:.6f}")

if exact_available:
    final_diff = abs(final_energy - exact_energy)
    print(f"Final error vs exact: {final_diff:.6f}")
    print(f"Final relative error: {final_diff/abs(exact_energy)*100:.2f}%")

In [None]:
# Visualize optimization progress
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Energy convergence
ax1.errorbar(iterations, energies, yerr=errors, alpha=0.8, capsize=2)
if exact_available:
    ax1.axhline(y=exact_energy, color='red', linestyle='--', label='Exact')
    ax1.legend()
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Energy')
ax1.set_title('VMC Energy Convergence')
ax1.grid(True, alpha=0.3)

# Error vs exact (if available)
if exact_available:
    energy_errors = [abs(e - exact_energy) for e in energies]
    ax2.semilogy(iterations, energy_errors)
    ax2.set_xlabel('Iteration')
    ax2.set_ylabel('|E - E_exact|')
    ax2.set_title('Absolute Error vs Exact')
    ax2.grid(True, alpha=0.3)
else:
    ax2.plot(iterations, errors)
    ax2.set_xlabel('Iteration')
    ax2.set_ylabel('Statistical Error')
    ax2.set_title('Monte Carlo Error')
    ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Computing Observables {#observables}

Once we have optimized our variational state, we can compute various physical observables.
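
Diagonal and off-diagonal observables are estimated differently from the same set of samples. A diagonal operator such as $\sigma^z_i$ is simply averaged over the sampled configurations, while an off-diagonal operator such as $\sigma^x_i$ needs a local estimator $\psi(\sigma')/\psi(\sigma)$, where $\sigma'$ is $\sigma$ with spin $i$ flipped. The NumPy sketch below illustrates both on a hypothetical product state $\log\psi(\sigma) = a\sum_i\sigma_i$, whose spins can be sampled exactly and whose expectation values are known in closed form.

```python
import numpy as np

a_field = 0.2                    # hypothetical product-state parameter
n_spins, n_samples = 8, 20000
rng = np.random.default_rng(1)

# For this state the spins are independent with P(sigma_i = +1) = e^{2a} / (2 cosh 2a)
p_up = np.exp(2 * a_field) / (2 * np.cosh(2 * a_field))
samples = np.where(rng.random((n_samples, n_spins)) < p_up, 1.0, -1.0)

# Diagonal observable: average sigma_i directly over the samples
mag_z_est = samples.mean()

# Off-diagonal observable: sigma^x_i flips spin i, so its local estimator is
# psi(sigma with spin i flipped) / psi(sigma) = exp(-2 a sigma_i)
mag_x_est = np.exp(-2 * a_field * samples).mean()

mag_z_exact = np.tanh(2 * a_field)
mag_x_exact = 1 / np.cosh(2 * a_field)

print(f"<sigma^z>: sampled {mag_z_est:.4f}, exact {mag_z_exact:.4f}")
print(f"<sigma^x>: sampled {mag_x_est:.4f}, exact {mag_x_exact:.4f}")
```

The same pattern applies to any operator with few nonzero matrix elements per row, which is why `expect` only needs the operator's connected configurations and the amplitude ratios.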

In [None]:
# Define observables
observables = {}

# Magnetization in z-direction
mag_z = sum([nk.operator.spin.sigmaz(hilbert, i) for i in range(N)]) / N
observables['⟨σᶻ⟩'] = mag_z

# Magnetization in x-direction
mag_x = sum([nk.operator.spin.sigmax(hilbert, i) for i in range(N)]) / N
observables['⟨σˣ⟩'] = mag_x

# Energy density
observables['⟨H⟩/N'] = hamiltonian

print("Computing observables...")
results = {}

for name, observable in observables.items():
    expectation = vqs.expect(observable)
    # Divide the total energy by N to obtain the energy density
    scale = N if name == '⟨H⟩/N' else 1
    results[name] = float(expectation.mean.real) / scale
    error = float(expectation.error_of_mean) / scale
    
    print(f"{name}: {results[name]:.6f} ± {error:.6f}")

print(f"\nPhysical interpretation:")
print(f"- |⟨σᶻ⟩| = {abs(results['⟨σᶻ⟩']):.4f} (nonzero only if the state breaks the Z₂ spin-flip symmetry)")
print(f"- ⟨σˣ⟩ = {results['⟨σˣ⟩']:.4f} (polarization along the transverse field)")
print(f"- Energy per site: {results['⟨H⟩/N']:.4f}")

In [None]:
# Compute correlation functions
print("Computing correlation functions...")

# Z-Z correlations
correlations_zz = []
distances = []

for d in range(1, N//2 + 1):
    corr_sum = 0
    count = 0
    
    for i in range(N):
        j = (i + d) % N
        sigma_z_i = nk.operator.spin.sigmaz(hilbert, i)
        sigma_z_j = nk.operator.spin.sigmaz(hilbert, j)
        corr_op = sigma_z_i * sigma_z_j
        
        corr_val = float(vqs.expect(corr_op).mean.real)
        corr_sum += corr_val
        count += 1
    
    avg_corr = corr_sum / count
    correlations_zz.append(avg_corr)
    distances.append(d)
    print(f"C_zz({d}): {avg_corr:.6f}")

# Plot correlations
plt.figure(figsize=(8, 5))
plt.plot(distances, correlations_zz, 'o-', label='⟨σᶻᵢσᶻⱼ⟩')
plt.xlabel('Distance |i-j|')
plt.ylabel('Correlation')
plt.title('Spin-Spin Correlations')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## 8. Performance Analysis {#performance}

Let's analyze the performance characteristics of our VMC calculation.
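
A useful way to validate an autocorrelation estimate is to run it on a chain whose correlation time is known. The sketch below uses a synthetic AR(1) process $x_{t+1} = \phi x_t + \text{noise}$, for which $\rho(t) = \phi^t$ and the integrated autocorrelation time is $\tau_{\text{int}} = (1+\phi)/(2(1-\phi))$. The function names and the truncation rule (stop at the first non-positive $\rho$) are choices made for this illustration.

```python
import numpy as np

def autocorrelation(x, max_lag):
    """Normalized autocorrelation rho(t) for t = 0 .. max_lag - 1."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    var = np.dot(x, x) / len(x)
    return np.array([np.dot(x[:len(x) - t], x[t:]) / (len(x) * var)
                     for t in range(max_lag)])

def integrated_autocorr_time(rho):
    """tau_int = 1/2 + sum_{t>=1} rho(t), truncated at the first non-positive value."""
    tau = 0.5
    for r in rho[1:]:
        if r <= 0:
            break
        tau += r
    return tau

# Synthetic AR(1) chain with known correlation structure
phi, n_steps = 0.8, 50000
rng = np.random.default_rng(2)
noise = rng.normal(size=n_steps)
x = np.empty(n_steps)
x[0] = noise[0]
for t in range(1, n_steps):
    x[t] = phi * x[t - 1] + noise[t]

rho = autocorrelation(x, max_lag=100)
tau_int = integrated_autocorr_time(rho)
tau_exact = (1 + phi) / (2 * (1 - phi))
n_eff = n_steps / (2 * tau_int)  # effective number of independent samples

print(f"Estimated tau_int: {tau_int:.2f}   (exact: {tau_exact:.2f})")
print(f"Effective sample size: {n_eff:.0f} of {n_steps}")
```

The statistical error of a mean scales as $\sqrt{2\tau_{\text{int}}/N}$ times the standard deviation, so a chain with $\tau_{\text{int}} \approx 4.5$ needs roughly nine times as many samples as an uncorrelated one for the same error bar.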

In [None]:
# Analyze sampling efficiency
print("Sampling efficiency analysis:")

# Get current samples
samples = vqs.sample()
print(f"Sample shape: {samples.shape}")
print(f"Total samples: {samples.size // samples.shape[-1]}")

# Compute autocorrelation time (approximate)
def autocorr_function(x, max_lag=50):
    """Compute autocorrelation function"""
    n = len(x)
    x = x - np.mean(x)
    autocorr = np.correlate(x, x, mode='full')
    autocorr = autocorr[n-1:]
    autocorr = autocorr / autocorr[0]
    return autocorr[:min(max_lag, len(autocorr))]

# Analyze first chain
first_chain = samples[0, :, 0]  # First spin of first chain
autocorr = autocorr_function(first_chain)

# Find decorrelation time
decorr_time = None
for i, ac in enumerate(autocorr):
    if ac < 1/np.e:  # 1/e criterion
        decorr_time = i
        break

print(f"Approximate decorrelation time: {decorr_time if decorr_time else '> 50'} sweeps")

# Effective sample size
if decorr_time:
    eff_samples = samples.shape[1] // (2 * decorr_time)
    print(f"Effective samples per chain: ~{eff_samples}")

# Plot autocorrelation
plt.figure(figsize=(8, 5))
plt.plot(autocorr, 'o-')
plt.axhline(y=1/np.e, color='red', linestyle='--', label='1/e')
if decorr_time:
    plt.axvline(x=decorr_time, color='red', linestyle='--', alpha=0.5)
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')
plt.title('Autocorrelation Function')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Performance vs sample size
print("Performance vs sample size analysis:")

sample_sizes = [100, 500, 1000, 2000, 5000]
energy_estimates = []
energy_errors = []
computation_times = []

import time

for n_samples in sample_sizes:
    print(f"Testing {n_samples} samples...")
    
    # Create temporary VQS with different sample size
    temp_vqs = nk.vqs.MCState(
        sampler=vqs.sampler,  # reuse the sampler of the optimized state
        model=model,
        n_samples=n_samples
    )
    temp_vqs.parameters = vqs.parameters
    
    # Time the energy computation. A new sample size typically triggers JIT
    # compilation, which is included in this wall-clock measurement.
    start_time = time.perf_counter()
    energy = temp_vqs.expect(hamiltonian)
    comp_time = time.perf_counter() - start_time
    
    energy_estimates.append(float(energy.mean.real))
    energy_errors.append(float(energy.error_of_mean))
    computation_times.append(comp_time)
    
    print(f"  Energy: {energy_estimates[-1]:.6f} ± {energy_errors[-1]:.6f} (time: {comp_time:.2f}s)")

# Plot results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Error vs sample size
ax1.loglog(sample_sizes, energy_errors, 'o-')
ax1.loglog(sample_sizes, 1/np.sqrt(sample_sizes) * energy_errors[0] * np.sqrt(sample_sizes[0]), 
           '--', alpha=0.5, label='1/√N scaling')
ax1.set_xlabel('Number of samples')
ax1.set_ylabel('Energy error')
ax1.set_title('Statistical Error vs Sample Size')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Computation time vs sample size
ax2.loglog(sample_sizes, computation_times, 'o-', color='red')
ax2.set_xlabel('Number of samples')
ax2.set_ylabel('Computation time (s)')
ax2.set_title('Computation Time vs Sample Size')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Conclusion {#conclusion}

In this notebook, we've explored the key aspects of Variational Monte Carlo methods in NetKet:

### Key Concepts Covered:

1. **VMC Framework**: Variational principle and energy minimization
2. **Neural Quantum States**: RBM as variational ansatz
3. **Monte Carlo Sampling**: Different sampling strategies and efficiency
4. **Optimization**: Various optimizers and convergence analysis
5. **Observables**: Computing physical quantities and correlations
6. **Performance**: Scaling and efficiency considerations

### Key Findings:

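- **Convergence**: Energy minimization yields a variational upper bound on the ground-state energy, which can be checked against exact diagonalization for small systems
- **Accuracy**: Depends on ansatz expressivity and optimization quality
- **Efficiency**: Sampling quality (acceptance rate, autocorrelation time, ergodicity) directly controls the statistical error
- **Scalability**: For a fixed ansatz, the cost of one iteration grows polynomially with system size, although the accuracy reached is not guaranteed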

### Best Practices:

1. **Choose appropriate ansatz** for the quantum phase
2. **Use sufficient samples** for stable gradient estimates
3. **Monitor convergence** carefully
4. **Validate results** against exact solutions when possible
5. **Consider symmetries** for improved efficiency

### Extensions:

- **Excited states**: Using penalty methods or orthogonalization
- **Real-time evolution**: Time-dependent variational principle
- **Finite temperature**: Thermal states and free energy
- **Open systems**: Steady states of Lindblad master equations

VMC with neural quantum states represents a powerful approach to quantum many-body problems, offering a good balance between accuracy and computational efficiency for medium to large systems.