In [7]:
import netket as nk
import numpy as np
from netket.operator.spin import sigmax, sigmaz
from scipy.sparse.linalg import eigsh

In [12]:
import jax
import netket as nk
import numpy as np
from jax import vmap
from netket.operator.spin import sigmax, sigmaz
import time
from jax import lax
from functools import partial
from model.model_utlis import *
import netket.nn as nknn
import flax.linen as nn
import jax.numpy as jnp
from scipy.sparse.linalg import eigsh
import argparse
import optax
import itertools
from jax.random import PRNGKey, split
import pickle

jax.config.update("jax_enable_x64", False)

parser = argparse.ArgumentParser()
parser.add_argument('--L', type = int, default = 16)
parser.add_argument('--p', type = int, default=1)
parser.add_argument('--numsamples', type = int, default = 2048)
parser.add_argument('--alpha', type = int, default=16)
parser.add_argument('--nchain_per_rank', type = int, default=8)
parser.add_argument('--numsteps', type = int, default=5000)
parser.add_argument('--dmrg', type = bool, default=False)
parser.add_argument('--H_type', type = str, default="cluster")
parser.add_argument('--angle', type = float, default=0.0)
parser.add_argument('--previous_training', type = bool, default=False)
args = parser.parse_args()

L = args.L
p = args.p
N = L
numsamples = args.numsamples
alpha = args.alpha
nchain_per_rank = args.nchain_per_rank
numsteps = args.numsteps
dmrg = args.dmrg
H_type = args.H_type
previous_training = args.previous_training
angle = args.angle
ang = round(angle, 3)

if dmrg == True:
    if H_type == "ES":
        M0 = jnp.load("DMRG/mps_tensors/ES_tensor_init_" + str(L * p) + "_angle_" + str(ang) + ".npy")
        M = jnp.load("DMRG/mps_tensors/ES_tensor_" + str(L * p) + "_angle_" + str(ang) + ".npy")
        Mlast = jnp.load("DMRG/mps_tensors/ES_tensor_last_" + str(L * p) + "_angle_" + str(ang) + ".npy")
    else:
        M0 = jnp.load("DMRG/mps_tensors/cluster_tensor_init_" + str(L * p) + "_angle_" + str(ang) + ".npy")
        M = jnp.load("DMRG/mps_tensors/cluster_tensor_" + str(L * p) + "_angle_" + str(ang) + ".npy")
        Mlast = jnp.load("DMRG/mps_tensors/cluster_tensor_last_" + str(L * p) + "_angle_" + str(ang) + ".npy")
    batch_log_phase_dmrg = jax.jit(vmap(partial(log_phase_dmrg, M0=M0, M=M, Mlast=Mlast, netket= True), 0))


    def uniform_init(scale):
        def init(key, shape, dtype=jnp.float32):
            return jax.random.uniform(key, shape, dtype=dtype, minval=-scale, maxval=scale)

        return init
    class RBM_dmrg_model(nn.Module):
        def setup(self):
            self.dense = nn.Dense(
            features=alpha * L,
            param_dtype=jnp.float32,
            kernel_init= uniform_init(1/jnp.sqrt(alpha * L)),
            bias_init= uniform_init(1/jnp.sqrt(alpha * L))
            )
            self.ai = jnp.zeros(L, dtype=jnp.float32)

        def __call__(self, x):
            y = self.dense(x)  # x shape: (batch_size, input_dim)
            y = nk.nn.activation.log_cosh(y)
            y_sum = jnp.sum(y, axis=-1).astype(jnp.complex64)
            # Apply batch_log_phase_dmrg to x
            phase_corrections = batch_log_phase_dmrg(x)
            y_sum += phase_corrections
            y_sum += jnp.dot(x, self.ai)
            return y_sum
    ma = RBM_dmrg_model()
else:
    ma = nk.models.RBM(alpha=alpha)

hi = nk.hilbert.Spin(s=1 / 2, N=N)
g = nk.graph.Hypercube(length=N, n_dim=1, pbc=False)

if H_type == "ES":
    h = - np.cos(angle) ** 2 * sigmaz(hi, 0) @ sigmax(hi, 1)
    h -= np.cos(angle) * np.sin(angle) * sigmaz(hi, 0) @ sigmaz(hi, 1)
    h += np.sin(angle) ** 2 * sigmax(hi, 0) @ sigmaz(hi, 1)
    h += np.cos(angle) * np.sin(angle) * sigmax(hi, 0) @ sigmax(hi, 1)

    # Last set of terms
    h -= np.cos(angle) ** 2 * sigmax(hi, L - 2) @ sigmax(hi, L - 1)
    h -= np.cos(angle) * np.sin(angle) * sigmax(hi, L - 2) @ sigmaz(hi, L - 1)
    h -= np.sin(angle) ** 2 * sigmaz(hi, L - 2) @ sigmaz(hi, L - 1)
    h -= np.cos(angle) * np.sin(angle) * sigmaz(hi, L - 2) @ sigmax(hi, L - 1)

    # Middle set of terms (for j = 1 to N-3)
    for j in range(1, L - 2):
        h -= np.cos(angle) ** 3 * sigmax(hi, j - 1) @ sigmaz(hi, j) @ sigmax(hi, j + 1)
        h -= np.cos(angle) ** 2 * np.sin(angle) * sigmaz(hi, j - 1) @ sigmaz(hi, j) @ sigmax(hi, j + 1)
        h += np.cos(angle) ** 2 * np.sin(angle) * sigmax(hi, j - 1) @ sigmax(hi, j) @ sigmax(hi, j + 1)
        h -= np.cos(angle) ** 2 * np.sin(angle) * sigmax(hi, j - 1) @ sigmaz(hi, j) @ sigmaz(hi, j + 1)
        h += np.cos(angle) * np.sin(angle) ** 2 * sigmaz(hi, j - 1) @ sigmax(hi, j) @ sigmax(hi, j + 1)
        h += np.cos(angle) * np.sin(angle) ** 2 * sigmax(hi, j - 1) @ sigmax(hi, j) @ sigmaz(hi, j + 1)
        h -= np.cos(angle) * np.sin(angle) ** 2 * sigmaz(hi, j - 1) @ sigmaz(hi, j) @ sigmaz(hi, j + 1)
        h += np.sin(angle) ** 3 * sigmaz(hi, j - 1) @ sigmax(hi, j) @ sigmaz(hi, j + 1)
    j += 1

    h -= np.cos(angle) ** 3 * sigmax(hi, j - 1) @ sigmaz(hi, j) @ sigmaz(hi, j + 1)
    h -= np.cos(angle) ** 2 * np.sin(angle) * sigmaz(hi, j - 1) @ sigmaz(hi, j) @ sigmaz(hi, j + 1)
    h += np.cos(angle) ** 2 * np.sin(angle) * sigmax(hi, j - 1) @ sigmax(hi, j) @ sigmaz(hi, j + 1)
    h += np.cos(angle) ** 2 * np.sin(angle) * sigmax(hi, j - 1) @ sigmaz(hi, j) @ sigmax(hi, j + 1)
    h += np.cos(angle) * np.sin(angle) ** 2 * sigmaz(hi, j - 1) @ sigmax(hi, j) @ sigmaz(hi, j + 1)
    h += np.cos(angle) * np.sin(angle) ** 2 * sigmaz(hi, j - 1) @ sigmaz(hi, j) @ sigmax(hi, j + 1)
    h -= np.cos(angle) * np.sin(angle) ** 2 * sigmax(hi, j - 1) @ sigmax(hi, j) @ sigmax(hi, j + 1)
    h -= np.sin(angle) ** 3 * sigmaz(hi, j - 1) @ sigmax(hi, j) @ sigmax(hi, j + 1)

elif H_type == "cluster":
    h = - np.cos(angle) ** 2 *sigmaz(hi, N-1)* sigmax(hi, 0) @ sigmaz(hi, 1)
    h -= np.cos(angle) * np.sin(angle) * sigmaz(hi, 0) @ sigmaz(hi, 1)
    h += np.cos(angle) * np.sin(angle) * sigmax(hi, 0) @ sigmax(hi, 1)
    h += np.sin(angle) ** 2 * sigmaz(hi, 0) @ sigmax(hi, 1)

    # Middle set of terms (for j = 1 to N-3)
    for j in range(1, L - 1):
        h -= np.cos(angle) ** 3 * sigmaz(hi, j - 1) @ sigmax(hi, j) @ sigmaz(hi, j + 1)
        h += np.cos(angle) ** 2 * np.sin(angle) * sigmax(hi, j - 1) @ sigmax(hi, j) @ sigmaz(hi, j + 1)
        h -= np.cos(angle) ** 2 * np.sin(angle) * sigmaz(hi, j - 1) @ sigmaz(hi, j) @ sigmaz(hi, j + 1)
        h += np.cos(angle) ** 2 * np.sin(angle) * sigmaz(hi, j - 1) @ sigmax(hi, j) @ sigmax(hi, j + 1)
        h += np.cos(angle) * np.sin(angle) ** 2 * sigmax(hi, j - 1) @ sigmaz(hi, j) @ sigmaz(hi, j + 1)
        h -= np.cos(angle) * np.sin(angle) ** 2 * sigmax(hi, j - 1) @ sigmax(hi, j) @ sigmax(hi, j + 1)
        h += np.cos(angle) * np.sin(angle) ** 2 * sigmaz(hi, j - 1) @ sigmaz(hi, j) @ sigmax(hi, j + 1)
        h -= np.sin(angle) ** 3 * sigmax(hi, j - 1) @ sigmaz(hi, j) @ sigmax(hi, j + 1)

    h -= np.cos(angle) ** 2 * sigmaz(hi, j) * sigmax(hi, j + 1) * sigmaz(hi, 0)
    h += np.cos(angle) * np.sin(angle) * sigmax(hi, j) @ sigmax(hi, j + 1)
    h -= np.cos(angle) * np.sin(angle) * sigmaz(hi, j) @ sigmaz(hi, j + 1)
    h += np.sin(angle) ** 2 * sigmax(hi, j) @ sigmaz(hi, j + 1)

#sampler
sa = nk.sampler.MetropolisLocal(hilbert=hi,
                                n_chains_per_rank = nchain_per_rank)
#learning rate schedule
schedule = optax.warmup_cosine_decay_schedule(init_value=1e-3,
                                              peak_value=3e-3,
                                              warmup_steps = 1000,
                                              decay_steps = 4500,
                                              end_value = 1e-4)
# Optimizer

# Stochastic Reconfiguration
if previous_training == False:
    op = nk.optimizer.Sgd(learning_rate=schedule)
    sr = nk.optimizer.SR(diag_shift = optax.linear_schedule(init_value = 0.03,
                                                            end_value = 0.001,
                                                            transition_steps = 1000))
else:
    op = nk.optimizer.Sgd(learning_rate=5e-4)
    sr = nk.optimizer.SR(diag_shift=0.001)

# The variational state

vs = nk.vqs.MCState(sa, ma, n_samples=numsamples)
if previous_training == True:
    with open(f"params/params_model1D_RBM_Htype{H_type}_L{L}_units{alpha}_batch{numsamples}_dmrg{dmrg}_angle{ang}.pkl", "rb") as f:
        params = pickle.load(f)
    vs.parameters = params

# The ground-state optimization loop
gs = nk.VMC(
    hamiltonian=h,
    optimizer=op,
    preconditioner=sr,
    variational_state=vs)

start = time.time()
gs.run(out='RBM' + "_angle=" + str(ang)+ "_L=" +str(N)+"_numsample="+str(numsamples), n_iter=numsteps)

with open(f"params/params_model1D_RBM_Htype{H_type}_L{L}_units{alpha}_batch{numsamples}_dmrg{dmrg}_angle{ang}.pkl", "wb") as f:
    pickle.dump(vs.parameters, f)

if N<= 20:
    combinations = np.array(list(itertools.product([-1, 1], repeat=N)))
    np.save("RBM" + "angle_=" + str(ang) +"_L="+str(N) + "_numsample"+str(numsamples)+"_amp.npy", vs.log_value(combinations))

end = time.time()
print('### RBM calculation')
print('Has', vs.n_parameters, 'parameters')
print('The RBM calculation took', end - start, 'seconds')


usage: ipykernel_launcher.py [-h] [--L L] [--p P] [--numsamples NUMSAMPLES]
                             [--alpha ALPHA]
                             [--nchain_per_rank NCHAIN_PER_RANK]
                             [--numsteps NUMSTEPS] [--dmrg DMRG]
                             [--H_type H_TYPE] [--angle ANGLE]
                             [--previous_training PREVIOUS_TRAINING]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\Administrator\AppData\Roaming\jupyter\runtime\kernel-d970d0df-3a25-40d0-963f-2d735c37f50d.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [17]:
import pickle
H_type = "ES"
L = 16
alpha = 16
numsamples = 2048
dmrg = False
ang = 0.0
with open(f"params/params_model1D_RBM_Htypecluster_L16_units16_batch2048_dmrgFalse_angle0.0.pkl", "rb") as f:
    params = pickle.load(f)

In [18]:
print(params)

{'Dense': {'bias': Array([-2.38109734e-02, -2.67544179e-03,  1.10629585e-03,  9.13409109e-04,
        5.30947931e-03, -1.59629900e-03,  8.94406170e-04, -3.88958189e-03,
       -1.01767005e-02, -4.95980727e-03,  7.68641068e-04,  9.58609954e-03,
        1.65765285e-02, -4.29517357e-03,  9.58315935e-03, -2.79951338e-02,
        1.31105846e-02, -6.31156238e-03,  4.30361228e-03,  1.39055625e-02,
        4.38441290e-03,  6.66679395e-03, -2.32704403e-03,  1.20315934e-02,
        3.81565746e-03,  9.21721850e-03, -3.94251151e-03, -1.32363371e-03,
       -6.17913017e-03, -1.35320295e-02, -2.40189373e-03,  1.23422779e-02,
       -4.36091749e-03, -6.37284154e-03,  2.85035484e-02,  1.24201784e-02,
       -1.39015168e-02,  9.98800155e-03,  3.51977855e-04, -1.12936126e-04,
        1.17451269e-02,  1.43117795e-04, -7.86920451e-03, -6.56824326e-03,
       -5.72321611e-03, -6.55319425e-04, -7.70693785e-03,  2.44107768e-02,
        9.15797241e-03, -1.26070541e-03,  1.39368568e-02,  1.61644164e-02,
      

In [30]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.nn.initializers import uniform

# Custom complex uniform initializer
def complex_uniform_init(scale=1e-2, dtype=jnp.complex64):
    def init(key, shape, dtype=dtype):
        key_real, key_imag = jax.random.split(key)  # Split PRNG key
        real_part = uniform(-scale, scale)(key_real, shape, dtype=jnp.float32)
        imag_part = uniform(-scale, scale)(key_imag, shape, dtype=jnp.float32)
        return real_part + 1j * imag_part
    return init

# Define a complex MLP model in Flax
class ComplexMLP(nn.Module):
    n_hidden_units: int  # Number of hidden units
    @nn.compact
    def __call__(self, x):
        # Get the number of visible units from input
        n_visible_units = x.shape[-1]

        # Define complex weights matrix W (hidden units, visible units)
        W = self.param("W", complex_uniform_init(),  (n_visible_units, self.n_hidden_units), jnp.complex64)
        # Define complex hidden biases b (hidden units)
        b = self.param("b", complex_uniform_init(),  (self.n_hidden_units,), jnp.complex64)
        # Define complex visible biases a (visible units)
        a = self.param("a", complex_uniform_init(),  (n_visible_units,), jnp.complex64)

        # Compute hidden layer pre-activation: W @ x + b
        hidden_pre_activation = jnp.dot(x, W) + b  # W: (n_hidden_units, n_visible_units), x: (batch, n_visible_units)
        # Apply non-linearity (log_cosh for complex inputs)
        hidden_activation = nk.nn.activation.log_cosh(hidden_pre_activation)
        # Sum the activations over the hidden units (axis=-1)
        y_sum = jnp.sum(hidden_activation, axis=-1).astype(jnp.complex64)
        # Add the contribution from the visible bias term
        y_sum += jnp.dot(x, a)

        return y_sum

# Example of creating and initializing the MLP
key = jax.random.PRNGKey(0)
x = jnp.array([[1.0, 2.0, 3.0], [3.,4., 5.]], dtype=jnp.float32)  # Input example, shape (batch_size, input_dim)

# Instantiate the model
model = ComplexMLP(n_hidden_units = 6)

# Initialize the model parameters
variables = model.init(key, x)

# Forward pass
output = model.apply(variables, x)

print("Output of the complex MLP:", output)
print(output.shape)

Output of the complex MLP: [-0.0386328 -0.02924762j -0.07583031-0.04051371j]
(2,)


In [35]:
class JasShort(nn.Module):
    @nn.compact
    def __call__(self, x):
        
        # Define the two variational parameters J1 and J2
        j1 = self.param(
            "j1", nn.initializers.normal(), (1,), float
        )
        j2 =self.param(
            "j2", nn.initializers.normal(), (1,), float
        )

        # compute the nearest-neighbor correlations
        corr1=x*jnp.roll(x,-1,axis=-1)
        corr2=x*jnp.roll(x,-2,axis=-1)

        # sum the output
        return jnp.sum(j1*corr1+j2*corr2,axis=-1)
    
model=JasShort()

vstate = nk.vqs.MCState(sampler, model, n_samples=1008)

NameError: name 'sampler' is not defined

In [34]:
import jax, jax.numpy as jnp
initializer = jax.nn.initializers.truncated_normal(1.0)
initializer(jax.random.key(42), (2, 3), jnp.complex64)  

ValueError: dtype argument to `truncated_normal` must be a float dtype, got complex64

In [5]:
import flax.linen as nn
import jax.numpy as jnp

class MyModule(nn.Module):
    func: callable
    param: float


    def __call__(self, x):
        # Use self.func to dynamically choose a function
        return self.func(x)
def some_func(x):
    return x * 2

model = MyModule(func=some_func, param=1.0)
output = model(jnp.array([1.0, 2.0, 3.0]))
