In [9]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
import numpy as np
from jax import numpy as jnp
from jax import random as rnd
from flax.traverse_util import flatten_dict

import torch 

from hwat_jax import FermiNet
from hwat_func import Ansatz_fb

from pyfig import Pyfig
from utils import flat_any

# JAX PRNG key used to initialise the FermiNet parameters below.
rng = rnd.PRNGKey(1)

# Config object for the torch side; `_convert` is asked for cpu / float64.
# NOTE(review): what Pyfig/_convert actually change is not visible here — confirm.
c = Pyfig(wb_mode='disabled', notebook=True)
c._convert(device='cpu', dtype=torch.float64)
n_b, n_e = c.data.n_b, c.data.n_e
torch.set_default_dtype(torch.float64)
# Random coordinates of shape (n_b, n_e, 3), shared by both models:
# `r` is the NumPy (float64) copy, `tr` the torch float64 copy.
# NOTE(review): drawn from NumPy's global RNG; no seed is set in the visible
# cells, so the numbers differ between runs.
r = np.random.normal(0, 1, (n_b, n_e, 3))
tr = torch.tensor(r, dtype=torch.float64)
# Second, un-converted config object used for the JAX side.
cj = Pyfig(wb_mode='disabled', notebook=True)

# JAX model and its parameters, initialised on a batch holding only sample 0.
jmod = cj.partial(FermiNet)
jp = jmod.init(rng, r[[0]])

# Torch model cast to float64, plus its (name, parameter) pairs as a list.
tmod = c.partial(Ansatz_fb).to(torch.float64)
tp = list(tmod.named_parameters())

def swap_params(jp, tp):
    """Map flattened JAX parameters onto the torch model's parameter names.

    Args:
        jp: nested parameter dict; it is flattened with ``flatten_dict`` so each
            key becomes a tuple path.
        tp: iterable of ``(name, value)`` pairs (the caller passes
            ``list(tmod.named_parameters())``).

    Returns:
        ``(new_tp, key)`` where ``new_tp`` maps torch parameter names to NumPy
        arrays (kernels are transposed, biases are copied as-is) and ``key``
        maps each JAX path to the torch name it was assigned to.

    Side effects:
        Prints the number of flattened JAX entries and of torch entries.
    """
    flat_jax = flatten_dict(jp)
    print(len(flat_jax), len(tp))

    # Split both parameter sets into bias-like and weight-like entries.
    # For JAX the test is tuple membership in the path; for torch it is a
    # substring test on the dotted name.
    jax_bias, jax_kernel = [], []
    for path, value in flat_jax.items():
        if 'bias' in path:
            jax_bias.append((path, value))
        if 'kernel' in path:
            jax_kernel.append((path, value))
    torch_bias = [(name, value) for name, value in tp if 'bias' in name]
    torch_weight = [(name, value) for name, value in tp if 'weight' in name]

    # Torch slot `dst` receives JAX entry `order[dst]`: entries 1 and 2 are
    # swapped and exactly nine entries per group are handled.
    # NOTE(review): presumably compensates for a different layer ordering in
    # the two models — confirm against the model definitions.
    order = [0, 2, 1, *range(3, 9)]

    # JAX path -> torch name (biases first, then kernels/weights).
    key = {}
    for torch_pairs, jax_pairs in ((torch_bias, jax_bias), (torch_weight, jax_kernel)):
        for dst, src in enumerate(order):
            key[jax_pairs[src][0]] = torch_pairs[dst][0]

    # Torch name -> NumPy value taken from the matching JAX entry.
    new_tp = {}
    for dst, src in enumerate(order):
        new_tp[torch_bias[dst][0]] = np.array(jax_bias[src][1])
    for dst, src in enumerate(order):
        # Kernels are transposed before being stored under the torch weight name.
        new_tp[torch_weight[dst][0]] = np.array(jax_kernel[src][1]).T

    return new_tp, key

# Build the torch-name -> NumPy value mapping from the JAX parameters.
new_tp, key = swap_params(jp, tp)
# Overwrite every torch parameter's data with the value stored under its name;
# a name missing from `new_tp` raises KeyError here.
for name, param in tmod.named_parameters():
    # print(name)
    # print(param.data.dtype, param.data.shape)
    param.data = torch.tensor(new_tp[name], requires_grad=True, dtype=torch.float64)
    # print(param.data.dtype, param.data.shape)

# Cast the whole module to float64 again after the parameter swap.
tmod = tmod.to(torch.float64)

# Show dtype and shape of the torch input coordinates.
print(tr.dtype, tr.shape)

import jax 
def compute_ke_b(state, r):
    """Batched kinetic-energy term from first and second derivatives of the model.

    With ``f(r) = state.apply_fn(state.params, r)``, ``g`` the gradient of
    ``f(r).sum()`` and ``h_ii`` the matching second derivatives, the result per
    batch row is ``-0.5 * sum_i (g_i**2 + h_ii)`` over the flattened coordinates.

    Args:
        state: object exposing ``apply_fn(params, r)`` and ``params``.
            ``apply_fn`` is called with the flattened ``(n_b, n_e * n_dim)``
            array. NOTE(review): the per-row reading assumes ``apply_fn``
            treats batch rows independently — confirm.
        r: array-like of shape ``(n_b, n_e, n_dim)``; NumPy input is accepted.

    Returns:
        Array of ``n_b`` values with size-1 axes squeezed away.
    """
    # Convert first so `r.dtype` is a dtype JAX actually supports. A float64
    # NumPy array would otherwise make `jnp.eye(..., dtype=r.dtype)` request
    # float64 explicitly, which emits a truncation warning when x64 is off.
    r = jnp.asarray(r)

    grad_fn = jax.grad(lambda r: state.apply_fn(state.params, r).sum())

    n_b, n_e, n_dim = r.shape
    n_jvp = n_e * n_dim
    r = r.reshape(n_b, n_jvp)
    # One identity matrix per batch row; column i is the i-th unit tangent.
    eye = jnp.eye(n_jvp, dtype=r.dtype)[None, ...].repeat(n_b, axis=0)

    def _body_fun(i, val):
        # primal is the gradient; tangent is the directional derivative of the
        # gradient along coordinate i, whose i-th entry is the second derivative.
        primal, tangent = jax.jvp(grad_fn, (r,), (eye[..., i],))
        return val + (primal[:, i]**2).squeeze() + (tangent[:, i]).squeeze()

    return (- 0.5 * jax.lax.fori_loop(0, n_jvp, _body_fun, jnp.zeros(n_b,))).squeeze()


# def compute_ke_b(model_rv: nn.Module, r: torch.Tensor):
# 	dtype, device = r.dtype, r.device
# 	# MODEL IS FUNCTIONAL THAT IS VMAPPED
# 	n_b, n_e, n_dim = r.shape
# 	n_jvp = n_e * n_dim
# 	r_flat = r.reshape(n_b, n_jvp)
# 	eyes = torch.eye(n_jvp, dtype=dtype, device=device)[None].repeat((n_b, 1, 1))

# 	# vjp method
# 	# grad_fn = grad(lambda _r: model_rv(_r).sum())
# 	# g, fn = vjp(grad_fn, r_flat)
# 	# gg2 = torch.stack([fn(eye_b[..., i])[0][:, i] for i in range(n_jvp)], dim=0).sum(0)
# 	# lap = gg2 + (g**2).sum(-1)
 
# 	# jvp 
# 	grad_fn = grad(lambda _r: model_rv(_r).sum())
# 	jvp_all = [jvp(grad_fn, (r_flat,), (eyes[:, i],)) for i in range(n_jvp)]  # grad out, jvp
# 	e_jvp = torch.stack([a[:, i]**2+b[:, i] for i, (a,b) in enumerate(jvp_all)]).sum(0)
 
# 	#  (primal[:, i]**2).squeeze() + (tangent[:, i]).squeeze()
# 	# primal, tangent = jax.jvp(grad_fn, (r,), (eye[..., i],))  
# 	# 	return val + (primal[:, i]**2).squeeze() + (tangent[:, i]).squeeze()
# 	# return (- 0.5 * jax.lax.fori_loop(0, n_jvp, _body_fun, jnp.zeros(n_b,))).squeeze()

# 	return e_jvp


### train step ###
from jax import numpy as jnp
from typing import NamedTuple
from functools import partial 
import optax
from flax.training.train_state import TrainState

def create_train_state(rng, r):
    """Create a flax ``TrainState`` for a newly initialised ``FermiNet``.

    Args:
        rng: PRNG key passed to the model's ``init``.
        r: example input passed to the model's ``init``.

    Returns:
        ``TrainState`` holding the model's ``apply`` function, the initial
        parameters, and an optax chain of block-RMS clipping (threshold 1.0)
        followed by AdamW (learning rate 0.001).
    """
    network = cj.partial(FermiNet)
    initial_params = network.init(rng, r)
    optimiser = optax.chain(
        optax.clip_by_block_rms(1.),
        optax.adamw(0.001),
    )
    return TrainState.create(
        apply_fn=network.apply,
        params=initial_params,
        tx=optimiser,
    )

def train_step(state, r_step):
    """Return the batched kinetic energy of ``r_step`` under ``state``.

    Currently a thin wrapper around ``compute_ke_b``; no parameters are updated.
    """
    return compute_ke_b(state, r_step)

# JAX side: build a train state and evaluate the kinetic energy on the full batch.
state = create_train_state(rng, r)
jke = train_step(state, r)
# print(jke)
# Torch side: project implementation of the kinetic energy, aliased to `ke`.
from hwat_func import compute_ke_b as ke
from functorch import vmap, make_functional 
# Turn the torch module into a pure function of (params, input) and vmap it
# over the batch axis of the input only.
model_fn, params = make_functional(tmod)
model_v = vmap(model_fn, in_dims=(None, 0))
model_rv = lambda _r: model_v(params, _r)
# NOTE(review): the recorded output of this cell ends with a RuntimeError from
# functorch's grad_and_value ("Expected f(*args) to return a scalar Tensor");
# presumably it is raised inside `ke` — confirm and fix in hwat_func.
tke = ke(model_rv, tr)

# Difference between the torch result (scaled by -0.5) and the JAX result.
print(tke.detach().cpu().numpy()*-0.5 - jke)

2022-12-22 01:04:05.944834: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


init PlugIn classes
updating configuration
Run: ['git', 'log', '--pretty=format:%h', '-n', '1'] at /home/amawi/projects/hwat
stdout: 61fc966 stderr: 
Run: ['hostname'] at .
stdout: oceanus  stderr: 
running script
setting exp_path
Run: ['git', 'log', '--pretty=format:%h', '-n', '1'] at /home/amawi/projects/hwat
stdout: 61fc966 stderr: 
Run: ['hostname'] at .
stdout: oceanus  stderr: 


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


init PlugIn classes
updating configuration
Run: ['git', 'log', '--pretty=format:%h', '-n', '1'] at /home/amawi/projects/hwat
stdout: 61fc966 stderr: 
Run: ['hostname'] at .
stdout: oceanus  stderr: 
running script
setting exp_path
Run: ['git', 'log', '--pretty=format:%h', '-n', '1'] at /home/amawi/projects/hwat
stdout: 61fc966 stderr: 
Run: ['hostname'] at .
stdout: oceanus  stderr: 
Run: ['git', 'log', '--pretty=format:%h', '-n', '1'] at /home/amawi/projects/hwat
stdout: 61fc966 stderr: 
Run: ['hostname'] at .
stdout: oceanus  stderr: 
Run: ['git', 'log', '--pretty=format:%h', '-n', '1'] at /home/amawi/projects/hwat
stdout: 61fc966 stderr: 
Run: ['hostname'] at .
stdout: oceanus  stderr: 
18 18
torch.float64 torch.Size([256, 4, 3])
Run: ['git', 'log', '--pretty=format:%h', '-n', '1'] at /home/amawi/projects/hwat
stdout: 61fc966 stderr: 
Run: ['hostname'] at .
stdout: oceanus  stderr: 


  eye = jnp.eye(n_jvp, dtype=r.dtype)[None, ...].repeat(n_b, axis=0)


RuntimeError: grad_and_value(f)(*args): Expected f(*args) to return a scalar Tensor, got tensor with 1 dims. Maybe you wanted to use the vjp or jacrev APIs instead?

In [28]:
def jpe(r, a=None, a_z=None):
    """Batched Coulomb potential energy (JAX): pairwise repulsion minus nuclear attraction.

    Args:
        r: array of shape ``(n_b, n_e, 3)`` holding the coordinates.
        a: optional array of shape ``(n_a, 3)`` with the centre positions.
            When ``None`` only the pairwise term is returned.
        a_z: array of shape ``(n_a,)`` with the charges; required when ``a``
            is given.

    Returns:
        Array of ``n_b`` energies with size-1 axes squeezed away.

    Raises:
        NotImplementedError: if more than one centre is supplied (the
            centre-centre term is not implemented).
    """
    # Pairwise distances; the strict lower triangle counts each pair once and
    # masks out the (infinite) diagonal.
    rr = jnp.expand_dims(r, -2) - jnp.expand_dims(r, -3)
    rr_len = jnp.linalg.norm(rr, axis=-1)
    pe_rr = jnp.tril(1. / rr_len, k=-1).sum((1, 2))

    if a is None:
        # No centres: previously this path hit an unbound `pe_ra`.
        return pe_rr.squeeze()

    # Check the number of centres BEFORE a batch axis is prepended; afterwards
    # len(a) is always 1 and the guard could never fire.
    if len(a) > 1:  # len(a) = n_a
        raise NotImplementedError('jpe supports a single centre only')

    a, a_z = a[None, :, :], a_z[None, None, :]
    ra = jnp.expand_dims(r, -2) - jnp.expand_dims(a, -3)
    ra_len = jnp.linalg.norm(ra, axis=-1)
    pe_ra = (a_z / ra_len).sum((1, 2))

    return (pe_rr - pe_ra).squeeze()

def tpe(r, a=None, a_z=None):
    """Batched Coulomb potential energy (torch): pairwise repulsion minus nuclear attraction.

    Args:
        r: tensor of shape ``(n_b, n_e, 3)`` holding the coordinates.
        a: optional tensor of shape ``(n_a, 3)`` with the centre positions.
            When ``None`` only the pairwise term contributes.
        a_z: tensor of shape ``(n_a,)`` with the charges; required when ``a``
            is given.

    Returns:
        Tensor of ``n_b`` energies (dtype and device of ``r``) with size-1
        axes squeezed away.

    Raises:
        NotImplementedError: if more than one centre is supplied (the
            centre-centre term ``pe_aa`` is not implemented).
    """
    dtype, device = r.dtype, r.device

    pe_rr = torch.zeros(r.shape[0], dtype=dtype, device=device)
    pe_ra = torch.zeros(r.shape[0], dtype=dtype, device=device)
    # Placeholder for the centre-centre term; stays zero until implemented.
    pe_aa = torch.zeros(r.shape[0], dtype=dtype, device=device)

    # Pairwise distances; the strict lower triangle counts each pair once and
    # masks out the (infinite) diagonal.
    rr = torch.unsqueeze(r, -2) - torch.unsqueeze(r, -3)
    rr_len = torch.linalg.norm(rr, dim=-1)
    pe_rr += torch.tril(1. / rr_len, diagonal=-1).sum((-1, -2))

    if a is not None:
        # Check the number of centres BEFORE batch axes are prepended;
        # afterwards len(a_z) is always 1 and the guard could never fire.
        if len(a_z) > 1:
            # aa = torch.unsqueeze(a, -2) - torch.unsqueeze(a, -3)
            # aa_len = torch.linalg.norm(aa, axis=-1)
            # pe_aa += torch.tril((a_z*a_z)/aa_len, diagonal=-1).sum((-1,-2))
            raise NotImplementedError('tpe supports a single centre only')

        a, a_z = a[None, :, :], a_z[None, None, :]
        ra = torch.unsqueeze(r, -2) - torch.unsqueeze(a, -3)
        ra_len = torch.linalg.norm(ra, dim=-1)
        pe_ra += (a_z / ra_len).sum((-1, -2))

    return (pe_rr - pe_ra + pe_aa).squeeze()

# Sanity check: batch-summed potential energy from the JAX and torch versions.
# NOTE(review): the JAX call reads `cj.data.a*` while the torch call reads
# `c.app.a*` — presumably the same values; confirm.
print(jpe(r, cj.data.a, cj.data.a_z).sum(), tpe(tr, c.app.a, c.app.a_z).sum())

-2343.137 tensor(-2343.1369)


In [29]:
# Forward pass of the torch model on sample 1 (unbatched input).
print(tmod(tr[1]))

# Forward pass of the JAX model on the same sample, kept as a batch of one;
# being the last expression, its value is shown as the cell output.
jmod.apply(jp, r[[1]])


tensor(-16.3276, grad_fn=<SqueezeBackward0>)


Array([-16.32761], dtype=float32)

In [13]:
# Display the torch-name -> NumPy array mapping produced by swap_params.
new_tp

{'Vs.0.bias': array([0.00593489, 0.00268932, 0.00064195, 0.00706287, 0.00621831,
        0.00998503, 0.00726382, 0.00146654, 0.00408086, 0.00807343,
        0.00637496, 0.00919925, 0.00673006, 0.00812667, 0.00376891,
        0.00985797, 0.00689847, 0.00966867, 0.00270503, 0.00158928,
        0.00351574, 0.00832261, 0.00841896, 0.00154599, 0.00474995,
        0.00800335, 0.00841118, 0.00508559, 0.00370565, 0.00682438,
        0.00971892, 0.00518902], dtype=float32),
 'Vs.1.bias': array([0.00625007, 0.00542044, 0.00135756, 0.00795248, 0.00810666,
        0.00583005, 0.00183058, 0.00284598, 0.00146769, 0.00457759,
        0.00290021, 0.00924134, 0.00183403, 0.00595838, 0.00854883,
        0.00378785, 0.00534647, 0.00389087, 0.00645652, 0.00056143,
        0.00566135, 0.00917947, 0.00111457, 0.00533534, 0.008281  ,
        0.00224588, 0.00311537, 0.00027112, 0.00229247, 0.00872462,
        0.00720863, 0.00751973], dtype=float32),
 'Ws.0.bias': array([0.00060303, 0.00728955, 0.0071933 , 0.0