In [10]:
%load_ext autoreload
%autoreload 2

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
import numpy as np
from ops.utils import compare
from functools import partial

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import random as rnd
from jax import lax, jit, vmap
from jax.tree_util import tree_structure, tree_flatten, tree_unflatten

from pytorch.models.og.model import fermiNet
from pytorch.sampling import MetropolisHasting
from pytorch.vmc import *
from pytorch.pretraining import Pretrainer
from pytorch.systems import Molecule as Moleculetc
from pytorch.utils import update_state_dict, from_np
import torch as tc
tc.set_default_dtype(tc.float64)

from ops.vmc.utils import create_atom_batch
from ops.systems import Molecule
from ops.wf.ferminet import create_wf, create_masks
from ops.wf.parameters import initialise_params, count_mixed_features
from ops.sampling import create_sampler
from ops.vmc import create_energy_fn, local_kinetic_energy, compute_potential_energy
from ops.pretraining import create_loss_and_sampler

In [12]:
# randomness
key = rnd.PRNGKey(1)
key, *subkeys = rnd.split(key, num=3)

# system
n_layers, n_sh, n_ph, n_det = 2, 64, 16, 2
n_walkers = 1024
n_el = 4
r_atoms = jnp.array([[0.0, 0.0, 0.0]])
z_atoms = jnp.array([4.])

# ansatz

mol = Molecule(r_atoms, z_atoms, n_el, n_layers=n_layers, n_sh=n_sh, n_ph=n_ph, n_det=n_det)
walkers = mol.initialise_walkers(n_walkers=n_walkers)
wf, wf_orbitals = create_wf(mol)
vwf = vmap(wf, in_axes=(None, 0, 0))
sampler = create_sampler(wf, correlation_length=10)
params = initialise_params(subkeys[0], mol)
compute_energy = create_energy_fn(wf, r_atoms, z_atoms)
laplacian_jax = jit(vmap(local_kinetic_energy(wf), in_axes=(None, 0)))
loss_function, sampler = create_loss_and_sampler(mol, wf, wf_orbitals)

walkers_tc = from_np(walkers)
r_atoms_tc = from_np(create_atom_batch(r_atoms, n_walkers))
z_atoms_tc = from_np(z_atoms)

mol_tc = Moleculetc(r_atoms_tc, z_atoms_tc, n_el, device='cpu', dtype=r_atoms_tc.dtype, n_layers=n_layers, n_sh=n_sh, n_ph=n_ph, n_det=n_det)

model_tc = fermiNet(mol_tc, diagonal=False)
model_tc = update_state_dict(model_tc, params)


System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
System: 
 Device  = cpu 
 dtype   = torch.float64 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
Model: 
 device   = cpu 
 n_sh     = 64 
 n_ph     = 16 
 n_layers = 2 
 n_det    = 2 

lin_split_in.w torch.Size([8, 64]) (8, 64)
stream_s0.w torch.Size([13, 64]) (13, 64)
stream_p0.w torch.Size([5, 16]) (5, 16)
single_splits.0.w torch.Size([128, 64]) (128, 64)
single_splits.1.w torch.Size([128, 64]) (128, 64)
single_intermediate.0.w torch.Size([97, 64]) (97, 64)
single_intermediate.1.w torch.Size([97, 64]) (97, 64)
pairwise_intermediate.0.w torch.Size([17, 16]) (17, 16)
pairwise_intermediate.1.w torch.Size([17, 16]) (17, 16)
env_up_linear.w torch.Size([2, 2, 65]) (2, 2, 65)


AttributeError: 'list' object has no attribute 'shape'

In [25]:
from jax.nn.initializers import orthogonal

init_orthogonal = orthogonal()


def init_linear_layer(key, shape, bias, bias_axis=0):
    key, subkey = rnd.split(key)
    p = init_orthogonal(key, shape)
    if bias:
        shape = list(shape)
        shape[bias_axis] = 1
        b = rnd.normal(subkey, tuple(shape))
        p = jnp.concatenate([p, b], axis=bias_axis)
    return p

def init_sigma(shape, bias=False):
    n_det, n_spin, n_atom, _, _ = shape
    subkeys = rnd.split(key, num=jnp.prod(jnp.array(shape[:3])))
    new_shape = (3, 3)
    p = jnp.concatenate([init_linear_layer(k, new_shape, bias)[None, ...] for k in subkeys], axis=0)
    p = p.reshape(shape)
    return p


def f1(sigma: jnp.array,
                ae_vectors: jnp.array):
    
    # n_det, n_spin, n_atom, _, _ = sigma.shape
    n_spin, n_atom, _ = ae_vectors.shape

    sigma = [jnp.squeeze(x) for x in jnp.split(sigma, n_atom, axis=2)]

    ae_vectors = [jnp.squeeze(x) for x in jnp.split(ae_vectors, n_atom, axis=1)]

    outs = []
    for ae_vector, s in zip(ae_vectors, sigma):
        exponent = jnp.dot(ae_vector, s).reshape(n_spin, -1, n_spin, n_atom, 3)
        out = jnp.exp(-jnp.linalg.norm(exponent, axis=-1))
        outs.append(out)
        print(out.shape)

    return jnp.concatenate(outs, axis=-1)


def f2(sigma: jnp.array,
                ae_vectors: jnp.array) -> jnp.array:
    # sigma (n_det, n_spin, n_atom, 3, 3)
    # ae_vectors (n_spin, n_atom, 3)

    exponent = jnp.einsum('jmv,kimvc->jkimc', ae_vectors, sigma)

    return jnp.exp(-jnp.linalg.norm(exponent, axis=-1))


key = rnd.PRNGKey(1)
n_spin, n_atom, n_det = 3, 1, 2
x = rnd.normal(key, (n_spin, n_atom, 3))
shape = (n_det, n_spin, n_atom, 3, 3)
sigma = init_sigma(shape)

z = f1(sigma, x)
y = f2(sigma, x)

print(z.shape, y.shape)
jnp.isclose(z, y)


(3, 2, 3, 1)
(3, 2, 3, 1) (3, 2, 3, 1)


DeviceArray([[[[ True],
               [ True],
               [ True]],

              [[ True],
               [ True],
               [ True]]],


             [[[ True],
               [ True],
               [ True]],

              [[ True],
               [ True],
               [ True]]],


             [[[ True],
               [ True],
               [ True]],

              [[ True],
               [ True],
               [ True]]]], dtype=bool)