In [10]:
import netket as nk
import numpy as np

import jax.numpy as jnp
from flax import linen as nn # flax is library that brings a more pytorch-like API to JAX

In [9]:
# Ising Hamiltonian on a 4x4 square lattice with periodic boundary conditions.

# Square lattice: a 2-dimensional hypercube with 4 sites per side (16 sites),
# wrapped periodically in both directions.
g = nk.graph.Hypercube(length=4, n_dim=2, pbc=True)

# Hilbert space: one spin-1/2 degree of freedom on every node of the graph.
hi = nk.hilbert.Spin(s=1/2, N=g.n_nodes)

# NetKet ships a ready-made Ising operator, so there is no need to assemble
# the Hamiltonian term by term; field strength h and coupling J are both 1.
h = nk.operator.Ising(
    hi,
    g,
    h=1.0,
    J=1.0,
    dtype=complex,
)

In [None]:
# Jastrow ansatz

class Jastrow(nn.Module):
    """Two-body Jastrow variational ansatz.

    For a configuration ``x`` whose last axis indexes the ``n_sites`` lattice
    sites, the model evaluates the quadratic form

        log_psi(x) = sum_{i,j} x_i (J + J^T)_{ij} x_j

    where ``J`` is a trainable ``(n_sites, n_sites)`` matrix (parameter name
    ``"J"``). Any leading axes of ``x`` are treated as batch axes and are kept
    in the output. The value is intended as a log-amplitude of the
    wavefunction (NetKet's convention for model outputs), not a probability.
    """

    @nn.compact
    def __call__(self, x):
        # Parameter shape is derived from the size of the last (site) axis.
        n_sites = x.shape[-1]

        J = self.param("J", nn.initializers.normal(), (n_sites, n_sites), float)

        # Cast both operands to their common *promoted* dtype (the wider of the
        # two, e.g. integer-valued spins with a float J give a float dtype) so
        # the contraction below runs in one dtype without losing precision.
        dtype = jnp.promote_types(J.dtype, x.dtype)
        J = J.astype(dtype)
        x = x.astype(dtype)

        # J is initialised without any symmetry; J + J^T is symmetric by
        # construction. The factor of 2 relative to (J + J^T) / 2 is simply
        # absorbed into the trainable parameters.
        J_symm = J.T + J

        # Quadratic form x^T J_symm x summed over all site pairs (i, j); the
        # explicit "->..." keeps only the leading batch axes in the result.
        log_psi = jnp.einsum("...i,ij,...j->...", x, J_symm, x)
        return log_psi