# Variational methods for Sherrington-Kirkpatrick model
Pan Zhang
Institute of Theoretical Physics, Chinese Academy of Sciences

We compare the variational free energy obtained from the variational mean-field method with that obtained from the Variational Autoregressive Network (VAN; Physical Review Letters 122, 080602) on the Sherrington-Kirkpatrick (SK) spin glass model.

## SK model
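$n$ denotes the number of spins, the couplings are Gaussian with zero mean and variance $1/n$, i.e. $J_{ij}\sim\frac{1}{\sqrt{n}}\mathcal N(0,1)$, and $\beta$ is the inverse temperature. Note that the diagonal of the coupling matrix must be zero, i.e. $J_{ii}=0$, and that the coupling matrix $\mathbf J$ is symmetric.
The energy of a configuration $\mathbf x\in\{+1,-1\}^n$ is $E(\mathbf x)=-\sum_{i<j}J_{ij}x_ix_j=-\frac{1}{2}\mathbf x^\top\mathbf J\mathbf x$.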

The following code generates a small instance of the SK model with $n=20$ spins at $\beta=0.5$:

In [203]:
import torch,math
import numpy as np
from torch import nn
from scipy.special import logsumexp
import sys
# use the GPU when one is available, otherwise fall back to the CPU
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n=20 # number of spins
beta=0.5 # inverse temperature
seed=1
torch.manual_seed(seed)
J=torch.randn(n,n,device=device)/math.sqrt(n)
J = torch.triu(J,diagonal=1) # take the upper triangular matrix
J = J+J.t() # make the coupling matrix symmetric
J_np = J.cpu().numpy()
J.requires_grad=False

## Exact enumerations for small systems
When the number of spins is small, e.g. $n\leq 20$, we can compute the free energy, energy and entropy exactly by enumerating all $2^n$ configurations.

In [204]:
def cfg_id_to_sample(cfg_id):
    """Map an integer in [0, 2^n) to a spin configuration in {-1,+1}^n (most significant bit first)."""
    return np.array( [((cfg_id >> i) & 1) * 2 - 1 for i in range(n-1,-1,-1)])

def list_energy(print_step=float('inf')):
    """Return the ids and the energies E(x) = -x^T J x / 2 of all 2^n configurations."""
    energy_arr = []
    for cfg_id in range(1 << n):
        if (cfg_id + 1) % print_step == 0: # progress report in percent
            sys.stdout.write("\rEnumerating all configurations: %d / 100"%(int(cfg_id /(1 << n)*100)+1))
        sample = cfg_id_to_sample(cfg_id)
        energy_arr.append(-sample.dot(J_np).dot(sample)/2.0)

    cfg_id_arr = np.arange(1 << n, dtype=int)
    energy_arr = np.array(energy_arr)

    return cfg_id_arr, energy_arr
    
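def f_exact():
    """Exact free energy, energy and entropy per spin from full enumeration."""
    if(n>20): # enumerating more than 2^20 configurations is too expensive
        return 0,0,0
    step=int((1<<n)/10.)
    arr, energy_arr = list_energy(step)
    logz = logsumexp(-beta * energy_arr) # log of the partition function
    f=-1.0*logz/beta/n
    prob_arr=np.exp(-1.0*beta*energy_arr-logz) # Boltzmann probabilities
    E=np.sum(prob_arr*energy_arr)/n
    S=beta*E+logz/n # follows from f = e - s/beta
    print("\nExact:\tf=%.6f\te=%.6f\ts=%.6f"%(f,E,S))
    return f,E,S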
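
The enumeration can be checked by hand on a much smaller system. The following is a minimal sketch in plain Python, assuming a hypothetical 3-spin instance (`J_toy` and `beta_toy` are illustrative values, not part of the notebook). It computes the entropy directly from the Boltzmann probabilities and confirms the identity $f=e-s/\beta$ that `f_exact` uses to obtain the entropy.

```python
import itertools
import math

# Hypothetical 3-spin toy instance (values chosen only for illustration).
J_toy = [[0.0, 0.8, -0.5],
         [0.8, 0.0, 0.3],
         [-0.5, 0.3, 0.0]]
beta_toy = 0.5

def energy(x, J):
    """E(x) = -sum_{i<j} J_ij x_i x_j for spins x_i in {-1, +1}."""
    return -sum(J[i][j] * x[i] * x[j]
                for i in range(len(x)) for j in range(i + 1, len(x)))

def exact_thermodynamics(J, beta):
    """Return (f, e, s) per spin by summing over all 2^n configurations."""
    n = len(J)
    energies = [energy(x, J) for x in itertools.product([-1, 1], repeat=n)]
    # log-sum-exp with the maximum subtracted for numerical stability
    a_max = max(-beta * e for e in energies)
    log_z = a_max + math.log(sum(math.exp(-beta * e - a_max) for e in energies))
    probs = [math.exp(-beta * e - log_z) for e in energies]
    e_mean = sum(p * e for p, e in zip(probs, energies)) / n
    s = -sum(p * math.log(p) for p in probs) / n  # Gibbs entropy, computed directly
    f = -log_z / (beta * n)
    return f, e_mean, s

f_toy, e_toy, s_toy = exact_thermodynamics(J_toy, beta_toy)
print("f=%.6f e=%.6f s=%.6f" % (f_toy, e_toy, s_toy))
print("f - (e - s/beta) = %.2e" % (f_toy - (e_toy - s_toy / beta_toy)))
```

The last printed number should vanish up to floating-point rounding, and the entropy per spin cannot exceed $\log 2$.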



## Variational Mean-field
The joint distribution of the variables $x\in\{+1,-1\}^n$ is factorized as $P(x)=\prod_ip(x_i).$

Defining the magnetization as $$m_i=\sum_{x_i\in\{+1,-1\}}x_ip(x_i)=p(x_i=1)-p(x_i=-1),$$

the entropy of the factorized distribution is $$S=-\sum_i\sum_{x_i}p(x_i)\log(p(x_i))=-\sum_i\left( \frac{1+m_i}{2}\log \frac{1+m_i}{2}+ \frac{1-m_i}{2}\log \frac{1-m_i}{2} \right).$$

With pairwise energies $E_{ij}(x_i,x_j)=-J_{ij}x_ix_j$, the average energy under the factorized distribution is $$\langle E \rangle=\sum_{(ij)}\langle E_{ij}(x_i,x_j)\rangle_{x\sim p(x)}=\sum_{(ij)}\sum_{x_i}\sum_{x_j}E_{ij}(x_i,x_j)p(x_i)p(x_j)=-\sum_{(ij)}J_{ij}m_im_j=-\frac{1}{2}\sum_{i,j}J_{ij}m_im_j,$$ where $(ij)$ runs over distinct pairs and the last equality uses $J_{ii}=0$ and the symmetry of $\mathbf J$.

Using the above expressions for the average energy and the entropy, the variational free energy per spin is written as
$$f=\frac{1}{n}\left(\langle E \rangle-\frac{1}{\beta}S\right).$$
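
Setting the derivative of $\langle E\rangle-S/\beta$ with respect to $m_i$ to zero gives $-\sum_jJ_{ij}m_j+\frac{1}{\beta}\operatorname{atanh}(m_i)=0$, i.e. the self-consistency equation $m_i=\tanh(\beta\sum_jJ_{ij}m_j)$, which is what a naive mean-field iteration solves. The sketch below is a plain-Python illustration on a hypothetical 3-spin instance (`J_toy`, `beta_toy` and the starting point are assumptions made for the example). It runs a damped fixed-point iteration and checks that the gradient of the variational free energy vanishes at the result.

```python
import math

# Hypothetical 3-spin toy couplings and inverse temperature, for illustration only.
J_toy = [[0.0, 0.8, -0.5],
         [0.8, 0.0, 0.3],
         [-0.5, 0.3, 0.0]]
beta_toy = 2.0

def local_field(J, m, i):
    """h_i = sum_j J_ij m_j."""
    return sum(J[i][j] * m[j] for j in range(len(m)))

def grad_free_energy(J, m, beta):
    """dF/dm_i = -sum_j J_ij m_j + atanh(m_i)/beta for F = <E> - S/beta."""
    return [-local_field(J, m, i) + math.atanh(m[i]) / beta for i in range(len(m))]

def damped_fixed_point(J, beta, m0, damping=0.9, max_iter=5000, tol=1e-12):
    """Iterate m <- damping*m + (1-damping)*tanh(beta*J m) until the update is below tol."""
    m = list(m0)
    for _ in range(max_iter):
        m_new = [damping * m[i] + (1 - damping) * math.tanh(beta * local_field(J, m, i))
                 for i in range(len(m))]
        change = max(abs(a - b) for a, b in zip(m_new, m))
        m = m_new
        if change < tol:
            break
    return m

m_star = damped_fixed_point(J_toy, beta_toy, m0=[0.3, 0.2, -0.1])
grad = grad_free_energy(J_toy, m_star, beta_toy)
print("m* =", [round(v, 4) for v in m_star])
print("max |dF/dm| =", max(abs(g) for g in grad))
```

The damping only slows the update down; it does not change the fixed points, so any converged result satisfies the stationarity condition.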

In [205]:
def get_entropy_fact(m):
    """ get_entropy_fact(m) returns the entropy of a factorized distribution of len(m) binary variables.
    m is the magnetization, i.e. (1+m)/2 = p(+1) and (1-m)/2 = p(-1)"""
    return -1.0*torch.sum( (1+m)/2*torch.log((1+m)/2)+(1-m)/2*torch.log((1-m)/2) )

def get_free_energy_fact(m):
    entropy=get_entropy_fact(m)/n
    energy=-0.5*m.t()@J@m/n
    free_energy=energy-entropy/beta
    return [free_energy,energy,entropy]

def nmf():
    """Naive mean-field: damped fixed-point iteration of m = tanh(beta*J@m)."""
    damping=0.9
    torch.manual_seed(seed)
    max_iter=8000
    conv_crit = 1.0e-8
    diff=float('inf')
    m=torch.tanh(torch.randn(n,1,device=device)) # magnetization, randomly initialized in (-1,1)
    for i in range(max_iter):
        m_new=damping*m+(1-damping)*torch.tanh(beta*J@m) # damped update
        diff=torch.norm(m_new-m).item()
        m=m_new
        if(diff<conv_crit):
            break
    f_mf,e_mf,s_mf=[v.item() for v in get_free_energy_fact(m)] # convert one-element tensors to floats
    print("Naive Mean-Field:\tf=%.6f\te=%.6f\ts=%.6f\tdiff=%.3f"%(f_mf,e_mf,s_mf,diff))
    return f_mf,e_mf,s_mf,diff

f,E,S=f_exact()
f_nmf,E_nmf,s_nmf,err_nmf=nmf()

Enumerating all configurations: 100 / 100
Exact:	f=-1.507863	e=-0.244720	s=0.631572
Naive Mean-Field:	f=-1.386294	e=-0.000000	s=0.693147	diff=0.000
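
The mean-field values printed above are those of the paramagnetic point $m_i=0$: there the energy vanishes, the entropy per spin is $\log 2$, and $f=-\log 2/\beta$. The short check below reproduces these numbers from the formulas and compares them with the exact value copied from the output above. Since a variational free energy is an upper bound on the exact one, the gap must be non-negative.

```python
import math

beta = 0.5                   # inverse temperature used in this notebook
f_exact_printed = -1.507863  # copied from the "Exact" line of the output above

s_para = math.log(2.0)       # entropy per spin at m = 0
e_para = 0.0                 # -0.5 * m^T J m / n vanishes at m = 0
f_para = e_para - s_para / beta

print("paramagnetic: f=%.6f e=%.6f s=%.6f" % (f_para, e_para, s_para))
print("gap to the exact free energy: %.6f" % (f_para - f_exact_printed))
```

At this temperature the factorized ansatz therefore captures none of the correlations between spins, which is what the autoregressive ansatz of the next section is meant to improve.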


## Variational autoregressive network
In VAN, the variational distribution is the product of conditional distributions

$p(\mathbf x)=\prod_ip(x_i|\mathbf x_{j<i})$

Here we use a very simple version of VAN in which each conditional is a Bernoulli distribution,

$p(\mathbf x)=\prod_i\left[\hat x_i\delta(x_i-1)+(1-\hat x_i)\delta(x_i+1)\right]$ 

with 

$\mathbf {\hat x}=\mathrm{Sigmoid}(\mathbf W\mathbf x)$,

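where $\mathbf W\in\mathbb R^{n\times n}$ is a lower-triangular matrix with zeros on the diagonal, so that $\hat x_i$ depends only on $\mathbf x_{j<i}$. (In the code, samples are stored as row vectors and multiplied as `samples@W`, so the corresponding mask is strictly upper triangular.) The REINFORCE algorithm (Williams 1992) with a simple baseline is used to estimate the gradient of the variational free energy with respect to the parameters $\mathbf W$.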
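
REINFORCE rests on the identity $\nabla_{\mathbf W}F_q=\mathbb E_{\mathbf x\sim q}\left[(\log q(\mathbf x)+\beta E(\mathbf x)-b)\,\nabla_{\mathbf W}\log q(\mathbf x)\right]$, valid for any constant baseline $b$, where $F_q=\mathbb E_q[\log q+\beta E]$ is $\beta$ times the variational free energy. The following plain-Python sketch illustrates it on a hypothetical 3-spin instance (`J_toy`, `W_toy` and `beta_toy` are assumptions made for the example). The expectation is taken exactly by enumeration rather than by sampling, and the result is compared with a finite-difference derivative of $F_q$.

```python
import itertools
import math

# Hypothetical 3-spin toy instance and weights, chosen only for illustration.
J_toy = [[0.0, 0.8, -0.5],
         [0.8, 0.0, 0.3],
         [-0.5, 0.3, 0.0]]
beta_toy = 0.5
n_toy = 3
# W[j][i] couples spin j to the conditional of spin i; only entries with j < i are used.
W_toy = [[0.0, 0.4, -0.7],
         [0.0, 0.0, 0.2],
         [0.0, 0.0, 0.0]]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def energy(x):
    return -sum(J_toy[i][j] * x[i] * x[j]
                for i in range(n_toy) for j in range(i + 1, n_toy))

def log_q(x, W):
    """log q(x) = sum_i log p(x_i | x_{j<i}), p(x_i=+1 | .) = sigmoid(sum_{j<i} x_j W[j][i])."""
    total = 0.0
    for i in range(n_toy):
        x_hat = sigmoid(sum(x[j] * W[j][i] for j in range(i)))
        total += math.log(x_hat if x[i] == 1 else 1.0 - x_hat)
    return total

def grad_log_q(x, W, j, i):
    """d log q(x) / d W[j][i] = x_j * ((x_i + 1)/2 - x_hat_i) for j < i."""
    x_hat = sigmoid(sum(x[k] * W[k][i] for k in range(i)))
    return x[j] * ((x[i] + 1) / 2 - x_hat)

configs = list(itertools.product([-1, 1], repeat=n_toy))

def beta_free_energy(W):
    """F_q = sum_x q(x) [log q(x) + beta E(x)]."""
    return sum(math.exp(log_q(x, W)) * (log_q(x, W) + beta_toy * energy(x))
               for x in configs)

def score_function_grad(W, j, i, baseline=0.0):
    """Exact expectation of (log q + beta E - baseline) * d log q / d W[j][i] under q."""
    return sum(math.exp(log_q(x, W))
               * (log_q(x, W) + beta_toy * energy(x) - baseline)
               * grad_log_q(x, W, j, i) for x in configs)

def finite_difference_grad(W, j, i, eps=1e-6):
    """Central finite difference of F_q with respect to W[j][i]."""
    W_plus = [row[:] for row in W]
    W_minus = [row[:] for row in W]
    W_plus[j][i] += eps
    W_minus[j][i] -= eps
    return (beta_free_energy(W_plus) - beta_free_energy(W_minus)) / (2 * eps)

norm = sum(math.exp(log_q(x, W_toy)) for x in configs)  # q is normalized by construction
results = []
for (j, i) in [(0, 1), (0, 2), (1, 2)]:
    g_sf = score_function_grad(W_toy, j, i)
    g_bl = score_function_grad(W_toy, j, i, baseline=beta_free_energy(W_toy))
    g_fd = finite_difference_grad(W_toy, j, i)
    results.append((g_sf, g_bl, g_fd))
    print("W[%d][%d]: score-function=%.6f with-baseline=%.6f finite-difference=%.6f"
          % (j, i, g_sf, g_bl, g_fd))
```

The baseline leaves the expectation unchanged because $\mathbb E_q[\nabla\log q]=0$; with finite batches its role is to reduce the variance of the estimate.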

In [206]:
mask=torch.ones([n, n],device=device)
mask = torch.triu(mask, diagonal=1) # strictly upper triangular mask
mask.requires_grad=False
W=torch.randn(n,n,device=device)
W.requires_grad=True
optimizer = torch.optim.Adam([W], lr=1e-2)
batch_size=5000
epsilon=1e-7
print_steps=100

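def do_sample(batch_size):
    """Draw batch_size configurations spin by spin from the autoregressive distribution."""
    samples = torch.zeros([batch_size, n],device=device)
    for i in range(n):
        x_hat = torch.sigmoid(samples@W) # with W masked, x_hat[:, i] depends only on spins j<i
        samples[:, i] = torch.bernoulli(x_hat[:, i]) * 2 - 1
    return samples,x_hat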


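for step in range(1000):
    optimizer.zero_grad()
    W.data =W.data*mask # enforce the autoregressive structure before sampling
    with torch.no_grad(): # no gradient flows through the sampling itself
        samples, x_hat = do_sample(batch_size)
    x_hat = torch.sigmoid(samples@W) # recompute the conditionals with gradients enabled
    m = (samples + 1) / 2 # map spins {-1,+1} to {0,1}
    # log-probability of each sample under the variational distribution
    log_prob = (torch.log(x_hat + epsilon) * m + torch.log(1 - x_hat + epsilon) * (1 - m)).view(batch_size, -1).sum(dim=1)
    with torch.no_grad():
        energy = -0.5*torch.sum((samples@J)*samples,dim=1)
        loss = log_prob + beta * energy # beta times the per-sample free energy
    assert not energy.requires_grad
    assert not loss.requires_grad
    # REINFORCE surrogate loss with the batch mean as baseline
    loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
    loss_reinforce.backward()
    optimizer.step()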
    if(step % print_steps == 0):
        free_energy_mean = loss.mean() / beta / n
        free_energy_std = loss.std() / beta / n
        entropy_mean = -log_prob.mean() / n
        energy_mean = energy.mean() / n
        print("#%d\t free energy=%.6g std=%.6g energy=%.3g entropy=%.6g"%(step,free_energy_mean,free_energy_std,energy_mean,entropy_mean))



#0	 free energy=-0.904866 std=0.25451 energy=-0.0622 entropy=0.421315
#100	 free energy=-1.31669 std=0.127507 energy=-0.326 entropy=0.49556
#200	 free energy=-1.48864 std=0.0530314 energy=-0.265 entropy=0.611913
#300	 free energy=-1.50781 std=0.00295863 energy=-0.245 entropy=0.631463
#400	 free energy=-1.50784 std=0.00185979 energy=-0.243 entropy=0.632562
#500	 free energy=-1.50791 std=0.00185464 energy=-0.242 entropy=0.632705
#600	 free energy=-1.50784 std=0.0018542 energy=-0.244 entropy=0.632042
#700	 free energy=-1.50785 std=0.00187904 energy=-0.244 entropy=0.632098
#800	 free energy=-1.50782 std=0.00182032 energy=-0.244 entropy=0.631871
#900	 free energy=-1.50783 std=0.00188075 energy=-0.242 entropy=0.632923
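
The `std` column in the log is the per-sample standard deviation, so the uncertainty of each printed batch mean is smaller by a factor $\sqrt{\text{batch size}}$. A rough check, using only numbers copied from the outputs above (the variable names are introduced here for illustration), estimates this standard error and expresses the remaining distance to the exact free energy in units of it.

```python
import math

batch_size = 5000            # samples per step, as in the training loop above
f_exact_printed = -1.507863  # "Exact" line of the enumeration output
# free energy mean and std at step 900, copied from the log above
f_van_mean, f_van_std = -1.50783, 0.00188075

std_err = f_van_std / math.sqrt(batch_size)  # standard error of the batch mean
gap = f_van_mean - f_exact_printed
print("standard error of the mean: %.2e" % std_err)
print("VAN - exact: %.2e (%.1f standard errors)" % (gap, gap / std_err))
```

The standard error is a few times $10^{-5}$, comparable to the differences between the printed means and the exact value. This is also why a few printed means can fall slightly below the exact free energy even though the true variational free energy cannot.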
