In [51]:
import os
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = '1'
from mpi4py import MPI
import pickle

# torch
import torch
import torch.nn as nn
torch.autograd.set_detect_anomaly(False)

# quimb
import quimb.tensor as qtn
import autoray as ar

from vmc_torch.experiment.tn_model import fMPSModel, fMPS_backflow_Model, fMPS_backflow_attn_Tensorwise_Model_v1
from vmc_torch.experiment.tn_model import init_weights_to_zero
from vmc_torch.sampler import MetropolisExchangeSamplerSpinful
from vmc_torch.variational_state import Variational_State
from vmc_torch.optimizer import SGD, SR, Adam, SGD_momentum, DecayScheduler
from vmc_torch.VMC import VMC
from vmc_torch.hamiltonian import spinful_Fermi_Hubbard_chain, spinful_Fermi_Hubbard_chain_quimb
from vmc_torch.torch_utils import SVD,QR


# Register safe SVD and QR functions to torch
ar.register_function('torch','linalg.svd',SVD.apply)
ar.register_function('torch','linalg.qr',QR.apply)

from vmc_torch.global_var import DEBUG
from vmc_torch.utils import closest_divisible
pwd = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/data'

COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()
RANK = COMM.Get_rank()

# Hamiltonian parameters
L = int(10)
symmetry = 'Z2'
t = 1.0
U = 8.0
N_f = int(L-2)
n_fermions_per_spin = (N_f//2, N_f//2)
H = spinful_Fermi_Hubbard_chain(L, t, U, N_f, pbc=False, n_fermions_per_spin=n_fermions_per_spin)
quimb_ham = spinful_Fermi_Hubbard_chain_quimb(L, t, U, mu=0.0, pbc=False, symmetry=symmetry)
graph = H.graph
# TN parameters
D = 8
chi = -2
dtype=torch.float64

# Load mps
skeleton = pickle.load(open(pwd+f"/L={L}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}/mps_skeleton.pkl", "rb"))
mps_params = pickle.load(open(pwd+f"/L={L}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}/mps_su_params.pkl", "rb"))
mps = qtn.unpack(mps_params, skeleton)
# fmps_tnf = form_gated_fmps_tnf(fmps=mps, ham=quimb_ham, depth=2)
mps.apply_to_arrays(lambda x: torch.tensor(x, dtype=dtype))

# # randomize the mps tensors
# mps.apply_to_arrays(lambda x: torch.randn_like(torch.tensor(x, dtype=dtype), dtype=dtype))

# VMC sample size
N_samples = int(15000)
N_samples = closest_divisible(N_samples, SIZE)
if (N_samples/SIZE)%2 != 0:
    N_samples += SIZE


In [64]:
import ast
from vmc_torch.experiment.tn_model import wavefunctionModel, fMPSModel

class fMPSModel_GPU(wavefunctionModel):
    def __init__(self, ftn, dtype=torch.float32):
        super().__init__()
        self.param_dtype = dtype
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fMPS (exact contraction)':{'D': ftn.max_bond(), 'L': ftn.L, 'symmetry': self.symmetry, 'cyclic': ftn.cyclic, 'skeleton': self.skeleton},
        }
    def amplitude(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        # `x` is expected to be batched as (batch_size, input_dim)
        
        # Ensure x is a tensor of the correct dtype and move to GPU
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=self.param_dtype)
        elif x.dtype != self.param_dtype:
            x = x.to(self.param_dtype)
        
        # Move x to GPU and enable gradient computation
        x = x.to('cuda')

        # Get model parameters list
        params_list = list(self.parameters())

        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            # Get the amplitude
            with torch.no_grad():
                amp = psi.get_amp(x_i, conj=True)
                amp_val = amp.contract()
            if amp_val == 0.0:
                amp_val = torch.tensor(0.0, device='cuda')
            batch_amps.append(amp_val)
        
        # Stack the amplitudes into a tensor
        batch_amps = torch.stack(batch_amps).to('cuda')
        return batch_amps

    def amplitude_grad(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        # `x` is expected to be batched as (batch_size, input_dim)
        
        # Ensure x is a tensor of the correct dtype and move to GPU
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=self.param_dtype)
        elif x.dtype != self.param_dtype:
            x = x.to(self.param_dtype)
        
        # # Move x to GPU and enable gradient computation
        # x = x.to('cuda')

        # Get model parameters list
        params_list = list(self.parameters())

        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            # Get the amplitude
            amp = psi.get_amp(x_i, conj=True)
            amp_val = amp.contract()
            if amp_val == 0.0:
                amp_val = torch.tensor(0.0, device='cuda')
            batch_amps.append(amp_val)

        # Stack the amplitudes into a tensor
        batch_amps = torch.stack(batch_amps).to('cuda')

        # Compute gradients with respect to the parameters
        gradients = []
        for amp in batch_amps:
            grad = torch.autograd.grad(amp, self.parameters(), retain_graph=True, allow_unused=True)
            flatten_grad = []
            for i in range(len(grad)):
                if grad[i] is None:
                    flatten_grad.append(torch.zeros_like(params_list[i]))
                else:
                    flatten_grad.append(grad[i])
            gradients.append(torch.cat([g.flatten() for g in flatten_grad]))
        # Stack the gradients into a tensor
        gradients = torch.stack(gradients)

        return batch_amps, gradients

In [65]:
model = fMPSModel_GPU(mps, dtype=dtype)
model1 = fMPSModel(mps, dtype=dtype)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

fMPSModel_GPU(
  (torch_tn_params): ModuleDict(
    (0): ParameterDict(
        ((0, 1)): Parameter containing: [torch.cuda.DoubleTensor of size 2x2 (cuda:0)]
        ((1, 0)): Parameter containing: [torch.cuda.DoubleTensor of size 2x2 (cuda:0)]
    )
    (1): ParameterDict(
        ((0, 0, 0)): Parameter containing: [torch.cuda.DoubleTensor of size 2x4x2 (cuda:0)]
        ((0, 1, 1)): Parameter containing: [torch.cuda.DoubleTensor of size 2x4x2 (cuda:0)]
        ((1, 0, 1)): Parameter containing: [torch.cuda.DoubleTensor of size 2x4x2 (cuda:0)]
        ((1, 1, 0)): Parameter containing: [torch.cuda.DoubleTensor of size 2x4x2 (cuda:0)]
    )
    (2): ParameterDict(
        ((0, 0, 0)): Parameter containing: [torch.cuda.DoubleTensor of size 4x4x2 (cuda:0)]
        ((0, 1, 1)): Parameter containing: [torch.cuda.DoubleTensor of size 4x4x2 (cuda:0)]
        ((1, 0, 1)): Parameter containing: [torch.cuda.DoubleTensor of size 4x4x2 (cuda:0)]
        ((1, 1, 0)): Parameter containing: [torch.

In [182]:
import jax
import pyinstrument
random_config = [H.hilbert.random_state(key=jax.random.PRNGKey(1)), H.hilbert.random_state(key=jax.random.PRNGKey(2))]
random_config = torch.tensor(random_config, dtype=dtype)
random_config_gpu = random_config.to(device)
with pyinstrument.Profiler() as prof:
    model.amplitude_grad(random_config_gpu)
print(prof.output_text(unicode=True, color=True))


  _     ._   __/__   _ _  _  _ _/_   Recorded: 22:43:56  Samples:  43
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.056     CPU time: 0.056
/   _/                      v4.7.3

Profile at /tmp/ipykernel_12453/2548696919.py:6

[31m0.055[0m [48;5;24m[38;5;15m<module>[0m  [2m../../../../../tmp/ipykernel_12453/2548696919.py:6[0m
└─ [31m0.055[0m [48;5;24m[38;5;15mfMPSModel_GPU.amplitude_grad[0m  [2m../../../../../tmp/ipykernel_12453/269258344.py:69[0m
   ├─ [33m0.032[0m [48;5;24m[38;5;15mfMPS.get_amp[0m  [2mvmc_torch/fermion_utils.py:471[0m
   │  ├─ [33m0.027[0m [48;5;24m[38;5;15mTensorNetwork.contract[0m  [2mquimb/tensor/tensor_core.py:8438[0m
   │  │  └─ [33m0.027[0m [48;5;24m[38;5;15mTensorNetwork.contract_tags[0m  [2mquimb/tensor/tensor_core.py:8328[0m
   │  │     ├─ [33m0.025[0m wrapper[0m  [2mfunctools.py:883[0m
   │  │     │  └─ [33m0.025[0m [48;5;24m[38;5;15mtensor_contract[0m  [2mquimb/tensor/tensor_core.py:207[0m
   │  │     │     

In [209]:
with pyinstrument.Profiler() as prof:
    for config in random_config:
        if config.ndim == 1:
            config = config.unsqueeze(0)
        amp = model1.amplitude(config)
        amp.backward()
        grad = model1.params_grad_to_vec()
print(prof.output_text(unicode=True, color=True))


  _     ._   __/__   _ _  _  _ _/_   Recorded: 22:44:01  Samples:  20
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.023     CPU time: 0.023
/   _/                      v4.7.3

Profile at /tmp/ipykernel_12453/3278166090.py:1

[31m0.022[0m [48;5;24m[38;5;15m<module>[0m  [2m../../../../../tmp/ipykernel_12453/3278166090.py:1[0m
├─ [31m0.018[0m [48;5;24m[38;5;15mfMPSModel.amplitude[0m  [2m../tn_model.py:566[0m
│  ├─ [33m0.012[0m [48;5;24m[38;5;15mfMPS.get_amp[0m  [2mvmc_torch/fermion_utils.py:471[0m
│  │  ├─ [33m0.009[0m [48;5;24m[38;5;15mTensorNetwork.contract[0m  [2mquimb/tensor/tensor_core.py:8438[0m
│  │  │  └─ [33m0.009[0m [48;5;24m[38;5;15mTensorNetwork.contract_tags[0m  [2mquimb/tensor/tensor_core.py:8328[0m
│  │  │     └─ [33m0.009[0m wrapper[0m  [2mfunctools.py:883[0m
│  │  │        └─ [33m0.009[0m [48;5;24m[38;5;15mtensor_contract[0m  [2mquimb/tensor/tensor_core.py:207[0m
│  │  │           ├─ [33m0.008[0m [48;5;24m[38;5;15marr