# [Defining the quantum system](https://www.netket.org/tutorials/jax.html)

 Transverse-field Ising model defined on a graph with random edges

In [8]:
# ensure we run on the CPU and not on the GPU
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [2]:

import netket as nk

In [4]:
#Define a random graph
n_nodes=10
n_edges=20
from numpy.random import choice
rand_edges=[choice(n_nodes, size=2,replace=False).tolist() for i in range(n_edges)]

graph=nk.graph.Graph(nodes=[i for i in range(n_nodes)], edges=rand_edges)


#Define the local hilbert space
hi=nk.hilbert.Spin(s=0.5)**graph.n_nodes

In [7]:
hi

Spin(s=1/2, N=10)

In [6]:

#Define the Hamiltonian as a sum of local operators
from netket.operator import LocalOperator as Op

# Pauli Matrices
sx = [[0, 1], [1, 0]]
sz = [[1, 0], [0, -1]]


# Defining the Hamiltonian as a LocalOperator acting on the given Hilbert space
ha = Op(hi)

#Adding a transverse field term on each node of the graph
for i in range(graph.n_nodes):
    ha += Op(hi, sx, [i])

#Adding nearest-neighbors interactions on the edges of the given graph
from numpy import kron
J=0.5
for edge in graph.edges():
    ha += J*Op(hi, kron(sz, sz), edge)




# Defining a JAX module to be used as a wave function

In [9]:
import jax
from jax.experimental import stax

#We define a custom layer that performs the sum of its inputs
def SumLayer():
    def init_fun(rng, input_shape):
        output_shape = (-1, 1)
        return output_shape, ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)

    return init_fun, apply_fun

#We construct a fully connected network with tanh activation
model=stax.serial(stax.Dense(2 * graph.n_nodes, W_init=nk.nn.initializers.normal(stddev=0.1, dtype=complex),
                             b_init=nk.nn.initializers.normal(stddev=0.1, dtype=complex)),
                  stax.Tanh,SumLayer())

# Train the neural network to find an approximate ground state

In [10]:
# Defining a sampler that performs local moves
# NetKet automatically dispatches here to MCMC sampler written using JAX types
sa = nk.sampler.MetropolisLocal(hilbert=hi, n_chains=2)

In [14]:
# Defining a sampler that performs local moves
# NetKet automatically dispatches here to MCMC sampler written using JAX types
sa = nk.sampler.MetropolisLocal(hilbert=hi, n_chains=2)

# Construct the variational state
vs = nk.variational.MCState(sa, model, n_samples=1000)

  _warn_deprecation(


In [18]:
# Using Sgd
# Also dispatching to JAX optimizer
op = nk.optimizer.Sgd(learning_rate=0.01)

# Using Stochastic Reconfiguration a.k.a. quantum natural gradient
# Also dispatching to a pure JAX version
sr = nk.optimizer.SR(diag_shift=0.01)

# Create the Variational Monte Carlo instance to learn the ground state
vmc = nk.VMC(
    hamiltonian=ha, optimizer=op, variational_state=vs, preconditioner=sr
)

In [19]:
# Running the learning loop and printing the energy every 50 steps
# [notice that the very first iteration is slow because of JIT compilation]
for it in vmc.iter(500,50):
    print(it,vmc.energy)

0 5.98-0.05j ± 0.13 [σ²=11.59, R̂=0.9998]
50 -6.57-0.09j ± 0.13 [σ²=9.37, R̂=0.9993]
100 -9.312-0.064j ± 0.046 [σ²=1.689, R̂=0.9997]
150 -10.067-0.000j ± 0.039 [σ²=0.796, R̂=1.0050]
200 -10.415+0.003j ± 0.041 [σ²=1.115, R̂=0.9997]
250 -10.711-0.005j ± 0.021 [σ²=0.312, R̂=0.9991]
300 -10.843-0.025j ± 0.018 [σ²=0.246, R̂=0.9990]
350 -10.947-0.014j ± 0.015 [σ²=0.182, R̂=0.9991]
400 -11.146-0.009j ± 0.022 [σ²=0.322, R̂=1.0004]
450 -11.366-0.009j ± 0.021 [σ²=0.193, R̂=0.9998]


In [20]:
import scipy
exact_ens=scipy.sparse.linalg.eigsh(ha.to_sparse(),k=1,which='SA',return_eigenvectors=False)
print("Exact energy is : ",exact_ens[0])
print("Relative error is : ", (abs((vmc.energy.mean-exact_ens[0])/exact_ens[0])))

Exact energy is :  -11.550612728038686
Relative error is :  0.006060526726524103
