In [1]:
import os
import torch
import torch.distributed as dist
import numpy as np # 仅用于非计算的简单统计或IO
import pickle
import json
import time
from tqdm import tqdm

# 假设这些是你现有的工具库
from vmc_torch.experiment.vmap.GPU_vmap_utils import sample_next, evaluate_energy, compute_grads, random_initial_config
from vmc_torch.experiment.vmap.GPU_vmap_utils import Transformer_fPEPS_Model_batchedAttn, fPEPS_Model
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_square_lattice_torch
from vmc_torch.experiment.tn_model import init_weights_to_zero
import quimb.tensor as qtn

# ==========================================
# 1. 初始化 Distributed 环境 (GPU Native)
# ==========================================
def setup_distributed():
    """Initialise ``torch.distributed`` from torchrun-style environment variables.

    When the script is not launched via ``torchrun`` (no ``RANK`` in the
    environment), single-process defaults are filled in so it can still run
    on one device. ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` are forced to a
    single-process layout in that case, while an already-exported
    ``MASTER_ADDR``/``MASTER_PORT`` is kept instead of being overwritten.

    The NCCL backend is used when CUDA is available; otherwise gloo on CPU.

    Returns:
        tuple: ``(rank, world_size, device)`` for the current process.
    """
    if "RANK" not in os.environ:
        # Debug mode: not launched with torchrun, fall back to a single process.
        print("Warning: Not using torchrun. Defaulting to single device.")
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_RANK"] = "0"
        # Keep a user-provided rendezvous address/port if there is one.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")

    use_cuda = torch.cuda.is_available()
    dist.init_process_group(backend="nccl" if use_cuda else "gloo", init_method="env://")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # LOCAL_RANK may be absent when RANK comes from a non-torchrun launcher;
    # in that case spread ranks round-robin over the visible GPUs.
    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    else:
        local_rank = rank % max(torch.cuda.device_count(), 1)

    if use_cuda:
        # Bind this process to its own GPU.
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")
    return rank, world_size, device

RANK, WORLD_SIZE, device = setup_distributed()

# Run every computation in double precision by default.
torch.set_default_dtype(torch.float64)
# Offset the seed by the process rank so that each rank draws
# independent Markov chains.
torch.manual_seed(RANK + 42)

# ==========================================
# 2. 参数设置与模型加载
# ==========================================
Lx, Ly = 2, 2
nsites = Lx * Ly
N_f = nsites  # total fermion number; split equally between spins below
D = 4
chi = -1  # passed to the model as max_bond; presumably -1 = no truncation (TODO confirm)

# Data paths (kept as in the original experiment layout).
pwd = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/vmc_torch/experiment/vmap'
u1z2 = True
appendix = '_U1SU' if u1z2 else ''
data_dir = pwd + f'/{Lx}x{Ly}/t=1.0_U=8.0/N={N_f}/Z2/D={D}'

# Load the PEPS skeleton and parameters on CPU (cheap), then move to GPU later.
# Every rank reads the files itself, which assumes the filesystem tolerates
# concurrent reads; alternatively load on rank 0 and broadcast.
# `with` closes the file handles that bare `pickle.load(open(...))` leaked.
# NOTE: pickle is only safe here because these are locally produced files.
with open(data_dir + f'/peps_su_params{appendix}.pkl', 'rb') as fh:
    params_pkl = pickle.load(fh)
with open(data_dir + f'/peps_skeleton{appendix}.pkl', 'rb') as fh:
    skeleton = pickle.load(fh)
peps = qtn.unpack(params_pkl, skeleton)

# Preprocessing (CPU): convert every tensor to the flat backend, scaled by 10.
for ts in peps.tensors:
    ts.modify(data=ts.data.to_flat() * 10)
for site in peps.sites:
    peps[site].data._label = site
    # Physical-index map; the other notebook cell marks this as important
    # for the U1->Z2 fPEPS conversion.
    peps[site].data.indices[-1]._linearmap = ((0, 0), (1, 0), (1, 1), (0, 1))

# Build the model and move all of it to this rank's device.
# Alternative model:
# fpeps_model = Transformer_fPEPS_Model_batchedAttn(
#     tn=peps, max_bond=chi, embed_dim=8, attn_heads=4, nn_hidden_dim=16, nn_eta=1, dtype=torch.float64,
# )
fpeps_model = fPEPS_Model(
    tn=peps, max_bond=chi, dtype=torch.float64,
)
fpeps_model.to(device)

# Weight initialisation: std is 10% of the std of the loaded parameters.
# init_weights_to_zero is project code; its exact effect is not visible here.
model_params_vec = torch.nn.utils.parameters_to_vector(fpeps_model.parameters())
init_std = float(model_params_vec.std().item()) * 0.1
fpeps_model.apply(lambda x: init_weights_to_zero(x, std=init_std))

n_params = sum(p.numel() for p in fpeps_model.parameters())
if RANK == 0:
    print(f'Model parameters: {n_params} | World Size: {WORLD_SIZE} | Device: {device}')

# Hamiltonian: spinful Fermi-Hubbard on an open Lx x Ly square lattice.
t, U = 1.0, 8.0
n_fermions_per_spin = (N_f // 2, N_f // 2)
H = spinful_Fermi_Hubbard_square_lattice_torch(
    Lx, Ly, t, U, N_f, pbc=False, n_fermions_per_spin=n_fermions_per_spin, no_u1_symmetry=False, gpu=True
)
graph = H.graph

# ==========================================
# 3. 采样配置
# ==========================================
Total_Ns = int(5e2)  # total number of samples per VMC step, summed over all ranks
# Every rank must contribute the same number of samples (dist.all_gather
# needs equal-sized tensors). Raise instead of `assert`, which `python -O` strips.
if Total_Ns % WORLD_SIZE != 0:
    raise ValueError(f"Total samples {Total_Ns} must be divisible by World Size {WORLD_SIZE}")
samples_per_rank = Total_Ns // WORLD_SIZE

# Number of Markov chains run in parallel on each rank (batch size).
# If memory allows, set it to samples_per_rank to collect a step's samples in
# one sweep; smaller values loop over several sweeps and accumulate.
batch_size_per_rank = 64
# Initial walkers, moved to this rank's device.
fxs_list = [random_initial_config(N_f, nsites, seed=None) for _ in range(batch_size_per_rank)]
fxs = torch.stack(fxs_list).to(device)

# Burn-in (warm-up) sweeps; adjust as needed.
n_burn_in = 2
for _ in range(n_burn_in):
    fxs, current_amps = sample_next(fxs, fpeps_model, graph, seed=None)

# VMC settings
vmc_steps = 50
minSR = True  # minSR inverts the (Ns, Ns) Gram matrix; cheap when Ns is small
learning_rate = 0.1
save_state_every = 10000
stats_file = pwd + f'/stats_{fpeps_model._get_name()}.json'
stats = {'mean': [], 'error': [], 'variance': []}

# ==========================================
# 4. VMC 主循环 (All on GPU)
# ==========================================
def gather_tensor(tensor):
    """All-gather ``tensor`` from every rank and concatenate along dim 0.

    ``dist.all_gather`` needs the same shape on every rank, which holds here
    because each rank contributes exactly ``samples_per_rank`` rows.
    """
    tensor = tensor.contiguous()  # collectives need contiguous buffers
    gather_list = [torch.empty_like(tensor) for _ in range(WORLD_SIZE)]
    dist.all_gather(gather_list, tensor)
    return torch.cat(gather_list)


# Set to True to write the stats JSON every step and a checkpoint every
# `save_state_every` steps (file output was disabled in the original run).
save_outputs = False

# Shown only on rank 0; disabled bars elsewhere let update()/close() be
# called unconditionally.
vmc_pbar = tqdm(total=vmc_steps, desc="VMC Steps", disable=RANK != 0)

for step in range(vmc_steps):
    # --- A. Local sampling, local energies and per-sample gradients ---
    local_energies_acc, local_grads_acc, local_amps_acc = [], [], []
    current_count = 0
    while current_count < samples_per_rank:
        fxs, current_amps = sample_next(fxs, fpeps_model, graph, seed=None)
        _, local_E = evaluate_energy(fxs, fpeps_model, H, current_amps)
        local_grads, local_amps = compute_grads(fxs, fpeps_model, vectorize=True)

        # Trim the last sweep so exactly samples_per_rank samples are kept.
        needed = min(batch_size_per_rank, samples_per_rank - current_count)
        local_energies_acc.append(local_E[:needed])
        local_grads_acc.append(local_grads[:needed])
        local_amps_acc.append(local_amps[:needed])
        current_count += needed

    my_energies = torch.cat(local_energies_acc)
    my_grads = torch.cat(local_grads_acc)  # (samples_per_rank, Np)
    my_amps = torch.cat(local_amps_acc)

    # --- B. Global gather ---
    # Gathering the full (Total_Ns, Np) gradient matrix is fine for small Np;
    # for very large Np switch to a reduce_scatter-based scheme.
    total_energies = gather_tensor(my_energies)
    total_amps = gather_tensor(my_amps)
    total_grads = gather_tensor(my_grads)  # (Total_Ns, Np)

    E_mean = torch.mean(total_energies)
    E_var = torch.var(total_energies)

    # --- C. Optimisation: SR solve on rank 0, result broadcast ---
    # Preallocated on every rank so the broadcast buffers agree in
    # shape/dtype/device.
    dp = torch.zeros(n_params, device=device, dtype=torch.float64)

    if RANK == 0:
        # Log-derivatives; reshape(-1, 1) divides row-wise whether the
        # amplitudes come back as (N,) or (N, 1).
        log_grads = total_grads / total_amps.reshape(-1, 1)
        O_sk = (log_grads - log_grads.mean(dim=0, keepdim=True)) / Total_Ns**0.5
        E_s = (total_energies.reshape(-1) - E_mean) / Total_Ns**0.5
        O_dag = O_sk.conj().T

        if minSR:
            # (Ns, Ns) Gram matrix T = O O^dagger; dp = O^dagger T^+ E_s.
            T = O_sk @ O_dag
            T_inv = torch.linalg.pinv(T, rtol=1e-12, hermitian=True)
            dp_root = O_dag @ (T_inv @ E_s)
        else:
            # (Np, Np) covariance S = O^dagger O; dp = S^+ O^dagger E_s.
            # Same least-squares solution as minSR, cheaper when Np < Ns.
            # (Previously this branch was `pass`, which left dp == 0.)
            S = O_dag @ O_sk
            S_inv = torch.linalg.pinv(S, rtol=1e-12, hermitian=True)
            dp_root = S_inv @ (O_dag @ E_s)

        # Copy into the preallocated buffer rather than rebinding `dp`.
        dp.copy_(dp_root.reshape(-1))

        print(f"Step {step}: E = {E_mean.item()/nsites:.6f}, Var = {E_var.item()/nsites**2:.2e}, Std of E_mean = {(E_var.item()/(Total_Ns*nsites**2))**0.5:.2e}")
        print(f'SR dp mean: {dp.mean()}, std: {dp.std()}')

    # --- D. Broadcast the update ---
    dist.broadcast(dp, src=0)

    # --- E. Parameter update (outside autograd) ---
    with torch.no_grad():
        current_params_vec = torch.nn.utils.parameters_to_vector(fpeps_model.parameters())
        torch.nn.utils.vector_to_parameters(
            current_params_vec - learning_rate * dp, fpeps_model.parameters()
        )

    # --- F. Logging (rank 0 only) ---
    if RANK == 0:
        stats['mean'].append(E_mean.item() / nsites)
        # Standard error of the mean energy per site.
        stats['error'].append((E_var.item() / Total_Ns) ** 0.5 / nsites)
        stats['variance'].append(E_var.item())

        if save_outputs:
            with open(stats_file, 'w') as fh:
                json.dump(stats, fh)
            if (step + 1) % save_state_every == 0:
                ckpt_path = pwd + f'/checkpoint_{step+1}.pt'
                torch.save(fpeps_model.state_dict(), ckpt_path)

    vmc_pbar.update(1)

vmc_pbar.close()

# Tear down the process group.
dist.destroy_process_group()

Model parameters: 128 | World Size: 1 | Device: cuda:0


VMC Steps:   2%|▏         | 1/50 [00:02<02:09,  2.65s/it]

Step 0: E = 0.342344, Var = 8.92e-01, Std of E_mean = 4.22e-02
SR dp mean: 0.1402525148382699, std: 4.339575989193563


VMC Steps:   4%|▍         | 2/50 [00:05<01:59,  2.49s/it]

Step 1: E = 0.264746, Var = 1.49e+00, Std of E_mean = 5.46e-02
SR dp mean: -0.25565849697787224, std: 3.2501419288451547


VMC Steps:   6%|▌         | 3/50 [00:07<01:59,  2.54s/it]

Step 2: E = 0.080028, Var = 3.04e-01, Std of E_mean = 2.47e-02
SR dp mean: -0.5029354180332347, std: 6.642014631556633


VMC Steps:   8%|▊         | 4/50 [00:10<02:01,  2.64s/it]

Step 3: E = -0.060726, Var = 1.14e-01, Std of E_mean = 1.51e-02
SR dp mean: 0.029852509806195547, std: 3.3284047160087624


VMC Steps:  10%|█         | 5/50 [00:13<02:03,  2.73s/it]

Step 4: E = -0.071733, Var = 6.14e-02, Std of E_mean = 1.11e-02
SR dp mean: -0.015150399601216329, std: 1.21282559461489


VMC Steps:  12%|█▏        | 6/50 [00:16<02:02,  2.78s/it]

Step 5: E = -0.103638, Var = 7.02e-03, Std of E_mean = 3.75e-03
SR dp mean: -0.07416685104231913, std: 0.5103943355049855


VMC Steps:  14%|█▍        | 7/50 [00:19<02:00,  2.81s/it]

Step 6: E = -0.097767, Var = 2.39e-02, Std of E_mean = 6.91e-03
SR dp mean: 0.05111357004638722, std: 1.318456989465329


VMC Steps:  16%|█▌        | 8/50 [00:21<01:59,  2.83s/it]

Step 7: E = -0.107848, Var = 3.92e-03, Std of E_mean = 2.80e-03
SR dp mean: 0.0011944513957517708, std: 0.33725259860144086


VMC Steps:  18%|█▊        | 9/50 [00:24<01:56,  2.84s/it]

Step 8: E = -0.113258, Var = 4.38e-03, Std of E_mean = 2.96e-03
SR dp mean: -0.00960746043527783, std: 0.5876895939502549


VMC Steps:  20%|██        | 10/50 [00:27<01:54,  2.86s/it]

Step 9: E = -0.110830, Var = 7.51e-03, Std of E_mean = 3.88e-03
SR dp mean: 0.0650330782163315, std: 0.49646564280259686


VMC Steps:  22%|██▏       | 11/50 [00:30<01:51,  2.87s/it]

Step 10: E = -0.142724, Var = 1.86e-01, Std of E_mean = 1.93e-02
SR dp mean: 0.09775656889323286, std: 2.039713887063871


VMC Steps:  24%|██▍       | 12/50 [00:33<01:49,  2.88s/it]

Step 11: E = -0.126652, Var = 6.60e-03, Std of E_mean = 3.63e-03
SR dp mean: -0.1302555064743034, std: 9.05015897467296


VMC Steps:  26%|██▌       | 13/50 [00:36<01:48,  2.93s/it]

Step 12: E = -0.062193, Var = 9.77e-02, Std of E_mean = 1.40e-02
SR dp mean: -0.025087843810528435, std: 2.4652350269034087


VMC Steps:  28%|██▊       | 14/50 [00:40<01:51,  3.09s/it]

Step 13: E = -0.048701, Var = 1.44e-01, Std of E_mean = 1.70e-02
SR dp mean: 0.29512474419635126, std: 3.11598471532187


VMC Steps:  30%|███       | 15/50 [00:43<01:54,  3.28s/it]

Step 14: E = -0.098231, Var = 5.41e-02, Std of E_mean = 1.04e-02
SR dp mean: -0.22168260342843, std: 1.9394327816510644


VMC Steps:  32%|███▏      | 16/50 [00:47<01:52,  3.32s/it]

Step 15: E = -0.105112, Var = 4.16e-02, Std of E_mean = 9.12e-03
SR dp mean: -0.06374344731524788, std: 1.9403349855157184


VMC Steps:  34%|███▍      | 17/50 [00:50<01:54,  3.48s/it]

Step 16: E = -0.113849, Var = 6.00e-02, Std of E_mean = 1.10e-02
SR dp mean: -0.8930259886747504, std: 4.5347204553487765


VMC Steps:  36%|███▌      | 18/50 [00:55<01:57,  3.67s/it]

Step 17: E = -0.123439, Var = 4.12e-02, Std of E_mean = 9.08e-03
SR dp mean: 0.2498581863896141, std: 2.039108596699125


VMC Steps:  38%|███▊      | 19/50 [00:58<01:47,  3.47s/it]

Step 18: E = -0.123508, Var = 3.55e-03, Std of E_mean = 2.66e-03
SR dp mean: 0.040718757554239704, std: 0.28438599925931624


VMC Steps:  40%|████      | 20/50 [01:01<01:42,  3.41s/it]

Step 19: E = -0.049759, Var = 1.85e+00, Std of E_mean = 6.08e-02
SR dp mean: 0.005335274036637583, std: 5.831408819309945


VMC Steps:  42%|████▏     | 21/50 [01:05<01:42,  3.54s/it]

Step 20: E = -0.071748, Var = 2.65e-01, Std of E_mean = 2.30e-02
SR dp mean: 0.3580838958382103, std: 3.390476063743327


VMC Steps:  44%|████▍     | 22/50 [01:08<01:37,  3.50s/it]

Step 21: E = -0.168455, Var = 1.81e-01, Std of E_mean = 1.90e-02
SR dp mean: -0.10796654117810331, std: 2.367569393039268


KeyboardInterrupt: 

In [5]:
import os
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ['MKL_NUM_THREADS'] = '2'
os.environ["OMP_NUM_THREADS"] = '1'
from mpi4py import MPI
import numpy as np
import symmray as sr
import quimb.tensor as qtn
import pickle
from autoray import do
import torch
import time
from tqdm import tqdm
from vmap_utils import sample_next, evaluate_energy, compute_grads, random_initial_config
from vmap_utils import fPEPS_Model
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_square_lattice_torch

COMM = MPI.COMM_WORLD
RANK = COMM.Get_rank()
SIZE = COMM.Get_size()

# Select the default device; use "cuda:0" instead to run on GPU.
# torch.set_default_device("cuda:0") # GPU
torch.set_default_device("cpu") # CPU
torch.random.manual_seed(42 + RANK)

Lx = 2
Ly = 2
nsites = Lx * Ly
N_f = nsites  # filling
D = 4
chi = -1
seed = RANK + 42
# only the flat backend is compatible with jax.jit
flat = True
pwd = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/vmc_torch/experiment/vmap'
u1z2 = True
appendix = '_U1SU' if u1z2 else ''
data_dir = pwd + f'/{Lx}x{Ly}/t=1.0_U=8.0/N={N_f}/Z2/D={D}'

# `with` closes the handles that bare `pickle.load(open(...))` leaked.
# NOTE: pickle is only safe here because these are locally produced files.
with open(data_dir + f'/peps_su_params{appendix}.pkl', 'rb') as fh:
    params = pickle.load(fh)
with open(data_dir + f'/peps_skeleton{appendix}.pkl', 'rb') as fh:
    skeleton = pickle.load(fh)
peps = qtn.unpack(params, skeleton)

# Convert to the flat backend (scaled by 10) and label sites.
for ts in peps.tensors:
    ts.modify(data=ts.data.to_flat() * 10)
for site in peps.sites:
    peps[site].data._label = site
    peps[site].data.indices[-1]._linearmap = ((0, 0), (1, 0), (1, 1), (0, 1)) # Important for U1->Z2 fPEPS

fpeps_model = fPEPS_Model(
    peps, max_bond=chi, dtype=torch.float64
)
n_params = sum(p.numel() for p in fpeps_model.parameters())
if RANK == 0:
    # print model size
    print(f'fPEPS-based model number of parameters: {n_params}')

# Spinful Fermi-Hubbard Hamiltonian on an open Lx x Ly square lattice.
t = 1.0
U = 8.0
n_fermions_per_spin = (N_f // 2, N_f // 2)
H = spinful_Fermi_Hubbard_square_lattice_torch(
    Lx,
    Ly,
    t,
    U,
    N_f,
    pbc=False,
    n_fermions_per_spin=n_fermions_per_spin,
    no_u1_symmetry=False,
)
graph = H.graph

# Small systems only: compare the exact ground-state energy with the
# energies of the initial (simple-update) state, on rank 0.
if Lx * Ly <= 6 and RANK == 0:
    # Pure diagnostics: no autograd graph is needed.
    with torch.no_grad():
        H_dense = torch.tensor(H.to_dense())
        # Flatten so the expectation value works for (n,) or (n, 1) outputs.
        psi_vec = fpeps_model(
            torch.tensor(H.hilbert.all_states(), dtype=torch.int32)
        ).reshape(-1)
        energies_exact, states_exact = torch.linalg.eigh(H_dense)
        print(f'Exact ground state energy: {energies_exact[0].item()/nsites}')
        SU_E = (psi_vec.conj() @ (H_dense @ psi_vec)) / (psi_vec.conj() @ psi_vec)
        print(f'SU variational energy: {SU_E.item()/nsites}')

    # Use the same U as the VMC Hamiltonian instead of a hard-coded literal.
    # NOTE(review): t is not passed here; relies on symmray's default hopping.
    terms = sr.hamiltonians.ham_fermi_hubbard_from_edges(
        "Z2",
        edges=tuple(peps.gen_bond_coos()),
        U=U,
        mu=0.0,
    )
    terms = {k: v.to_flat() for k, v in terms.items()}
    new_peps = peps.copy()
    new_peps.apply_to_arrays(lambda x: np.array(x))
    E_double = new_peps.compute_local_expectation_exact(terms, normalized=True)
    print(f'Double layer energy: {E_double/nsites}')


# Prepare initial samples
Ns = int(1024)  # nominal total sample size (recorded in stats)
# batch size per rank
B = 1024
B_grad = 10  # NOTE(review): not used in the visible code
fxs = torch.stack([random_initial_config(N_f, nsites, seed=None) for _ in range(B)])

# Burn-in for each rank.
t0 = MPI.Wtime()
for _ in range(10):
    fxs, current_amps = sample_next(fxs, fpeps_model, graph)
t1 = MPI.Wtime()
if RANK == 0:
    print(f'Burn-in sampling time: {t1-t0:.4f} s')

vmc_steps = 50
TAG_OFFSET = 424242  # NOTE(review): not used in the visible code
# Draw the progress bar on rank 0 only; other ranks get a disabled bar so
# update()/close() remain safe to call everywhere.
vmc_pbar = tqdm(total=vmc_steps, desc="VMC steps", disable=RANK != 0)
minSR = False
learning_rate = 0.1

stats_file = pwd+f'/{Lx}x{Ly}/t=1.0_U=8.0/N={N_f}/Z2/D={D}/vmc_mpi_stats_{fpeps_model._get_name()}.json'
stats = {
    'Np': n_params,
    'sample size': Ns,
    'mean': [],
    'error': [],
    'variance': [],
}
save_state_every = 10

import json
import scipy.sparse.linalg as spla

# Debug switch for the non-minSR path. When True, the update is the plain
# energy gradient (what the original debug run used), and the MINRES SR
# solve is skipped instead of being computed and then discarded.
# Set to False to use SR.
use_plain_gradient = True

for step in range(vmc_steps):
    t0 = MPI.Wtime()

    # --- Sampling: every rank draws one batch of B walkers ---
    # Previously only rank 0 sampled, so np.concatenate on the other ranks
    # received an empty list and raised ValueError whenever SIZE > 1.
    pbar = tqdm(total=fxs.shape[0], desc="Sampling starts...", disable=RANK != 0)
    fxs, current_amps = sample_next(fxs, fpeps_model, graph, seed=None)
    _, local_energies_t = evaluate_energy(fxs, fpeps_model, H, current_amps)
    grads_vec_t, amps_t = compute_grads(fxs, fpeps_model, vectorize=True)
    pbar.update(fxs.shape[0])
    pbar.close()

    local_energies = local_energies_t.detach().cpu().numpy()
    grads_vec = grads_vec_t.detach().cpu().numpy()
    amps = amps_t.detach().cpu().numpy()

    # Per-sample log-derivatives, shape (n_local, Np). reshape(-1, 1) makes
    # the division row-wise whether amps is (n,) or (n, 1).
    local_logamp_grads_vec = grads_vec / amps.reshape(-1, 1)

    # --- Global energy statistics ---
    all_energies = np.concatenate(COMM.allgather(local_energies))
    N_total = all_energies.shape[0]
    energy = np.mean(all_energies)
    energy_var = np.var(all_energies) / N_total  # variance of the sample mean

    if RANK == 0:
        print(f'\n\nSTEP {step} VMC energy: {energy/nsites}')
        print(f'Total sample size: {N_total}')

    # --- Parameter update direction ---
    dp = None
    if minSR:
        # Gather log-derivatives on rank 0 and solve O dp = E_s in the
        # minimum-norm least-squares sense via the (N, N) Gram matrix.
        gathered = COMM.gather(local_logamp_grads_vec, root=0)
        if RANK == 0:
            with torch.no_grad():
                O = torch.as_tensor(np.concatenate(gathered), dtype=torch.float64)
                E = torch.as_tensor(all_energies, dtype=torch.float64)
                O_sk = (O - O.mean(dim=0, keepdim=True)) / (N_total**0.5)  # (N_total, Np)
                E_s = (E - E.mean()) / (N_total**0.5)  # (N_total,)
                T = O_sk @ O_sk.T.conj()  # (N_total, N_total)
                T_inv = torch.linalg.pinv(T, rtol=1e-12, atol=0, hermitian=True)
                dp = O_sk.conj().T @ (T_inv @ E_s)  # (Np,)
            print(f'MinSR dp mean: {dp.mean()}, std: {dp.std()}')
    else:
        # Energy gradient <E O> - <E><O> from per-rank partial sums.
        local_logamp_grads_vec_sum = np.sum(local_logamp_grads_vec, axis=0)  # (Np,)
        local_E_logamp_grads_vec_sum = np.dot(local_energies, local_logamp_grads_vec)  # (Np,)
        logamp_grads_vec_mean = np.sum(COMM.allgather(local_logamp_grads_vec_sum), axis=0) / N_total
        E_logamp_grads_vec_mean = np.sum(COMM.allgather(local_E_logamp_grads_vec_sum), axis=0) / N_total
        energy_grad = E_logamp_grads_vec_mean - energy * logamp_grads_vec_mean  # (Np,)

        if use_plain_gradient:
            dp = energy_grad
        else:
            def matvec(x, eta=1e-4, O_loc=local_logamp_grads_vec, O_mean=logamp_grads_vec_mean, n_tot=N_total):
                """Return (S + eta*I) x, with S the covariance of O over all ranks."""
                x_out = COMM.allreduce(O_loc.T @ (O_loc @ x), op=MPI.SUM) / n_tot
                x_out -= (O_mean @ x) * O_mean
                return x_out + eta * x

            # Every rank runs MINRES on identical inputs, so the allreduce
            # calls inside matvec stay in lockstep across ranks.
            A = spla.LinearOperator((n_params, n_params), matvec=matvec)
            dp, info = spla.minres(A, energy_grad, rtol=1e-4, maxiter=100)
            if info != 0 and RANK == 0:
                print(f'Warning: MINRES did not converge (info={info})')
        if RANK == 0:
            print(f'SR dp mean: {np.mean(dp)}, std: {np.std(dp)}')

    # --- Update on rank 0, broadcast new parameters to all ranks ---
    new_params_vec = None
    if RANK == 0:
        with torch.no_grad():
            params_vec = torch.nn.utils.parameters_to_vector(fpeps_model.parameters())
            new_params_vec = params_vec - learning_rate * torch.as_tensor(
                dp, dtype=torch.float64, device=params_vec.device
            )
    new_params_vec = COMM.bcast(new_params_vec, root=0)
    with torch.no_grad():
        torch.nn.utils.vector_to_parameters(new_params_vec, fpeps_model.parameters())

    vmc_pbar.update(1)
    t1 = MPI.Wtime()
    if RANK == 0:
        # Record step statistics and rewrite the stats file each step.
        stats['mean'].append(float(energy / nsites))
        stats['error'].append(float(np.sqrt(energy_var) / nsites))
        stats['variance'].append(float(energy_var))
        with open(stats_file, 'w') as fh:
            json.dump(stats, fh)
        print(f'STEP {step}:\nEnergy per site: {energy/nsites}\nEnergy variance square root: {np.sqrt(energy_var)/nsites}\nSample size: {N_total}\nTime elapsed: {t1 - t0} seconds\n\n')


vmc_pbar.close()

fPEPS-based model number of parameters: 128
Exact ground state energy: -0.33005873956798376
SU variational energy: 0.6509558133972732
Double layer energy: 0.5043600730329295
Burn-in sampling time: 1.4283 s


VMC steps:  28%|██▊       | 14/50 [00:46<01:59,  3.32s/it]
Sampling starts...: 100%|██████████| 1024/1024 [00:27<00:00, 37.66it/s] 
VMC steps:   2%|▏         | 1/50 [00:01<01:06,  1.36s/it]



STEP 0 VMC energy: 0.563015872939573
Total sample size: 1024
SR dp mean: -0.007137715565171081, std: 0.09208547255690996
STEP 0:
Energy per site: 0.563015872939573
Energy variance square root: 0.040940264297686546
Sample size: 1024
Time elapsed: 1.3636891200000036 seconds





Sampling starts...: 100%|██████████| 1024/1024 [00:01<00:00, 751.32it/s]

VMC steps:   4%|▍         | 2/50 [00:02<01:03,  1.33s/it]



STEP 1 VMC energy: 0.5598735873216727
Total sample size: 1024
SR dp mean: 0.0017383671039250574, std: 0.0695445824606592
STEP 1:
Energy per site: 0.5598735873216727
Energy variance square root: 0.03347178868674712
Sample size: 1024
Time elapsed: 1.3098883119999982 seconds




Sampling starts...: 100%|██████████| 1024/1024 [00:01<00:00, 781.47it/s]
VMC steps:   6%|▌         | 3/50 [00:03<01:02,  1.32s/it]



STEP 2 VMC energy: 0.5454575718193018
Total sample size: 1024
SR dp mean: 0.0015265783079231022, std: 0.06896169801579798
STEP 2:
Energy per site: 0.5454575718193018
Energy variance square root: 0.032293804246041974
Sample size: 1024
Time elapsed: 1.3127004859999971 seconds





Sampling starts...: 100%|██████████| 1024/1024 [00:01<00:00, 779.86it/s]


KeyboardInterrupt: 