In [12]:
import netket as nk
import numpy as np

from typing import Callable, Optional
from functools import partial

import jax
import jax.numpy as jnp

from netket.utils.types import PyTree, Array
import netket.jax as nkjax

from netket.vqs.mc import (
    kernels,
    check_hilbert,
    get_local_kernel_arguments,
    get_local_kernel,
)

from netket.vqs.mc.mc_state.state import MCState

from netket.operator import (
    AbstractOperator,
    DiscreteOperator,
    Squared,
    ContinuousOperator,
)

from netket.vqs.mc import (
    get_local_kernel_arguments,
    get_local_kernel,
)
from netket.stats import Stats, statistics

In [13]:
#in this test we want to test whether the local_value kernel works as intended
def local_value_kernel(logpsi: Callable, pars: PyTree, σ: Array, args: PyTree):
    """
    local_value kernel for MCState and generic operators
    """
    σp, mel = args
    return jnp.sum(mel * jnp.exp(logpsi(pars, σp) - logpsi(pars, σ)))


def get_local_kernel_arguments(vstate: MCState, Ô: DiscreteOperator):  # noqa: F811
    check_hilbert(vstate.hilbert, Ô.hilbert)

    σ = vstate.samples
    σp, mels = Ô.get_conn_padded(σ)
    return σ, (σp, mels)


In [14]:
def CSHam(N, B, Ak):
    # Make graph with of length N with no periodic boundary conditions
    g = nk.graph.Hypercube(length=N, n_dim=1, pbc=False)
    # Spin based Hilbert Space
    hilbertSpace = nk.hilbert.Spin(s=0.5, N=g.n_nodes)
    # Define spin operators with \hbar set to 1
    sz = 0.5 * np.array([[1, 0], [0, -1]])
    sx = 0.5 * np.array([[0, 1], [1, 0]])
    sy = 0.5 * np.array([[0, -1j], [1j, 0]])
    operators = []
    sites = []
    # Central spin term
    operators.append((B * sz).tolist()) #array to list(ordered and changeable)
    sites.append([0])
    # Interaction term
    itOp = np.kron(sz, sz) + np.kron(sx, sx) + np.kron(sy, sy) #kronecker product here
    for i in range(N - 1):
        operators.append((Ak[i] * itOp).tolist())
        sites.append([0, (i+1)])  #pretty convoluted indexing, but ok
    # Create hamiltonian
    hamiltonian = nk.operator.LocalOperator(hilbertSpace, operators=operators, acting_on=sites, dtype=complex)
    #acting_on specifier necessary as this is a central spin model
    return hamiltonian, hilbertSpace

In [26]:
#machine
ma = nk.models.RBM(alpha=1, dtype=complex,use_visible_bias=True, use_hidden_bias=True)

#get sampler

N=3
alpha = 1   #density of RBM
M = alpha*N
# Constant A
B = 0.95
# Variable A
#B=N/2
#A = N/2
#N0 = N/2
# List of Ak
Ak = []
for i in range(N - 1):
    # Constant A
    Ak_i = 1
    # Variable A
    #Ak_i = A / (N0) * np.exp(-i / N0)
    Ak.append(Ak_i)

ha, hi = CSHam(N,B,Ak)
sampler = nk.sampler.MetropolisLocal(hilbert=hi)

vs = nk.variational.MCState(sampler, ma, n_samples=3000, n_discard=300)
vs.init_parameters(nk.nn.initializers.normal(stddev=0.25))

  vs.init_parameters(nk.nn.initializers.normal(stddev=0.25))


In [27]:

σ, args = get_local_kernel_arguments(vs, ha)
logpsi = vs._apply_fun
pars = vs.parameters
model_state = vs.model_state

σ_shape = σ.shape
if jnp.ndim(σ) != 2:
    σ = σ.reshape((-1, σ_shape[-1]))

In [17]:
local_estimator_fun = get_local_kernel(vs, ha)

#now we can get the output
E_loc = local_estimator_fun(
        logpsi,
        {"params": pars, **model_state},
        σ,
        args,
    )
print(σ.shape)
print(E_loc.shape)
print(E_loc.reshape(σ_shape[:-1]).T.shape)

(3008, 3)
(3008,)
(16, 188)


In [18]:
a = np.array(args[0])
print(a.shape)

(188, 16, 3, 3)


In [19]:
print(logpsi({"params": pars, **model_state}, args[0]).shape)

(188, 16, 3)


In [20]:
print(logpsi({"params": pars, **model_state},σ ).shape)

(3008,)


In [21]:
import inspect
lines = inspect.getsource(local_estimator_fun)
print(lines)

    def vmapped_kernel(logpsi, pars, σ, args):
        """
        local_value kernel for MCState and generic operators
        """
        σp, mels = args

        if jnp.ndim(σp) != 3:
            σp = σp.reshape((σ.shape[0], -1, σ.shape[-1]))
            mels = mels.reshape(σp.shape[:-1])

        vkernel = jax.vmap(kernel, in_axes=(None, None, 0, (0, 0)), out_axes=0)
        return vkernel(logpsi, pars, σ, (σp, mels))



In [22]:
#everything above was for checking how it works, now we check our own kernel function
vs1 = nk.variational.MCState(sampler, ma, n_samples=1000, n_discard=100)
vs1.init_parameters(nk.nn.initializers.normal(stddev=0.25))

vs2 = nk.variational.MCState(sampler, ma, n_samples=1000, n_discard=100)
vs2.init_parameters(nk.nn.initializers.normal(stddev=0.25))

  vs1.init_parameters(nk.nn.initializers.normal(stddev=0.25))
  vs2.init_parameters(nk.nn.initializers.normal(stddev=0.25))


In [23]:
def penalty_kernel(logpsi: Callable, pars1: PyTree, pars2: PyTree, σ2: Array):
    return jnp.exp(logpsi(pars1, σ2) - logpsi(pars2, σ2))

In [24]:
logpsi = vs1._apply_fun
pars1 = vs1.parameters
pars2 = vs2.parameters
model_state1 = vs1.model_state
model_state2 = vs2.model_state

σ2 = vs2.samples
σ2_shape = σ2.shape
if jnp.ndim(σ2) != 2:
    σ2 = σ2.reshape((-1, σ2_shape[-1]))

psi_loc = penalty_kernel(
        logpsi,
        {"params": pars1, **model_state1},
        {"params": pars2, **model_state2},
        σ2,
    )

psi = statistics(psi_loc)
print(psi.mean)

(0.5318746540769356-0.08575881979458345j)


In [30]:
c = nk.stats.statistics(psi_loc.reshape(σ_shape[:-1]).T)
print(c)

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (1008,) and (188, 16)

In [32]:
print(logpsi)

<hashable partial maybe_scalar_fun with args=(<hashable partial <lambda> with args=(RBM(
    # attributes
    dtype = complex
    activation = log_cosh
    alpha = 1
    use_hidden_bias = True
    use_visible_bias = True
    precision = None
    kernel_init = init
    hidden_bias_init = init
    visible_bias_init = init
),) and kwargs={}, hash=-896942998488472433>,) and kwargs={}, hash=-6737409820451515932>


In [34]:
logpsi1 = vs1._apply_fun
logpsi2 = vs2._apply_fun

In [36]:
#let's try cheating, by pre-setting other states' parameters and input only one of the state parameters
def initial(pars_1, σ):
    return logpsi1(pars_1,σ) + logpsi2(pars2,σ)
    
    

In [38]:
a = sampler.sample(machine = initial, parameters = pars1)

ApplyScopeInvalidVariablesError: The first argument passed to an apply function should be a dictionary of collections. Each collection should be a dictionary with string keys. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesError)

In [38]:
#we define necessary functions here
def batch_discrete_kernel(kernel):
    """
    Batch a kernel that only works with 1 sample so that it works with a
    batch of samples.
    Works only for discrete-kernels who take two args as inputs
    """

    def vmapped_kernel(logpsi_1, logpsi_2, pars_1,pars_2, σ, args):
        """
        local_value kernel for MCState and generic operators
        """
        σp, mels = args

        if jnp.ndim(σp) != 3:
            σp = σp.reshape((σ.shape[0], -1, σ.shape[-1]))
            mels = mels.reshape(σp.shape[:-1])

        vkernel = jax.vmap(kernel, in_axes=(None, None,None,None, 0, (0, 0)), out_axes=0)
        return vkernel(logpsi_1, logpsi_2, pars_1, pars_2, σ, (σp, mels))

    return vmapped_kernel

@batch_discrete_kernel
def local_value_kernel(logpsi_1: Callable, logpsi_2: Callable, pars_1: PyTree, pars_2: PyTree, σ: Array, args: PyTree):
    """
    local_value kernel for MCState and generic operators
    """
    σp, mel = args
    return jnp.sum(mel * jnp.exp(logpsi_1(pars_1, σp) - logpsi_2(pars_2, σ)))

In [40]:
σ, args = get_local_kernel_arguments(vs, ha)
logpsi = vs._apply_fun
pars = vs.parameters
model_state_1 = vs1.model_state
model_state_2 = vs2.model_state

logpsi_1 = vs1._apply_fun
logpsi_2 = vs2._apply_fun

pars_1 = vs1.parameters
pars_2 = vs2.parameters

σ_shape = σ.shape
if jnp.ndim(σ) != 2:
    σ = σ.reshape((-1, σ_shape[-1]))

E_loc = local_value_kernel(
        logpsi_1,
        logpsi_2,
        {"params": pars_1, **model_state_1},
        {"params": pars_2, **model_state_2},
        σ,
        args,
    )

print(E_loc)

[0.01171846-0.00381545j 0.38957479+0.52711731j 0.01171846-0.00381545j ...
 0.01171846-0.00381545j 0.9611983 +0.10100919j 1.57259989+1.0424999j ]
