In [1]:
# Distribution ✨ jit ❇ Demo 💪 
%load_ext autoreload
%autoreload 2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
%env "WANDB_NOTEBOOK_NAME" "run" # ❕same as notebook
# from jax.config import config
# config.update('jax_disable_jit', True)

from functools import partial
import numpy as np
import re as regex
from pprint import pprint

import jax
from jax import pmap, grad
from jax import numpy as jnp
from flax.training.train_state import TrainState
from flax.core.frozen_dict import FrozenDict	

from pyfig import Pyfig
from hwat import FermiNet, sample, compute_ke_b, PotentialEnergy
from hwat import PotentialEnergy, sample, compute_ke_b
import wandb

from utils import flat_any

env: "WANDB_NOTEBOOK_NAME"="run" # ❕same as notebook


In [2]:
def collect_stats(k, v, new_d, p='tr', suf='', sep='/', sep_long='-'):
	depth = p.count('/')
	if depth > 1:
		sep = sep_long
	if isinstance(v, dict):
		for k_sub,v_sub in v.items():
			collect_stats(k, v_sub, new_d, p=(p+sep+k_sub))
	else:
		new_d[p+sep+k+suf] = v
	return new_d

def compute_metrix(d:dict, mode='tr'):
	pattern_ignore = ['Dense']
	fancy = dict(
		pe		= r'$V(X)',    				
		ke		= r'$\nabla^2',    		
		e		= r'$E',						
		log_psi	= r'$\log\psi', 			
		deltar	= r'$\delta_\mathrm{r}',	
		x		= r'$r_\mathrm{e}',
	)
	_d = {}
	for k,v in d.items():
		k = fancy.get(k, k)
		v = jax.device_get(v)
		if isinstance(v, FrozenDict):
			v = v.unfreeze()
		
		v_mean = jax.tree_map(lambda x: x.mean(), v) if not np.isscalar(v) else v
		v_std = jax.tree_map(lambda x: x.std(), v) if not np.isscalar(v) else 0.

		_d = collect_stats(k, v_mean, _d, p=mode, suf=r'_\mu$')
		_d = collect_stats(k, v_std, _d, p=mode, suf=r'_\sigma$')

	# return {k:v for k,v in _d.items() if not any([regex.match(k, pat) for pat in pattern_ignore])}
	return {k:v for k,v in _d.items() if not any([pat in k for pat in pattern_ignore])}


In [6]:
args = {'l_e':[4,],'n_u': 2,'n_b': 64, 'n_sv': 16, 'n_pv': 8, 'corr_len': 20, 'n_step': 1000, 'log_metric_step': 1,'exp_name':'demo-final'}
c = Pyfig(wandb_mode='online', args=args, get_sys_arg=False)


0 args unmerged: ✅
Path:  /home/amawi/projects/hwat/exp/demo-final/3sHgDsN ✅
System 
{'a': DeviceArray([[0., 0., 0.]], dtype=float32),
 'a_z': DeviceArray(4, dtype=int32, weak_type=True),
 'acc_target': 0.5,
 'corr_len': 20,
 'equil_len': 1000,
 'init_walker': functools.partial(<function init_walker at 0x7f386eaa5a20>, n_b=64, n_u=2, n_d=2, center=DeviceArray([[0., 0., 0.]], dtype=float32), std=0.1),
 'l_e': [4],
 'n_b': 64,
 'n_d': 2,
 'n_e': 4,
 'n_u': 2}
Model 
{'compute_p_emb': functools.partial(<function compute_emb at 0x7f386eaa4940>, terms=['xx']),
 'compute_s_emb': functools.partial(<function compute_emb at 0x7f386eaa4940>, terms=['x_rlen', 'x']),
 'compute_s_perm': functools.partial(<function compute_s_perm at 0x7f386eaa51b0>, n_u=2),
 'masks': '...',
 'n_det': 1,
 'n_fb': 16,
 'n_fb_out': 64,
 'n_pv': 8,
 'n_sv': 16,
 'terms_p_emb': ['xx'],
 'terms_s_emb': ['x_rlen', 'x']}


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669286298565567, max=1.0…

run:  xmax1/hwat/tu81yc6r ✅


In [7]:
part_sample = partial(sample, acc_target=c.data.acc_target)

rng = c.rng_init
x = c.data.init_walker(rng, n_b=c.data.n_b)
deltar = jnp.ones((c.n_device, 1), dtype=x.dtype)*0.02
print(f'Init: x {x.shape} rng {rng.shape} ✅')

@partial(jax.pmap, axis_name='b', in_axes=(0,0))
def create_train_state(rng, x):
    model = c.partial(FermiNet)  
    params = model.init(rng, x)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=c.opt.tx)
    return state

state = create_train_state(rng, x)
print('Model ✅')

@partial(jax.pmap, axis_name='b', in_axes=(0,0,0,0))
def equil(rng, state, x, deltar):
    x, v_sam = part_sample(rng, state, x, deltar)
    return x, v_sam
print('Equil ✅')

wandb.define_metric("*", step_metric="eq/step")
for step in range(1, c.data.equil_len//c.data.corr_len):
    x, v_sam = equil(rng, state, x, deltar)
    rng, deltar, acc = v_sam['rng'], v_sam['deltar'], v_sam['acc']
    if not (step % c.log_metric_step):
        wandb.log({'eq/step':step, **v_sam})
print('Walkers ✅ Training Variables ✅')

@partial(pmap, in_axes=(0, 0, 0, 0))
def train_step(rng, state, x, deltar):
    x, v_sam = part_sample(rng, state, x, deltar)

    pe = partial(PotentialEnergy(a=c.data.a, a_z=c.data.a_z).apply, {})(x)
    ke = compute_ke_b(state, x)
    e = pe + ke

    def loss(_params):
        return (e * state.apply_fn(_params, x)).mean()
    
    grads = grad(loss)(state.params)

    v = dict(grads=grads, pe=pe, ke=ke, e=e, deltar=v_sam['deltar'],
                    rng=rng, x=x, acc=v_sam['acc'])
                    
    state = state.apply_gradients(grads=grads)
    return state, v

state, data = train_step(rng, state, x, deltar)
print('Train Step ✅')

print('Go seek: ', c.wandb_c.wandb_run_path)

Init: x (1, 64, 4, 3) rng (1, 2) ✅
Model ✅
Equil ✅
Walkers ✅ Training Variables ✅
Train Step ✅
Go seek:  xmax1/hwat/tu81yc6r


In [None]:

# check all params
# check model top to bottoms
# check acc and deltar
# test 1


wandb.define_metric("*", step_metric="tr/step")
for step in range(1, c.n_step+1):
    state, data = train_step(rng, state, x, deltar)
    rng, deltar, x = data['rng'], data['deltar'], data['x']

    if not (step % c.log_metric_step):
        metrix = compute_metrix(data)
        wandb.log({'tr/step':step, **metrix})
        m = ' '.join([f'{k} {v:.5f} ' for k,v in metrix.items() if 'E_' in k])
        print(f'Step {step} {m}')
    if not (step % c.log_state_step):
        ...

Step 1 tr/$E_\mu$ 22.07416  tr/$E_\sigma$ 52.53654 
Step 2 tr/$E_\mu$ 29.62917  tr/$E_\sigma$ 81.63223 
Step 3 tr/$E_\mu$ 33.16518  tr/$E_\sigma$ 83.25869 
Step 4 tr/$E_\mu$ 31.72752  tr/$E_\sigma$ 100.58117 
Step 5 tr/$E_\mu$ 39.78465  tr/$E_\sigma$ 109.34447 
Step 6 tr/$E_\mu$ 43.45498  tr/$E_\sigma$ 79.99047 
Step 7 tr/$E_\mu$ 40.46207  tr/$E_\sigma$ 63.64935 
Step 8 tr/$E_\mu$ 50.50372  tr/$E_\sigma$ 123.09397 
Step 9 tr/$E_\mu$ 46.79241  tr/$E_\sigma$ 173.12579 
Step 10 tr/$E_\mu$ 59.42635  tr/$E_\sigma$ 392.29947 
Step 11 tr/$E_\mu$ 28.16961  tr/$E_\sigma$ 105.05470 
Step 12 tr/$E_\mu$ 27.50140  tr/$E_\sigma$ 67.91907 
Step 13 tr/$E_\mu$ 27.80373  tr/$E_\sigma$ 68.04050 
Step 14 tr/$E_\mu$ 35.46673  tr/$E_\sigma$ 117.95511 
Step 15 tr/$E_\mu$ 35.19810  tr/$E_\sigma$ 130.44794 
Step 16 tr/$E_\mu$ 31.32325  tr/$E_\sigma$ 128.62617 
Step 17 tr/$E_\mu$ 6.54882  tr/$E_\sigma$ 113.23344 
Step 18 tr/$E_\mu$ 23.03849  tr/$E_\sigma$ 66.42947 
Step 19 tr/$E_\mu$ 27.84310  tr/$E_\sigma$ 127

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 44)

In [None]:

# heroku create

# codey code

# 🌵 


In [None]:
""" bone zone

# 💓 
class tr_data:
    x:jnp.ndarray=x
    rng:jnp.ndarray=rng

# Likely needed for updating outside the loop, not sure
@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
print('Update ✅')


# add mutable states to the trian state where the parallelism is handled? 
from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: flax.core.FrozenDict[str, Any]

# how to include other variables
def loss_fn(params):
    outputs, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        inputs,
        mutable=['batch_stats'])
    loss = xent_loss(outputs, labels)
    return loss, new_model_state

  (loss, new_model_state), grads = jax.value_and_grad(
      loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(
      grads=grads,
      batch_stats=new_model_state['batch_stats'],
  )

class arg_cls(Pyfig):
    def __init__(_i):
        pass
arg_cls.data.n_e = 6
arg_cls.data.n_u = 6
arg_cls.data.n_b = 6
arg_cls.n_step = 20
arg_cls.log_metric_step = 5
arg_cls.exp_name = 'demo'
args = flat_any(arg_cls().d)

"""