# Neural Quantum States (NQS) with NetKet

This notebook explores neural-network representations of quantum states using NetKet. Neural Quantum States (NQS) represent the wavefunction of a quantum many-body system with an artificial neural network.

## Table of Contents
1. [Introduction to Neural Quantum States](#introduction)
2. [Setting up NetKet](#setup)
3. [Basic NQS: Restricted Boltzmann Machine](#rbm)
4. [Feed-Forward Neural Networks](#ffnn)
5. [Complex-valued Neural Networks](#complex)
6. [Autoregressive Neural Networks](#autoregressive)
7. [Performance Comparison](#comparison)
8. [Symmetries and Physical Constraints](#symmetries)
9. [Conclusion](#conclusion)

## 1. Introduction to Neural Quantum States {#introduction}

Neural Quantum States represent quantum many-body wavefunctions using neural networks. For a quantum state $|\psi\rangle$, the NQS ansatz is:

$$\psi(\sigma) = \langle \sigma | \psi \rangle = \mathcal{N} \exp\left(\sum_i \ln f_i(\sigma; \theta)\right)$$

where:
- $\sigma$ represents a configuration in the computational basis
- $f_i(\sigma; \theta)$ are neural network outputs parameterized by $\theta$
- $\mathcal{N}$ is a normalization constant
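
To make the role of the log-amplitude and of $\mathcal{N}$ concrete, here is a minimal NumPy sketch for a system small enough to enumerate. The one-parameter-per-site model `toy_log_psi` and the helper `all_configurations` are hypothetical illustrations, not part of NetKet; the point is only that a model returns an unnormalized $\log\psi(\sigma)$ and that normalization follows from summing $|\psi|^2$ over all configurations.

```python
import itertools
import numpy as np

def all_configurations(n_spins):
    """Return all 2**n_spins configurations with entries in {-1, +1}."""
    return np.array(list(itertools.product([-1.0, 1.0], repeat=n_spins)))

def toy_log_psi(sigma, theta):
    """Hypothetical toy model: log psi(sigma) = sum_i theta_i * sigma_i."""
    return sigma @ theta

n_spins = 4
theta = np.linspace(-0.5, 0.5, n_spins)  # illustrative parameters
configs = all_configurations(n_spins)     # shape (16, 4)
log_psi = toy_log_psi(configs, theta)     # unnormalized log-amplitudes

# |psi|^2 up to normalization; subtracting the maximum avoids overflow
log_weights = 2.0 * log_psi
weights = np.exp(log_weights - log_weights.max())
probs = weights / weights.sum()           # dividing by the sum fixes the normalization

print(f"Number of configurations: {len(configs)}")
print(f"Sum of probabilities: {probs.sum():.6f}")
```

For realistic system sizes this enumeration is impossible, which is why the normalization is never computed explicitly and expectation values are estimated by sampling instead.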

### Key Advantages:
- **Expressivity**: Can represent complex quantum states
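- **Scalability**: The number of parameters and the cost of evaluating $\psi(\sigma)$ grow polynomially with system size, while the Hilbert space grows exponentially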
- **Flexibility**: Can incorporate physical symmetries
- **Optimization**: Leverages automatic differentiation

In [2]:
# Import necessary libraries
import netket as nk
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, Callable
import warnings
warnings.filterwarnings('ignore')

# Run on CPU for this tutorial
jax.config.update('jax_platform_name', 'cpu')
# Seed the random number generator for reproducibility
key = jax.random.PRNGKey(42)

print(f"NetKet version: {nk.__version__}")
print(f"JAX version: {jax.__version__}")

NetKet version: 3.19.2
JAX version: 0.5.3


## 2. Setting up NetKet {#setup}

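Let's start by setting up a simple quantum system, the 1D Ising model with a transverse field. In the sign convention of `nk.operator.Ising`, the coupling term enters with a plus sign:

$$H = J \sum_{i} \sigma^z_i \sigma^z_{i+1} - h \sum_i \sigma^x_i$$

so $J > 0$ favors anti-aligned neighbors. On a periodic chain with an even number of sites, flipping the sign of $J$ leaves the energy spectrum unchanged.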
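
For a chain this small, the Hamiltonian can also be built as a dense matrix and diagonalized exactly, which gives a reference energy to compare variational results against. The sketch below uses plain NumPy Kronecker products; `site_operator` and `tfim_matrix` are hypothetical helper names, and the sign of the coupling term is passed explicitly as `zz_coupling` so that either sign convention can be reproduced.

```python
import numpy as np

SX = np.array([[0.0, 1.0], [1.0, 0.0]])   # Pauli X
SZ = np.array([[1.0, 0.0], [0.0, -1.0]])  # Pauli Z
ID = np.eye(2)

def site_operator(op, site, n_spins):
    """Embed a single-site operator at position `site` of an n_spins chain."""
    out = np.array([[1.0]])
    for k in range(n_spins):
        out = np.kron(out, op if k == site else ID)
    return out

def tfim_matrix(n_spins, zz_coupling, h):
    """Dense H = zz_coupling * sum_i Z_i Z_{i+1} - h * sum_i X_i on a periodic chain."""
    dim = 2 ** n_spins
    H = np.zeros((dim, dim))
    for i in range(n_spins):
        j = (i + 1) % n_spins  # periodic boundary conditions
        H += zz_coupling * site_operator(SZ, i, n_spins) @ site_operator(SZ, j, n_spins)
        H -= h * site_operator(SX, i, n_spins)
    return H

H = tfim_matrix(n_spins=8, zz_coupling=-1.0, h=0.5)
ground_energy = np.linalg.eigvalsh(H)[0]  # eigenvalues are returned in ascending order
print(f"Exact ground-state energy: {ground_energy:.6f}")
```

Dense diagonalization scales as $2^N \times 2^N$ and is only practical for roughly a dozen spins, but it is a useful sanity check at $N = 8$.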

In [3]:
# Define the lattice
N = 8  # Number of spins
lattice = nk.graph.Chain(length=N, pbc=True)  # Periodic boundary conditions

# Define the Hilbert space (spin-1/2)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)

# Define the Hamiltonian
J = 1.0  # Coupling strength
h = 0.5  # Transverse field strength

# Ising interaction
hamiltonian = nk.operator.Ising(hilbert=hilbert, graph=lattice, J=J, h=h)

print(f"System size: {N} spins")
print(f"Hilbert space dimension: {hilbert.n_states}")
print(f"Hamiltonian: {hamiltonian}")

System size: 8 spins
Hilbert space dimension: 256
Hamiltonian: IsingJax(J=1.0, h=0.5; dim=8)


## 3. Basic NQS: Restricted Boltzmann Machine {#rbm}

The Restricted Boltzmann Machine (RBM) is one of the simplest and most successful NQS architectures. It consists of visible and hidden units:

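$$\psi_{\text{RBM}}(\sigma) = \sum_{\{h_j\}} \exp\left(\sum_i a_i \sigma_i + \sum_j b_j h_j + \sum_{ij} W_{ij} \sigma_i h_j\right)$$

where the hidden units $h_j \in \{-1, +1\}$ are marginalized. The sum can be carried out analytically: $\psi_{\text{RBM}}(\sigma) = \exp\left(\sum_i a_i \sigma_i\right) \prod_j 2\cosh\left(b_j + \sum_i W_{ij} \sigma_i\right)$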
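
Summing over binary hidden units $h_j \in \{-1, +1\}$ turns each hidden unit into a factor $2\cosh(b_j + \sum_i W_{ij}\sigma_i)$. The NumPy sketch below checks this closed form against an explicit sum over all hidden configurations for a tiny RBM. The function names and the random parameters are illustrative assumptions, not NetKet's implementation.

```python
import itertools
import numpy as np

def rbm_log_psi(sigma, a, b, W):
    """Closed-form log psi(sigma) after summing out hidden units h_j in {-1, +1}."""
    theta = b + sigma @ W
    return sigma @ a + np.sum(np.log(2.0 * np.cosh(theta)))

def rbm_log_psi_brute_force(sigma, a, b, W):
    """Same quantity, computed by explicitly summing over all hidden configurations."""
    total = 0.0
    for h in itertools.product([-1.0, 1.0], repeat=len(b)):
        h = np.array(h)
        total += np.exp(sigma @ a + h @ b + sigma @ W @ h)
    return np.log(total)

rng = np.random.default_rng(0)
n_visible, n_hidden = 4, 3
a = rng.normal(scale=0.1, size=n_visible)             # visible biases
b = rng.normal(scale=0.1, size=n_hidden)              # hidden biases
W = rng.normal(scale=0.1, size=(n_visible, n_hidden)) # couplings

sigma = np.array([1.0, -1.0, 1.0, 1.0])
closed_form = rbm_log_psi(sigma, a, b, W)
brute_force = rbm_log_psi_brute_force(sigma, a, b, W)
print(f"closed form: {closed_form:.8f}, brute force: {brute_force:.8f}")
```

The closed form costs one matrix-vector product per configuration, whereas the explicit sum grows as $2^{M}$ in the number of hidden units $M$.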

In [None]:
# Create an RBM-based Neural Quantum State
alpha = 2  # Hidden unit density (hidden units = alpha * visible units)

# RBM with default (real) parameters; pass param_dtype=complex for complex weights
rbm_real = nk.models.RBM(
    alpha=alpha,
    use_visible_bias=True,
    use_hidden_bias=True
)

print(f"RBM Architecture:")
print(f"- Visible units: {N}")
print(f"- Hidden units: {alpha * N}")
print(f"- Total parameters: {N + alpha * N + N * alpha * N}")

# Initialize the model with proper input shape
key, subkey = jax.random.split(key)
sample_input = hilbert.random_state(subkey, size=1)

# Initialize parameters correctly
try:
    rbm_params = rbm_real.init(subkey, sample_input)
    print(f"\nModel initialization successful!")
except Exception as e:
    print(f"Standard initialization failed: {e}")
    # Try alternative initialization
    rbm_params = rbm_real.init({'params': subkey}, sample_input)
    print(f"Alternative initialization successful!")

print(f"\nParameter structure:")
# The parameters form a nested pytree (the hidden layer holds its own kernel and bias),
# so walk over the leaves instead of assuming a flat dictionary of arrays
for path, param in jax.tree_util.tree_leaves_with_path(rbm_params):
    print(f"- {jax.tree_util.keystr(path)}: {param.shape} (dtype: {param.dtype})")

# Verify the model works
try:
    test_output = rbm_real.apply(rbm_params, sample_input)
    print(f"\nModel test successful. Output shape: {test_output.shape}, dtype: {test_output.dtype}")
except Exception as e:
    print(f"\nModel test failed: {e}")


In [None]:
# Demonstrate RBM evaluation
# Generate some sample configurations
key, subkey = jax.random.split(key)
sample_configs = hilbert.random_state(subkey, size=5)

print("Sample configurations and their RBM log-amplitudes:")
print(f"Sample configs shape: {sample_configs.shape}")
print(f"Sample configs dtype: {sample_configs.dtype}")

try:
    log_psi = rbm_real.apply(rbm_params, sample_configs)
    
    print(f"RBM output shape: {log_psi.shape}")
    print(f"RBM output dtype: {log_psi.dtype}")
    
    for i, (config, log_amp) in enumerate(zip(sample_configs, log_psi)):
        print(f"Config {i+1}: {config} -> log|ψ| = {log_amp:.4f}")
    
    # Unnormalized probabilities |ψ|² (the RBM output is not normalized)
    print(f"\nUnnormalized probabilities |ψ|²:")
    probs = jnp.exp(2 * jnp.real(log_psi))
    for i, prob in enumerate(probs):
        print(f"Config {i+1}: |ψ|² = {prob:.6f}")
        
except Exception as e:
    print(f"Error in RBM evaluation: {e}")
    print("Debugging - trying individual configurations:")
    
    # Try with individual configurations
    for i, config in enumerate(sample_configs):
        try:
            # Ensure proper shape for single configuration
            single_config = config.reshape(1, -1) if len(config.shape) == 1 else config
            log_amp = rbm_real.apply(rbm_params, single_config)
            print(f"Config {i+1}: {config} -> log|ψ| = {log_amp[0]:.4f}")
        except Exception as inner_e:
            print(f"Config {i+1} failed: {inner_e}")
            # Try without reshaping
            try:
                log_amp = rbm_real.apply(rbm_params, config)
                print(f"Config {i+1} (no reshape): {config} -> log|ψ| = {log_amp:.4f}")
            except Exception as inner_e2:
                print(f"Config {i+1} (no reshape) also failed: {inner_e2}")

## 4. Feed-Forward Neural Networks {#ffnn}

Feed-forward neural networks can provide more expressive representations than RBMs. In NetKet, we can easily create custom architectures using Flax.
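
Before building a deeper network it helps to know how its size compares with the RBM above. The short sketch below counts parameters for a stack of dense layers with biases; `count_dense_params` is a hypothetical helper, and the layer sizes `(16, 16, 8)` plus a single output unit are taken from the network defined in the next cell.

```python
def count_dense_params(n_inputs, layer_sizes):
    """Number of weights and biases in a stack of fully connected layers."""
    total = 0
    fan_in = n_inputs
    for fan_out in layer_sizes:
        total += fan_in * fan_out + fan_out  # kernel entries + biases
        fan_in = fan_out
    return total

N, alpha = 8, 2

# RBM: visible biases + hidden biases + weight matrix
rbm_count = N + alpha * N + N * alpha * N

# Feed-forward network: hidden layers (16, 16, 8) followed by one output unit
ffnn_count = count_dense_params(N, (16, 16, 8, 1))

print(f"RBM parameters:  {rbm_count}")   # 152
print(f"FFNN parameters: {ffnn_count}")  # 561
```

These are counts of array entries; when the parameters are complex, each entry carries two real degrees of freedom.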

In [None]:
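def reim_relu(z):
    """ReLU applied separately to the real and imaginary parts of complex inputs."""
    if jnp.iscomplexobj(z):
        return nn.relu(z.real) + 1j * nn.relu(z.imag)
    return nn.relu(z)


class DeepFFNN(nn.Module):
    """
    Deep Feed-Forward Neural Network for quantum states.
    """
    layers: tuple = (32, 32, 16)
    activation: Callable = reim_relu  # plain ReLU is not well defined for complex values
    dtype: Any = jnp.complex128       # dtype of the network parameters
    
    @nn.compact
    def __call__(self, x):
        # Convert spin configuration to real values
        x = jnp.asarray(x, dtype=jnp.float64)
        
        # Hidden layers; param_dtype makes the weights themselves complex
        # (dtype alone would only cast real weights during the computation)
        for layer_size in self.layers:
            x = nn.Dense(layer_size, param_dtype=self.dtype)(x)
            x = self.activation(x)
        
        # Output layer (single complex number per configuration)
        x = nn.Dense(1, param_dtype=self.dtype)(x)
        return jnp.squeeze(x, axis=-1)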

# Create the deep neural network
deep_nn = DeepFFNN(layers=(16, 16, 8), dtype=jnp.complex128)

# Initialize parameters
key, subkey = jax.random.split(key)
deep_params = deep_nn.init(subkey, sample_input[0])

print(f"Deep FFNN Architecture:")
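total_params = 0
# Each entry of 'params' is a layer that holds its own kernel and bias arrays
for layer_name, layer_params in deep_params['params'].items():
    for param_name, param in layer_params.items():
        param_count = int(np.prod(param.shape))
        total_params += param_count
        print(f"- {layer_name}/{param_name}: {param.shape} ({param_count} parameters)")
print(f"Total parameters: {total_params}")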

In [None]:
# Test the deep neural network
print("Deep FFNN evaluation on sample configurations:")
log_psi_deep = deep_nn.apply(deep_params, sample_configs)

for i, (config, log_amp) in enumerate(zip(sample_configs, log_psi_deep)):
    print(f"Config {i+1}: {config} -> log ψ = {log_amp:.4f}")

# Compare magnitude with RBM
print(f"\nComparison of |log ψ| magnitudes:")
print(f"RBM: {jnp.abs(log_psi).mean():.4f} ± {jnp.abs(log_psi).std():.4f}")
print(f"Deep NN: {jnp.abs(log_psi_deep).mean():.4f} ± {jnp.abs(log_psi_deep).std():.4f}")

## 5. Complex-valued Neural Networks {#complex}

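Many quantum states have amplitudes with nontrivial signs or phases, which a real-valued log-amplitude cannot represent because $\exp(\log\psi)$ is then always positive. NetKet supports complex neural networks naturally through JAX's complex number support.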
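
When a model returns a complex $\log\psi$, its real part is the log-modulus $\log|\psi|$ and its imaginary part is the phase of $\psi$ (defined modulo $2\pi$). The NumPy sketch below illustrates this decomposition on a few made-up values; the numbers are arbitrary and do not come from any model in this notebook.

```python
import numpy as np

# Arbitrary complex log-amplitudes, for illustration only
log_psi = np.array([0.3 + 0.5j, -1.2 + 4.0j, 0.0 - 2.0j])

log_modulus = log_psi.real                    # log|psi|
phase = np.angle(np.exp(1j * log_psi.imag))   # phase wrapped to (-pi, pi]
psi = np.exp(log_psi)                         # the amplitude itself

for lp, m, p in zip(log_psi, np.exp(log_modulus), phase):
    print(f"log psi = {lp:.2f} -> |psi| = {m:.4f}, phase = {p:.4f}")
```

Working with $\log\psi$ rather than $\psi$ keeps the computation numerically stable when amplitudes span many orders of magnitude.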

In [None]:
# Complex RBM: request complex parameters through param_dtype
rbm_complex = nk.models.RBM(
    alpha=alpha,
    use_visible_bias=True,
    use_hidden_bias=True,
    param_dtype=complex
)

# Initialize with complex parameters
key, subkey = jax.random.split(key)
try:
    rbm_complex_params = rbm_complex.init(subkey, sample_input)
    
    # Evaluate complex amplitudes
    log_psi_complex = rbm_complex.apply(rbm_complex_params, sample_configs)
    
    print("Complex RBM evaluation:")
    for i, (config, log_amp) in enumerate(zip(sample_configs, log_psi_complex)):
        # Handle both real and complex outputs
        if jnp.iscomplexobj(log_amp):
            phase = jnp.angle(jnp.exp(log_amp))
            magnitude = jnp.abs(jnp.exp(log_amp))
            print(f"Config {i+1}: log ψ = {log_amp:.4f}, |ψ| = {magnitude:.4f}, phase = {phase:.4f}")
        else:
            magnitude = jnp.exp(log_amp)
            print(f"Config {i+1}: log ψ = {log_amp:.4f}, |ψ| = {magnitude:.4f} (real)")
            
except Exception as e:
    print(f"Complex RBM failed: {e}")
    print("This version may not support complex RBMs directly")

## 6. Autoregressive Neural Networks {#autoregressive}

Autoregressive models factorize the wavefunction as:
$$\psi(\sigma_1, \ldots, \sigma_N) = \prod_{i=1}^N \psi_i(\sigma_i | \sigma_1, \ldots, \sigma_{i-1})$$

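When each conditional is normalized, $\sum_{\sigma_i} |\psi_i(\sigma_i | \sigma_1, \ldots, \sigma_{i-1})|^2 = 1$, the full wavefunction is normalized by construction and configurations can be sampled exactly, one site at a time, without a Markov chain.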
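
The sampling procedure implied by this factorization can be sketched in a few lines of NumPy. The example uses a hypothetical logistic model for the conditional probabilities $|\psi_i|^2$, chosen only because it is easy to normalize; it is not how NetKet parameterizes its autoregressive networks. The final check sums the probability over all configurations to confirm that the product of normalized conditionals is itself normalized.

```python
import itertools
import numpy as np

def prob_up(prefix, weights, biases):
    """p(sigma_i = +1 | sigma_<i) for i = len(prefix); a hypothetical logistic model."""
    i = len(prefix)
    logit = biases[i] + np.dot(weights[i, :i], prefix)
    return 1.0 / (1.0 + np.exp(-logit))

def sample(rng, weights, biases):
    """Draw one configuration exactly, one site at a time."""
    sigma = np.zeros(len(biases))
    for i in range(len(biases)):
        p = prob_up(sigma[:i], weights, biases)
        sigma[i] = 1.0 if rng.random() < p else -1.0
    return sigma

def probability(sigma, weights, biases):
    """Product of the conditionals, i.e. |psi(sigma)|^2 for this toy model."""
    prob = 1.0
    for i, s in enumerate(sigma):
        p = prob_up(sigma[:i], weights, biases)
        prob *= p if s > 0 else 1.0 - p
    return prob

rng = np.random.default_rng(0)
n_spins = 4
weights = rng.normal(size=(n_spins, n_spins))  # only the strictly lower triangle is used
biases = rng.normal(size=n_spins)

configuration = sample(rng, weights, biases)
total = sum(
    probability(np.array(c), weights, biases)
    for c in itertools.product([-1.0, 1.0], repeat=n_spins)
)
print(f"Sampled configuration: {configuration}")
print(f"Total probability over all configurations: {total:.6f}")
```

Each sample costs one pass over the sites and is independent of the previous one, in contrast to Markov chain sampling, where consecutive samples are correlated.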

In [None]:
# Create an autoregressive neural network
# For simplicity, we'll use NetKet's dense autoregressive model, ARNNDense
try:
    autoregressive_nn = nk.models.ARNNDense(
        hilbert=hilbert,
        layers=2,
        features=8,
        param_dtype=jnp.complex128
    )
    
    # Initialize parameters
    key, subkey = jax.random.split(key)
    ar_params = autoregressive_nn.init(subkey, sample_input)
    
    # Evaluate
    log_psi_ar = autoregressive_nn.apply(ar_params, sample_configs)
    
    print("Autoregressive NN evaluation:")
    for i, (config, log_amp) in enumerate(zip(sample_configs, log_psi_ar)):
        print(f"Config {i+1}: {config} -> log ψ = {log_amp:.4f}")
        
except Exception as e:
    print(f"Autoregressive model not available in this NetKet version: {e}")
    print("Using alternative implementation...")
    
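    # Simple alternative: a sum of per-site contributions to log ψ
    # Note: the factors are not normalized conditionals, so this sketch does not support exact sampling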
    class SimpleAutoregressive(nn.Module):
        features: int = 8
        dtype: Any = jnp.complex128
        
        @nn.compact
        def __call__(self, x):
            x = jnp.array(x, dtype=jnp.float64)
            log_psi = 0.0
            
            for i in range(x.shape[-1]):
                # Condition on the spins up to and including site i
                context = x[..., :i+1]
                h = nn.Dense(self.features, dtype=self.dtype, name=f'dense_{i}')(context)
                h = nn.tanh(h)
                contrib = nn.Dense(1, dtype=self.dtype, name=f'output_{i}')(h)
                log_psi += jnp.squeeze(contrib, axis=-1)
            
            return log_psi
    
    simple_ar = SimpleAutoregressive()
    key, subkey = jax.random.split(key)
    simple_ar_params = simple_ar.init(subkey, sample_input[0])
    log_psi_simple_ar = simple_ar.apply(simple_ar_params, sample_configs)
    
    print("Simple Autoregressive evaluation:")
    for i, (config, log_amp) in enumerate(zip(sample_configs, log_psi_simple_ar)):
        print(f"Config {i+1}: {config} -> log ψ = {log_amp:.4f}")

## 7. Performance Comparison {#comparison}

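Let's compare the architectures at random initialization: the distribution of $\log|\psi|$ over random configurations and the time needed to evaluate a batch. Expressivity itself can only be judged after each model has been optimized for a given problem.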
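
Timing JAX code has two pitfalls: the first call includes tracing and compilation, and results are dispatched asynchronously, so the clock can stop before the computation has finished. The helper below is a minimal sketch of one way to handle both; `time_call` is a hypothetical name, and the `sum(range(...))` call only stands in for a model evaluation.

```python
import time

def time_call(fn, *args, repeats=5):
    """Return the best wall-clock time in seconds over `repeats` calls to fn(*args)."""
    def run():
        result = fn(*args)
        # JAX arrays expose block_until_ready(); other return types are used as-is
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        return result

    run()  # warm-up call, absorbs one-off costs such as compilation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        timings.append(time.perf_counter() - start)
    return min(timings)

# Stand-in workload; in the notebook this would be a model evaluation
elapsed = time_call(sum, range(100_000))
print(f"Best of 5 runs: {elapsed:.6f}s")
```

Taking the minimum over several runs reduces the influence of unrelated system activity on the measurement.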

In [None]:
# Generate a larger sample for statistics
key, subkey = jax.random.split(key)
large_sample = hilbert.random_state(subkey, size=1000)

# Evaluate all models
models = {
    'RBM (Real)': (rbm_real, rbm_params),
    'RBM (Complex)': (rbm_complex, rbm_complex_params),
    'Deep FFNN': (deep_nn, deep_params)
}

results = {}

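import time

for name, (model, params) in models.items():
    # Time the evaluation; block_until_ready waits for JAX's asynchronous dispatch to finish.
    # The measured time of this first call still includes one-off tracing overhead.
    start_time = time.time()
    log_psi = model.apply(params, large_sample).block_until_ready()
    eval_time = time.time() - start_time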
    
    # Calculate statistics
    log_psi_var = jnp.var(jnp.real(log_psi))
    log_psi_mean = jnp.mean(jnp.real(log_psi))
    
    results[name] = {
        'mean': log_psi_mean,
        'variance': log_psi_var,
        'eval_time': eval_time,
        'log_psi': log_psi
    }

# Display results
print("Model Comparison (batch of 1000 configurations):")
print("-" * 60)
for name, stats in results.items():
    print(f"{name:15} | Mean: {stats['mean']:8.4f} | Var: {stats['variance']:8.4f} | Time: {stats['eval_time']:6.3f}s")

In [None]:
# Visualize the distributions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for i, (name, stats) in enumerate(results.items()):
    axes[i].hist(jnp.real(stats['log_psi']), bins=50, alpha=0.7, density=True)
    axes[i].set_title(f'{name}\nlog|ψ| Distribution')
    axes[i].set_xlabel('log|ψ|')
    axes[i].set_ylabel('Density')
    axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Parameter count comparison
print("\nParameter Count Comparison:")
print("-" * 40)
param_counts = {
    'RBM': N + alpha * N + N * alpha * N,
    'Deep FFNN': total_params
}

for name, count in param_counts.items():
    print(f"{name:15}: {count:6d} parameters")

## 8. Symmetries and Physical Constraints {#symmetries}

Real quantum systems often have symmetries that can be incorporated into NQS for better performance and physical accuracy.
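
One generic way to impose a symmetry is to sum the amplitudes of an arbitrary model over all symmetry-related configurations, $\psi_{\text{sym}}(\sigma) = \sum_T \psi(T\sigma)$. The NumPy sketch below does this for lattice translations of a periodic chain, assuming a real-valued log-amplitude so that a standard log-sum-exp can be used. The names `symmetrized_log_psi` and `toy_log_psi` are hypothetical, and the toy model is deliberately not translation invariant.

```python
import numpy as np

def symmetrized_log_psi(log_psi_fn, sigma):
    """log of sum_T psi(T sigma) over all translations T (real log-amplitudes assumed)."""
    n = sigma.shape[-1]
    values = np.array([log_psi_fn(np.roll(sigma, shift)) for shift in range(n)])
    vmax = values.max()  # log-sum-exp trick for numerical stability
    return vmax + np.log(np.sum(np.exp(values - vmax)))

# Site-dependent weights make this toy model depend on where the spins sit
site_weights = np.array([0.1, -0.2, 0.3, 0.05])

def toy_log_psi(sigma):
    return sigma @ site_weights

sigma = np.array([1.0, 1.0, -1.0, 1.0])
shifted = np.roll(sigma, 1)

print(f"plain:       {toy_log_psi(sigma):.6f} vs {toy_log_psi(shifted):.6f}")
print(f"symmetrized: {symmetrized_log_psi(toy_log_psi, sigma):.6f} vs "
      f"{symmetrized_log_psi(toy_log_psi, shifted):.6f}")
```

This costs one model evaluation per symmetry operation; building the symmetry into the weights, as in the cell that follows, reaches the same invariance with shared parameters.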

In [None]:
# Example: Translation-invariant RBM
# For periodic systems, we can enforce translation symmetry

class TranslationInvariantRBM(nn.Module):
    """
    RBM with enforced translation symmetry.
    """
    alpha: int = 1
    dtype: Any = jnp.complex128
    
    @nn.compact
    def __call__(self, x):
        N = x.shape[-1]
        
        # Visible bias (same for all sites)
        a = self.param('visible_bias', nn.initializers.normal(stddev=0.1), (1,), self.dtype)
        visible_contrib = a[0] * jnp.sum(x, axis=-1)
        
        # Hidden units with translation symmetry
        # Each hidden unit's weight vector spans all N sites and is applied to every translation of x
        W = self.param('weights', nn.initializers.normal(stddev=0.1), (N, self.alpha), self.dtype)
        b = self.param('hidden_bias', nn.initializers.normal(stddev=0.1), (self.alpha,), self.dtype)
        
        hidden_activations = []
        for shift in range(N):
            # Translate the configuration (equivalent to translating the weights)
            x_shifted = jnp.roll(x, shift, axis=-1)
            for j in range(self.alpha):
                activation = b[j] + jnp.sum(W[:, j] * x_shifted, axis=-1)
                hidden_activations.append(jnp.log(jnp.cosh(activation)))
        
        hidden_contrib = jnp.sum(jnp.array(hidden_activations), axis=0)
        
        return visible_contrib + hidden_contrib

# Create and test the symmetric RBM
sym_rbm = TranslationInvariantRBM(alpha=1)
key, subkey = jax.random.split(key)
sym_params = sym_rbm.init(subkey, sample_input[0])

# Test translation invariance
test_config = jnp.array([1, -1, 1, -1, 1, -1, 1, -1])
shifted_config = jnp.roll(test_config, 1)

log_psi_orig = sym_rbm.apply(sym_params, test_config)
log_psi_shifted = sym_rbm.apply(sym_params, shifted_config)

print(f"Original config: {test_config}")
print(f"Shifted config:  {shifted_config}")
print(f"log ψ (original): {log_psi_orig:.6f}")
print(f"log ψ (shifted):  {log_psi_shifted:.6f}")
print(f"Difference: {abs(log_psi_orig - log_psi_shifted):.8f}")
print(f"Translation invariant: {abs(log_psi_orig - log_psi_shifted) < 1e-10}")

## 9. Conclusion {#conclusion}

In this notebook, we've explored various Neural Quantum State architectures in NetKet:

### Key Takeaways:

1. **RBMs** are simple, effective, and well-studied NQS architectures
2. **Deep neural networks** offer more expressivity but require more parameters
3. **Complex-valued networks** are needed to represent wavefunctions with nontrivial sign or phase structure
4. **Autoregressive models** provide exact sampling capabilities
5. **Symmetries** can be incorporated for better physical accuracy and efficiency

### Practical Considerations:

- **Parameter count**: More parameters don't always mean better performance
- **Initialization**: Proper initialization is crucial for training stability
- **Architecture choice**: Depends on the specific quantum system and properties of interest
- **Computational cost**: Balance between expressivity and evaluation speed

### Next Steps:

- Explore variational optimization with these NQS (covered in VMC notebook)
- Investigate ground state and excited state calculations
- Study quantum phase transitions
- Implement custom architectures for specific problems

The flexibility of NetKet's design allows for easy experimentation with different NQS architectures, making it an ideal platform for quantum many-body research.