In [1]:
# Reload edited modules (e.g. the local pytorch/ and ops/ packages imported
# below) before each cell executes, so project-code edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [41]:
import numpy as np
from tutorial_functions import compare
from functools import partial

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import random as rnd
from jax import lax

from pytorch.models.og.model import fermiNet
from pytorch.sampling import MetropolisHasting
from pytorch.vmc import get_energy_and_center, compute_local_energy, compute_potential_energy, batched_cdist_l2
from pytorch.pretraining import Pretrainer
from pytorch.systems import Molecule as Moleculetc
from pytorch.utils import update_state_dict, from_np
import torch as tc
tc.set_default_dtype(tc.float64)

from ops.vmc.utils import create_atom_batch
from ops.systems import Molecule
from ops.wf.ferminet import model, create_masks
from ops.wf.parameters import initialise_params, count_mixed_features

In [42]:
# --- System and network sizes ----------------------------------------------
n_el = 4
n_atom = 1
n_up = 2
n_down = n_el - n_up
n_layers = 2
n_sh = 20
n_ph = 10
n_det = 5
n_walkers = 20

# --- Random inputs -----------------------------------------------------------
# A JAX key must be consumed at most once: drawing both the walkers and the
# atom positions from the same key (and then initialising the parameters with
# it as well) correlates the samples. Split off one subkey per draw here and
# leave the returned `key` unconsumed for the parameter initialisation later.
key = rnd.PRNGKey(1)
key, walkers_key, atoms_key = rnd.split(key, 3)

walkers = rnd.normal(walkers_key, (n_walkers, n_el, 3))
r_atoms = create_atom_batch(rnd.normal(atoms_key, (n_atom, 3)), n_walkers)

# One neutral atom: its charge is set to the electron count. That only holds
# for a single atom, so fail loudly if n_atom is changed on its own.
assert n_atom == 1, "z_atoms assumes a single neutral atom"
z_atoms = jnp.array([n_el])

# Copies of the same inputs for the PyTorch reference implementation.
walkers_tc = from_np(walkers)
r_atoms_tc = from_np(r_atoms)
z_atoms_tc = from_np(z_atoms)

mol_tc = Moleculetc(r_atoms_tc, z_atoms_tc, n_el, device='cpu', dtype=r_atoms_tc.dtype)
mol = Molecule(r_atoms, z_atoms, n_el)

# In this phase: vmap everything.
# (A comment rather than a bare string literal, which would be echoed as the
# cell's output.)

System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202
System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 

converged SCF energy = -14.351880476202


'\nIn this phase\n- vmap everything\n'

In [43]:
def compute_ae_vectors(walkers: jnp.ndarray,
                       r_atoms: jnp.ndarray) -> jnp.ndarray:
    """Electron-minus-atom displacement vectors.

    The new axes are inserted relative to the END of each array, so the same
    function handles a single configuration (e.g. inside ``jax.vmap``) and an
    explicitly batched one.

    Args:
        walkers: electron positions, shape (..., n_electrons, D).
        r_atoms: atom positions, shape (..., n_atoms, D). Leading batch
            dimensions broadcast against those of ``walkers``.

    Returns:
        Array of shape (..., n_electrons, n_atoms, D) with
        ``out[..., i, a, :] = walkers[..., i, :] - r_atoms[..., a, :]``.
    """
    # (..., n_electrons, 1, D) - (..., 1, n_atoms, D)
    ae_vectors = jnp.expand_dims(walkers, axis=-2) - jnp.expand_dims(r_atoms, axis=-3)
    return ae_vectors

# Map over the walker axis only; the single set of atom positions is shared
# by every walker (in_axes=None).
fn = jax.vmap(compute_ae_vectors, in_axes=(0, None))
# Index r_atoms directly: `r_atoms_tmp` is only defined in a later cell, so
# referencing it here fails on a fresh kernel (Restart & Run All).
fn(walkers, r_atoms[0]).shape

(20, 4, 1, 3)

In [44]:
# setup the models
masks = create_masks(n_atom, n_el, n_up, n_layers, n_sh, n_ph)
params = initialise_params(key, n_atom, n_up, n_down, n_layers, n_sh, n_ph, n_det)

model_tc = fermiNet(mol_tc, n_det=n_det, n_sh=n_sh, n_ph=n_ph, diagonal=False)
model_tc = update_state_dict(model_tc, params)

r_atoms_tmp = r_atoms[0]
pmodel = partial(model, r_atoms=r_atoms_tmp, masks=masks, n_up=n_up, n_down=n_down)
print(r_atoms_tmp.shape)

Model: 
 device   = cpu 
 n_sh     = 20 
 n_ph     = 10 
 n_layers = 2 
 n_det    = 5 

(1, 3)


In [45]:
# Evaluate the PyTorch reference and the JAX model on the same walkers and
# compare the results (presumably log|psi|, going by the names).
logpsi_tc = model_tc(walkers_tc)
# NOTE(review): the recorded output of this cell ends in
# "ValueError: Incompatible shapes for broadcasting" — presumably raised by
# this call, with the JAX model mid-refactor towards per-walker (vmapped)
# inputs while being handed the full (n_walkers, n_el, 3) batch. TODO confirm.
logpsi = pmodel(params, walkers)
compare(logpsi_tc, logpsi)

(20, 4, 3) (1, 3)
(4, 3) (1, 3)
(4, 1, 3) (1, 1, 3)
(4, 1, 3)
(20, 4, 3) (20, 4, 1, 3)


ValueError: Incompatible shapes for broadcasting: ((20, 1, 4, 3), (1, 1, 20, 4))

In [None]:
def compute_ae_vectors(walkers: jnp.ndarray,
                       r_atoms: jnp.ndarray) -> jnp.ndarray:
    """Electron-minus-atom displacement vectors.

    The new axes are inserted relative to the END of each array, so the same
    function handles a batch of walkers (as used below) as well as a single
    configuration (e.g. inside ``jax.vmap``).

    Args:
        walkers: electron positions, shape (..., n_electrons, D).
        r_atoms: atom positions, shape (..., n_atoms, D). Leading batch
            dimensions broadcast against those of ``walkers``.

    Returns:
        Array of shape (..., n_electrons, n_atoms, D) with
        ``out[..., i, a, :] = walkers[..., i, :] - r_atoms[..., a, :]``.
    """
    # (..., n_electrons, 1, D) - (..., 1, n_atoms, D)
    ae_vectors = jnp.expand_dims(walkers, axis=-2) - jnp.expand_dims(r_atoms, axis=-3)
    return ae_vectors


def drop_diagonal(square):
    """
    Remove the diagonal entries of a square array along its first two axes.

    Args:
        square: array of shape (n, n, *rest).

    Returns:
        Array of shape (n * (n - 1), *rest) holding every off-diagonal entry
        ``square[i, j]`` (i != j) in row-major order:
        (0, 1), (0, 2), ..., (1, 0), (1, 2), ..., (n-1, n-2).

    Notes:
        The gather indices are computed with NumPy from the (static) shape, so
        they are concrete constants under ``jax.jit`` and ``jax.vmap``.
        This replaces an earlier version (compared against the masking method
        in debugging/drop_diagonal). That version built the result from 2n
        ``jnp.split`` slices and finished with ``jnp.squeeze``. The squeeze also
        stripped any size-1 trailing axis (e.g. ``rest == (1,)``), and the
        split version raised for n == 1.
    """
    n = square.shape[0]
    # Host-side indices of all (i, j) with i != j, in row-major order.
    rows, cols = np.nonzero(~np.eye(n, dtype=bool))
    return square[rows, cols]


def compute_inputs(walkers, ae_vectors):
    """
    Build the one-electron and two-electron input features for one walker.

    Args:
        walkers: electron positions of a single configuration, (n_electrons, D)
            (the transpose below requires it to be 2-D).
        ae_vectors: electron-atom displacement vectors, (n_electrons, n_atoms, D).
            The final reshape requires D == 3.

    Returns:
        single_inputs: (n_electrons, 4 * n_atoms). For each atom, the
            displacement vector followed by its length.
        pairwise_inputs: the electron-electron vectors r_i - r_j left after
            ``drop_diagonal`` removes the i == j entries, each followed by its
            length.

    Notes:
        The diagonal is removed before the norm is taken, so the zero vectors
        r_i - r_i never reach ``jnp.linalg.norm``. An earlier version dropped
        the diagonal with a boolean mask built from ``~jnp.eye(n_electrons)``.
    """
    n_electrons, n_atoms = ae_vectors.shape[:2]

    # One-electron features: [dx, dy, dz, |r|] per atom, flattened per electron.
    ae_lengths = jnp.linalg.norm(ae_vectors, axis=-1, keepdims=True)
    single_inputs = jnp.concatenate([ae_vectors, ae_lengths], axis=-1).reshape(n_electrons, 4 * n_atoms)

    # Two-electron features: all pairwise r_i - r_j, then drop i == j.
    left = jnp.expand_dims(walkers, axis=1)      # (n, 1, D)
    right = jnp.transpose(left, (1, 0, 2))       # (1, n, D)
    ee_vectors = drop_diagonal(left - right)
    ee_lengths = jnp.linalg.norm(ee_vectors, axis=-1, keepdims=True)
    pairwise_inputs = jnp.concatenate([ee_vectors, ee_lengths], axis=-1)

    return single_inputs, pairwise_inputs

# Electron-atom vectors for every walker at once.
ae_vectors = compute_ae_vectors(walkers, r_atoms)
# compute_inputs handles one walker at a time, so map it over axis 0 of both arguments.
fn = jax.vmap(compute_inputs, in_axes=(0, 0))
x, y = fn(walkers, ae_vectors)
# x: per-electron features (4 per atom).
# y: 4 features per ordered electron pair i != j.
# Assert the expected shapes instead of only printing them.
assert x.shape == (n_walkers, n_el, 4 * n_atom), x.shape
assert y.shape == (n_walkers, n_el * (n_el - 1), 4), y.shape
print(x.shape, y.shape)