# Simplest example

In [1]:
# import torch
# import torch.nn as nn
# from functorch import vmap, jacrev

# class MyModel(nn.Module):
#     def __init__(self, in_dim, out_dim):
#         super().__init__()
#         self.linear = nn.Linear(in_dim, out_dim, bias=True)

#     def forward(self, x):
#         # x shape: [batch_size, in_dim]
#         return self.linear(x)  # [batch_size, out_dim]

# # Instantiate a toy model
# model = MyModel(in_dim=3, out_dim=1)

# # Create a batch of inputs: shape [m, in_dim]
# X = torch.randn(5, 3)  # m=5 examples, each dimension=3

# # Extract model parameters as a tuple
# # For a single Linear layer: params=(weight, bias)
# params0 = tuple(model.parameters())  # (W, b)
# params = {'W': params0[0], 'b': params0[1]}

# def model_functional(params, x):
#     """
#     params = {'W': W, 'b': b}
#     x      = single input, shape: [in_dim]
#     Return f_theta(x), shape: [out_dim].
#     """
#     W = params['W']
#     b = params['b']
#     return x @ W.T + b

# def single_sample_jac(params, x):
#     """
#     Return the Jacobian of model_functional w.r.t. 'params'
#     for a single input x.

#     Shape details:
#       * f_theta(x) in R^(out_dim)
#       * 'params' is a dict {'W': W, 'b': b}
#     Result is a dict with the same keys as 'params':
#       {'W': Jac_of_W, 'b': Jac_of_b}
#     """
#     # 'lambda p: model_functional(p, x)' is a function of 'p' only
#     return jacrev(lambda p: model_functional(p, x))(params)

# # We'll define a "batched" version of single_sample_jac:
# batched_param_jac = vmap(single_sample_jac, in_dims=(None, 0))
# #  -> 'params' is not varying (None), 
# #     'x' is taken from the 0th dim of X.

# # Now compute the per-example Jacobian:
# jac_per_sample = batched_param_jac(params, X)

# jac_per_sample

In [2]:
# Cap the BLAS/OpenMP thread pools at one thread each. These variables are
# written before numpy/torch are imported further down; presumably that
# ordering is intentional so the numerical backends see them at load time
# (and MPI ranks do not oversubscribe cores) -- TODO confirm.
import os
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = '1'
# NOTE(review): DENSE_TENSOR is not read anywhere in this notebook; presumably
# it is consumed by vmc_torch at import time -- confirm.
os.environ['DENSE_TENSOR'] = '0'
import sys
import warnings
# Silences every warning for the rest of the session.
warnings.filterwarnings("ignore")
from mpi4py import MPI
import numpy as np
import pickle
# Root directory of the pickled PEPS data used below.
# NOTE(review): machine-specific absolute path; adjust for your checkout.
pwd = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/data'
# torch
import torch
torch.autograd.set_detect_anomaly(False)

# quimb
import quimb.tensor as qtn
import autoray as ar

# Wildcard import: fTNModel_vec and fTNModel used below are not imported by
# name anywhere else, so they come from here.
from vmc_torch.experiment.tn_model import *
from vmc_torch.sampler import MetropolisExchangeSamplerSpinful, MetropolisMPSSamplerSpinful
from vmc_torch.variational_state import Variational_State
from vmc_torch.optimizer import SGD, SR,Adam, SGD_momentum, DecayScheduler, TrivialPreconditioner
from vmc_torch.VMC import VMC
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_square_lattice_torch
# from vmc_torch.torch_utils import SVD,QR

# # Register safe SVD and QR functions to torch
# ar.register_function('torch','linalg.svd',SVD.apply)
# ar.register_function('torch','linalg.qr',QR.apply)

from vmc_torch.global_var import DEBUG
from vmc_torch.utils import closest_divisible


# MPI layout of this run
COMM = MPI.COMM_WORLD
SIZE, RANK = COMM.Get_size(), COMM.Get_rank()

# Hamiltonian parameters: Lx x Ly open-boundary spinful Fermi-Hubbard model
# with one fermion per site, split evenly between the two spin species.
Lx, Ly = 4, 4
symmetry = 'Z2'
t, U = 1.0, 8.0
N_f = Lx * Ly
n_fermions_per_spin = (N_f // 2,) * 2
H = spinful_Fermi_Hubbard_square_lattice_torch(
    Lx,
    Ly,
    t,
    U,
    N_f,
    pbc=False,
    n_fermions_per_spin=n_fermions_per_spin,
)
graph = H.graph

# TN parameters
D = 4
chi = -1
dtype = torch.float64

# Seed torch and numpy with the MPI rank so every rank draws a different stream.
torch.random.manual_seed(RANK)
np.random.seed(RANK)

# Load PEPS
# Both pickles live in the same directory, so build that path once.
peps_dir = pwd+f"/{Lx}x{Ly}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}"
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# data files you produced yourself.
# Context managers close the file handles deterministically.
with open(peps_dir+"/peps_skeleton.pkl", "rb") as skeleton_file:
    skeleton = pickle.load(skeleton_file)
with open(peps_dir+"/peps_su_params.pkl", "rb") as params_file:
    peps_params = pickle.load(params_file)
peps = qtn.unpack(peps_params, skeleton)

# Convert every tensor block, and the network's scalar exponent, to torch
# tensors of the working dtype on the target device.
device = torch.device("cuda")
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=dtype, device=device))
peps.exponent = torch.tensor(peps.exponent, dtype=dtype, device=device)

# # randomize the PEPS tensors
# peps.apply_to_arrays(lambda x: torch.randn_like(torch.tensor(x, dtype=dtype), dtype=dtype))

# VMC sample size: start from 20, snap to a multiple of the MPI world size,
# then add one more sample per rank if the per-rank count is not even.
N_samples = closest_divisible(int(20), SIZE)
if (N_samples/SIZE)%2 != 0:
    N_samples += SIZE

# nn_hidden_dim = Lx*Ly
# Two models built from the same PEPS; model1 reuses model's contraction tree
# so both contract along the same path.
model = fTNModel_vec(
    peps,
    max_bond=chi,
    dtype=dtype,
    functional=False,
    device=device,
)
model1 = fTNModel(
    peps,
    max_bond=chi,
    dtype=dtype,
    functional=False,
)
model1.tree = model.tree
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = fTN_backflow_attn_Tensorwise_Model_v1(
#     peps,
#     max_bond=chi,
#     embedding_dim=16,
#     attention_heads=4,
#     nn_final_dim=4,
#     nn_eta=1.0,
#     dtype=dtype,
# )

# Set up sampler
sampler = MetropolisExchangeSamplerSpinful(
    H.hilbert,
    graph,
    N_samples=N_samples,
    burn_in_steps=2,
    reset_chain=False,
    random_edge=False,
    equal_partition=True,
    dtype=dtype,
    subchain_length=10,
)
# mps_dir = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/data'+f'/{Lx}x{Ly}/t={t}_U={U}/N={N_f}/tmp'
# sampler = MetropolisMPSSamplerSpinful(H.hilbert, graph, mps_dir=mps_dir, mps_n_sample=1, N_samples=N_samples, burn_in_steps=20, reset_chain=True, equal_partition=True, dtype=dtype)

F=4.62 C=5.61 S=10.00 P=11.35: 100%|██████████| 10/10 [00:00<00:00, 90.47it/s]


In [3]:
# X = [H.hilbert.random_state(i) for i in range(10)]
# X = torch.tensor(X, dtype=dtype, device=device)
# amp0 = peps.get_amp(X[0], functional=True)
# amp00 = peps.get_amp(X[0], functional=False)

# amp1 = amp0.contract_boundary_from_xmin(max_bond=-1, xrange=(0, Lx//2-1))
# amp2 = amp1.contract_boundary_from_xmax(max_bond=-1, xrange=(Lx//2, Lx-1))
# amp00.contract(), amp0.contract()

In [4]:
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.func import functional_call
from torch.autograd.functional import jacobian
import pyinstrument

param_dict = dict(model.named_parameters())
# 1) Record, in one fixed order, the name / shape / element count of every
#    parameter so a flat vector can later be sliced back into named tensors.
#    Reading keys() and values() separately (dicts keep insertion order) also
#    works for a model with no parameters, where `zip(*items)` would raise
#    ValueError on unpacking.
names = tuple(param_dict.keys())
params = tuple(param_dict.values())
shapes = [p.shape for p in params]
numels = [p.numel() for p in params]

def vector_to_param_dict(vec):
    """
    Slice a flat parameter vector back into named, correctly shaped tensors.

    vec: 1D Tensor holding every parameter back to back, in the order given by
         the module-level `names` / `shapes` / `numels` sequences.
    returns: dict { name_i : tensor_i } where tensor_i is the matching slice of
             `vec` reshaped to that parameter's original shape.
    """
    rebuilt = {}
    offset = 0
    for key, target_shape, count in zip(names, shapes, numels):
        # Take the next `count` entries and restore the parameter's shape.
        chunk = vec[offset:offset + count]
        rebuilt[key] = chunk.reshape(target_shape)
        offset += count
    return rebuilt

# Move both models (and the skeleton's exponent, which is a plain attribute
# rather than a registered parameter) onto the working device.
model.to(device)
model1.to(device)
model.skeleton.exponent = model.skeleton.exponent.to(device)

# Example usage: flatten the parameters, then rebuild the named dict from the
# flat vector.
new_vec = model.from_params_to_vec()
new_param_dict = vector_to_param_dict(new_vec)  # {param_name: tensor(...), ...}

# 2) Now define a "functional" forward using functional_call.
#    A later cell calls `fmodel` inside `jacobian(...)`, so it has to be
#    defined here for the notebook to run top to bottom.
def fmodel(vec, x):
    """Evaluate `model` on `x` with its parameters taken from the flat vector `vec`."""
    pdict = vector_to_param_dict(vec)
    return functional_call(model, pdict, (x,))

# 3) Finally, build the batch of configurations used by the cells below:
np.random.seed(0)
rand_number = np.random.randint(0, 1000)
X = [H.hilbert.random_state(i) for i in range(10)]
X = torch.tensor(X, dtype=dtype, device=device)

# vmodel = vmap(model)
# amps_vec, amps = vmodel(X), model1(X)
# amps_vec/amps

In [5]:
# Wrap the vectorised model with torch.compile and run it once on a single
# configuration (the batched call is left commented out).
model_c = torch.compile(model)
# model_c(X)
model_c(X[1])

W0424 23:22:28.406000 11034 torch/_dynamo/variables/builtin.py:783] [13/0] incorrect arg count <bound method BuiltinVariable.call_next of BuiltinVariable(next)> too many positional arguments and no constant handler
W0424 23:22:29.764000 11034 torch/_dynamo/variables/builtin.py:783] [15/2] incorrect arg count <bound method BuiltinVariable._call_min_max of BuiltinVariable(max)> got an unexpected keyword argument 'key' and no constant handler
W0424 23:22:29.778000 11034 torch/_dynamo/variables/builtin.py:783] [16/0] incorrect arg count <bound method BuiltinVariable._call_min_max of BuiltinVariable(max)> got an unexpected keyword argument 'key' and no constant handler
W0424 23:22:29.792000 11034 torch/_dynamo/variables/builtin.py:783] [17/0] incorrect arg count <bound method BuiltinVariable._call_min_max of BuiltinVariable(max)> got an unexpected keyword argument 'key' and no constant handler
W0424 23:22:30.152000 11034 torch/_dynamo/variables/builtin.py:783] [16/1] incorrect arg count <bo

tensor(-1.1611e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<CompiledFunctionBackward>)

In [6]:
# Attempt to batch the compiled single-sample model with vmap.
# NOTE(review): the recorded run of this cell raises; the traceback is
# truncated in the saved output, so the root cause is not visible here.
# NOTE(review): `vmap` is not imported by name in any cell above this one; it
# is imported in a later cell (or possibly supplied by the wildcard import of
# vmc_torch.experiment.tn_model) -- confirm before relying on run order.
vmodel_c = vmap(model_c)
vmodel_c(X)

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/sijingdu/TNVMC/VMC_code/mpsds/mpsds/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3549, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_11034/4088809724.py", line 2, in <module>
    vmodel_c(X)
  File "/home/sijingdu/TNVMC/VMC_code/mpsds/mpsds/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/home/sijingdu/TNVMC/VMC_code/mpsds/mpsds/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/home/sijingdu/TNVMC/VMC_code/mpsds/mpsds/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sijingdu/TNVMC/VMC_code/mpsds/mpsds/lib/python3.11/site-packages/torch/nn/modules/module.py", l

In [35]:
# choose number of streams to interleave launches (tunable)
num_streams = 10
streams = [torch.cuda.Stream(device=device) for _ in range(num_streams)]

# Work queued on a side stream is not ordered after work already queued on the
# current stream. Make every side stream wait for the stream that produced X
# (and the model parameters) so no launch can read half-written data.
producer_stream = torch.cuda.current_stream(device)
for side_stream in streams:
    side_stream.wait_stream(producer_stream)

# prepare output list
B = X.shape[0]
outputs = [None] * B

# launch each sample on a stream in round-robin
for i in range(B):
    stream = streams[i % num_streams]
    x_i = X[i]  # one configuration
    with torch.cuda.stream(stream):
        # model_c only handles single-sample input
        outputs[i] = model_c(x_i)

# wait for all streams to finish before the results are consumed below
torch.cuda.synchronize()

# stack results back into (B, ...) and show them next to the reference model
Y = torch.stack(outputs, dim=0)
Y, model1(X)

(tensor([-7.4144e-11, -1.1611e-09, -2.9750e-16,  7.4415e-13, -6.3460e-09,
          2.3613e-12, -8.1625e-07, -1.4848e-09,  7.5873e-14,  9.3872e-16],
        device='cuda:0', dtype=torch.float64, grad_fn=<StackBackward0>),
 tensor([-7.4144e-11, -1.1611e-09,  2.9750e-16,  7.4415e-13,  6.3460e-09,
          2.3613e-12,  8.1625e-07, -1.4848e-09, -7.5873e-14,  9.3872e-16],
        device='cuda:0', dtype=torch.float64, grad_fn=<StackBackward0>))

In [27]:
# Reference amplitudes: contract the PEPS directly for every configuration in X.
amp_val = torch.tensor([peps.get_amp(x).contract() for x in X], device=device)
# Ratio of model1's batched output to the reference; the recorded output is
# 1.0 for every sample.
model1(X)/amp_val

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000], dtype=torch.float64, grad_fn=<DivBackward0>)

In [31]:
# check contraction tree
# for i, (_, left_tids, right_tids) in enumerate(model.tree.traverse()):
#     print(i, left_tids, right_tids)

In [32]:
# Profile one batched forward pass of each model and print both reports,
# the vectorised model first and the reference model second.
for profiled_model in (model, model1):
    with pyinstrument.Profiler() as prof:
        profiled_model(X)
    prof.print()


  _     ._   __/__   _ _  _  _ _/_   Recorded: 20:54:39  Samples:  159
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.164     CPU time: 0.164
/   _/                      v5.0.1

Profile at /tmp/ipykernel_22040/3565115910.py:1

0.163 <module>  ../ipykernel_22040/3565115910.py:1
└─ 0.163 fTNModel_vec._wrapped_call_impl  torch/nn/modules/module.py:1735
   └─ 0.163 fTNModel_vec._call_impl  torch/nn/modules/module.py:1743
      └─ 0.163 fTNModel_vec.forward  ../tn_model.py:1374
         └─ 0.163 fTNModel_vec.amplitude  ../tn_model.py:1330
            ├─ 0.160 wrapped  torch/_functorch/apis.py:202
            │  └─ 0.160 vmap_impl  torch/_functorch/vmap.py:309
            │     └─ 0.160 _flat_vmap  torch/_functorch/vmap.py:472
            │        └─ 0.160 amplitude_func  ../tn_model.py:1343
            │           ├─ 0.072 PEPS.contract  quimb/tensor/tensor_core.py:8934
            │           │     [42 frames hidden]  functools, quimb, cotengra, symmray, ...
            │           ├─ 0

In [4]:
# use vmap to compute the Jacobian
# The standalone `functorch` package is deprecated; the same transforms live
# in `torch.func`, which this notebook already imports `functional_call` from.
from torch.func import jacrev, vmap

In [5]:
# Set up variational state
variational_state = Variational_State(model, hi=H.hilbert, sampler=sampler, dtype=dtype)
# Profile 1: Jacobian of the batched amplitudes with respect to the flat
# parameter vector, via the functional wrapper.
# Requires `fmodel(vec, x)` to have been defined by an earlier cell.
with pyinstrument.Profiler() as prof:
    J = jacobian(lambda v: fmodel(v, X), new_vec, vectorize=True)
prof.print()
# Profile 2: per-sample amplitude/gradient through the variational state,
# keeping only the amplitudes.
# NOTE(review): the recorded output of this cell ends with
# "TypeError: iteration over a 0-d tensor"; the traceback is not preserved, so
# whether it comes from this loop or from inside amplitude_grad needs checking.
with pyinstrument.Profiler() as prof:
    amp_list = []
    for x in X:
        amp, _ = variational_state.amplitude_grad(x)
        amp_list.append(amp)
prof.print()
print("amp_list", amp_list)


  _     ._   __/__   _ _  _  _ _/_   Recorded: 18:05:36  Samples:  348
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.876     CPU time: 0.860
/   _/                      v5.0.1

Profile at /tmp/ipykernel_31234/2870813720.py:3

0.876 <module>  ../ipykernel_31234/2870813720.py:1
└─ 0.875 jacobian  torch/autograd/functional.py:575
      [6 frames hidden]  torch
         0.566 <lambda>  ../ipykernel_31234/2870813720.py:4
         └─ 0.566 fmodel  ../ipykernel_31234/2641142638.py:36
            └─ 0.564 functional_call  torch/_functorch/functional_call.py:11
                  [2 frames hidden]  torch
                     0.561 fTNModel_vec._call_impl  torch/nn/modules/module.py:1743
                     └─ 0.561 fTNModel_vec.forward  ../tn_model.py:1364
                        └─ 0.561 fTNModel_vec.amplitude  ../tn_model.py:1317
                           └─ 0.554 wrapped  torch/_functorch/apis.py:202
                              └─ 0.554 vmap_impl  torch/_functorch/vmap.py:309
        

TypeError: iteration over a 0-d tensor

In [12]:
# Rebind `device` to the CPU and move the model (plus the skeleton exponent,
# which is moved by hand) there. Every later cell that reads `device` now
# sees the CPU device.
device = torch.device("cpu")
model.to(device)
model.skeleton.exponent = model.skeleton.exponent.to(device)
print(model.tree)

# Example usage: re-flatten the parameters and rebuild the named dict
new_vec = model.from_params_to_vec()
new_param_dict = vector_to_param_dict(new_vec)  # {param_name: tensor(...), ...}

# 2) Now define a "functional" forward using functional_call:
def fmodel(vec, x):
    """Evaluate `model` on `x` with its parameters taken from the flat vector `vec`."""
    # Rebuild the named parameter dict from `vec` and call the model with that
    # dict supplied in place of its own parameters.
    return functional_call(model, vector_to_param_dict(vec), (x,))
# 3) Finally, we can compute the Jacobian:
np.random.seed(0)
# NOTE(review): rand_number is not used in this cell, but drawing it still
# advances numpy's global RNG state.
rand_number = np.random.randint(0, 1000)
# Two configurations only, to keep the CPU profile short.
X = [H.hilbert.random_state(i) for i in range(2)]
X = torch.tensor(X, dtype=dtype, device=device)
# Set up variational state
variational_state = Variational_State(model, hi=H.hilbert, sampler=sampler, dtype=dtype)
# Profile 1: full Jacobian of the batched amplitudes w.r.t. the flat vector.
with pyinstrument.Profiler() as prof:
    J = jacobian(lambda v: fmodel(v, X), new_vec, vectorize=True)
prof.print()
# Profile 2: one amplitude_grad call per configuration, keeping the amplitudes.
with pyinstrument.Profiler() as prof:
    amp_list = []
    for x in X:
        amp, _ = variational_state.amplitude_grad(x)
        amp_list.append(amp)
prof.print()
print("amp_list", amp_list)

<ContractionTree(N=16, F=4.62, C=5.61, S=10.00, P=11.35)>

  _     ._   __/__   _ _  _  _ _/_   Recorded: 15:31:44  Samples:  57
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.096     CPU time: 0.096
/   _/                      v5.0.1

Profile at /tmp/ipykernel_6982/2801438995.py:22

0.095 <module>  /tmp/ipykernel_6982/2801438995.py:1
└─ 0.095 jacobian  torch/autograd/functional.py:575
      [6 frames hidden]  torch
         0.056 <lambda>  /tmp/ipykernel_6982/2801438995.py:23
         └─ 0.056 fmodel  /tmp/ipykernel_6982/2801438995.py:11
            └─ 0.056 functional_call  torch/_functorch/functional_call.py:11
                  [11 frames hidden]  torch
                     0.053 fTN_backflow_attn_Tensorwise_Model_v1._call_impl  torch/nn/modules/module.py:1743
                     └─ 0.053 fTN_backflow_attn_Tensorwise_Model_v1.forward  ../tn_model.py:91
                        └─ 0.053 fTN_backflow_attn_Tensorwise_Model_v1.amplitude  ../tn_model.py:2306
                          

In [6]:
# Display the first parameter tensor of the model held by the variational state.
next(variational_state.vstate_func.parameters())

Parameter containing:
tensor([[[ 0.0507,  0.0008],
         [ 0.0943, -0.0176],
         [ 0.0776,  0.0564],
         [ 0.1230,  0.0463]],

        [[ 0.0230,  0.0078],
         [-0.0103, -0.0031],
         [ 0.0686, -0.0179],
         [ 0.0319,  0.0093]],

        [[-0.1914,  0.0003],
         [ 0.0407, -0.0029],
         [ 0.0526,  0.0109],
         [-0.0262, -0.0017]],

        [[-0.0308,  0.0062],
         [-0.0817,  0.0053],
         [-0.0261,  0.0110],
         [ 0.0314,  0.0036]]], dtype=torch.float64, requires_grad=True)

In [242]:
# Draw two configurations keyed by a random offset and profile one
# amplitude_grad call per configuration.
rand_number = np.random.randint(0, 1000)
X = torch.tensor(
    [H.hilbert.random_state(base + rand_number) for base in (11, 22)],
    dtype=dtype,
)
with pyinstrument.Profiler() as prof:
    for x in X:
        variational_state.amplitude_grad(x)
prof.print()


  _     ._   __/__   _ _  _  _ _/_   Recorded: 00:41:44  Samples:  39
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.049     CPU time: 0.050
/   _/                      v5.0.1

Profile at /tmp/ipykernel_31016/2912553885.py:4

0.049 <module>  ../ipykernel_31016/2912553885.py:1
├─ 0.048 wrapper  ../../utils.py:39
│  └─ 0.048 Variational_State.amplitude_grad  ../../variational_state.py:98
│     ├─ 0.035 fTNModel._wrapped_call_impl  torch/nn/modules/module.py:1735
│     │  └─ 0.035 fTNModel._call_impl  torch/nn/modules/module.py:1743
│     │     └─ 0.035 fTNModel.forward  ../tn_model.py:94
│     │        └─ 0.035 fTNModel.amplitude  ../ipykernel_31016/3828677844.py:122
│     │           ├─ 0.022 PEPS.contract  quimb/tensor/tensor_core.py:8934
│     │           │     [43 frames hidden]  functools, quimb, cotengra, symmray, ...
│     │           ├─ 0.007 fPEPS.get_amp  ../../fermion_utils.py:158
│     │           │  └─ 0.007 fPEPS.get_amp_efficient  ../../fermion_utils.py:185
│     │     

In [64]:
import ast
from vmc_torch.experiment.tn_model import wavefunctionModel, fMPSModel

class fMPSModel_GPU(wavefunctionModel):
    """Fermionic MPS wavefunction whose amplitudes are returned on a chosen device.

    The tensor network is packed into nested torch containers (one
    ParameterDict of symmetry-sector blocks per tensor id) so that every block
    is a trainable parameter. Amplitudes are obtained by rebuilding the
    network from those parameters and contracting it exactly, one
    configuration at a time.

    Args:
        ftn: the fermionic MPS to wrap. Its arrays are turned into parameters
            as they are; they are NOT moved or cast by this class.
        dtype: dtype that input configurations are cast to.
        device: device the returned amplitudes (and, in `amplitude`, the input
            batch) are placed on. Defaults to 'cuda', which matches the
            previous hard-coded behaviour.
    """

    def __init__(self, ftn, dtype=torch.float32, device='cuda'):
        super().__init__()
        self.param_dtype = dtype
        self._output_device = torch.device(device)
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict.
        # Keys must be strings, so tensor ids and sector labels are stringified
        # here and parsed back in `_rebuild_tn`.
        self.torch_tn_params = torch.nn.ModuleDict({
            str(tid): torch.nn.ParameterDict({
                str(sector): torch.nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fMPS (exact contraction)':{'D': ftn.max_bond(), 'L': ftn.L, 'symmetry': self.symmetry, 'cyclic': ftn.cyclic, 'skeleton': self.skeleton},
        }

    def _rebuild_tn(self):
        """Rebuild the tensor network from the current parameter values."""
        # Undo the stringification done in __init__: tensor ids back to int,
        # sector labels back to their literal Python value.
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        return qtn.unpack(params, self.skeleton)

    def _as_config_batch(self, x):
        """Return `x` as a tensor of dtype `param_dtype` with a leading batch axis."""
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=self.param_dtype)
        elif x.dtype != self.param_dtype:
            x = x.to(self.param_dtype)
        # A single configuration is promoted to a batch of one; otherwise the
        # per-sample loop would iterate over individual sites.
        if x.ndim == 1:
            x = x.unsqueeze(0)
        return x

    def _zero_amplitude(self):
        """Constant zero amplitude matching the parameters' dtype."""
        # Use the dtype of the stored parameters (not `param_dtype`, which only
        # governs the inputs) so the placeholder stacks with real amplitudes.
        reference = next(self.parameters(), None)
        zero_dtype = self.param_dtype if reference is None else reference.dtype
        return torch.tensor(0.0, dtype=zero_dtype, device=self._output_device)

    def _contract_amplitude(self, psi, x_i):
        """Contract `psi` for one configuration; zero results become a constant tensor."""
        amp_val = psi.get_amp(x_i, conj=True).contract()
        # A vanishing amplitude (tensor or plain number) is replaced by a
        # constant tensor so it can always be stacked with the other samples.
        if amp_val == 0.0:
            amp_val = self._zero_amplitude()
        return amp_val

    def amplitude(self, x):
        """Amplitudes for a batch of configurations, computed without autograd.

        Args:
            x: configurations of shape (batch_size, input_dim); a single 1-D
                configuration is treated as a batch of one.

        Returns:
            Tensor of shape (batch_size,) on the model's output device.
        """
        psi = self._rebuild_tn()
        # The input batch is moved to the output device before contraction.
        x = self._as_config_batch(x).to(self._output_device)

        batch_amps = []
        with torch.no_grad():
            for x_i in x:
                batch_amps.append(self._contract_amplitude(psi, x_i))

        return torch.stack(batch_amps).to(self._output_device)

    def amplitude_grad(self, x):
        """Amplitudes and their parameter gradients for a batch of configurations.

        Args:
            x: configurations of shape (batch_size, input_dim); a single 1-D
                configuration is treated as a batch of one.

        Returns:
            (batch_amps, gradients): `batch_amps` has shape (batch_size,) on
            the output device; `gradients` has shape (batch_size, n_params),
            each row being the flattened gradient of that sample's amplitude
            in `self.parameters()` order. Parameters that do not contribute to
            an amplitude get zeros.
        """
        psi = self._rebuild_tn()
        # NOTE(review): unlike `amplitude`, the input batch is deliberately
        # left on its current device here (the move was commented out in the
        # original) -- confirm which behaviour `get_amp` actually wants.
        x = self._as_config_batch(x)

        params_list = list(self.parameters())

        per_sample_amps = [self._contract_amplitude(psi, x_i) for x_i in x]
        batch_amps = torch.stack(per_sample_amps).to(self._output_device)

        gradients = []
        for amp in per_sample_amps:
            if amp.requires_grad:
                grads = torch.autograd.grad(
                    amp, params_list, retain_graph=True, allow_unused=True
                )
            else:
                # Constant (zero) amplitude: autograd.grad would raise because
                # nothing requires grad, and the gradient is zero anyway.
                grads = (None,) * len(params_list)
            gradients.append(torch.cat([
                (torch.zeros_like(param) if grad is None else grad).flatten()
                for grad, param in zip(grads, params_list)
            ]))
        # Stack the gradients into a tensor
        gradients = torch.stack(gradients)

        return batch_amps, gradients

In [182]:
import jax
import pyinstrument
# Two configurations drawn with fixed JAX PRNG keys, converted to a torch
# batch and copied to `device`.
random_config = [H.hilbert.random_state(key=jax.random.PRNGKey(1)), H.hilbert.random_state(key=jax.random.PRNGKey(2))]
random_config = torch.tensor(random_config, dtype=dtype)
random_config_gpu = random_config.to(device)
# Profile one batched amplitude_grad call.
# NOTE(review): the recorded profile shows fMPSModel_GPU.amplitude_grad, so
# `model` was presumably an fMPSModel_GPU instance when this ran -- confirm.
with pyinstrument.Profiler() as prof:
    model.amplitude_grad(random_config_gpu)
print(prof.output_text(unicode=True, color=True))


  _     ._   __/__   _ _  _  _ _/_   Recorded: 22:43:56  Samples:  43
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.056     CPU time: 0.056
/   _/                      v4.7.3

Profile at /tmp/ipykernel_12453/2548696919.py:6

[31m0.055[0m [48;5;24m[38;5;15m<module>[0m  [2m../../../../../tmp/ipykernel_12453/2548696919.py:6[0m
└─ [31m0.055[0m [48;5;24m[38;5;15mfMPSModel_GPU.amplitude_grad[0m  [2m../../../../../tmp/ipykernel_12453/269258344.py:69[0m
   ├─ [33m0.032[0m [48;5;24m[38;5;15mfMPS.get_amp[0m  [2mvmc_torch/fermion_utils.py:471[0m
   │  ├─ [33m0.027[0m [48;5;24m[38;5;15mTensorNetwork.contract[0m  [2mquimb/tensor/tensor_core.py:8438[0m
   │  │  └─ [33m0.027[0m [48;5;24m[38;5;15mTensorNetwork.contract_tags[0m  [2mquimb/tensor/tensor_core.py:8328[0m
   │  │     ├─ [33m0.025[0m wrapper[0m  [2mfunctools.py:883[0m
   │  │     │  └─ [33m0.025[0m [48;5;24m[38;5;15mtensor_contract[0m  [2mquimb/tensor/tensor_core.py:207[0m
   │  │     │     

In [209]:
# Profile the reference path: one forward + backward per configuration.
with pyinstrument.Profiler() as prof:
    for config in random_config:
        # amplitude() expects a batch; promote a single configuration to (1, n).
        if config.ndim == 1:
            config = config.unsqueeze(0)
        # backward() accumulates into .grad, so clear the previous sample's
        # gradients first; otherwise `grad` would be a running sum over samples.
        # set_to_none=False keeps existing .grad tensors (zeroed) rather than
        # replacing them with None.
        model1.zero_grad(set_to_none=False)
        amp = model1.amplitude(config)
        amp.backward()
        grad = model1.params_grad_to_vec()
print(prof.output_text(unicode=True, color=True))


  _     ._   __/__   _ _  _  _ _/_   Recorded: 22:44:01  Samples:  20
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.023     CPU time: 0.023
/   _/                      v4.7.3

Profile at /tmp/ipykernel_12453/3278166090.py:1

[31m0.022[0m [48;5;24m[38;5;15m<module>[0m  [2m../../../../../tmp/ipykernel_12453/3278166090.py:1[0m
├─ [31m0.018[0m [48;5;24m[38;5;15mfMPSModel.amplitude[0m  [2m../tn_model.py:566[0m
│  ├─ [33m0.012[0m [48;5;24m[38;5;15mfMPS.get_amp[0m  [2mvmc_torch/fermion_utils.py:471[0m
│  │  ├─ [33m0.009[0m [48;5;24m[38;5;15mTensorNetwork.contract[0m  [2mquimb/tensor/tensor_core.py:8438[0m
│  │  │  └─ [33m0.009[0m [48;5;24m[38;5;15mTensorNetwork.contract_tags[0m  [2mquimb/tensor/tensor_core.py:8328[0m
│  │  │     └─ [33m0.009[0m wrapper[0m  [2mfunctools.py:883[0m
│  │  │        └─ [33m0.009[0m [48;5;24m[38;5;15mtensor_contract[0m  [2mquimb/tensor/tensor_core.py:207[0m
│  │  │           ├─ [33m0.008[0m [48;5;24m[38;5;15marr