In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"


In [37]:
import numpy as np
from ops.utils import compare
from functools import partial

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import random as rnd
from jax import lax, jit, vmap, value_and_grad, grad
from jax.tree_util import tree_structure, tree_flatten, tree_unflatten

from pytorch.models.og.model import fermiNet
from pytorch.sampling import MetropolisHasting
from pytorch.vmc import *
from pytorch.pretraining_v2 import Pretrainer, tile_labels, mse_error
from pytorch.systems import Molecule as Moleculetc
from pytorch.utils import update_state_dict, from_np
import torch as tc
tc.set_default_dtype(tc.float64)

from ops.vmc.utils import create_atom_batch
from ops.systems import Molecule
from ops.wf.ferminet import create_wf, create_masks
from ops.wf.parameters import initialise_params, count_mixed_features
from ops.sampling import create_sampler
from ops.vmc import create_energy_fn, local_kinetic_energy, compute_potential_energy
from ops.pretraining import create_loss_and_sampler

def df(arr_tc, arr_jax):
    arr_tc = arr_tc.detach().cpu().numpy()
    diff = jnp.mean(jnp.abs(arr_tc - arr_jax)) 
    print(diff)
    return diff
    
def compare_grads(model_tc, grads):
    tmp = []
    for k, value in grads.items():
        if k == 'intermediate':
            for intermediate in zip(*params[k]):
                for ps in intermediate:
                    tmp.append(ps)

        elif k == 'envelopes':
            order = ('linear', 'sigma', 'pi')
            for spin in (0, 1):
                for layer in order:
                    ps = grads[k][layer][spin]
                    tmp.append(ps)

        else:
            tmp.append(value)

    sd = model_tc.state_dict(keep_vars=True)
    for (k, val), p in zip(sd.items(), tmp):
        g = val.grad
        if not g is None: 
            diff = df(g, p)
            if diff > 0.001:
                print(k, diff)
                print(g[0], '\n', p[0], '\n')
                
                print(g[2], '\n', p[2], '\n')

In [38]:
# randomness
key = rnd.PRNGKey(1)
key, *subkeys = rnd.split(key, num=3)

# system
n_walkers = 1024
n_el = 4
r_atoms = jnp.array([[0.0, 0.0, 0.0]])
z_atoms = jnp.array([4.])

# ansatz
n_layers = 1
n_sh = 4
n_ph = 2
n_det = 1

mol = Molecule(r_atoms, z_atoms, n_el, n_det=n_det, n_sh=n_sh, n_ph=n_ph, n_layers=n_layers)
walkers = mol.initialise_walkers(n_walkers=n_walkers)
wf, wf_orbitals = create_wf(mol)
vwf = vmap(wf, in_axes=(None, 0, 0))
sampler = create_sampler(wf, correlation_length=10)
params = initialise_params(subkeys[0], mol)
compute_energy = create_energy_fn(wf, r_atoms, z_atoms)
laplacian_jax = jit(vmap(local_kinetic_energy(wf), in_axes=(None, 0)))
loss_function, sampler = create_loss_and_sampler(mol, wf, wf_orbitals)

loss_function2 = grad(loss_function)
loss_function = value_and_grad(loss_function)


walkers_tc = from_np(walkers)
r_atoms_tc = from_np(create_atom_batch(r_atoms, n_walkers))
z_atoms_tc = from_np(z_atoms)

mol_tc = Moleculetc(r_atoms_tc, z_atoms_tc, n_el, device='cpu', dtype=r_atoms_tc.dtype)

wf_tc = fermiNet(mol_tc, n_det=n_det, n_sh=n_sh, n_ph=n_ph, n_layers=n_layers, diagonal=False)
wf_tc = update_state_dict(wf_tc, params)


System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
System: 
 Device  = cpu 
 dtype   = torch.float64 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
Model: 
 device   = cpu 
 n_sh     = 4 
 n_ph     = 2 
 n_layers = 1 
 n_det    = 1 

lin_split_in.w torch.Size([8, 4]) (8, 4)
stream_s0.w torch.Size([13, 4]) (13, 4)
stream_p0.w torch.Size([5, 2]) (5, 2)
single_splits.0.w torch.Size([8, 4]) (8, 4)
single_intermediate.0.w torch.Size([9, 4]) (9, 4)
pairwise_intermediate.0.w torch.Size([3, 2]) (3, 2)
env_up_linear.w torch.Size([1, 2, 5]) (1, 2, 5)
env_up_sigma.sigma_einsum torch.Size([1, 2, 1, 3, 3]) (1, 2, 1, 3, 3)
env_up_pi.pi torch.Size([1, 2, 1]) (1, 2, 1)
env_down_linear.w torch.Size([1, 2, 5]) (1, 2, 5)
env_down_sigma.sigma_einsum torch.Size([1, 2, 1, 3, 3]) (1, 2, 1, 3, 3)
env_down_pi.pi torch.Size([1, 2, 1]) (1, 2, 1)


In [46]:


class fermiNet(nn.Module):
    def __init__(self,
                 mol,
                 diagonal: bool = False):

        super(fermiNet, self).__init__()
        
        from pytorch.models.og.model import Mixer, LinearSplit, LinearSingle, LinearPairwise, EnvelopeLinear, EnvelopeSigma, EnvelopePi
        
        
        self.device = mol.device
        self.dtype = mol.dtype
        dv, dt = self.device, self.dtype
        self.diagonal = diagonal
        
        n_layers, n_sh, n_ph, n_det = mol.n_layers, mol.n_sh, mol.n_ph, mol.n_det
        r_atoms, n_el, n_up, n_atoms = mol.r_atoms, mol.n_el, mol.n_up, mol.n_atoms
        #r_atoms = from_np(r_atoms)

        # things we need
        self.n_layers = n_layers
        self.r_atoms = r_atoms
        self.n_el = int(n_el)
        self.n_pairwise = int(n_el ** 2 - int(not diagonal) * n_el)
        self.n_up = n_up
        self.n_down = n_el - n_up
        self.n_atoms = int(n_atoms)
        n_down = n_el - n_up
        self.n_determinants = n_det

        # layers
        s_in = 4 * n_atoms
        p_in = 4
        s_hidden = n_sh
        self.s_hidden = s_hidden
        p_hidden = n_ph
        self.p_hidden = p_hidden
        s_mixed_in = 4 * n_atoms + 4 * 2
        s_mixed = n_sh * 3 + n_ph * 2

        self.mix_in = Mixer(s_in, p_in, n_el, n_up, diagonal, dv, dt)
        self.lin_split_in = LinearSplit(2 * s_in, s_hidden, dv, dt)

        self.stream_s0 = LinearSingle(s_mixed_in, s_hidden, dv, dt)
        self.stream_p0 = LinearPairwise(p_in, p_hidden, dv, dt)
        self.m0 = Mixer(s_hidden, p_hidden, n_el, n_up, diagonal, dv, dt)

        self.single_splits = \
            tc.nn.ModuleList([LinearSplit(2 * s_hidden, s_hidden, dv, dt) for _ in range(n_layers)])
        self.single_intermediate = \
            tc.nn.ModuleList([LinearSingle(s_mixed - 2 * s_hidden, s_hidden, dv, dt) for _ in range(n_layers)])
        self.pairwise_intermediate = \
            tc.nn.ModuleList([LinearPairwise(p_hidden, p_hidden, dv, dt) for _ in range(n_layers)])
        self.intermediate_mix = Mixer(s_hidden, p_hidden, n_el, n_up, diagonal, dv, dt)

        self.env_up_linear = EnvelopeLinear(s_hidden, n_up, n_det, dv, dt)
        self.env_up_sigma = EnvelopeSigma(n_up, n_det, n_atoms, dv, dt)
        self.env_up_pi = EnvelopePi(n_up, n_det, n_atoms, dv, dt)

        self.env_down_linear = EnvelopeLinear(s_hidden, n_down, n_det, dv, dt)
        self.env_down_sigma = EnvelopeSigma(n_down, n_det, n_atoms, dv, dt)
        self.env_down_pi = EnvelopePi(n_down, n_det, n_atoms, dv, dt)

        print('Model: \n',
              'device   = %s \n' % self.device,
              'n_sh     = %i \n' % n_sh,
              'n_ph     = %i \n' % n_ph,
              'n_layers = %i \n' % n_layers,
              'n_det    = %i \n' % n_det)

    def layers(self):
        for m in self.children():
            if len(list(m.parameters())) == 0:
                continue
            elif isinstance(m, tc.nn.ModuleList):
                yield from m
            else:
                yield m

#     def forward(self, walkers):
#         from pytorch.models.og.model import logabssumdet

#         up_orbitals, down_orbitals = self.generate_orbitals(walkers)

#         return up_orbitals

    def forward(self, walkers):
        from pytorch.models.og.model import compute_ae_vectors, compute_inputs
        #walkers = from_np(walkers)
        n_walkers = int(walkers.shape[0])

        self.single_input_residual = tc.zeros((n_walkers, self.n_el, self.s_hidden), device=walkers.device, dtype=walkers.dtype)
        self.pairwise_input_residual = tc.zeros((n_walkers, self.n_pairwise, self.p_hidden), device=walkers.device, dtype=walkers.dtype)

        ae_vectors = compute_ae_vectors(self.r_atoms, walkers)

        # the inputs
        single, pairwise = compute_inputs(walkers, n_walkers, ae_vectors, self.n_atoms, self.n_el)

        if self.diagonal:
            diagonal_pairwise_input = tc.zeros((n_walkers, self.n_el, 4), device=walkers.device, dtype=walkers.dtype)
            pairwise = tc.cat((pairwise, diagonal_pairwise_input), dim=1)

        # mix in
        single_mixed, single_split = self.mix_in(single, pairwise)

        # first layer
        single_split = self.lin_split_in(single_split)
        single = self.stream_s0(single_mixed, single_split, self.single_input_residual)
        pairwise = self.stream_p0(pairwise, self.pairwise_input_residual)

        # intermediate layers
        for ss, ls, ps in zip(self.single_intermediate, self.single_splits, self.pairwise_intermediate):
            single_mixed, single_split = self.intermediate_mix(single, pairwise)

            single_split = ls(single_split)
            single_res = single*0.
            single = ss(single_mixed, single_split, single_res)

        return single


In [47]:
from ops.wf.ferminet import *

def create_wf(mol):

    n_up, n_down, r_atoms, n_el = mol.n_up, mol.n_down, mol.r_atoms, mol.n_el
    masks = create_masks(mol.n_atoms, mol.n_el, mol.n_up, mol.n_layers, mol.n_sh, mol.n_ph)

    def _wf(params, walkers):

        if len(walkers.shape) == 1:  # this is a hack to get around the jvp
            walkers = walkers.reshape(n_up+n_down, 3)

        ae_vectors = compute_ae_vectors_i(walkers, r_atoms)

        single, pairwise = compute_inputs_i(walkers, ae_vectors)

        single_mixed, split = mixer_i(single, pairwise, n_el, n_up, n_down, *masks[0])

        split = linear_split(params['split0'], split)
        single = linear(params['s0'], single_mixed, split)
        pairwise = linear_pairwise(params['p0'], pairwise)

        split_params, s_params, p_params = params['intermediate'][0]
        mask = masks[1]

        single_mixed, split = mixer_i(single, pairwise, n_el, n_up, n_down, *mask)

        split = linear_split(split_params, split)
        single = linear(s_params, single_mixed, split)

        return single

    return _wf


# system
n_walkers = 1024
n_el = 4
r_atoms = jnp.array([[0.0, 0.0, 0.0]])
z_atoms = jnp.array([4.])

# ansatz
n_layers = 1
n_sh = 4
n_ph = 2
n_det = 1

mol = Molecule(r_atoms, z_atoms, n_el, n_det=n_det, n_sh=n_sh, n_ph=n_ph, n_layers=n_layers)
walkers = mol.initialise_walkers(n_walkers=n_walkers)
params = initialise_params(subkeys[0], mol)

walkers_tc = from_np(walkers)
r_atoms_tc = from_np(create_atom_batch(r_atoms, n_walkers))
z_atoms_tc = from_np(z_atoms)
mol_tc = Moleculetc(r_atoms_tc, z_atoms_tc, n_el, device='cpu', dtype=r_atoms_tc.dtype, n_det=n_det, n_sh=n_sh, n_ph=n_ph, n_layers=n_layers)
wf_tc = fermiNet(mol_tc)
wf_tc = update_state_dict(wf_tc, params)

wf = create_wf(mol)
vwf = vmap(wf, in_axes=(None, 0, 0))


wf_tc.zero_grad()

lp = vwf(params, walkers)
lp_tc = wf_tc(walkers_tc)

print('LOSS DIFF')
df(lp_tc, lp)
print('\n')

lp_loss_tc = lp_tc.sum()
lp_loss_tc.backward()

def swf(params, walkers):
    lp = vwf(params, walkers)
    return jnp.sum(lp)

lp_grad_fn = grad(swf)
grads = lp_grad_fn(params, walkers)

compare_grads(wf_tc, grads)

# print(lp[1])
# print(lp_tc[1])
# print(lp_tc.detach().cpu().numpy() - lp)

System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
System: 
 Device  = cpu 
 dtype   = torch.float64 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
Model: 
 device   = cpu 
 n_sh     = 4 
 n_ph     = 2 
 n_layers = 1 
 n_det    = 1 

lin_split_in.w torch.Size([8, 4]) (8, 4)
stream_s0.w torch.Size([13, 4]) (13, 4)
stream_p0.w torch.Size([5, 2]) (5, 2)
single_splits.0.w torch.Size([8, 4]) (8, 4)
single_intermediate.0.w torch.Size([9, 4]) (9, 4)
pairwise_intermediate.0.w torch.Size([3, 2]) (3, 2)
env_up_linear.w torch.Size([1, 2, 5]) (1, 2, 5)
env_up_sigma.sigma_einsum torch.Size([1, 2, 1, 3, 3]) (1, 2, 1, 3, 3)
env_up_pi.pi torch.Size([1, 2, 1]) (1, 2, 1)
env_down_linear.w torch.Size([1, 2, 5]) (1, 2, 5)
env_down_sigma.sigma_einsum torch.Size([1, 2, 1, 3, 3]) (1, 2, 1, 3, 3)
env_down_pi.pi torch.Size([1, 2, 1]) (1, 2, 1)
LOSS DIFF
6.33031830889303e-17


2.604035311735975e-06
2.2947455901290734e-14
7.86926079854

In [45]:
class Pretrainer():
    def __init__(self,
                 mol,
                 n_pretrain_iterations: int = 1000):

        self.mol = mol

        self.n_up = mol.n_up
        self.n_down = mol.n_down
        self.n_el = mol.n_el

        self.n_iterations = n_pretrain_iterations

    def compute_orbital_probability(self, samples):
        up_dets, down_dets = self.hf_orbitals(samples)

        spin_ups = up_dets ** 2
        spin_downs = down_dets ** 2

        p_up = tc.diagonal(spin_ups, dim1=-2, dim2=-1).prod(-1)
        p_down = tc.diagonal(spin_downs, dim1=-2, dim2=-1).prod(-1)
        # p_up = spin_ups.prod(1).prod(1)
        # p_down = spin_downs.prod(1).prod(1)

        probabilities = p_up * p_down

        return probabilities

    def pyscf_call(self, samples):
        samples = samples.cpu().numpy()
        ao_values = self.mol.pyscf_mol.eval_gto("GTOval_cart", samples)
        return tc.from_numpy(ao_values)

    def hf_orbitals(self, coord):
        coord = coord.view((-1, 3))

        number_spin_down = self.n_down
        number_spin_up = self.n_el - number_spin_down

        ao_values = self.pyscf_call(coord).to(device=coord.device, dtype=coord.dtype)
        ao_values = ao_values.view((int(len(ao_values) / self.n_el), self.n_el, len(ao_values[0])))

        spin_up = tc.stack([(self.mol.moT[orb_number, :] * ao_values[:, el_number, :]).sum(-1)
             for orb_number in range(number_spin_up) for el_number in
             range(number_spin_up)], dim=1).view((-1, number_spin_up, number_spin_up))

        spin_down = tc.stack([(self.mol.moT[orb_number, :] * ao_values[:, el_number, :]).sum(-1)
                            for orb_number in range(number_spin_down) for el_number in
                            range(number_spin_up, self.n_el)], dim=1).view((-1, number_spin_down, number_spin_down))

        return spin_up, spin_down
    
tc_hf = Pretrainer(mol_tc)

In [7]:
# compare the losses
def compare_grads(model_tc, grads):
    tmp = []
    for k, value in grads.items():
        if k == 'intermediate':
            for intermediate in zip(*params[k]):
                for ps in intermediate:
                    tmp.append(ps)

        elif k == 'envelopes':
            order = ('linear', 'sigma', 'pi')
            for spin in (0, 1):
                for layer in order:
                    ps = grads[k][layer][spin]
                    tmp.append(ps)

        else:
            tmp.append(value)

    sd = model_tc.state_dict(keep_vars=True)
    for (k, val), p in zip(sd.items(), tmp):
#         print(k, val.shape, p.shape)
#         print(k)
        assert val.shape == p.shape
        if val.grad is None:   # the hanging pairwise stream is None in pytorch
#             print(jnp.sum(jnp.abs(p)))
            continue
        g = val.grad.detach().cpu().numpy()
#         print(jnp.mean(jnp.abs(g - p)))
        
#         if k in ('env_up_sigma.sigma_einsum', 'env_down_sigma.sigma_einsum'):
#             print(g[0], '\n', p[0])
            
        if k in ('single_splits.0.w',):
            print(k)
#             print(g[0], '\n', p[0])
            print(g, '\n', p)
            print(jnp.mean(jnp.abs(p)) / jnp.mean(jnp.abs(g)))
#         sd[kprint(g[0], '\n', p[0]) = from_np(p)

#     model_tc.load_state_dict(sd, strict=True
    
    
up_dets, down_dets = tc_hf.hf_orbitals(walkers_tc)
up_dets = tile_labels(up_dets, wf_tc.n_determinants)
down_dets = tile_labels(down_dets, wf_tc.n_determinants)

model_up_dets, model_down_dets = wf_tc.generate_orbitals(walkers_tc)

loss = mse_error(up_dets, model_up_dets)
loss += mse_error(down_dets, model_down_dets)
wf_tc.zero_grad()
print(loss)

loss.backward()  # in order for hook to work must call backward

grads1 = loss_function2(params, walkers)
loss_value, grads2 = loss_function(params, walkers)
print(loss_value)

compare_grads(wf_tc, grads1)
# compare_grads(wf_tc, grads2)

tensor(0.7432, grad_fn=<AddBackward0>)
0.7431866161590366
single_splits.0.w
[[ 0.0395846   0.21111888  0.00150585  0.02713028]
 [-0.02316328  0.35843016  0.00348127  0.03004971]
 [-0.02788671  0.40136708  0.0044372   0.03414911]
 [ 0.03642672 -0.2012093  -0.00100034  0.00042881]
 [ 0.01944948  0.27694805  0.00156295  0.02711742]
 [-0.02936083  0.38747432  0.00436606  0.03332011]
 [-0.0330114   0.40305125  0.00466554  0.03306731]
 [ 0.03739027 -0.34935674 -0.00373558 -0.020753  ]] 
 [[ 0.19159925 -0.14835714  0.51060385  0.13056381]
 [ 0.02165198 -0.6838221  -0.18707593  0.3582895 ]
 [ 0.28646317 -0.08596921  0.31779283 -0.66303575]
 [ 0.0518322   0.28676125 -0.01401097 -0.11338016]
 [ 0.00448408 -0.5360741  -0.08405411 -0.536513  ]
 [-0.6326073  -0.22622964  0.59515196  0.08993056]
 [-0.32973337 -0.16543783 -0.48901087 -0.19934303]
 [ 0.6075665  -0.23415759  0.05191307  0.25775766]]
2.9651579849493763


In [8]:
print(grads)

OrderedDict([('split0', DeviceArray([[   66.65078 ,    58.176144,  -133.46118 ,   243.43398 ],
             [  167.34445 ,   256.2817  ,    95.12666 ,  -236.2991  ],
             [  -55.878025,   171.31535 ,   308.44946 ,   585.5462  ],
             [  758.3378  ,   340.0263  ,   246.6875  , -2330.7363  ],
             [  243.97282 ,   166.17392 ,  -191.01616 ,   181.22307 ],
             [  225.89394 ,  -170.51274 ,    66.91763 ,   394.34787 ],
             [  102.54219 ,   268.52792 ,  -308.95078 ,   566.5387  ],
             [  740.1543  ,   498.4542  ,   597.3778  , -1961.298   ]],            dtype=float32)), ('s0', DeviceArray([[ 1534.15635844, -1689.07992037,   160.82784695,
               1502.8496008 ],
             [  776.49583645,  -360.58406003,  -114.11196112,
                823.73215607],
             [  983.99286244,  3588.98830435,  -338.39412403,
               -179.76773983],
             [ -896.06693427,  2283.62528911,    -4.42706734,
              -2821.91394965],


In [9]:
device, dtype = wf_tc.device, wf_tc.dtype
sampler = MetropolisHastingsPretrain()
wf_walkers = walkers

wf_sampler = MetropolisHasting(wf)
for i in range(500):
    wf_walkers, wf_acc = wf_sampler(wf_walkers)
    e_locs = compute_local_energy(wf, wf_walkers, self.mol.r_atoms, self.mol.z_atoms)
    print(e_locs.mean())


opt = tc.optim.Adam(list(wf.parameters()), lr=lr)
steps = trange(
    0,  # init_step = 0
    n_it,
    initial=0,
    total=n_it,
    desc='pretraining',
    disable=None,
)

# walkers = initialize_walkers(self.mol.n_el_atoms, self.mol.atom_positions, n_walkers).to(device=device, dtype=dtype)

for step in steps:
    wf_walkers, wf_acc = wf_sampler(wf_walkers)
    e_locs = compute_local_energy(wf, wf_walkers, self.mol.r_atoms,  self.mol.z_atoms)

    walkers = sampler(wf, self, walkers)

    up_dets, down_dets = hf_orbitals(walkers)
    up_dets = tile_labels(up_dets, wf.n_determinants)
    down_dets = tile_labels(down_dets, wf.n_determinants)

    model_up_dets, model_down_dets = wf.generate_orbitals(walkers)

    loss = mse_error(up_dets, model_up_dets)
    loss += mse_error(down_dets, model_down_dets)
    opt.zero_grad()
    loss.backward()  # in order for hook to work must call backward
    opt.step()

NameError: name 'MetropolisHastingsPretrain' is not defined

In [None]:

import numpy as np
from tqdm.auto import trange
from jax import value_and_grad, grad, vmap, jit
import jax
import jax.numpy as jnp
import jax.random as rnd
from jax.experimental.optimizers import adam
from jax.tree_util import tree_unflatten, tree_flatten

from ops.pretraining import create_loss_and_sampler, equilibrate
from ops.vmc import create_energy_fn
from ops.



compute_local_energy = create_energy_fn(wf, mol.r_atoms, mol.z_atoms)

loss_function, sampler = create_loss_and_sampler(mol, wf, wf_orbitals)
loss_function = value_and_grad(loss_function)

walkers, step_size = equilibrate(params, walkers, compute_local_energy, sampler, key, n_it=50, step_size=0.02)
wf_walkers = jnp.array(walkers, copy=True)

init, update, get_params = adam(1e-3)
state = init(params)

steps = trange(0, n_it, initial=0, total=n_it, desc='pretraining', disable=None)
for step in steps:
    key, *subkeys = rnd.split(key, num=3)

    wf_walkers, acc = wf_sampler(params, wf_walkers, subkeys[0], step_size)
    e_locs = compute_local_energy(params, wf_walkers)

    walkers, mix_acc = sampler(params, walkers, subkeys[1], step_size)

    loss_value, grads = loss_function(params, walkers)

    params = sgd(params, grads, lr)
    # state = update(step, grads, state)
    # params = get_params(state)


    print('step %i | e_mean %.2f | loss %.2f | wf_acc %.2f | mix_acc %.2f |'
          % (step, jnp.mean(e_locs), loss_value, acc, mix_acc))
    # steps.set_postfix(E=f'{e_locs.mean():.6f}')

return params, walkers