In [1]:
# Force JAX onto the CPU and expose 4 virtual host devices so pmap can be
# exercised without accelerators. Both variables have to be set before jax
# initialises its backend, hence before `import jax` below.
# NOTE(review): newer JAX releases read JAX_PLATFORMS rather than JAX_PLATFORM_NAME — confirm for the installed version.
import os
os.environ['JAX_PLATFORM_NAME']='cpu'
os.environ['XLA_FLAGS']="--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax import local_device_count, vmap, jit, grad, lax
devices = jax.devices()
n_devices = len(devices)
print('Devices: ', devices)
from jax import pmap
import jax.random as rnd
from jax.tree_util import tree_unflatten, tree_flatten

from collections import OrderedDict

# NOTE(review): star import hides where setup, SystemAnsatz, create_wf, create_sampler,
# create_energy_fn, initialise_params, initialise_d0s, ... come from; prefer explicit names.
from nn_ansatx import *

Devices:  [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


  h5py.get_config().default_file_mode = 'a'


In [2]:
def key_gen(keys):
    """Split every PRNG key in ``keys`` into a ``(next_key, subkey)`` pair.

    Args:
        keys: stack of PRNG keys, one per row (shape (n_keys, 2) for the
            legacy uint32 keys used in this notebook).

    Returns:
        ``[next_keys, subkeys]``, each with the same leading length as ``keys``.
        Indexing the split axis (instead of ``squeeze``) keeps the leading axis
        even when only one key is passed.
    """
    pairs = jnp.stack([rnd.split(key) for key in keys])  # (n_keys, 2, *key_shape)
    return [pairs[:, 0], pairs[:, 1]]

In [3]:
def split_variables_for_pmap(n_devices, *args):
    """Reshape arrays so their leading axis is split into ``n_devices`` shards.

    Each array of shape (N, ...) becomes (n_devices, N // n_devices, ...).

    Args:
        n_devices: number of shards to create along axis 0.
        *args: arrays that all share the same leading length N.

    Returns:
        The single reshaped array when one argument is given, otherwise a list
        of reshaped arrays in argument order.

    Raises:
        AssertionError: if leading lengths differ or N is not divisible by
            ``n_devices``.
    """
    # Neighbouring arrays must agree on their leading length.
    for left, right in zip(args, args[1:]):
        assert len(left) == len(right)

    assert len(args[0]) % n_devices == 0

    reshaped = [
        x.reshape(n_devices, x.shape[0] // n_devices, *x.shape[1:])
        for x in args
    ]
    return reshaped[0] if len(args) == 1 else reshaped
        

In [4]:
key = rnd.PRNGKey(123)

config = setup(n_walkers=8)

mol = SystemAnsatz(**config)

wf, kfac_wf, wf_orbitals = create_wf(mol)
params = initialise_params(key, mol)
d0s = initialise_d0s(mol)
walkers = mol.initialise_walkers(n_walkers=config['n_walkers'])
# Shard the walkers with the same device count the rest of the setup uses
# (mol.n_devices also sizes the K-FAC moving averages), not a hard-coded 4.
walkers = split_variables_for_pmap(mol.n_devices, walkers)

version 		 130521
seed 		 369
n_devices 		 4
save_every 		 1000
print_every 		 0
exp_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m8_s32_p8_l2_det2/run7
events_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m8_s32_p8_l2_det2/run7/events
models_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m8_s32_p8_l2_det2/run7/models
opt_state_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m8_s32_p8_l2_det2/run7/models/opt_state
pre_path 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/pretrained/s32_p8_l2_det2_1lr-4_i1000.pk
timing_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m8_s32_p8_l2_det2/run7/events/timing
system 		 Be
r_atoms 		 [[0. 0. 0.]]
z_atoms 		 [4.]
n_el 		 4
n_el_a

In [5]:
# One PRNG key per device shard, matching the device count used to shard the walkers.
key = rnd.PRNGKey(123)
keys = rnd.split(key, mol.n_devices)
keys, subkeys = key_gen(keys) 
print(keys.shape, subkeys.shape)

(4, 2) (4, 2)


In [6]:
# Shape of one d0s entry; the leading axis presumably indexes device shards
# (d0s is mapped with in_axes=0 by the pmapped functions below) — confirm.
print(jnp.shape(d0s['split0']))

(4, 2, 1, 32)


In [7]:
# Build the walker sampler for `wf` on `mol`; when called it yields (walkers, acceptance, step_size).
# `equilibrate` is not used further in this notebook.
sampler, equilibrate = create_sampler(wf, mol)


In [8]:
# Advance the sharded walkers with one sampler call.
walkers, acceptance, step_size = sampler(params, walkers, d0s, subkeys, config['step_size'])
# Leading axis is the device shard; presumably (devices, walkers/device, electrons, 3).
print(walkers.shape)
# Which device holds shard 2 of the pmapped output.
print(walkers[2].device())

devs = jax.devices()

# Cost of moving a shard that presumably lives on another device (walkers[1])
# to devs[0], vs. one already on devs[0] (walkers[0]), vs. plain indexing.
%timeit x = jax.device_put(walkers[1], devs[0]).block_until_ready()
%timeit x = jax.device_put(walkers[0], devs[0]).block_until_ready()
%timeit x = walkers[0]

# %timeit x = jax.device_put(walkers[0], devs[0]).block_until_ready()
# %timeit x = jax.device_put(walkers[1], devs[0]).block_until_ready()
# %timeit x = jax.device_put(walkers[0], devs[0]).block_until_ready()

# Displayed as the cell output: shard 0 lives on device 0.
walkers[0].device()


(4, 2, 4, 3)
cpu:2
50 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30 µs ± 1.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
7.36 µs ± 33.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


CpuDevice(id=0)

In [None]:
def clip_and_center(e_locs, clip_width=5.):
    """Clip outlying local energies around their median, then subtract the mean.

    Values are clipped to ``median ± clip_width * mad``, where ``mad`` is the
    mean absolute deviation from the median.

    Args:
        e_locs: array of local energies.
        clip_width: half-width of the clipping window in units of ``mad``
            (default 5., the previously hard-coded value).

    Returns:
        The clipped energies minus their mean, so the result has zero mean.
    """
    median = jnp.median(e_locs)
    # Mean absolute deviation from the median (a spread measure, not a variance).
    mad = jnp.mean(jnp.abs(e_locs - median))
    lower, upper = median - clip_width * mad, median + clip_width * mad
    e_locs = jnp.clip(e_locs, lower, upper)
    return e_locs - jnp.mean(e_locs)

def create_grad_function(wf, mol):
    """Build a pmapped function returning device-averaged parameter gradients.

    The differentiated loss is ``mean(clip_and_center(E_L) * log_psi)``, with the
    local energies ``E_L`` treated as constants, so only ``log_psi`` contributes
    to the gradient.

    Args:
        wf: wavefunction ``wf(params, walker, d0)`` returning log psi for one walker.
        mol: system description passed to ``create_energy_fn``.

    Returns:
        A pmapped ``f(params, walkers, d0s) -> (grads, e_locs)``. ``params`` is
        broadcast to all devices, while ``walkers``/``d0s`` are mapped over their
        leading axis. ``grads`` is averaged over the ``'g'`` axis, and ``e_locs``
        are the raw (unclipped) per-device local energies.
    """
    batched_log_psi = vmap(wf, in_axes=(None, 0, 0))
    local_energy = create_energy_fn(wf, mol)

    def _loss(params, walkers, d0s):
        # Local energies are constants of the loss: only log_psi gets differentiated.
        energies = lax.stop_gradient(local_energy(params, walkers, d0s))
        weights = clip_and_center(energies)
        log_psi = batched_log_psi(params, walkers, d0s)
        return jnp.mean(weights * log_psi), energies

    loss_grad = jit(grad(_loss, has_aux=True))

    def _device_grads(params, walkers, d0s):
        param_grads, energies = loss_grad(params, walkers, d0s)
        # Average the per-device gradients across the 'g' pmap axis.
        return jax.lax.pmean(param_grads, axis_name='g'), energies

    return pmap(_device_grads, in_axes=(None, 0, 0), axis_name='g')

# Build the pmapped gradient function and evaluate it once on the sharded walkers;
# both outputs carry a leading device axis.
pgrad_fn = create_grad_function(wf, mol)
grads, e_locs = pgrad_fn(params, walkers, d0s)

In [None]:
# Inspect the gradient pytree returned by the pmapped gradient function.
flat_grads, grads_tree = tree_flatten(grads)
for leaf in flat_grads:
    print(leaf.shape)

# Where a freshly created (non-pmapped) array lands; `.device()` matches the
# usage elsewhere in this notebook (device_buffer is a deprecated accessor).
a = jnp.zeros((flat_grads[0].shape[-1],))
print(a.device())

print(e_locs.reshape(-1).shape)
print(flat_grads[-1].device())

In [None]:
def update_maa_and_mss(step, maa, aa, mss, ss):
    """Blend the running K-FAC covariance factors with new instantaneous estimates.

    Args:
        step: iteration counter; the running-average weight is ``min(step, 0.95)``.
        maa, mss: running activation / sensitivity covariance averages.
        aa, ss: instantaneous activation / sensitivity covariances.

    Returns:
        ``(maa, mss)`` updated as ``w * running + (1 - w) * instant``.

    NOTE(review): for integer steps the weight is 0 at step 0 and 0.95 for every
    step >= 1 — confirm this schedule is intended.
    """
    w_old = jnp.array([step, 0.95]).min()
    w_new = 1. - w_old
    # The two weights sum to one; normalising keeps the original arithmetic exactly.
    norm = w_old + w_new

    def _blend(running, instant):
        return (w_old * running + w_new * instant) / norm

    return _blend(maa, aa), _blend(mss, ss)

def kfac(kfac_wf, wf, mol, params, walkers, d0s, lr, damping, norm_constraint):
    """Create a K-FAC optimiser in ``(update, get_params, ...)`` style.

    Args:
        kfac_wf, wf, mol, params, walkers, d0s: forwarded to
            ``create_natural_gradients_fn`` to build the update and initial state.
        lr, damping, norm_constraint: hyperparameters stored in the state.

    Returns:
        ``(_update, _get_params, kfac_update, state)`` where ``state`` is the list
        ``[params, maas, msss, lr, damping, norm_constraint]``. ``kfac_update``
        turns gradients into natural-gradient steps, and ``_update`` subtracts such
        steps from the parameters.
    """
    kfac_update, substate = create_natural_gradients_fn(kfac_wf, wf, mol, params, walkers, d0s)

    def _get_params(state):
        # The parameters are always the first entry of the optimiser state.
        return state[0]

    def _update(step, grads, state):
        # ``grads`` must be a flat list ordered like tree_flatten(params), e.g. the
        # natural gradients returned by ``kfac_update``. ``step`` is unused.
        flat_params, tree = tree_flatten(_get_params(state))
        if len(grads) != len(flat_params):
            # zip would silently truncate and tree_unflatten would fail obscurely.
            raise ValueError(
                f'expected {len(flat_params)} gradient leaves, got {len(grads)}')
        new_params = [p - g for p, g in zip(flat_params, grads)]
        return [tree_unflatten(tree, new_params), *state[1:]]

    state = [*substate, lr, damping, norm_constraint]

    return _update, _get_params, kfac_update, state


def create_sensitivities_grad_fn(kfac_wf):
    """Build a pmapped gradient of the batch-mean log psi with respect to ``d0s``.

    ``kfac_wf(params, walker, d0)`` returns ``(log_psi, activations)``; the
    activations are discarded here. Differentiating w.r.t. the third argument
    (``d0s``) yields per-walker sensitivities.

    NOTE(review): the loss is the *mean* of log psi, so the sensitivities are
    scaled by 1 / walkers-per-device — confirm this scaling is intended.

    Returns:
        A pmapped ``f(params, walkers, d0s)`` with ``params`` broadcast and
        ``walkers``/``d0s`` mapped over their leading axis.
    """
    batched_wf = vmap(kfac_wf, in_axes=(None, 0, 0))

    def _mean_log_psi(params, walkers, d0s):
        log_psi, _ = batched_wf(params, walkers, d0s)
        return log_psi.mean()

    return pmap(jit(grad(_mean_log_psi, argnums=2)), in_axes=(None, 0, 0))


def create_natural_gradients_fn(kfac_wf, wf, mol, params, walkers, d0s):
    """Build the K-FAC natural-gradient update and its initial optimiser substate.

    Args:
        kfac_wf: wavefunction returning ``(log_psi, activations)`` for one walker.
        wf: plain wavefunction (not used by this function).
        mol: system description; ``mol.n_devices`` sizes the moving averages.
        params: network parameters (pytree), broadcast to all devices.
        walkers, d0s: sharded batches with a leading device axis, used once here
            to discover activation/sensitivity shapes.

    Returns:
        ``(_compute_natural_gradients, substate)``. ``substate`` is
        ``(params, maas, msss)``: the parameters plus zero-initialised per-device
        moving averages of the activation (``maas``) and sensitivity (``msss``)
        covariances.
    """
    sensitivities_fn = create_sensitivities_grad_fn(kfac_wf)
    vwf = pmap(jit(vmap(kfac_wf, in_axes=(None, 0, 0))), in_axes=(None, 0, 0))

    @jit
    def _kfac_step(step, gradients, activations, sensitivities, maas, msss, lr, damping, norm_constraint):
        # Every leaf of gradients/activations/sensitivities carries a leading
        # device axis (they are outputs of pmapped functions); maas/msss are
        # lists of (n_devices, d, d) running covariances in the same leaf order.
        gradients, _ = tree_flatten(gradients)
        activations, _ = tree_flatten(activations)
        sensitivities, _ = tree_flatten(sensitivities)

        ngs = []
        new_maas = []
        new_msss = []
        for g, a, s, maa, mss in zip(gradients, activations, sensitivities, maas, msss):
            n = a.shape[1]  # walkers per device
            sl_factor = 1.

            # Rank-4 leaves have an extra inner axis (presumably shared-weight
            # positions — confirm); average it out and rescale by its size squared.
            if len(a.shape) == 4:
                sl_factor = float(a.shape[2] ** 2)
                a = a.mean(2)
            if len(s.shape) == 4:
                sl_factor = float(s.shape[2] ** 2)
                s = s.mean(2)

            # Per-device covariance estimates, shape (n_devices, d, d).
            aa = jnp.transpose(a, axes=(0, 2, 1)) @ a / float(n)
            ss = jnp.transpose(s, axes=(0, 2, 1)) @ s / float(n)

            maa, mss = update_maa_and_mss(step, maa, aa, mss, ss)

            # Average the running covariances over the device axis.
            smaa = jax.device_put(maa, jax.devices()[0]).mean(0)
            smss = jax.device_put(mss, jax.devices()[0]).mean(0)

            dmaa, dmss = damp(smaa, smss, sl_factor, damping)

            # Symmetrise to remove round-off asymmetry before the Cholesky factorisation.
            dmaa = (dmaa + jnp.transpose(dmaa)) / 2.
            dmss = (dmss + jnp.transpose(dmss)) / 2.

            chol_dmaa = jax.scipy.linalg.cho_factor(dmaa)
            chol_dmss = jax.scipy.linalg.cho_factor(dmss)

            inv_dmaa = jax.scipy.linalg.cho_solve(chol_dmaa, jnp.eye(smaa.shape[0]))
            inv_dmss = jax.scipy.linalg.cho_solve(chol_dmss, jnp.eye(smss.shape[0]))

            # Index 0 takes the device-0 copy of the gradient.
            ng = inv_dmaa @ g[0] @ inv_dmss / sl_factor

            ngs.append(ng)
            new_maas.append(maa)
            new_msss.append(mss)

        # The Fisher norm must pair each natural gradient with the same single
        # gradient copy it was built from. Pairing it with the full
        # (n_devices, ...) stack would broadcast and inflate the norm n_devices-fold.
        single_device_grads = [g[0] for g in gradients]
        eta = compute_norm_constraint(ngs, single_device_grads, lr, norm_constraint)

        return [lr * eta * ng for ng in ngs], (new_maas, new_msss, lr, damping, norm_constraint)

    def _compute_natural_gradients(step, grads, state, walkers, d0s):
        """Return ``(lr-scaled natural gradients, new state)`` for one step."""
        params, maas, msss, lr, damping, norm_constraint = state
        # This costs an extra forward and backward pass; the activations and
        # sensitivities could be shared with the gradient computation.
        _, activations = vwf(params, walkers, d0s)
        sensitivities = sensitivities_fn(params, walkers, d0s)

        ngs, state = _kfac_step(step, grads, activations, sensitivities, maas, msss, lr, damping, norm_constraint)

        return ngs, (params, *state)

    # Run once to discover the activation/sensitivity widths for the moving averages.
    _, activations = vwf(params, walkers, d0s)
    sensitivities = sensitivities_fn(params, walkers, d0s)

    activations, _ = tree_flatten(activations)
    sensitivities, _ = tree_flatten(sensitivities)
    maas = [jnp.zeros((mol.n_devices, a.shape[-1], a.shape[-1])) for a in activations]
    msss = [jnp.zeros((mol.n_devices, s.shape[-1], s.shape[-1])) for s in sensitivities]
    substate = (params, maas, msss)

    return _compute_natural_gradients, substate




# Build the K-FAC optimiser from the current params/walkers and turn the pmapped
# gradients computed above into one set of natural-gradient steps.
# NOTE(review): step=1 gives the moving averages weight 0.95 (see update_maa_and_mss),
# so the zero-initialised maas/msss receive only 5% of this first estimate; step=0
# would use the estimate directly — confirm which is intended.
update, get_params, kfac_update, state = kfac(kfac_wf, wf, mol, params, walkers, d0s, 
                                              lr=1e-4, damping=1e-3, norm_constraint=1e-3)
step = 1
kfac_grads, state = kfac_update(step, grads, state, walkers, d0s)

In [None]:
# Check device placement and shapes of each gradient leaf. ``grads`` is a pytree
# (it is flattened with tree_flatten elsewhere), so iterate its leaves. Iterating
# the container directly would yield its keys if it is a dict.
grad_leaves = tree_flatten(grads)[0]
for leaf in grad_leaves:
    print(leaf[2].device())
for leaf in grad_leaves:
    print(leaf.shape)


In [None]:
# After kfac_update, state is (params, maas, msss, lr, damping, norm_constraint),
# so state[2] holds the updated sensitivity moving averages (msss).
# print(state[0])
# print(state[1])
print(state[2])

In [None]:
# Apply the natural-gradient step. ``kfac_grads`` (from kfac_update) is the flat,
# lr-scaled list ordered like tree_flatten(params), which is what ``update`` expects.
# The raw ``grads`` pytree carries a leading device axis and would be zipped by its
# keys if it is a dict.
state = update(step, kfac_grads, state)

params = get_params(state)


In [None]:
# Sanity check of the batched A^T A product used in K-FAC: swap the last two axes
# of each (2, 3) slice and multiply, giving a (4, 3, 3) stack.
x = jnp.ones((4, 2, 3))
y = jnp.swapaxes(x, 1, 2) @ x
print(y.shape)


In [None]:
def check_symmetric(x):
    """Print and return the mean absolute asymmetry of ``x`` over its last two axes.

    The previous version printed the signed mean of ``x - x^T``, which is exactly
    zero for *any* square matrix, so it could never detect asymmetry. It also used
    ``transpose(-1, -2)``, which fails on batched (..., d, d) input.

    Args:
        x: square matrix or batch of square matrices.

    Returns:
        ``mean(|x - x^T|)``; 0 for a symmetric input.
    """
    asymmetry = jnp.abs(x - jnp.swapaxes(x, -1, -2)).mean()
    print(asymmetry)
    return asymmetry


def compute_norm_constraint(nat_grads, grads, lr, norm_constraint):
    """Step-size scale ``eta`` that caps the squared Fisher norm of the update.

    Args:
        nat_grads: list of natural-gradient arrays.
        grads: list of matching gradient arrays (same shapes/order).
        lr: learning rate.
        norm_constraint: upper bound on ``lr**2 * sum(ng * g)``.

    Returns:
        ``min(1, sqrt(norm_constraint / (lr**2 * sum(ng * g))))``.
    """
    sq_fisher_norm = sum(((ng * g).sum() for ng, g in zip(nat_grads, grads)), 0.)
    ratio = norm_constraint / (lr**2 * sq_fisher_norm)
    return jnp.array([1., jnp.sqrt(ratio)]).min()


def decay_variable(self, variable, iteration):
    """Inverse-time decay: ``variable / (1 + self.decay * iteration)``.

    Written method-style but defined at module level; ``self`` only needs a
    ``decay`` attribute.
    """
    denominator = 1. + self.decay * iteration
    return variable / denominator


def damp(maa, mss, sl_factor, damping):
    """Add factored Tikhonov damping to the two K-FAC covariance factors.

    The damping is split between the factors by ``pi``, the ratio of their mean
    diagonal entries (``trace / dim``).

    Args:
        maa: activation covariance factor, shape (d_a, d_a).
        mss: sensitivity covariance factor, shape (d_s, d_s).
        sl_factor: extra scale applied to the damping terms.
        damping: overall damping strength.

    Returns:
        ``(maa + c_a * I, mss + c_s * I)`` as new arrays. The inputs are not
        modified: the previous ``+=`` mutated NumPy inputs in place.
    """
    dim_a = maa.shape[-1]
    dim_s = mss.shape[-1]

    tr_a = get_tr_norm(maa)
    tr_s = get_tr_norm(mss)

    pi = (tr_a * dim_s) / (tr_s * dim_a)

    eye_a = jnp.eye(dim_a, dtype=maa.dtype)
    # Each identity matches the dtype of the factor it is added to.
    eye_s = jnp.eye(dim_s, dtype=mss.dtype)

    m_aa_damping = jnp.sqrt(pi * damping / sl_factor)
    m_ss_damping = jnp.sqrt(damping / (pi * sl_factor))

    return maa + eye_a * m_aa_damping, mss + eye_s * m_ss_damping


def get_tr_norm(x):
    """Trace of ``x`` over its last two axes, floored at 1e-5.

    Accepts a single square matrix or a batch of shape (..., d, d). The previous
    ``jnp.diagonal(x)`` / ``jnp.array([1e-5, trace])`` form only worked for 2-D
    input. The floor keeps ratios of traces finite when a factor is (near) zero.
    """
    trace = jnp.trace(x, axis1=-2, axis2=-1)
    return jnp.maximum(trace, 1e-5)