In [1]:
# Distribution ✨ jit ❇ 
%load_ext autoreload
%autoreload 2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
from jax.config import config
config.update('jax_disable_jit', True)
from functools import partial

import jax
from jax import pmap
from jax import random as rnd, numpy as jnp
from flax.training.train_state import TrainState
from flax import linen as jax_utils
import optax

from pyfig import Pyfig
from hwat import FermiNet, sample, compute_ke_b, PotentialEnergy

import wandb

In [2]:

# args = {
#     'n_e': 6,
#     'n_u': 3,
#     'n_b': 256,
# }

c = Pyfig(wandb_mode='disabled')

rng = c.rng_init

x = c.data.init_walker(rng, n_b=2)    
print(f'Distributed shapes: rng {rng.shape} x {x.shape}')

@partial(jax.pmap, axis_name='b', in_axes=(0, 0))
def create_train_state(rng, x):
    model = c.partial(FermiNet)  
    params = model.init(rng, x) 
    tx = optax
    return TrainState.create(
        apply_fn=model.apply, 
        params=params, 
        tx=tx
    )

state = create_train_state(rng, x)
print('State ✅')  
# print(state.apply_fn(state.params, x))

deltar = jnp.array([[0.02]]).repeat(c.n_device)
sample = pmap(sample, in_axes=(0, 0, 0, 0))
print('Sample ✅')  
# print(sample(rng, state, x, deltar))

compute_pe = pmap(partial(PotentialEnergy(a=c.data.a, a_z=c.data.a_z).apply, {}))
print(compute_pe(x))
compute_ke = pmap(compute_ke_b, in_axes=(0, 0))
print(compute_ke(state, x))
print('Energy ✅')

2022-11-29 16:53:04.226858: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:497] The NVIDIA driver's CUDA version is 11.4 which is older than the ptxas CUDA version (11.6.55). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Distributed shapes:  (2, 2) (2, 2, 4, 3)
State ✅
Sample ✅
[[ -66.702484 -111.34934 ]
 [ -83.3162    -58.39977 ]]
[[532.11206  -32.812122]
 [ 11.851755  -8.156422]]
Energy ✅


In [8]:
from hwat import PotentialEnergy, sample, compute_ke_b
from jax import grad

@partial(pmap, in_axes=(0, 0, 0, 0))
def train_step(state, x, rng, deltar):
    
    x, v_sam = sample(rng, state, x, deltar)

    pe = partial(PotentialEnergy(a=c.data.a, a_z=c.data.a_z).apply, {})(x)
    ke = compute_ke_b(state, x)
    e = pe + ke

    def loss(x):
        return (e * state.apply_fn(state.params, x)).sum()

    grad_d = grad(loss)(x)

    v = dict(
        v_sam=v_sam,
        grad = grad_d,
        pe = pe,
        ke=ke,
        e=e,
        deltar=deltar,
        rng=rng,
        x=x
    )
    
    return state, v

train_step(state, x, rng, deltar)

(TrainState(step=ShardedDeviceArray([0, 0], dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of VmapFermiNet()>, params=FrozenDict({
     params: {
         Dense_0: {
             bias: ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0.],
                                 [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
             kernel: ShardedDeviceArray([[[-0.17291118,  0.8124223 , -0.38197178,
                                   -0.09221374, -0.18767868, -0.5424766 ,
                                   -0.30145955,  0.37979496],
                                  [ 0.65056443, -0.0448688 , -0.44495034,
                                   -0.6970015 , -0.8876309 , -0.8185091 ,
                                   -0.9850648 ,  1.2940776 ],
                                  [-0.3604867 , -0.9748268 , -0.44511953,
                                   -0.57526463, -0.17571512,  0.17359738,
                                   -0.4216642 , -0.6350346 ]],
             
     

In [None]:
import numpy as np
from copy import copy
from utils import npify

def compute_metrix(d:dict, specific:list|dict=None):
	# specific: dict = {metric: [variable,]}
	def compute_mean_std(_d, k, v):
		try:
			_d[k+r'_\mu'], _d[k+r'_\sigma'] = np.mean(v), np.std(v)
		except:
			print(f'Could not compute mean or std of {k} type {type(v)}')
		return _d
	
	fancy = dict(
		pe		= r'V(X)',    				
		ke		= r'$\nabla^2',    		
		e		= r'E',						
		log_psi	= r'$\log\psi$', 			
		deltar	= r'\delta_\mathrm{r}',	
		x		= r'r_\mathrm{e}',
	)
	_d = {}
	for k,v in d.items():
		k = fancy.get(k, k)
		v = jax.tree_map(lambda x: np.array(jax.device_get(x)), v)
		
		v_mean = jax.tree_map(lambda x: x.mean(), v)
		v_std = jax.tree_map(lambda x: x.std(), v)
		_d[k+r'_\mu'], _d[k+r'_\sigma'] = v_mean, v_std

		# if isinstance(v, dict or list):
		# 	_d[k] = compute_metrix(v)
		# compute_mean_std(_d, k, v)
	return _d


In [4]:
wandb.define_metric("*", step_metric="tr/step")

for step in range(1, c.n_step+1):

	state, data = train_step(rng, state, x, deltar)
	state.apply_gradients(grads=grad)

	if step % c.log_metric_step:
		metrix = compute_metrix(data)
		wm = {f'tr/{k}': v for k, v in metrix.items()}
		wandb.log({'tr/step': step, **wm})

	if step % c.log_state_step:
		print('not implemented')

'\\mu'