In [1]:
# Distribution ✨ jit ❇ Demo 💪 
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

%load_ext autoreload
%autoreload 2
%env "WANDB_NOTEBOOK_NAME" "run.ipynb" # ❕same as notebook
# from jax.config import config
# config.update('jax_disable_jit', True)

env: "WANDB_NOTEBOOK_NAME"="run.ipynb" # ❕same as notebook


In [2]:
from functools import partial
import numpy as np
import re as regex
from pprint import pprint

import jax
from jax import pmap, grad
from jax import numpy as jnp
from jax import random as rnd
from flax.training.train_state import TrainState
from flax.core.frozen_dict import FrozenDict	
import optax

from pyfig import Pyfig
from hwat import FermiNet, sample, compute_ke_b, PotentialEnergy
from hwat import PotentialEnergy, sample, compute_ke_b
import wandb

from utils import flat_any
from utils import debug
from hwat import check_antisym


In [3]:
# with jax.checking_leaks():

args = {'l_e':[4,], 'a_z':[4,], 'n_u': 2,'n_b': 16, 'n_sv': 16, 'n_pv': 8, 'corr_len': 20, 'n_step': 1000, 'log_metric_step': 1,'exp_name':'demo-final'}
args = {}

c = Pyfig(wandb_mode='online', args=args, get_sys_arg=False)

rng = rnd.split(rnd.PRNGKey(c.seed), c.n_device)
r = c.data.init_walker(rng, n_b=c.data.n_b)
print(f'innit: r {r.shape} rng {rng.shape} ✅')

@partial(jax.pmap, axis_name='dev', in_axes=(0,0))
def create_train_state(rng, r):
    model = c.partial(FermiNet)  
    params = model.init(rng, r)
    return TrainState.create(apply_fn=model.apply, params=params, tx=optax.chain(optax.adaptive_grad_clip(0.1), optax.adam(c.opt.lr)))
state = create_train_state(rng, r)
print('Model ✅')

@partial(jax.pmap, axis_name='b', in_axes=(0,0,0,0))
def equil(rng, state, r, deltar):
    r, v_sam = partial(sample, acc_target=c.data.acc_target)(rng, state, r, deltar)
    return r, v_sam
print('Equil ✅')

deltar = jnp.ones((c.n_device, 1), dtype=r.dtype)*0.02
for step in range(1, 2):
    r, v_sam = equil(rng, state, r, deltar)

check_antisym(c, rng, r)

Path:  /home/amawi/projects/hwat/exp/demo-final/TEVJ5pJ ✅
System 
{'a': array([[0., 0., 0.]]),
 'a_z': array([4.]),
 'acc_target': 0.5,
 'corr_len': 20,
 'equil_len': 10000,
 'init_walker': functools.partial(<function init_walker at 0x7f0a244201f0>, n_b=512, n_u=2, n_d=2, center=array([[0., 0., 0.]]), std=0.1),
 'l_e': [4],
 'n_b': 512,
 'n_d': 2,
 'n_e': 4,
 'n_u': 2,
 'with_sign': False}
Model 
{'compute_p_emb': functools.partial(<function compute_emb at 0x7f0a243f7010>, terms=['rr']),
 'compute_s_emb': functools.partial(<function compute_emb at 0x7f0a243f7010>, terms=['r_len', 'r', 'ra', 'ra_len'], a=array([[0., 0., 0.]])),
 'n_det': 1,
 'n_fb': 3,
 'n_fbv': 128,
 'n_pv': 16,
 'n_sv': 32,
 'terms_p_emb': ['rr'],
 'terms_s_emb': ['r_len', 'r', 'ra', 'ra_len']}


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxmax1[0m. Use [1m`wandb login --relogin`[0m to force relogin


run:  xmax1/hwat/2fmwizzz ✅
innit: r (1, 512, 4, 3) rng (1, 2) ✅
Model ✅
Equil ✅
[ 0.10409795  0.04441693 -0.10007349] [ 0.07109304  0.08764078 -0.04961066] [ 0.10409795  0.04441693 -0.10007349]
[ 0.07109304  0.08764078 -0.04961066] [ 0.10409795  0.04441693 -0.10007349] [ 0.07109304  0.08764078 -0.04961066]
[-0.02051444  0.0849402   0.06399783] [-0.02051444  0.0849402   0.06399783] [-0.02405036  0.04376032 -0.06164024]
[-0.02405036  0.04376032 -0.06164024] [-0.02405036  0.04376032 -0.06164024] [-0.02051444  0.0849402   0.06399783]
-13.1151495 -13.253824 -13.348494
-10.732815 -10.62299 -10.751295
-10.161472 -10.20397 -10.236872
-11.279682 -11.407277 -11.183874
1.0 -1.0 -1.0
-1.0 1.0 1.0
-1.0 1.0 1.0
1.0 -1.0 -1.0


In [4]:
# Equilibration
wandb.define_metric("*", step_metric="eq/step")
for step in range(1, c.data.equil_len//c.data.corr_len + 1):
    r, v_sam = equil(rng, state, r, deltar)
    rng, deltar, acc = v_sam['rng'], v_sam['deltar'], v_sam['acc']
    # if not (step % c.log_metric_step//10):
    if not (step % 1):
        wandb.log({'eq/step':step, 'acc': acc.mean(), 'deltar': deltar.mean()})
print('Walkers ✅ Training Variables ✅')

@partial(pmap, in_axes=(0, 0, 0, 0))
def train_step(rng, state, r, deltar):
    
    r, v_sam = partial(sample, acc_target=c.data.acc_target)(rng, state, r, deltar)
    pe = partial(PotentialEnergy(a=c.data.a, a_z=c.data.a_z).apply, {})(r)
    ke = compute_ke_b(state, r)
    e = pe + ke

    def loss(_params):
        return ((e - jnp.mean(e)) * state.apply_fn(_params, r)).mean()
    
    grads = grad(loss)(state.params)

    v = dict(
        rng=rng, \
        grads=grads, r=r, \
        pe=pe, ke=ke, e=e, \
        deltar=v_sam['deltar'], acc=v_sam['acc']
    )
                    
    state = state.apply_gradients(grads=grads)
    return state, v

state, data = train_step(rng, state, r, deltar)
print('Train Step ✅')

print('Go seek: ', c.wandb_c.wandb_run_path)

Walkers ✅ Training Variables ✅
Train Step ✅
Go seek:  xmax1/hwat/2fmwizzz


In [5]:
def collect_stats(k, v, new_d, p='tr', suf='', sep='/', sep_long='-'):
	depth = p.count('/')
	if depth > 1:
		sep = sep_long
	if isinstance(v, dict):
		for k_sub,v_sub in v.items():
			collect_stats(k, v_sub, new_d, p=(p+sep+k_sub))
	else:
		new_d[p+sep+k+suf] = v
	return new_d

def compute_metrix(d:dict, mode='tr'):
	pattern_ignore = ['Dense']
	fancy = dict(
		pe		= r'$V(X)',    				
		ke		= r'$\nabla^2',    		
		e		= r'$E',						
		log_psi	= r'$\log\psi', 			
		deltar	= r'$\delta_\mathrm{r}',	
		x		= r'$r_\mathrm{e}',
	)
	_d = {}
	for k,v in d.items():
		k = fancy.get(k, k)
		v = jax.device_get(v)
		if isinstance(v, FrozenDict):
			v = v.unfreeze()
		
		v_mean = jax.tree_map(lambda x: x.mean(), v) if not np.isscalar(v) else v
		v_std = jax.tree_map(lambda x: x.std(), v) if not np.isscalar(v) else 0.

		_d = collect_stats(k, v_mean, _d, p=mode, suf=r'_\mu$')
		_d = collect_stats(k, v_std, _d, p=mode, suf=r'_\sigma$')

	# return {k:v for k,v in _d.items() if not any([regex.match(k, pat) for pat in pattern_ignore])}
	return {k:v for k,v in _d.items() if not any([pat in k for pat in pattern_ignore])}


In [6]:
wandb.define_metric("*", step_metric="tr/step")

for step in range(1, c.n_step+1):
    state, data = train_step(rng, state, r, deltar)
    rng, deltar, r = data['rng'], data['deltar'], data['r']

    # if not (step % c.log_metric_step):
    if not (step % 1):
        metrix = compute_metrix(data)
        wandb.log({'tr/step':step, **metrix})
        m = ' '.join([f'{k} {v:.5f} ' for k,v in metrix.items() if 'E_' in k])
        print(f'Step {step} {m}')

        nans = {k:jnp.isnan(v) for k,v in metrix.items()}
        for k,v in nans.items():
        
            print(k)
            pprint(nans)

    if step == 100:
        break

        

    # if not (step % c.log_state_step):
    #     ...

  x = asanyarray(arr - arrmean)


Step 1 tr/$E_\mu$ nan  tr/$E_\sigma$ nan 
tr/acc_\mu$
{'tr/$E_\\mu$': DeviceArray(True, dtype=bool),
 'tr/$E_\\sigma$': DeviceArray(True, dtype=bool),
 'tr/$V(X)_\\mu$': DeviceArray(False, dtype=bool),
 'tr/$V(X)_\\sigma$': DeviceArray(True, dtype=bool),
 'tr/$\\delta_\\mathrm{r}_\\mu$': DeviceArray(False, dtype=bool),
 'tr/$\\delta_\\mathrm{r}_\\sigma$': DeviceArray(False, dtype=bool),
 'tr/$\\nabla^2_\\mu$': DeviceArray(True, dtype=bool),
 'tr/$\\nabla^2_\\sigma$': DeviceArray(True, dtype=bool),
 'tr/acc_\\mu$': DeviceArray(False, dtype=bool),
 'tr/acc_\\sigma$': DeviceArray(False, dtype=bool),
 'tr/params/p_0-bias-grads': DeviceArray(True, dtype=bool),
 'tr/params/p_0-kernel-grads': DeviceArray(True, dtype=bool),
 'tr/params/p_1-bias-grads': DeviceArray(True, dtype=bool),
 'tr/params/p_1-kernel-grads': DeviceArray(True, dtype=bool),
 'tr/params/p_2-bias-grads': DeviceArray(True, dtype=bool),
 'tr/params/p_2-kernel-grads': DeviceArray(True, dtype=bool),
 'tr/params/s_0-bias-grads': D

In [None]:

# heroku create

# codey code

# 🌵 


In [None]:
""" bone zone

# 💓 
class tr_data:
    x:jnp.ndarray=x
    rng:jnp.ndarray=rng

# Likely needed for updating outside the loop, not sure
@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
print('Update ✅')


# add mutable states to the trian state where the parallelism is handled? 
from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: flax.core.FrozenDict[str, Any]

# how to include other variables
def loss_fn(params):
    outputs, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        inputs,
        mutable=['batch_stats'])
    loss = xent_loss(outputs, labels)
    return loss, new_model_state

  (loss, new_model_state), grads = jax.value_and_grad(
      loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(
      grads=grads,
      batch_stats=new_model_state['batch_stats'],
  )

class arg_cls(Pyfig):
    def __init__(_i):
        pass
arg_cls.data.n_e = 6
arg_cls.data.n_u = 6
arg_cls.data.n_b = 6
arg_cls.n_step = 20
arg_cls.log_metric_step = 5
arg_cls.exp_name = 'demo'
args = flat_any(arg_cls().d)

"""