In [2]:
%pip install --quiet netket





In [1]:
import os

# JAX reads its platform configuration when it is first imported, so these
# variables must be set *before* `jax` (or `netket`, which imports jax) is
# imported; assigning them afterwards has no effect. JAX_PLATFORMS is the
# current name, JAX_PLATFORM_NAME the older (deprecated) one — set both.
# If jax was already imported in this kernel, restart the kernel for this to apply.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import netket as nk

In [2]:
# Number of sites in the spin chain, and the Hilbert space of N_chain spin-1/2 sites.
N_chain = 10
hi = nk.hilbert.Spin(N=N_chain, s=0.5)

In [5]:
from netket.operator.spin import sigmax, sigmaz

# Couplings: Gamma multiplies the transverse (sigma^x) field on every site,
# V multiplies the nearest-neighbour sigma^z sigma^z interaction.
Gamma = -4
V = -1

# Transverse-field term on each site.
H = sum(Gamma * sigmax(hi, site) for site in range(N_chain))
# Nearest-neighbour zz term; the modulo wraps the last site back to site 0.
H += sum(
    V * sigmaz(hi, site) * sigmaz(hi, (site + 1) % N_chain)
    for site in range(N_chain)
)

In [6]:
from scipy.sparse.linalg import eigsh

# Exact reference: ARPACK (via eigsh) on the sparse Hamiltonian matrix for the
# two smallest-algebraic ("SA") eigenpairs.
sp_h = H.to_sparse()
eig_vals, eig_vecs = eigsh(sp_h, k=2, which="SA")

# Exact ground-state energy: the benchmark for the variational result.
E_gs = eig_vals[0]
print("Energy per spin with scipy sparse:", E_gs / N_chain)

Energy per spin with scipy sparse: -4.062748120226056


In [15]:
import jax.numpy as jnp
import flax
import flax.linen as nn

class FFN(nn.Module):
    """Single-hidden-layer feed-forward network with a ReLU activation.

    For every input row the module evaluates ``sum_j relu(W x + b)_j``: a dense
    layer of width ``alpha * x.shape[-1]``, a ReLU, then a sum over the hidden
    units, yielding one scalar per configuration. Following NetKet's
    convention, the MCState built from this module treats that scalar as the
    log-amplitude of the variational wave function.

    Attributes:
        alpha: hidden-layer density; hidden width = ``alpha`` times the size of
            the last input axis. It is a module attribute with a default so the
            hyper-parameter can be changed without redefining the module.

    Caveat kept from the original notes: flax's ``nn.Dense`` was reported not to
    work with complex parameters, so use ``nk.nn.Dense`` for complex
    Hamiltonians (may be outdated — confirm for the installed flax version).
    """

    alpha: int = 1

    @nn.compact
    def __call__(self, x):
        # x: batch of spin configurations, presumably shape (..., n_sites) — TODO confirm.
        n_hidden = self.alpha * x.shape[-1]
        # Pre-built flax dense layer (auto-named "Dense_0") followed by ReLU.
        hidden = nn.relu(nn.Dense(features=n_hidden)(x))
        # Collapse the hidden units into a single output per configuration.
        return hidden.sum(axis=-1)
    
# Metropolis sampler with local (single-site) moves over the Hilbert space `hi`.
sampler = nk.sampler.MetropolisLocal(hi)
model = FFN(alpha=1)

# Explicit seeds fix the parameter initialisation (`seed`) and the Markov-chain
# sampling (`sampler_seed`), so Restart & Run All regenerates the same results
# (given the same library versions).
MC_SEED = 1234
vstate = nk.vqs.MCState(
    sampler,
    model,
    n_samples=1008,
    seed=MC_SEED,
    sampler_seed=MC_SEED,
)

In [29]:
optimizer = nk.optimizer.Sgd(learning_rate=0.1)

# Stochastic Reconfiguration (SR) preconditions the plain SGD updates with the
# quantum geometric tensor, which considerably improves the optimisation;
# diag_shift regularises that matrix before it is inverted.
gs = nk.driver.VMC(
    H,
    optimizer,
    variational_state=vstate,
    preconditioner=nk.optimizer.SR(diag_shift=0.1),
)

# In-memory log of the per-iteration observables (energy history).
log = nk.logging.RuntimeLog()
gs.run(n_iter=300, out=log)

# Final Monte Carlo estimate with the optimised state, compared against the
# exact ground-state energy E_gs from the sparse diagonalisation.
# NOTE(review): the saved output below does not match this notebook's H
# (E_gs/N ≈ -4.06 above, yet E ≈ -10.64 with a 3.6e-5 relative error) — it
# comes from an earlier run; re-execute from a fresh kernel.
ffn_energy = vstate.expect(H)
error = abs((ffn_energy.mean - E_gs) / E_gs)
print("Variational energy per spin:", ffn_energy.mean / N_chain)
print("Optimized energy and relative error: ", ffn_energy, error)

100%|██████████| 300/300 [00:06<00:00, 47.16it/s, Energy=-10.63492 ± 0.00021 [σ²=0.00004, R̂=1.0115]]


Found ground state -1.063598586236023
Optimized energy and relative error:  -10.63599 ± 0.00024 [σ²=0.00006, R̂=1.0195] 3.586566382153778e-05
