In [8]:
import os
# Pin BLAS / OpenMP to a single thread per process. These variables are set
# before torch is imported so its thread pools pick them up.
# NOTE(review): presumably parallelism comes from MPI ranks instead — confirm.
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = '1'
# suppress warnings (blanket filter: this also hides deprecation / numerical warnings)
import warnings
warnings.filterwarnings("ignore")
from mpi4py import MPI
import pickle
import sys

# torch
import torch
torch.autograd.set_detect_anomaly(False)

# quimb
import quimb.tensor as qtn
import autoray as ar

from vmc_torch.experiment.tn_model import fMPSModel, fMPS_backflow_Model, fMPS_backflow_attn_Tensorwise_Model_v1, fMPS_BFA_cluster_Model, fMPS_BFA_cluster_Model_reuse
from vmc_torch.experiment.tn_model import init_weights_to_zero
from vmc_torch.sampler import MetropolisExchangeSamplerSpinful, MetropolisExchangeSamplerSpinful_1D_reusable
from vmc_torch.variational_state import Variational_State
from vmc_torch.optimizer import SGD, SR, Adam, SGD_momentum, DecayScheduler
from vmc_torch.VMC import VMC
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_chain_torch
from vmc_torch.torch_utils import SVD,QR
from vmc_torch.fermion_utils import generate_random_fmps, form_gated_fmps_tnf

# Register safe SVD and QR functions to torch: autoray dispatches
# 'linalg.svd' / 'linalg.qr' on torch arrays to these custom autograd functions.
ar.register_function('torch','linalg.svd',SVD.apply)
ar.register_function('torch','linalg.qr',QR.apply)

from vmc_torch.global_var import DEBUG
from vmc_torch.utils import closest_divisible

# Root directory holding saved MPS skeletons/params and model checkpoints.
# Set VMC_DATA_DIR to point elsewhere instead of editing this absolute path.
pwd = os.environ.get('VMC_DATA_DIR', '/home/sijingdu/TNVMC/VMC_code/vmc_torch/data')

# MPI handles: SIZE = number of ranks, RANK = this process's rank.
COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()
RANK = COMM.Get_rank()

# Hamiltonian parameters: spinful Fermi-Hubbard chain with open boundaries
# (pbc=False), N_f = L fermions split equally between the two spin species.
L = int(100)
symmetry = 'Z2'
t = 1.0
U = 8.0
N_f = int(L)
n_fermions_per_spin = (N_f//2, N_f//2)
H = spinful_Fermi_Hubbard_chain_torch(L, t, U, N_f, pbc=False, n_fermions_per_spin=n_fermions_per_spin)
graph = H.graph
# TN parameters
D = 10
chi = -1  # NOTE(review): presumably "no truncation" during contraction — confirm; also part of the checkpoint path
dtype=torch.float64

# Load mps (skeleton + parameter arrays saved separately, recombined with qtn.unpack).
# pickle.load can execute arbitrary code: only load files from a trusted source.
mps_dir = pwd + f"/L={L}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}"
with open(mps_dir + "/mps_skeleton.pkl", "rb") as fh:
    skeleton = pickle.load(fh)
with open(mps_dir + "/mps_su_params.pkl", "rb") as fh:
    mps_params = pickle.load(fh)
mps = qtn.unpack(mps_params, skeleton)
# NOTE(review): dead alternative below references `quimb_ham`, which is not defined in this notebook.
# fmps_tnf = form_gated_fmps_tnf(fmps=mps, ham=quimb_ham, depth=2)
# Convert every tensor to a torch tensor of `dtype`, scaled by 2.
# NOTE(review): the factor 2 is unexplained — presumably an overall amplitude
# rescaling (e.g. to avoid underflow for L=100); confirm intent.
mps.apply_to_arrays(lambda x: torch.tensor(2*x, dtype=dtype))

# # randomize the mps tensors
# mps.apply_to_arrays(lambda x: torch.randn_like(torch.tensor(x, dtype=dtype), dtype=dtype))

# VMC sample size: round the requested count via closest_divisible against the
# number of MPI ranks, then bump it by SIZE whenever N_samples / SIZE is odd.
# (Assuming closest_divisible returns a multiple of SIZE, this makes the
# per-rank share even.)
N_samples = closest_divisible(int(4), SIZE)
samples_per_rank = N_samples / SIZE
if samples_per_rank % 2:
    N_samples += SIZE

# Two variants of the backflow-attention fMPS ansatz built from the same MPS:
#   model0 -- fMPS_BFA_cluster_Model (reference implementation)
#   model  -- fMPS_BFA_cluster_Model_reuse (the one wired into the VMC driver;
#             presumably reuses cached environments between configurations — confirm)
init_std = 5e-2
seed = 2
# Seed before construction so any construction-time random init is reproducible.
torch.manual_seed(seed)
# model = fMPSModel(mps, dtype=dtype)
# model = fMPS_backflow_attn_Tensorwise_Model_v1(mps, embedding_dim=16, attention_heads=4, nn_final_dim=int(L/2), nn_eta=1.0, dtype=dtype)
model0 = fMPS_BFA_cluster_Model(mps, embedding_dim=16, attention_heads=4, nn_final_dim=int(L/2), nn_eta=1.0, radius=1, dtype=dtype)
model = fMPS_BFA_cluster_Model_reuse(mps, embedding_dim=16, attention_heads=4, nn_final_dim=int(L/2), nn_eta=1.0, radius=1, dtype=dtype)
# model = fMPS_backflow_Model(mps, nn_eta=1.0, num_hidden_layer=2, nn_hidden_dim=2*L, dtype=dtype)
# Re-seed before EACH initialisation. With a single seed, model0 would draw its
# weights from wherever model left the RNG stream, so the two variants would
# get different weights even if their parameter layouts match.
torch.manual_seed(seed)
model.apply(lambda x: init_weights_to_zero(x, std=init_std))
torch.manual_seed(seed)
model0.apply(lambda x: init_weights_to_zero(x, std=init_std))
# model.apply(lambda x: init_weights_kaiming(x))

# Short names used to build checkpoint directories.
model_names = {
    fMPSModel: 'fMPS',
    fMPS_backflow_Model: 'fMPS_backflow',
    fMPS_backflow_attn_Tensorwise_Model_v1: 'fMPS_backflow_attn_Tensorwise_v1',
    fMPS_BFA_cluster_Model: 'fMPS_BFA_cluster',
    fMPS_BFA_cluster_Model_reuse: 'fMPS_BFA_cluster_reuse',
}
model_name = model_names.get(type(model), 'UnknownModel')


init_step = 0
final_step = 250
total_steps = final_step - init_step
# Optimizer state restored from a checkpoint; stays None when starting from scratch.
optimizer_state = None
# Load model parameters when resuming from a saved step.
if init_step != 0:
    ckpt_path = pwd+f'/L={L}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}/{model_name}/chi={chi}/model_params_step{init_step}.pth'
    # torch.load unpickles the file: only load trusted checkpoints.
    saved_model_params = torch.load(ckpt_path)
    saved_model_state_dict = saved_model_params['model_state_dict']
    saved_model_params_vec = torch.tensor(saved_model_params['model_params_vec'])
    try:
        model.load_state_dict(saved_model_state_dict)
    except RuntimeError as err:
        # load_state_dict raises RuntimeError on missing/unexpected keys or shape
        # mismatches; fall back to the flat parameter vector in that case.
        if RANK == 0:
            print(f'load_state_dict failed ({err}); loading flat parameter vector instead')
        model.load_params(saved_model_params_vec)
    optimizer_state = saved_model_params.get('optimizer_state', None)

# Set up optimizer and scheduler
learning_rate = 1e-1
scheduler = DecayScheduler(init_lr=learning_rate, decay_rate=0.9, patience=50, min_lr=1e-2)
# Whether to resume the optimizer's internal state (momentum / Adam moments) from the checkpoint.
use_prev_opt = True
if optimizer_state is not None and use_prev_opt:
    optimizer_name = optimizer_state['optimizer']
    if optimizer_name == 'SGD_momentum':
        optimizer = SGD_momentum(learning_rate=learning_rate, momentum=0.9)
    elif optimizer_name == 'Adam':
        optimizer = Adam(learning_rate=learning_rate, weight_decay=1e-5)
    else:
        # Without this, `optimizer` would be undefined (or stale from an earlier run).
        raise ValueError(f'Cannot restore optimizer state for unsupported optimizer {optimizer_name!r}')
    print('Loading optimizer: ', optimizer)
    optimizer.lr = learning_rate
    if isinstance(optimizer, SGD_momentum):
        optimizer.velocity = optimizer_state['velocity']
    if isinstance(optimizer, Adam):
        optimizer.m = optimizer_state['m']
        optimizer.v = optimizer_state['v']
        optimizer.t = optimizer_state['t']
else:
    # optimizer = SignedSGD(learning_rate=learning_rate)
    # optimizer = SignedRandomSGD(learning_rate=learning_rate)
    optimizer = SGD(learning_rate=learning_rate)
    # optimizer = SGD_momentum(learning_rate=learning_rate, momentum=0.9)
    # optimizer = Adam(learning_rate=learning_rate, t_step=init_step, weight_decay=1e-5)

# Samplers: sampler0 feeds the reference model0, sampler (1D "reusable" variant)
# feeds model. Both are configured identically.
sampler_kwargs = dict(
    N_samples=N_samples,
    burn_in_steps=10,
    reset_chain=False,
    random_edge=False,
    equal_partition=True,
    dtype=dtype,
)
sampler0 = MetropolisExchangeSamplerSpinful(H.hilbert, graph, **sampler_kwargs)
sampler = MetropolisExchangeSamplerSpinful_1D_reusable(H.hilbert, graph, **sampler_kwargs)

# Variational states pairing each model with its sampler.
variational_state0 = Variational_State(model0, hi=H.hilbert, sampler=sampler0, dtype=dtype)
variational_state = Variational_State(model, hi=H.hilbert, sampler=sampler, dtype=dtype)

# SR preconditioner; `exact` is True only when no sampler is supplied.
preconditioner = SR(
    dense=False,
    exact=(sampler is None),
    use_MPI4Solver=True,
    diag_eta=1e-3,
    iter_step=1e5,
    dtype=dtype,
)
# preconditioner = TrivialPreconditioner()

# VMC driver, optimising `model` through `variational_state`.
vmc = VMC(
    hamiltonian=H,
    variational_state=variational_state,
    optimizer=optimizer,
    preconditioner=preconditioner,
    scheduler=scheduler,
    SWO=False,
    beta=0.01,
)


In [9]:
# Draw one configuration (with burn-in), then evaluate the amplitude gradient
# once with cache-env mode switched on. The finally-clause guarantees the mode
# is switched back off even if the evaluation raises.
config, amp_val = sampler._sample_next(variational_state, burn_in=True)
variational_state.set_cache_env_mode(True)
try:
    variational_state.amplitude_grad(config)
finally:
    variational_state.set_cache_env_mode(False)
# NOTE(review): presumably the cached environments are then available as
# model.env_left_cache / model.env_right_cache — confirm.

In [10]:
# Configurations connected to `config` by H. The second return value is bound to
# a real name rather than `_`: assigning `_` in IPython shadows the output-history
# variable and stops IPython from updating _/__/___ for the rest of the session.
etas, eta_mels = H.get_conn(config)  # NOTE(review): presumably (connected configs, matrix elements) — confirm

In [11]:
import pyinstrument
# Profile one gradient-free forward pass of the "reuse" model over the batch of
# connected configurations, then print the call tree.
with pyinstrument.Profiler() as prof, torch.no_grad():
    eta_amps = model(etas)
prof.print()


  _     ._   __/__   _ _  _  _ _/_   Recorded: 22:59:57  Samples:  1460
 /_//_/// /_\ / //_// / //_'/ //     Duration: 1.680     CPU time: 1.680
/   _/                      v5.0.1

Profile at /tmp/ipykernel_12210/4286058046.py:2

1.679 <module>  /tmp/ipykernel_12210/4286058046.py:1
└─ 1.679 fMPS_BFA_cluster_Model_reuse._wrapped_call_impl  torch/nn/modules/module.py:1735
   └─ 1.679 fMPS_BFA_cluster_Model_reuse._call_impl  torch/nn/modules/module.py:1743
      └─ 1.679 fMPS_BFA_cluster_Model_reuse.forward  ../tn_model.py:96
         └─ 1.679 fMPS_BFA_cluster_Model_reuse.amplitude  ../tn_model.py:1562
            ├─ 0.988 fMPS_BFA_cluster_Model_reuse.get_amp_tn  ../tn_model.py:1410
            │  ├─ 0.278 MatrixProductState.select  quimb/tensor/tensor_core.py:5100
            │  │     [5 frames hidden]  quimb
            │  ├─ 0.209 reconstruct_proj_params  ../../fermion_utils.py:1317
            │  │  ├─ 0.138 [self]  ../../fermion_utils.py
            │  │  └─ 0.061 Tensor.reshape  <b

In [12]:
# Imported here too so this cell runs even if the previous profiling cell was skipped.
import pyinstrument
# Same profile for the reference model0, for a timing comparison with the
# "reuse" variant. Note: this overwrites `eta_amps` from the previous cell.
with pyinstrument.Profiler() as prof0:
    with torch.no_grad():
        eta_amps = model0(etas)
prof0.print()


  _     ._   __/__   _ _  _  _ _/_   Recorded: 22:59:59  Samples:  10045
 /_//_/// /_\ / //_// / //_'/ //     Duration: 10.330    CPU time: 10.331
/   _/                      v5.0.1

Profile at /tmp/ipykernel_12210/2291535159.py:1

10.329 <module>  /tmp/ipykernel_12210/2291535159.py:1
└─ 10.329 fMPS_BFA_cluster_Model._wrapped_call_impl  torch/nn/modules/module.py:1735
   └─ 10.329 fMPS_BFA_cluster_Model._call_impl  torch/nn/modules/module.py:1743
      └─ 10.329 fMPS_BFA_cluster_Model.forward  ../tn_model.py:96
         └─ 10.329 fMPS_BFA_cluster_Model.amplitude  ../tn_model.py:1139
            ├─ 6.083 MatrixProductState.contract  quimb/tensor/tensor_core.py:8934
            │     [47 frames hidden]  quimb, functools, cotengra, symmray, ...
            ├─ 2.201 <listcomp>  ../tn_model.py:1164
            │  └─ 2.153 SelfAttn_FFNN_block._wrapped_call_impl  torch/nn/modules/module.py:1735
            │     └─ 2.142 SelfAttn_FFNN_block._call_impl  torch/nn/modules/module.py:1743
       