In [21]:
from vmap_utils import Transformer_fPEPS_Model_batchedAttn, Transformer_fPEPS_Model, random_initial_config, fPEPS_Model, compute_grads

import quimb.tensor as qtn

import symmray as sr

import torch

# Lattice geometry and fPEPS hyper-parameters.
Lx, Ly = 4, 4
nsites = Lx * Ly
D = 4
seed = 42
flat = True

# Build a random Z2-symmetric fermionic PEPS on the Lx x Ly lattice.
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    # One (charge, offset) pair per linear physical index,
    # i.e. charge sectors 0 -> (0, 3) and 1 -> (2, 1).
    phys_dim=[
        (0, 0),  # linear index 0
        (1, 1),  # linear index 1
        (1, 0),  # linear index 2
        (0, 1),  # linear index 3
    ],
    # Optional, for testing: put an odd number of odd sites in.
    # site_charge=lambda site: int(site in [(0, 0), (0, 1), (1, 0)]),
    subsizes="equal",
    flat=flat,
    seed=seed,
)

# Initial samples: B configurations (B is the per-rank batch size),
# stacked along a new leading batch dimension.
B = 10
fxs = torch.stack(
    [random_initial_config(nsites, nsites) for _ in range(B)],
    dim=0,
)

# Transformer-backflow fPEPS model.
# (Alternative without the NN: fPEPS_Model(tn=peps, max_bond=D, dtype=torch.float64).)
model = Transformer_fPEPS_Model_batchedAttn(
    tn=peps,
    max_bond=-1,
    dtype=torch.float64,
    nn_eta=1,
    nn_hidden_dim=16,
    embed_dim=16,
    attn_heads=4,
)

# Total number of scalar parameters registered on the model.
nparams = sum(param.numel() for param in model.parameters())

# Forward pass on the whole batch; the cell displays the result.
model(fxs)

tensor([ -56626.0253, -126844.2878,  -27390.6366,   12296.2005,   43295.4322,
         -29739.2064,  -13452.7840,  -24764.9240,  -52658.3799,  -45615.5187],
       dtype=torch.float64, grad_fn=<MulBackward0>)

In [49]:
from vmc_torch.experiment.vmap.vmap_utils import flatten_params
from torch.utils._pytree import tree_map, tree_flatten

def compute_grads_decoupled(fxs, fpeps_model, batch_size=None):
    """Per-sample parameter gradients with the TN and NN backward passes decoupled.

    The computation runs in four steps (chunking is supported to save memory):

    1. Forward the NN once over the whole batch, without autograd, to get the
       numerical backflow correction ``delta_P`` for every sample.
    2. For chunks of ``batch_size`` samples, ``vmap`` a ``torch.func.grad`` of
       ``fpeps_model.tn_contraction`` with respect to the TN parameters and
       ``delta_P``. The gradient with respect to ``delta_P`` is the
       "sensitivity" vector of each sample.
    3. Push each sensitivity vector back through the NN (a vector-Jacobian
       product), one sample at a time.
    4. Concatenate the flattened TN gradients and NN gradients per sample.

    Args:
        fxs: Batch of configurations; the leading dimension is the batch.
        fpeps_model: Object exposing ``dtype``, ``ftn_params``, ``nn_backflow``,
            ``nn_param_names`` and ``tn_contraction(x, ftn_params, delta_p)``.
            ``nn_param_names`` is zipped with ``nn_backflow.parameters()``, so
            it presumably lists the names in that same order -- verify against
            the model class.
        batch_size: Number of samples per TN-gradient chunk. ``None`` (default)
            processes the whole batch in a single chunk.

    Returns:
        ``(g_params_vec, amps)`` where ``g_params_vec`` has one row per sample
        holding the flattened TN-parameter gradients followed by the flattened
        NN-parameter gradients, and ``amps`` holds the amplitudes (a 1-D result
        is reshaped to ``(B, 1)``). Both are detached from any autograd graph.

    Raises:
        ValueError: If ``fxs`` is empty or ``batch_size`` is not positive.
    """
    B = fxs.shape[0]
    if B == 0:
        raise ValueError("fxs must contain at least one configuration")
    if batch_size is not None and batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )
    # Chunk size for the TN gradient; default is the whole batch at once.
    chunk_size = B if batch_size is None else batch_size
    dtype = fpeps_model.dtype

    # === Parameters ===
    ftn_params = list(fpeps_model.ftn_params)
    nn_params = list(fpeps_model.nn_backflow.parameters())
    nn_params_dict = dict(zip(fpeps_model.nn_param_names, nn_params))

    # Cast the configurations once; both NN passes (steps 1 and 3) reuse this.
    nn_inputs = fxs.to(dtype)

    # =================================================================
    # Step 1: NN forward (native batch)
    # Only the *values* of delta_P are needed here, as inputs to the TN,
    # so no autograd graph is recorded.
    # =================================================================
    with torch.no_grad():
        batch_delta_p = torch.func.functional_call(
            fpeps_model.nn_backflow, nn_params_dict, nn_inputs
        )

    # =================================================================
    # Step 2: TN backward (chunked vmap over grad)
    # Sensitivity of the amplitude with respect to delta_P, plus the
    # gradient with respect to the TN parameters.
    # =================================================================
    def tn_only_func(x_i, ftn_p_list, delta_p_i):
        # Pure TN contraction; the amplitude is returned twice so that it is
        # both the differentiated target and the auxiliary output.
        amp = fpeps_model.tn_contraction(x_i, ftn_p_list, delta_p_i)
        return amp, amp

    tn_grad_vmap_func = torch.vmap(
        torch.func.grad(tn_only_func, argnums=(1, 2), has_aux=True),
        in_dims=(0, None, 0),
    )

    g_ftn_chunks = []
    g_sensitivity_chunks = []
    amps_chunks = []

    for b_start in range(0, B, chunk_size):
        b_end = b_start + chunk_size  # slicing clamps at B for the last chunk

        (g_ftn_c, g_sens_c), amps_c = tn_grad_vmap_func(
            fxs[b_start:b_end], ftn_params, batch_delta_p[b_start:b_end]
        )

        # Detach *every* output, not just the amplitudes: if the TN parameters
        # require grad, the gradient tensors can also reference the outer
        # autograd graph, which would otherwise stay alive across chunks.
        leaves, _ = tree_flatten(g_ftn_c)
        g_ftn_chunks.append(
            torch.cat(
                # reshape (rather than flatten(start_dim=1)) also handles
                # per-sample gradients that are 1-D, i.e. scalar parameters.
                [leaf.detach().reshape(leaf.shape[0], -1) for leaf in leaves],
                dim=1,
            )
        )
        g_sensitivity_chunks.append(g_sens_c.detach())
        amps_chunks.append(amps_c.detach())

        # Drop the chunk-local references explicitly to help the GC.
        del g_ftn_c, g_sens_c, amps_c, leaves

    # --- Aggregate the chunks along the batch dimension ---
    g_ftn_params_vec = torch.cat(g_ftn_chunks, dim=0)
    g_sensitivity = torch.cat(g_sensitivity_chunks, dim=0)
    amps = torch.cat(amps_chunks, dim=0)
    if amps.dim() == 1:
        amps = amps.unsqueeze(-1)

    # =================================================================
    # Step 3: NN backward (sequential loop)
    # Contract each sample's sensitivity vector with the NN Jacobian.
    # torch.autograd.grad returns the gradients directly and never writes
    # to ``.grad``, so the module's accumulated gradients are left untouched.
    # =================================================================
    g_nn_params_list = []
    for i in range(B):
        x_i = nn_inputs[i].unsqueeze(0)
        g_sens_i = g_sensitivity[i].unsqueeze(0)

        # enable_grad so this also works when the caller is under no_grad.
        with torch.enable_grad():
            out_i = torch.func.functional_call(
                fpeps_model.nn_backflow, nn_params_dict, x_i
            )
            target = torch.sum(out_i * g_sens_i)
            grads_i = torch.autograd.grad(target, nn_params, retain_graph=False)

        g_nn_params_list.append(flatten_params(grads_i))

    g_nn_params_vec = torch.stack(g_nn_params_list)

    # =================================================================
    # Step 4: assemble the final result, one row per sample
    # =================================================================
    g_params_vec = torch.cat([g_ftn_params_vec, g_nn_params_vec], dim=1)

    return g_params_vec, amps

# Cross-check the chunked, decoupled gradients against the reference
# implementation `compute_grads`.
g_params_vec, amps = compute_grads_decoupled(fxs, model, batch_size=6)
g_params_vec_benchmark, amps_benchmark = compute_grads(fxs, model, vectorize=True)
# torch.allclose broadcasts its inputs, so pin the shapes explicitly; a bare
# shape tuple in the middle of a cell is neither displayed nor checked.
assert g_params_vec.shape == g_params_vec_benchmark.shape, (
    g_params_vec.shape,
    g_params_vec_benchmark.shape,
)
torch.allclose(g_params_vec, g_params_vec_benchmark)

True

In [8]:
import quimb as qu
# Work on a copy so that `peps` itself is left untouched.
tn = peps.copy()

# Split the network into its array data (`params`) and the remaining skeleton.
params, skeleton = qtn.pack(tn)

# torch wants a flat list of tensors: flatten the params pytree and keep the
# reference tree returned alongside it (requested with get_ref=True).
ftn_params_flat, ftn_params_pytree = qu.utils.tree_flatten(params, get_ref=True)

# Register every flattened array as a float64 entry of a ParameterList.
ftn_params = torch.nn.ParameterList(
    [torch.as_tensor(arr, dtype=torch.float64) for arr in ftn_params_flat]
)
ftn_params

ParameterList(
    (0): Parameter containing: [torch.float64 of size 4x2x2x2]
    (1): Parameter containing: [torch.float64 of size 4x2x2x2]
    (2): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (3): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (4): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (5): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (6): Parameter containing: [torch.float64 of size 4x2x2x2]
    (7): Parameter containing: [torch.float64 of size 4x2x2x2]
)