In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Select the GPU with index 1. This is set here, ahead of the JAX import below,
# so that it is already in the environment when JAX is loaded.
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Turn JIT compilation off globally (debugging aid: functions run op-by-op so
# intermediate values can be inspected).
# NOTE(review): this flag stays on for the whole notebook, including the
# training loop - confirm that is intended outside of debugging sessions.
from jax.config import config
config.update('jax_disable_jit', True)

from pathlib import Path
from typing import Callable
import wandb

from jax import vmap
from functools import partial
import jax
from jax import numpy as jnp
from jax import random as rnd
from flax.training.train_state import TrainState
from flax.training import common_utils
from flax import linen as nn, jax_utils
import optax

from pyfig import Pyfig
import os
from utils import debug, wpr
from typing import Callable
from pprint import pprint

In [3]:
# Experiment configuration object used by every cell below.
# wandb_mode: 'online' = logging on | 'disabled' = logging off | 'offline' = local only.
# NOTE(review): the effect of debug=True is defined inside Pyfig and is not visible here.
c = Pyfig(wandb_mode='disabled', debug=True)
# Show the resolved configuration dictionary.
pprint(c.d)

2022-11-28 20:15:28.777897: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:497] The NVIDIA driver's CUDA version is 11.4 which is older than the ptxas CUDA version (11.6.55). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


{}
{'TMP': PosixPath('/home/amawi/projects/hwat/exp/tmp'),
 'commit_id': '76377b6',
 'data': {'a': DeviceArray([[0., 0., 0.]], dtype=float32),
          'a_z': DeviceArray([4.], dtype=float32),
          'n_b': 16,
          'n_d': 2,
          'n_e': 4,
          'n_u': 2},
 'data_dir': PosixPath('/home/amawi/projects/data'),
 'dtype': 'f32',
 'entity': 'xmax1',
 'env': 'dex',
 'exp_id': 'qpSsCK4',
 'exp_name': 'junk',
 'exp_path': PosixPath('/home/amawi/projects/hwat/exp/junk/qpSsCK4'),
 'git_branch': 'main',
 'git_remote': 'origin',
 'half_precision': True,
 'iter_exp_dir': True,
 'log_metric_step': 5,
 'log_sample_step': 5,
 'log_state_step': 10,
 'merge': <bound method Pyfig.merge of <pyfig.Pyfig object at 0x7fa704c3b460>>,
 'model': {'compute_p_emb': functools.partial(<function compute_emb at 0x7fa63b393e20>, terms=['xx']),
           'compute_s_emb': functools.partial(<function compute_emb at 0x7fa63b393e20>, terms=['x_rlen']),
           'compute_s_perm': functools.partial(<fun

In [19]:
from hwat import FermiNet, init_walker

# Root PRNG key for the notebook, seeded from the config.
# NOTE(review): the same key is used for x_init, model.init and in later cells;
# JAX recommends splitting a key rather than reusing it - confirm this is acceptable.
rng = rnd.PRNGKey(c.seed)

x_init = rnd.normal(rng, (c.data.n_e, 3)) # one walker of shape (n_e, 3); NB no batch dim - batchless implementation

# Build the module with its constructor arguments taken from the config.
model = c.partial(FermiNet)
params = model.init(rng, x_init)['params'] # init returns {'params': p, ...other variable collections if they exist}
model.apply({'params':params}, x_init) # parameters are passed in explicitly on every call; result is displayed as the cell output
# model(x_init) fails with "Can't call compact methods on unbound modules":
# the module holds no parameters itself, so it has to be run through model.apply.
# See state.apply_fn(params, x_init) further down for the TrainState equivalent.

DeviceArray(-6.689785, dtype=float32)

In [15]:
@partial(jax.pmap, axis_name='b')
def create_train_state(rng):
	"""Initialise a FermiNet and wrap it in a ``TrainState`` with an SGD optimiser.

	The function is mapped by ``jax.pmap`` over the leading axis of ``rng``,
	so one state is built per key.

	Args:
		rng: PRNG key; used both to draw the dummy input and to initialise
			the parameters.

	Returns:
		A ``TrainState`` holding ``model.apply``, the initial parameters and
		an ``optax.sgd`` transformation with learning rate ``c.opt.lr``.
	"""
	net = c.partial(FermiNet)
	# Dummy single walker of shape (n_e, 3), only needed to trace the parameter shapes.
	dummy_walker = rnd.normal(rng, (c.data.n_e, 3))
	variables = net.init(rng, dummy_walker)
	return TrainState.create(
		apply_fn=net.apply,
		params=variables['params'],
		tx=optax.sgd(c.opt.lr),
	)

# One PRNG key per local device; pmap maps over local devices, so the leading
# axis of rng_n must match jax.local_device_count().
rng_n = rnd.split(rng, jax.local_device_count())

# create_train_state is pmapped, so the returned state already carries a leading
# device axis. Calling jax_utils.replicate on it would stack a second device
# axis in front, which a pmapped step function cannot consume - so it is not
# replicated again here.
# NOTE(review): because each device receives a different key, each replica
# starts from different parameters - confirm that is intended.
state = create_train_state(rng_n)
# state.params

In [34]:
# debug(True)
# wpr(dict(p=state.params))
# debug(False)
# type(state.params) # is the param dictionary, not the variable dictionary

KeyError: 'params'

In [30]:
from hwat import SampleState
# Initial walker batch; the recorded output below shows shape (16, 4, 3).
x = init_walker(rng, c.data.n_b, c.data.n_u, c.data.n_d, center=c.data.a, std=0.1)
print('Walker shape: ', x.shape)
# NOTE(review): the result of this call is neither stored nor displayed (it is
# not the last expression of the cell) - looks like a leftover smoke test.
model.apply({'params':params}, x_init)
# Sampler object bound to the model; called below as sample(params, x).
sample = SampleState(rng, model=model)
# Batched sampler: first argument (params) is shared, second argument is mapped
# over its leading axis.
sample_vmap = vmap(sample, in_axes=(None, 0))
# sample_vmap(params, x)

Walker shape:  (16, 4, 3)


In [26]:
from hwat import compute_pe
print('Walker shape: ', x.shape)
# Evaluate compute_pe on the whole walker batch with the atom data from the
# config (c.data.a, c.data.a_z); the recorded output holds 16 values for the
# (16, 4, 3) input. Presumably the potential energy per walker - confirm in hwat.
compute_pe(x, c.data.a, c.data.a_z)

Walker shape:  (16, 4, 3)


DeviceArray([  22.007858, -141.53467 ,  -45.23154 ,  -95.129295,
             -116.35297 , -173.2504  ,  -62.280594,  -60.103165,
              -98.63029 ,  -84.12774 ,  -81.23646 , -110.813065,
             -161.86557 ,  -86.836624, -127.94581 ,  -89.53885 ],            dtype=float32)

In [31]:
from hwat import create_compute_ke
# NOTE(review): debug() comes from utils; what debug(False) switches off is not visible here.
debug(False)
# Build the compute_ke function for this model, then evaluate it on the walker
# batch; the recorded output holds 16 values for the (16, 4, 3) input.
# Presumably the kinetic energy per walker - confirm in hwat.
compute_ke = create_compute_ke(model)
compute_ke(params, x)

DeviceArray([ 3115.9775  ,   -77.107544,    24.304708,   -15.595001,
                36.848007,    31.614841,    14.340487,   -34.231533,
               240.05957 ,  -137.8808  ,   813.81354 ,    64.84529 ,
             -4007.5537  ,  -185.4494  ,    68.41682 ,   -21.068924],            dtype=float32)

In [5]:


# train step framework get
# train step framework copy
# test energy function
# kinetic energy
# potential energy
# include atoms
# include sampler
# write metric
# run loop

@partial(jax.pmap, axis_name='b')
def train_step(state, x):
	"""Sample walkers, evaluate energies and compute parameter gradients.

	Mapped by ``jax.pmap`` over the leading (device) axis of both arguments.
	The mapped axis is named ``'b'``, matching the other pmapped functions in
	this notebook. Relies on the notebook globals ``sample``, ``compute_ke``,
	``compute_pe`` and ``c``.

	Args:
		state: train state; ``state.params`` and ``state.apply_fn`` are read.
		x: walkers for this device - assumed ``(n_walker, n_e, 3)``; TODO confirm.

	Returns:
		``(state, grads, x, v_b)``: the state (unchanged - the optimiser update
		is applied separately), the gradients of the loss w.r.t.
		``state.params``, the walkers returned by ``sample`` and a dict of
		diagnostics.
	"""
	params = state.params

	# NOTE(review): `sample` is also wrapped as vmap(sample, in_axes=(None, 0))
	# elsewhere in this notebook, which suggests it is batchless; if so the
	# vmapped version is the one to call here - confirm.
	x = sample(params, x)

	# Same call signatures as the standalone energy checks earlier in the notebook.
	ke = compute_ke(params, x)
	pe = compute_pe(x, c.data.a, c.data.a_z)
	# Computed from `params` outside loss_fn, so no gradient flows through the energy.
	e = ke + pe

	def loss_fn(p):
		# The model is batchless, so map it over the walker axis.
		log_psi = jax.vmap(state.apply_fn, in_axes=(None, 0))({'params': p}, x)
		# Return log_psi as auxiliary output so it can be logged.
		return jnp.mean(log_psi * e), log_psi

	grad_fn = jax.value_and_grad(loss_fn, has_aux=True)  # has_aux: loss_fn returns (loss, aux)
	(_, log_psi), grads = grad_fn(params)

	# Diagnostics for logging.
	v_b = dict(
		pe=pe,
		ke=ke,
		e=e,
		log_psi=log_psi,
		move_std=sample.move_std,
		x=x,
	)

	return state, grads, x, v_b

@jax.pmap
def update_model(state, grads):
	"""Apply ``grads`` to ``state`` and return the updated state.

	Mapped by ``jax.pmap`` over the leading axis of both arguments.

	Args:
		state: object exposing ``apply_gradients(grads=...)``.
		grads: gradients passed through to ``state.apply_gradients``.

	Returns:
		Whatever ``state.apply_gradients(grads=grads)`` returns.
	"""
	updated_state = state.apply_gradients(grads=grads)
	return updated_state

In [None]:

# Register "train/step" as the step metric for every metric name matched by "*".
wandb.define_metric("*", step_metric="train/step")

def compute_metric(d:dict):
    """Average every entry of ``d`` across the mapped axis named ``'b'``.

    Must be called from inside a function that is mapped with
    ``axis_name='b'`` (for example under ``jax.pmap(..., axis_name='b')``);
    ``jax.lax.pmean`` cannot resolve the axis name anywhere else.

    Args:
        d: dict (or any pytree) of arrays holding per-device values.

    Returns:
        A pytree with the same structure as ``d`` in which each leaf is the
        mean of that leaf over the ``'b'`` axis.
    """
    metrics = jax.lax.pmean(d, axis_name='b')
    return metrics
# NOTE(review): this builds a second Pyfig and rebinds `c`, replacing the
# config created near the top of the notebook - looks like a pasted duplicate;
# confirm it is needed.
# wandb_mode: 'online' = logging on | 'disabled' = logging off | 'offline' = local only.
c = Pyfig(wandb_mode='disabled', debug=True)
# pprint(c.d)

# # Display a project workspace
# %wandb USERNAME/PROJECT
# # Display a single run
# %wandb USERNAME/PROJECT/runs/RUN_ID
# # Display a sweep
# %wandb USERNAME/PROJECT/sweeps/SWEEP_ID
# # Display a report
# %wandb USERNAME/PROJECT/reports/REPORT_ID
# # Specify the height of embedded iframe
# %wandb USERNAME/PROJECT -h 2048

In [None]:
train_metrics = []  # one summary dict per logged step

# train_step is pmapped, so the walkers need a leading device axis:
# (n_b, ...) -> (n_device, n_b // n_device, ...). n_b must be divisible by the
# local device count. A new name is used so that `x` stays unsharded and this
# cell can be re-run.
x_dev = common_utils.shard(x)

for step in range(c.n_step):

    state, grads, x_dev, data = train_step(state, x_dev)
    # Apply the update inside pmap so state and grads are consumed per device.
    # NOTE(review): gradients are not averaged across devices before the
    # update, so replicas can drift apart on multi-device runs - confirm.
    state = update_model(state, grads)

    if step % c.log_metric_step == 0:

        # TODO: also log r'sgn$(\cdot)$' : sgn
        # Reduce each diagnostic to one Python float (mean over devices and
        # walkers) on the host.
        summary = {
            f'train/{k}': float(jnp.mean(v))
            for k, v in data.items()
        }
        train_metrics.append(summary)

        wandb.log({
                "train/step": step,
                **summary
        })

# Display labels for the logged quantities, keyed by metric name.
# The original listed six bare string literals, which Python concatenates into
# a single string, so dict(...) raised ValueError. Labels that contain LaTeX
# commands are now wrapped in balanced $...$.
# NOTE(review): the key for each label is inferred from the order of the six
# diagnostics collected in train_step (pe, ke, e, log_psi, move_std, x) - confirm.
nice_keys = dict(
    pe=r'V(X)',
    ke=r'$\nabla^2$',
    e='E',
    log_psi=r'$\log\psi$',
    move_std=r'$\delta_\mathrm{r}$',
    x=r'$r_\mathrm{e}$',
)


In [None]:
# EMA decay extension.
# The pasted indentation was invalid (the `elif` did not line up with its
# `if`), which stopped the whole cell from parsing; only the indentation is
# repaired here.
# NOTE(review): this snippet reads `step` and `state`, so it belongs inside the
# training loop; p_copy_params_to_ema, ema_decay_fn and p_apply_ema are not
# defined anywhere in this notebook and must be provided before it can run.
if step <= c.ema.update_after_step:
    state = p_copy_params_to_ema(state)
elif step % c.ema.update_every == 0:
    ema_decay = ema_decay_fn(step)
    state = p_apply_ema(state, ema_decay)

def create_model_apply(params=None):
    """Build a batched forward function for the FermiNet.

    Args:
        params: optional parameter dict (the ``'params'`` collection). When
            omitted, fresh parameters are initialised from the notebook-level
            ``rng`` using a random walker of shape ``(c.data.n_e, 3)``.

    Returns:
        ``model_apply(x)``: applies the model to every element along the
        leading axis of ``x`` (via ``vmap``) and returns the sum of each
        output.
    """
    # c.partial(FermiNet) is the constructor call used by the other cells.
    model = c.partial(FermiNet)
    if params is None:
        # Previously the init result was discarded and an unrelated global
        # `params` was passed to apply.
        params = model.init(rng, rnd.normal(rng, (c.data.n_e, 3)))['params']
    # model.apply expects the variable dict, not the bare 'params' collection.
    variables = {'params': params}

    @vmap
    def model_apply(x):
        out = model.apply(variables, x)
        return out.sum()

    return model_apply