In [1]:
# IPython autoreload in mode 2: re-import all modules before each cell executes, so edits
# to local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os 
# Expose only GPU index 1 to this process.
# NOTE(review): placed before the jax import, presumably so the setting is already in the
# environment when JAX initialises its backend — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from jax.config import config
# Debug setting: turn jit off so code runs op-by-op and intermediate values can be inspected.
config.update('jax_disable_jit', True)

from pathlib import Path
from typing import Callable
import wandb

import jax
from jax import numpy as jnp
from jax import random as rnd
from flax.training.train_state import TrainState
from flax.training import common_utils
from flax import linen as nn, jax_utils
import optax

from pyfig import Pyfig
from hwat import logabssumdet, create_masks

def wpr(d:dict):
    """Debug-print one summary line per entry of `d`.

    Each line shows the key followed by mean, std, shape and dtype of the value.
    A value without a `shape` attribute (e.g. a plain Python number) is printed
    unchanged in the `mean` slot, with std and shape reported as None. dtype is
    None whenever the value has no `dtype` attribute.

    Args:
        d: mapping from a display name to the value to summarise.
    """
    for k, v in d.items():
        has_shape = hasattr(v, 'shape')
        shape = v.shape if has_shape else None
        dtype = getattr(v, 'dtype', None)
        # Only array-like values get reduced; anything else is shown as-is.
        mean = jnp.mean(v) if has_shape else v
        std = jnp.std(v) if has_shape else None
        print(k, f'\t mean={mean} \t std={std} \t shape={shape} \t dtype={dtype}')


# Method 1: refer completely to Pyfig:
    # - Can't get the module in the args
# Must have a shape debug print

# 11am: 
# 1- Putting all variables into every Sub - done, it was a loopy mutable issue
# 2- Stop printing mask - done, moved masks to Ferminet
# 3- 

# Project configuration object; later cells read their settings from it
# (c.data.n_e, c.opt.lr, c.n_step, ...) and build the model via c.pass_arg.
# wandb_mode options: online -> on | disabled -> off | offline -> local only.
c = Pyfig(wandb_mode='disabled', debug=True)
# c.d



2022-11-28 11:54:27.891218: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:497] The NVIDIA driver's CUDA version is 11.4 which is older than the ptxas CUDA version (11.6.55). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


{}


In [None]:
class FermiNet(nn.Module):
    """Flax module that turns electron coordinates into a single `log_psi` value.

    The forward pass embeds the input into a single-electron stream and a
    pairwise stream, runs `n_fb` residual tanh/Dense blocks over both, builds
    spin-up and spin-down orbital stacks and hands them to `logabssumdet`.
    All hyperparameters and the embedding/permutation callables are supplied
    as module fields (the notebook fills them in through `c.pass_arg`).
    """
    # NOTE: field order is kept exactly as declared originally.
    n_e: int = None                     # electron count, passed to create_masks
    n_u: int = None                     # rows [0, n_u) of the input are the spin-up block
    n_d: int = None                     # spin-down count, sizes the down-orbital Dense layers
    compute_s_emb: Callable = None      # x -> initial single-electron stream
    compute_p_emb: Callable = None      # x -> initial pairwise stream
    compute_s_perm: Callable = None     # (s_stream, p_stream, mask_u, mask_d) -> mixed single stream
    n_det: int = None                   # number of chunks the orbital features are split into
    n_fb: int = None                    # number of residual blocks
    n_fb_out: int = None                # width of the Dense layer applied after the blocks
    n_pv: int = None                    # Dense width for the pairwise stream
    n_sv: int = None                    # Dense width for the single-electron stream

    @nn.compact
    def __call__(_i, x):
        """Evaluate the network on one configuration and return `log_psi`.

        `_i` plays the role of `self`. `x` is split along axis 0 at `n_u`;
        the notebook calls this with an array of shape (n_e, 3).
        The sign returned by `logabssumdet` is discarded.
        """
        n_up = _i.n_u
        mask_u, mask_d = create_masks(_i.n_e, n_up)

        # Raw coordinates per spin block, used later for the exponential factor.
        pos_u, pos_d = jnp.split(x, [n_up,], axis=0)
        s_stream = _i.compute_s_emb(x)
        p_stream = _i.compute_p_emb(x)
        wpr(dict(x_s_var=s_stream, x_p_var=p_stream))

        # Residual accumulators start as scalars so the first block is a plain layer.
        s_skip = 0.
        p_skip = 0.
        for _ in range(_i.n_fb):
            # Pairwise stream is updated first: the single stream consumes the new value.
            p_hidden = nn.Dense(_i.n_pv)(p_stream)
            p_stream = p_skip = nn.tanh(p_hidden) + p_skip

            s_mixed = _i.compute_s_perm(s_stream, p_stream, mask_u, mask_d)
            s_hidden = nn.Dense(_i.n_sv)(s_mixed)
            s_stream = s_skip = nn.tanh(s_hidden) + s_skip
            wpr(dict(x_p_var=p_stream, x_s_var=s_stream))

        w_all = nn.tanh(nn.Dense(_i.n_fb_out)(s_stream))
        w_up, w_down = jnp.split(w_all, [n_up,], axis=0)
        w_up = nn.tanh(nn.Dense(_i.n_det*_i.n_u)(w_up))
        w_down = nn.tanh(nn.Dense(_i.n_det*_i.n_d)(w_down))
        wpr(dict(x_w=w_all, x_wu=w_up, x_wd=w_down))

        # Orbitals: features times exp(-Dense(-coords)), cut into n_det chunks along
        # the last axis and stacked on a new leading axis.
        # (e, f(e)) (e, (f(e))*n_det)
        env_up = jnp.exp(-nn.Dense(_i.n_u*_i.n_det)(-pos_u))
        orb_u = jnp.stack((w_up * env_up).split(_i.n_det, axis=-1))
        env_down = jnp.exp(-nn.Dense(_i.n_d*_i.n_det)(-pos_d))
        orb_d = jnp.stack((w_down * env_down).split(_i.n_det, axis=-1))
        wpr(dict(orb_u=orb_u, orb_d=orb_d))

        log_psi, _sgn = logabssumdet(orb_u, orb_d)
        return log_psi

# Build the network from the config and run one init/apply round trip as a smoke test.
model = c.pass_arg(FermiNet)

# JAX PRNG keys should be consumed once: derive separate keys for sampling the
# input and for parameter initialisation, and keep a fresh `rng` for later cells.
rng = rnd.PRNGKey(1)
rng, rng_x, rng_init = rnd.split(rng, 3)
x = rnd.normal(rng_x, (c.data.n_e, 3))
params = model.init(rng_init, x)
model.apply(params, x)

from functools import partial

@partial(jax.pmap, axis_name='b')
def create_train_state(rng):
    """Create one TrainState per device (model parameters plus an SGD optimiser).

    The function is pmapped, so it is mapped over the leading axis of `rng`
    and the returned state carries a leading device axis.

    Args:
        rng: the PRNG key for this device.

    Returns:
        TrainState built from the initialised params, `model.apply` and
        `optax.sgd(c.opt.lr)`.
    """
    # Do not reuse a key: one for the dummy input, one for the initialiser.
    rng_x, rng_init = rnd.split(rng)
    model = c.pass_arg(FermiNet)
    # The sampled input is only needed to run `model.init`.
    x = rnd.normal(rng_x, (c.data.n_e, 3))
    params = model.init(rng_init, x)['params']
    tx = optax.sgd(c.opt.lr)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# One PRNG key per device; the pmapped create_train_state maps over this leading axis.
rng = rnd.split(rng, len(jax.devices()))
# The state returned by a pmapped function already has a leading device axis.
# Passing it through jax_utils.replicate as well would stack a second device axis
# onto every leaf, so the pmapped train_step would see wrongly shaped parameters.
state = create_train_state(rng)

# train step framework get
# train step framework copy
# test energy function
# kinetic energy
# potential energy
# include atoms
# include sampler
# write metric
# run loop
from hwat import SampleState
# Sampler instance used by train_step: it is called as sample(x, state) to produce the
# next x, and its `move_std` attribute is reported with the step metrics.
sample = SampleState()

@partial(jax.pmap, axis_name='x')
def train_step(state, x):
    """Run one per-device step: resample x, evaluate energies, and compute gradients.

    The optimiser update is NOT applied here; the caller applies `grads`
    to the state separately.

    Args:
        state: train state for this device; `state.params` are differentiated.
        x: current sample for this device, passed to the sampler.

    Returns:
        (state, grads, x, v_b): the unchanged state, gradients of the loss with
        respect to the params, the resampled x, and a dict of values to log.
    """
    x = sample(x, state)

    ke = c.compute_ke()
    pe = c.compute_pe()
    e = ke+pe

    def loss_fn(p):
        # Return log_psi as an auxiliary output so the logged value is the model
        # output itself and not the loss.
        log_psi = model.apply({'params': p}, x)
        return jnp.mean(log_psi*e), log_psi

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)  # has_aux for more than one out
    (loss, log_psi), grads = grad_fn(state.params)

    v_b = { # scalars
        r'V(X)'                 : pe,
        r'$\nabla^2$'           : ke,
        'E'                     : e,
        r'$\log\psi$'           : log_psi,
        r'\delta_\mathrm{r}'    : sample.move_std,
        'loss'                  : loss,
    }

    return state, grads, x, v_b

@jax.pmap
def update_model(state, grads):
    """Apply gradients to the train state, device by device.

    Both arguments are mapped over their leading axis by pmap.

    Args:
        state: object exposing `apply_gradients(grads=...)`.
        grads: gradients forwarded to `state.apply_gradients`.

    Returns:
        Whatever `state.apply_gradients` returns (the updated state).
    """
    updated_state = state.apply_gradients(grads=grads)
    return updated_state

In [None]:

# Use the logged "train/step" value as the step (x-axis) metric for every metric ("*").
wandb.define_metric("*", step_metric="train/step")

# # Display a project workspace
# %wandb USERNAME/PROJECT
# # Display a single run
# %wandb USERNAME/PROJECT/runs/RUN_ID
# # Display a sweep
# %wandb USERNAME/PROJECT/sweeps/SWEEP_ID
# # Display a report
# %wandb USERNAME/PROJECT/reports/REPORT_ID
# # Specify the height of embedded iframe
# %wandb USERNAME/PROJECT -h 2048

In [None]:
# History of the logged summaries, one dict per logging step.
train_metrics = []

for step in range(c.n_step):

    # train_step only computes gradients; the per-device update is applied by the
    # pmapped update_model so it runs on the same device layout as the state.
    state, grads, x, data = train_step(state, x)
    state = update_model(state, grads)

    if step % c.log_metric_step == 0:

        # r'sgn$(\cdot)$'     : sgn
        # Reduce every value returned by train_step (over devices and any remaining
        # axes) to a single Python float under a "train/" prefix.
        summary = {
            f'train/{k}': float(jnp.mean(v))
            for k, v in data.items()
        }
        train_metrics.append(summary)

        wandb.log({
                "train/step": step, 
                **summary
        })


In [None]:
# EMA Decay extension
# NOTE(review): p_copy_params_to_ema, ema_decay_fn and p_apply_ema are not defined in
# the visible notebook, and `step` comes from the training loop; this looks like a
# snippet intended to live inside that loop — confirm before running.
if step <= c.ema.update_after_step:
    # Up to and including `update_after_step`: copy the params into the EMA.
    state = p_copy_params_to_ema(state)
elif step % c.ema.update_every == 0:
    # Afterwards: apply an EMA update every `update_every` steps.
    ema_decay = ema_decay_fn(step)
    state = p_apply_ema(state, ema_decay)