In [1]:
# Environment must be configured before `import jax`: JAX reads these variables
# when it initialises its backend, so setting them after the import is too late.
import os
# Run on CPU only.
# NOTE(review): newer JAX releases use JAX_PLATFORMS rather than JAX_PLATFORM_NAME — confirm for the installed version.
os.environ['JAX_PLATFORM_NAME']='cpu'
# Make the host CPU appear as 4 XLA devices so pmap can be exercised on one machine
# (the cell output lists CpuDevice ids 0-3).
os.environ['XLA_FLAGS']="--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax import local_device_count  # NOTE(review): not used in the visible cells
devices = jax.devices()
# Number of devices the pmap'd sampler shards over; later cells divide n_walkers by it.
n_devices = len(devices)
print('Devices: ', devices)

from jax import pmap
import jax.random as rnd

# Star import presumably supplies the project helpers used below (setup, SystemAnsatz,
# create_wf, initialise_params, initialise_d0s, expand_d0s, create_sampler) — none are defined here.
# NOTE(review): explicit imports would make the origin of those names visible to readers.
from nn_ansatx import *

Devices:  [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


  h5py.get_config().default_file_mode = 'a'


In [2]:
def key_gen(keys):
    """Split every PRNG key in a stacked batch into a (carry key, subkey) pair.

    Args:
        keys: raw PRNG keys stacked along a leading axis, e.g. the
            ``(n_devices, 2)`` array returned by ``rnd.split(key, n_devices)``.
            A Python list of keys is also accepted.

    Returns:
        ``[new_keys, subkeys]``, each with the same shape as ``keys``; row ``i``
        holds the two halves of ``rnd.split(keys[i])``. The leading axis is kept
        even when it has length 1, so a single-device run still yields
        ``(1, 2)`` arrays that ``pmap`` can map over axis 0.
    """
    keys = jnp.asarray(keys)
    # Split each key independently; the result has shape (n, 2, *key_shape).
    pairs = jax.vmap(rnd.split)(keys)
    # Explicit indexing (not squeeze) so a length-1 device axis is never dropped.
    return [pairs[:, 0], pairs[:, 1]]

In [3]:
def split_variables_for_pmap(n_devices, *args):
    """Reshape arrays so their leading axis is split across ``n_devices``.

    Every array in ``args`` must share the same leading length, and that length
    must be divisible by ``n_devices`` (both checked with ``assert``). An array of
    shape ``(N, *rest)`` becomes ``(n_devices, N // n_devices, *rest)``, the layout
    expected by ``pmap`` for an argument mapped over axis 0.

    Returns the reshaped array itself when exactly one array is passed, otherwise
    a list of reshaped arrays in argument order.
    """
    # Compare neighbouring arguments' leading lengths, left to right.
    for current, following in zip(args, args[1:]):
        assert len(current) == len(following)

    assert len(args[0]) % n_devices == 0

    reshaped = [
        arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])
        for arr in args
    ]
    return reshaped[0] if len(args) == 1 else reshaped
        

In [7]:
# Build the system, wavefunction, parameters and initial walkers for this pmap debug run.
# NOTE(review): the root key is hard-coded to 123 while the printed config reports
# `seed 369`; if the config seed is meant to drive initialisation, use it here — TODO confirm.
key = rnd.PRNGKey(123)

# `setup()` comes from the star import; it returns the run configuration (indexed
# like a dict below) and appears to print it (see this cell's output).
config = setup()

mol = SystemAnsatz(**config)

# create_wf returns three objects — presumably the wavefunction, a KFAC variant and
# an orbitals function, judging only by the names; verify against the library.
wf, kfac_wf, wf_orbitals = create_wf(mol)
params = initialise_params(key, mol)
# d0s are expanded to the per-device walker count (n_walkers // n_devices), the same
# count later passed to psampler, which broadcasts d0s to every device (in_axes None).
d0s = expand_d0s(initialise_d0s(mol), config['n_walkers'] // n_devices)
# Full, un-split walker array; it is reshaped to one slice per device before pmap.
walkers = mol.initialise_walkers(n_walkers=config['n_walkers'])

version 		 110521
seed 		 369
n_devices 		 4
save_every 		 1000
print_every 		 0
exp_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m1024_s32_p8_l2_det2/run10
events_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m1024_s32_p8_l2_det2/run10/events
models_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m1024_s32_p8_l2_det2/run10/models
opt_state_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m1024_s32_p8_l2_det2/run10/models/opt_state
pre_path 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/pretrained/s32_p8_l2_det2_1lr-4_i1000.pk
timing_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/pmap/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_m1024_s32_p8_l2_det2/run10/events/timing
system 		 Be
r_atoms 		 [[0. 0. 0.]]
z_atoms 		 [

In [5]:
# One PRNG key per device, each split into a (carry key, subkey) pair for the sampler.
# NOTE(review): this rebuilds the same root PRNGKey(123) that initialise_params received,
# so parameter init and sampling share a root; consider splitting a dedicated sampling key off it.
key = rnd.PRNGKey(123)
keys = rnd.split(key, n_devices)  # follow the actual device count rather than a literal 4
keys, subkeys = key_gen(keys)
print(keys.shape, subkeys.shape)

(4, 2) (4, 2)


In [8]:

# Build the sampler and map it over devices. in_axes order matches the call below:
# params (broadcast), walkers (split on axis 0), d0s (broadcast), keys (split on axis 0),
# walkers-per-device count (broadcast).
# NOTE(review): a broadcast Python int is still traced under pmap; if the sampler ever
# needs it as a static value, use static_broadcasted_argnums instead.
sampler, equilibrate = create_sampler(wf, mol)
psampler = pmap(sampler, in_axes=(None, 0, None, 0, None))
walkers_per_device = config['n_walkers'] // n_devices
pwalkers = split_variables_for_pmap(n_devices, walkers)

xwalkers, acceptance, step_size = psampler(params, pwalkers, d0s, subkeys, walkers_per_device)

In [9]:
# The sampler must return walkers in the same per-device layout it was given.
assert xwalkers.shape == pwalkers.shape
print(xwalkers.shape)

(4, 256, 4, 3)


In [11]:
# Advance the per-device keys before resampling. psampler does not return new keys, so
# passing the previous `subkeys` again would presumably replay the same random draws.
keys, subkeys = key_gen(keys)
xwalkers, acceptance, step_size = psampler(params, xwalkers, d0s, subkeys, config['n_walkers'] // n_devices)

In [12]:
# Feeding the sampler its own output must preserve the per-device walker layout.
assert xwalkers.shape == pwalkers.shape
print(xwalkers.shape)

(4, 256, 4, 3)
