In [3]:
import os
# Pin the BLAS / OpenMP thread pools to a single thread per process. These are
# set before torch and quimb are imported below, because the native libraries
# read them when they initialise. Presumably parallelism is meant to come from
# MPI ranks instead — confirm against the launch script.
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = '1'
from mpi4py import MPI
import pickle
import sys
# Root data directory: MPS inputs are read from, and model checkpoints are
# looked up under, sub-directories of this machine-specific path.
pwd = '/pscratch/sd/s/sijingdu/VMC/fermion/data'
# torch
import torch
# quimb
import quimb.tensor as qtn
import autoray as ar
from autoray import do

# NOTE(review): this star import is what provides the model classes used later
# (fMPSModel, fMPS_BFA_cluster_Model_reuse, ...). Explicit imports would make
# that dependency visible.
from vmc_torch.experiment.tn_model import *
from vmc_torch.experiment.tn_model import init_weights_uniform
from vmc_torch.sampler import MetropolisExchangeSamplerSpinful_1D_reusable
from vmc_torch.variational_state import Variational_State
from vmc_torch.optimizer import SGD, SR, Adam, SGD_momentum, DecayScheduler
from vmc_torch.VMC import VMC
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_chain_torch
from vmc_torch.torch_utils import SVD,QR

# Register safe SVD and QR functions to torch.
# After this, autoray dispatches 'linalg.svd' / 'linalg.qr' on torch arrays to
# SVD.apply / QR.apply from vmc_torch.torch_utils (presumably custom autograd
# Functions with numerically safer backward passes — confirm in torch_utils).
ar.register_function('torch','linalg.svd',SVD.apply)
ar.register_function('torch','linalg.qr',QR.apply)

from vmc_torch.global_var import DEBUG
from vmc_torch.utils import closest_divisible

# World communicator shared by every process. SIZE is the number of MPI ranks,
# and RANK is the index of the current process within it.
COMM = MPI.COMM_WORLD
SIZE, RANK = COMM.Get_size(), COMM.Get_rank()

# Run configuration: chain length L, MPS bond dimension D, requested number of
# VMC samples, and initial learning rate. init_step is the step to start from
# (a non-zero value resumes from a saved checkpoint); final_step is the last step.
L = 24
D = 4
N_samples = 4
init_lr = 1.0
init_step, final_step = 0, 2

# Spinful Fermi-Hubbard chain with open boundaries (pbc=False). t and U follow
# the usual Hubbard notation (hopping, on-site interaction). N_f fermions are
# split evenly between the two spin species. `symmetry` selects which stored
# MPS directory is read later.
symmetry = 'Z2'
t, U = 1.0, 8.0
N_f = 20
n_per_spin = N_f // 2
n_fermions_per_spin = (n_per_spin, n_per_spin)
H = spinful_Fermi_Hubbard_chain_torch(
    L, t, U, N_f,
    pbc=False,
    n_fermions_per_spin=n_fermions_per_spin,
)
graph = H.graph

# Tensor-network settings. chi is only embedded in checkpoint paths in this
# script; -1 presumably means "no truncation" — confirm. All tensors use
# double precision.
chi = -1
dtype = torch.float64

# Load the pre-computed MPS. It is stored as a pickled quimb skeleton plus a
# pickled parameter container; qtn.unpack reassembles them, and every array is
# then converted to a torch tensor of `dtype` in place.
# NOTE: pickle can execute arbitrary code on load — only read trusted files.
mps_dir = pwd + f"/L={L}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}"
# Use context managers so both file handles are closed deterministically.
with open(mps_dir + "/mps_skeleton.pkl", "rb") as fh:
    skeleton = pickle.load(fh)
with open(mps_dir + "/mps_su_params.pkl", "rb") as fh:
    mps_params = pickle.load(fh)
mps = qtn.unpack(mps_params, skeleton)
mps.apply_to_arrays(lambda x: torch.tensor(x, dtype=dtype))

# VMC sample size. First pass the requested count through closest_divisible
# (presumably this rounds it to a multiple of the MPI world size — confirm).
# Then, if the per-rank share is not even, add SIZE more samples, i.e. one
# extra sample per rank.
N_samples = closest_divisible(N_samples, SIZE)
samples_per_rank = N_samples / SIZE
if samples_per_rank % 2:
    N_samples += SIZE

# Model selection. The alternative architectures are kept commented out so they
# can be switched in quickly; the active one is fMPS_BFA_cluster_Model_reuse.
# model = fMPSModel(mps, dtype=dtype)
# model = fMPS_backflow_attn_Model(mps, nn_eta=1.0, nn_hidden_dim=2*L, embedding_dim=16, attention_heads=4, dtype=dtype)
# model = fMPS_backflow_attn_Tensorwise_Model_v1(mps, nn_eta=1.0, embedding_dim=16, attention_heads=4, nn_final_dim=int(L/2), dtype=dtype)
radius = 1
# model = fMPS_BFA_cluster_Model(mps, nn_eta=1.0, embedding_dim=16, attention_heads=4, nn_final_dim=int(L/2), radius=radius, dtype=dtype)
model = fMPS_BFA_cluster_Model_reuse(
    mps,
    nn_eta=1.0,
    embedding_dim=16,
    attention_heads=4,
    nn_final_dim=L // 2,
    radius=radius,
    dtype=dtype,
)

# Seeded initialisation: model.apply runs init_weights_uniform on every
# submodule with bounds [-init_std, init_std].
init_std = 5e-3
seed = 2
torch.manual_seed(seed)


def _init_uniform(module):
    """Apply the symmetric uniform initialiser to one submodule."""
    return init_weights_uniform(module, a=-init_std, b=init_std)


model.apply(_init_uniform)

# Short identifier per model class; it becomes part of the checkpoint path.
model_names = {
    fMPSModel: 'fMPS',
    fMPS_backflow_Model: 'fMPS_backflow',
    fMPS_backflow_attn_Model: 'fMPS_backflow_attn',
    fMPS_backflow_attn_Tensorwise_Model_v1: 'fMPS_backflow_attn_Tensorwise_v1',
    fMPS_BFA_cluster_Model: f'fMPS_BFA_cluster_r={radius}',
    fMPS_BFA_cluster_Model_reuse: f'fMPS_BFA_cluster_r={radius}_reuse',
}
model_name = model_names.get(type(model), 'UnknownModel') + '_test'


total_steps = final_step - init_step

# Resume from a checkpoint when not starting from step 0. Model weights are
# restored from the saved state dict. If that does not match the current model,
# they come from the flat parameter vector instead. Any saved optimizer state is
# kept for the optimizer set-up that follows.
optimizer_state = None
if init_step != 0:
    ckpt_path = pwd + f'/L={L}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}/{model_name}/chi={chi}/model_params_step{init_step}.pth'
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    saved_model_params = torch.load(ckpt_path, weights_only=False)
    saved_model_state_dict = saved_model_params['model_state_dict']
    try:
        model.load_state_dict(saved_model_state_dict)
    except RuntimeError:
        # load_state_dict (strict) raises RuntimeError on missing/unexpected
        # keys or shape mismatches. Only then is the flat vector needed, so it
        # is read lazily and a checkpoint without it still loads via the state dict.
        saved_model_params_vec = torch.as_tensor(saved_model_params['model_params_vec'])
        model.load_params(saved_model_params_vec)
    optimizer_state = saved_model_params.get('optimizer_state', None)

# Set up optimizer and scheduler.
learning_rate = init_lr
scheduler = DecayScheduler(init_lr=learning_rate, decay_rate=0.9, patience=50, min_lr=1e-3)
# If True and the checkpoint carried optimizer state, rebuild that optimizer and
# restore its buffers. Otherwise start a fresh plain SGD optimizer.
use_prev_opt = False
if optimizer_state is not None and use_prev_opt:
    optimizer_name = optimizer_state['optimizer']
    if optimizer_name == 'SGD_momentum':
        optimizer = SGD_momentum(learning_rate=learning_rate, momentum=0.9)
    elif optimizer_name == 'Adam':
        optimizer = Adam(learning_rate=learning_rate, weight_decay=1e-5)
    else:
        # No restore path exists for any other saved optimizer. Fail loudly
        # instead of continuing with `optimizer` unbound.
        raise ValueError(
            f"Cannot restore unsupported optimizer {optimizer_name!r} from checkpoint"
        )
    print('Loading optimizer: ', optimizer)
    # NOTE(review): redundant with the constructor argument unless the
    # optimizer reads `lr` separately — kept as in the original.
    optimizer.lr = learning_rate
    # Restore the optimizer's internal buffers saved in the checkpoint.
    if isinstance(optimizer, SGD_momentum):
        optimizer.velocity = optimizer_state['velocity']
    if isinstance(optimizer, Adam):
        optimizer.m = optimizer_state['m']
        optimizer.v = optimizer_state['v']
        optimizer.t = optimizer_state['t']
else:
    optimizer = SGD(learning_rate=learning_rate)

# Sampler: Metropolis exchange sampler for the spinful 1D chain, sized to the
# adjusted N_samples.
# sampler = MetropolisExchangeSamplerSpinful(H.hilbert, graph, N_samples=N_samples, burn_in_steps=40, reset_chain=False, random_edge=False, equal_partition=False, dtype=dtype)
sampler = MetropolisExchangeSamplerSpinful_1D_reusable(
    H.hilbert,
    graph,
    N_samples=N_samples,
    burn_in_steps=1,
    reset_chain=False,
    random_edge=False,
    equal_partition=True,
    dtype=dtype,
)

# The variational state ties the model to its Hilbert space and sampler.
variational_state = Variational_State(model, hi=H.hilbert, sampler=sampler, dtype=dtype)

# SR preconditioner: exact mode is requested only when no sampler is present.
use_exact_sr = sampler is None
preconditioner = SR(
    dense=False,
    exact=use_exact_sr,
    use_MPI4Solver=True,
    solver='minres',
    diag_eta=1e-3,
    iter_step=1e3,
    dtype=dtype,
    rtol=1e-5,
)

# Driver that combines Hamiltonian, state, optimizer, preconditioner and scheduler.
vmc = VMC(
    hamiltonian=H,
    variational_state=variational_state,
    optimizer=optimizer,
    preconditioner=preconditioner,
    scheduler=scheduler,
    SWO=False,
    beta=0.01,
)



In [None]:
# Scratch check: draw one random configuration from the Hilbert space and
# display its entries 0, 1 and 2 (the cell's last expression is shown).
# model.get_local_amp_tensors([0,1,2])
random_config = torch.tensor(H.hilbert.random_state())
# model.get_local_amp_tensors([0,1,2], random_config)
# random_config[0,1,2, 5]
probe_sites = [0, 1, 2]
random_config[probe_sites]

tensor([0, 1, 0])