In [57]:
import os
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = '1'
import numpy as np
import quimb as qu
import quimb.tensor as qtn
import symmray as sr
import torch
import torch.nn as nn
import pickle
from vmc_torch.fermion_utils import fPEPS

Lx = 2
Ly = 2
nsites = Lx * Ly
N_f = nsites  # half-filling
D = 4
chi = 2*D
seed = 42
# only the flat backend is compatible with jax.jit
flat = True
pwd = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/vmc_torch/experiment/vmap'
# params = pickle.load(open(pwd+f'/{Lx}x{Ly}/t=1.0_U=8.0/N={N_f}/Z2/D={D}/peps_su_params.pkl', 'rb'))
# skeleton = pickle.load(open(pwd+f'/{Lx}x{Ly}/t=1.0_U=8.0/N={N_f}/Z2/D={D}/peps_skeleton.pkl', 'rb'))
# peps = qtn.unpack(params, skeleton)
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    phys_dim=[
        (0, 0),  # linear index 0 -> charge 0, offset 0
        (1, 1),  # linear index 1 -> charge 1, offset 1
        (1, 0),  # linear index 2 -> charge 1, offset 0
        (0, 1),  # linear index 3 -> charge 0, offset 1
    ],
    subsizes="equal",
    flat=flat,
    seed=seed,
)
original_peps = peps.copy()

# SU
terms = sr.ham_fermi_hubbard_from_edges(
    symmetry='Z2',
    edges=tuple(peps.gen_bond_coos()),
    t=1.0,
    U=8.0,
    mu=4.0,
)
ham = qtn.LocalHam2D(Lx, Ly, terms)
ham.apply_to_arrays(lambda A: A.to_flat())
su = qtn.SimpleUpdateGen(
    peps,
    ham,
    # setting a cutoff is important to turn on dynamic charge sectors
    # cutoff=1e-12,
    cutoff=0.0,
    second_order_reflect=True,
    # SimpleUpdateGen computes cluster energies by default
    # which might not be accurate
    compute_energy_every=10,
    compute_energy_opts=dict(max_distance=1),
    compute_energy_per_site=True,
    # use a fixed trotterization order
    ordering="sort",
    # if the gauge difference drops below this, we consider the PEPS converged
    tol=1e-9,
)

# run the evolution, these are reasonable defaults
tau = 0.1
steps = 20
su.evolve(steps, tau=tau)
peps = su.get_state()
for i in range(len(peps.arrays)):
    peps.arrays[i]._label = original_peps.arrays[i].label

# get pytree of initial parameters, and reference tn structure
params, skeleton = qtn.pack(peps)

n=20, D=4, tau=0.1, max|dS|=0.0136, energyâ‰ˆ-4.2385: 100%|##########| 20/20 [00:00<00:00, 33.66it/s] 


In [58]:
fpeps = peps.copy()
for ts in fpeps.tensors:
    ts.modify(data=ts.data.to_blocksparse())

fpeps = fPEPS(fpeps)
fpeps.apply_to_arrays(lambda x: torch.tensor(x, dtype=torch.float64))

In [59]:
def amplitude(x, params):
    tn = qtn.unpack(params, skeleton)
    # might need to specify the right site ordering here
    amp = tn.isel({tn.site_ind(site): x[i] for i, site in enumerate(tn.sites)})
    amp.contract_boundary_from_ymin_(max_bond=chi, cutoff=0.0, yrange=[0, amp.Ly//2-1])
    amp.contract_boundary_from_ymax_(max_bond=chi, cutoff=0.0, yrange=[amp.Ly//2, amp.Ly-1])
    return amp.contract()


# generate half-filling configs
# batchsize
B = 256
rng = np.random.default_rng(seed)
xs_u = np.concatenate(
    [
        np.zeros((B, nsites // 2), dtype=np.int32),
        np.ones((B, nsites // 2), dtype=np.int32),
    ],
    axis=1,
)
xs_d = xs_u.copy()
xs_u = rng.permuted(xs_u, axis=1)
xs_d = rng.permuted(xs_d, axis=1)
xs = np.concatenate([xs_u[:, :, None], xs_d[:, :, None]], axis=2).reshape(B, -1)
fxs = 2 * xs[:, ::2] + xs[:, 1::2]  # map to 0,1,2,3

# torch.set_default_device("cuda:0") # GPU
torch.set_default_device("cpu") # CPU

# convert bitstrings and arrays to torch
fxs = torch.tensor(fxs)
params = qu.tree_map(
    lambda x: torch.tensor(x, dtype=torch.float64),
    params,
)

vamp = torch.vmap(
    amplitude,
    # batch on configs, not parameters
    in_dims=(0, None),
)

class SimpleModel(nn.Module):
    def __init__(self, tn, dtype=torch.float64):
        super().__init__()
        params, skeleton = qtn.pack(tn)
        self.dtype = dtype
        self.skeleton = skeleton
        self.params = nn.ParameterDict(
            {
                str(k): nn.Parameter(torch.tensor(v['blocks'], dtype=self.dtype))
                for k, v in params.items()
            }
        )
    
    def params_as_dict(self):
        return {int(k): v for k, v in self.params.items()}
    
    def amplitude(self, x, params):
        tn = qtn.unpack(params, self.skeleton)
        # might need to specify the right site ordering here
        amp = tn.isel({tn.site_ind(site): x[i] for i, site in enumerate(tn.sites)})
        amp.contract_boundary_from_ymin_(max_bond=chi, cutoff=0.0, yrange=[0, amp.Ly//2-1])
        amp.contract_boundary_from_ymax_(max_bond=chi, cutoff=0.0, yrange=[amp.Ly//2, amp.Ly-1])
        return amp.contract()
    
    def vamp(self, x, params):
        params = {int(k): {'blocks':v} for k, v in params.items()}
        return torch.vmap(
            self.amplitude,
            in_dims=(0, None),
        )(x, params)

    def forward(self, x):
        return self.vamp(x, self.params)

peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=torch.float64))
fpeps_model = SimpleModel(peps)
fpeps_model(fxs[0].unsqueeze(0))  # warm up

tensor([-1.4924e-06], dtype=torch.float64, grad_fn=<MulBackward0>)

In [60]:
import random
def propose_exchange_or_hopping(i, j, current_config, hopping_rate=0.25):
    ind_n_map = {0: 0, 1: 1, 2: 1, 3: 2}
    if current_config[i] == current_config[j]:
        return current_config, 0
    proposed_config = current_config.clone()
    config_i = current_config[i].item()
    config_j = current_config[j].item()
    if random.random() < 1 - hopping_rate:
        # exchange
        proposed_config[i] = config_j
        proposed_config[j] = config_i
    else:
        # hopping
        n_i = ind_n_map[current_config[i].item()]
        n_j = ind_n_map[current_config[j].item()]
        delta_n = abs(n_i - n_j)
        if delta_n == 1:
            # consider only valid hopping: (0, u) -> (u, 0); (d, ud) -> (ud, d)
            proposed_config[i] = config_j
            proposed_config[j] = config_i
        elif delta_n == 0:
            # consider only valid hopping: (u, d) -> (0, ud) or (ud, 0)
            choices = [(0, 3), (3, 0)]
            choice = random.choice(choices)
            proposed_config[i] = choice[0]
            proposed_config[j] = choice[1]
        elif delta_n == 2:
            # consider only valid hopping: (0, ud) -> (u, d) or (d, u)
            choices = [(1, 2), (2, 1)]
            choice = random.choice(choices)
            proposed_config[i] = choice[0]
            proposed_config[j] = choice[1]
        else:
            raise ValueError("Invalid configuration")
    return proposed_config, 1

Sampling

In [61]:
# Batched Metropolis-Hastings updates
import time

def sample_next(fxs, fpeps_model, graph):
    current_amps = fpeps_model(fxs)
    for row, edges in graph.row_edges.items():
        for edge in edges:
            i, j = edge
            proposed_fxs = []
            new_flags = []
            # t0 = time.time()
            for fx in fxs:
                proposed_fx, new = propose_exchange_or_hopping(i, j, fx)
                proposed_fxs.append(proposed_fx)
                new_flags.append(new)
            # t1 = time.time()
            # print(f"Propose time: {t1 - t0}")
            proposed_fxs = torch.stack(proposed_fxs, dim=0)

            # only compute amplitudes for newly proposed configs
            new_proposed_fxs = proposed_fxs[torch.tensor(new_flags, dtype=torch.bool)]
            new_proposed_amps = fpeps_model(new_proposed_fxs)
            # print(f"Number of new proposals: {new_proposed_amps.shape[0]} ({B})" )
            proposed_amps = current_amps.clone()
            proposed_amps[torch.tensor(new_flags, dtype=torch.bool)] = new_proposed_amps
            ratio = proposed_amps**2 / current_amps**2
            accept_prob = torch.minimum(ratio, torch.ones_like(ratio))
            for k in range(B):
                if random.random() < accept_prob[k].item():
                    fxs[k] = proposed_fxs[k]
                    current_amps[k] = proposed_amps[k]

    for col, edges in graph.col_edges.items():
        for edge in edges:
            i, j = edge
            proposed_fxs = []
            new_flags = []
            for fx in fxs:
                proposed_fx, new = propose_exchange_or_hopping(i, j, fx)
                proposed_fxs.append(proposed_fx)
                new_flags.append(new)
            proposed_fxs = torch.stack(proposed_fxs, dim=0)
            # only compute amplitudes for newly proposed configs
            new_proposed_fxs = proposed_fxs[torch.tensor(new_flags, dtype=torch.bool)]
            new_proposed_amps = fpeps_model(new_proposed_fxs)
            # print(f"Number of new proposals: {new_proposed_amps.shape[0]} ({B})" )
            proposed_amps = current_amps.clone()
            proposed_amps[torch.tensor(new_flags, dtype=torch.bool)] = new_proposed_amps
            ratio = proposed_amps**2 / current_amps**2
            accept_prob = torch.minimum(ratio, torch.ones_like(ratio))
            for k in range(B):
                if random.random() < accept_prob[k].item():
                    fxs[k] = proposed_fxs[k]
                    current_amps[k] = proposed_amps[k]
    
    return fxs, current_amps



In [62]:
# # Sequential Metropolis-Hastings updates
# current_amps = torch.stack([amplitude(fx, params) for fx in fxs], dim=0)
# for row, edges in graph.row_edges.items():
#     for edge in edges:
#         i, j = edge
#         for k in range(B):
#             proposed_fx, new = propose_exchange_or_hopping(i, j, fxs[k])
#             proposed_amp = amplitude(proposed_fx, params) if new == 1 else current_amps[k]
#             ratio = (proposed_amp**2) / (current_amps[k]**2)
#             accept_prob = min(ratio.item(), 1.0)
#             if random.random() < accept_prob:
#                 fxs[k] = proposed_fx
#                 current_amps[k] = proposed_amp
# for col, edges in graph.col_edges.items():
#     for edge in edges:
#         i, j = edge
#         for k in range(B):
#             proposed_fx = propose_exchange_or_hopping(i, j, fxs[k])
#             proposed_amp = amplitude(proposed_fx, params)
#             ratio = (proposed_amp**2) / (current_amps[k]**2)
#             accept_prob = min(ratio.item(), 1.0)
#             if random.random() < accept_prob:
#                 fxs[k] = proposed_fx
#                 current_amps[k] = proposed_amp

Local Energy

In [63]:
def evaluate_energy(fxs, fpeps_model, H, current_amps):
    B = fxs.shape[0]
    # get connected configurations and coefficients
    conn_eta_num = []
    conn_etas = []
    conn_eta_coeffs = []
    for fx in fxs:
        eta, coeffs = H.get_conn(fx)
        conn_eta_num.append(len(eta))
        conn_etas.append(torch.tensor(eta))
        conn_eta_coeffs.append(torch.tensor(coeffs))

    conn_etas = torch.cat(conn_etas, dim=0)
    conn_eta_coeffs = torch.cat(conn_eta_coeffs, dim=0)

    print(f'Prepared batched conn_etas and coeffs: {conn_etas.shape}, {conn_eta_coeffs.shape} (batch size {B})')

    # calculate amplitudes for connected configs, in the future consider TN reuse to speed up calculation, TN reuse is controlled by a param that is not batched over (control flow?)
    conn_amps = fpeps_model(conn_etas)

    # Local energy \sum_{s'} H_{s,s'} <s'|psi>/<s|psi>

    local_energies = []
    offset = 0
    for b in range(B):
        n_conn = conn_eta_num[b]
        amps_ratio = conn_amps[offset:offset+n_conn] / current_amps[b]
        energy_b = torch.sum(conn_eta_coeffs[offset:offset+n_conn] * amps_ratio)
        local_energies.append(energy_b)
        offset += n_conn
    local_energies = torch.stack(local_energies, dim=0)
    print(f'Batched local energies: {local_energies.shape}')

    # Energy: (1/N) * \sum_s <s|H|psi>/<s|psi> = (1/N) * \sum_s \sum_{s'} H_{s,s'} <s'|psi>/<s|psi>
    energy = torch.mean(local_energies)
    print(f'Energy: {energy.item()}')

    return energy, local_energies



Gradient

In [64]:
def compute_grads(fxs, fpeps_model, vectorize=True):
    if vectorize:
        # Vectorized gradient calculation
        # need per sample gradient of amplitude -- jacobian
        params_pytree = dict(fpeps_model.params)
        # params is a pytree, fxs has shape (B, nsites)
        def g(x, p):
            results = fpeps_model.vamp(x, p)
            return results, results
        t0 = time.time()
        jac_pytree, amps = torch.func.jacrev(g, argnums=1, has_aux=True)(fxs, params_pytree)
        t1 = time.time()
        print(f"Vectorized jacobian time: {t1 - t0}")
        # jac_pytree has shape same as params_pytree, each leaf has shape (B, )

        # Get per-sample batched grads in list of dicts format
        batched_grads_vec = []
        for b in range(B):
            grad_b_iter = [jac_pytree[k][b] for k in jac_pytree.keys()]
            batched_grads_vec.append(torch.nn.utils.parameters_to_vector(grad_b_iter))
        batched_grads_vec = torch.stack(batched_grads_vec, dim=0)  # shape (B, Np), Np is number of parameters
        amps.unsqueeze_(1)  # shape (B, 1)
        return batched_grads_vec, amps
    
    else:
        # Sequential non-vectorized gradient calculation
        amps = []
        batched_grads_vec = []
        t0 = time.time()
        for fx in fxs:
            amp = fpeps_model(fx.unsqueeze(0))
            amps.append(amp)
            amp.backward()
            grads = qu.tree_map(lambda x: x.grad, fpeps_model.params_as_dict())
            batched_grads_vec.append(torch.nn.utils.parameters_to_vector(grads.values()))
            qu.tree_map(lambda x: x.grad.zero_(), fpeps_model.params_as_dict())
        t1 = time.time()
        print(f"Sequential jacobian time: {t1 - t0}")
        amps = torch.stack(amps, dim=0)
        batched_grads_vec = torch.stack(batched_grads_vec, dim=0)
        return batched_grads_vec, amps


In [65]:
from vmc_torch.fermion_utils import from_quimb_config_to_netket_config, from_netket_config_to_quimb_config, from_netket_config_to_quimb_binary_config, from_quimb_biniary_config_to_netket_config
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_square_lattice_torch
# generate Hamiltonian graph
t=1.0
U=8.0
N_f = int(Lx*Ly) # half-filling
n_fermions_per_spin = (N_f // 2, N_f // 2)

H = spinful_Fermi_Hubbard_square_lattice_torch(Lx, Ly, t, U, N_f, pbc=False, n_fermions_per_spin=None, no_u1_symmetry=True)
graph = H.graph
print(len(H.hilbert.all_states()))

fxs_nk = from_quimb_config_to_netket_config(fxs)
fxs_quimb_binary = from_netket_config_to_quimb_binary_config(fxs_nk)

128


In [66]:
import quimb.experimental.operatorbuilder as qop
edges = qtn.edges_2d_square(Lx, Ly)
sites = [(i, j) for i in range(Lx) for j in range(Ly)]
H_quimb = qop.fermi_hubbard_from_edges(
    edges,
    U=8,
    mu=0,
    # this ordering pairs spins together, as with the fermionic TN
    order=lambda site: (site[1], site[0]),
    sector=int(sum(ary.charge for ary in peps.arrays) % 2),
    symmetry="Z2",
)
# H_quimb.flatconfig_coupling(fxs[0])
hs = H_quimb.hilbert_space
eta_quimb, eta_coeff_quimb = H_quimb.flatconfig_coupling(fxs_quimb_binary[0])
from_netket_config_to_quimb_config(from_quimb_biniary_config_to_netket_config(eta_quimb)), eta_coeff_quimb, hs, H.get_conn(fxs[0]), fxs[0]

(array([[3, 2, 0, 1],
        [0, 2, 3, 1],
        [2, 3, 1, 0],
        [2, 0, 1, 3]]),
 array([-1., -1., -1., -1.]),
 HilbertSpace(nsites=8, total_size=128, symmetry=Z2, sector=0),
 (array([[0, 2, 3, 1],
         [3, 2, 0, 1],
         [2, 0, 1, 3],
         [2, 3, 1, 0]]),
  array([-1., -1., -1., -1.])),
 tensor([2, 2, 1, 1], dtype=torch.int32))

In [67]:
# params, peps = qtn.pack(peps)
# peps = qtn.unpack(params, skeleton)
def amplitude_binary(x):
    x = 2 * x[::2] + x[1::2]
    x = torch.tensor(x, dtype=torch.int64)
    # might need to specify the right site ordering here
    amp = peps.isel({peps.site_ind(site): x[i] for i, site in enumerate(peps.sites)})
    amp.contract_boundary_from_ymin_(max_bond=chi, cutoff=0.0, yrange=[0, amp.Ly//2-1])
    amp.contract_boundary_from_ymax_(max_bond=chi, cutoff=0.0, yrange=[amp.Ly//2, amp.Ly-1])
    return amp.contract()

In [68]:
with torch.no_grad():
    H_dense = torch.tensor(H.to_dense())
    all_states = torch.tensor(H.hilbert.all_states())
    psi_vec = fpeps_model(all_states)
    fpsi_vec = torch.tensor([fpeps.get_amp(fx).contract().item() for fx in all_states], dtype=torch.float64)

    E = (psi_vec @ H_dense @ psi_vec)/(psi_vec @ psi_vec)
    E1 = (fpsi_vec @ H_dense @ fpsi_vec)/(fpsi_vec @ fpsi_vec)
print(f'Dense energy: {E.item() / Lx / Ly}, {E1.item() / Lx / Ly}')

# compute the full exact energy at the amplitude level
O = 0.0
p = 0.0

fcs = []
for i in range(hs.size):
    fx = hs.rank_to_flatconfig(i)
    # fx_tn = torch.tensor(2*fx[::2] + fx[1::2], dtype=torch.int64)
    # xpsi = amplitude(fx_tn, params).item()
    xpsi = amplitude_binary(fx).item()
    if not xpsi:
        continue

    pi = abs(xpsi) ** 2
    p += pi

    Oloc = 0.0
    for fy, hxy in zip(*H_quimb.flatconfig_coupling(fx)):
        # fy_tn = torch.tensor(2*fy[::2] + fy[1::2], dtype=torch.int64)
        # ypsi = amplitude(fy_tn, params).item()
        ypsi = amplitude_binary(fy).item()
        Oloc = Oloc + hxy * ypsi / xpsi

    O += Oloc * pi

print(f'MC energy: {O / p / Lx / Ly}')

Dense energy: 3.648023928988014, -0.23850422302929083
MC energy: 3.648023392066006


In [69]:
# exact energy via local expectation contraction
terms = sr.hamiltonians.ham_fermi_hubbard_from_edges(
    "Z2",
    edges=edges,
    U=8,
    mu=0.0,
)
ham = qtn.LocalHam2D(Lx, Ly, terms)
ham.apply_to_arrays(lambda A: A.to_flat())
new_peps = peps.copy()
new_peps.apply_to_arrays(lambda x: np.array(x))
new_peps.compute_local_expectation_exact(ham, normalized=True) / nsites

  new_peps.apply_to_arrays(lambda x: np.array(x))


np.float64(-0.23850422302929086)

VMC update

In [16]:
for _ in range(50):
    fxs, current_amps = sample_next(fxs, fpeps_model, graph)
    energy, local_energies = evaluate_energy(fxs, fpeps_model, H, current_amps)
    batched_grads_vec, amps = compute_grads(fxs, fpeps_model, vectorize=True)
    # Now that we have local energies, amps and per-sample gradients, we can compute the energy gradient
    # With the energy gradient, we can further do SR for optimization
    # Or we can do minSR, which is simpler here
    with torch.no_grad():
        local_energies # shape (B,)
        local_energies_mean = torch.mean(local_energies)
        amps # shape (B,)
        params # pytree with each leaf of shape (param_shape...)

        # flatten the model params into a 1d vector of shape (Np,)
        params_vec = torch.nn.utils.parameters_to_vector(fpeps_model.parameters())

        # compute log-derivative grads
        batched_log_grads_vec = batched_grads_vec / amps  # shape (B, Np)
        log_grads_vec_mean = torch.mean(batched_log_grads_vec, dim=0)  # shape (Np,)

        O_sk = (batched_log_grads_vec - log_grads_vec_mean[None, :]) / (B**0.5)  # shape (B, Np)
        T = (O_sk @ O_sk.T.conj())  # shape (B, B)
        E_s = (local_energies - local_energies_mean) / (B**0.5)  # shape (B,)

        # minSR: need to solve O_sk * dp = E_s in the least square sense, using the pseudo-inverse of O_sk to get the minimum norm solution
        T_inv = torch.linalg.pinv(T,  rtol=1e-12, atol=0, hermitian=True)
        dp = O_sk.conj().T @ (T_inv @ E_s)  # shape (Np,)
        # update params
        learning_rate = 0.01
        new_params_vec = params_vec - learning_rate * dp

    # load the new params back to the model
    torch.nn.utils.vector_to_parameters(new_params_vec, fpeps_model.parameters())

    energy, local_energies = evaluate_energy(fxs, fpeps_model, H, current_amps)
    print(f'STEP {_} VMC energy after update: {energy.item()}\n')

Prepared batched conn_etas and coeffs: torch.Size([3142, 8]), torch.Size([3142]) (batch size 256)
Batched local energies: torch.Size([256])
Energy: 16.21894879507688
Vectorized jacobian time: 1.1292724609375
Prepared batched conn_etas and coeffs: torch.Size([3142, 8]), torch.Size([3142]) (batch size 256)
Batched local energies: torch.Size([256])
Energy: 15.995233761048851
STEP 0 VMC energy after update: 15.995233761048851

Prepared batched conn_etas and coeffs: torch.Size([3133, 8]), torch.Size([3133]) (batch size 256)
Batched local energies: torch.Size([256])
Energy: 17.37750064260142
Vectorized jacobian time: 0.9365081787109375
Prepared batched conn_etas and coeffs: torch.Size([3133, 8]), torch.Size([3133]) (batch size 256)
Batched local energies: torch.Size([256])
Energy: 16.87564648459265
STEP 1 VMC energy after update: 16.87564648459265

Prepared batched conn_etas and coeffs: torch.Size([3156, 8]), torch.Size([3156]) (batch size 256)
Batched local energies: torch.Size([256])
Energ

KeyboardInterrupt: 