In [1]:
%load_ext autoreload
%autoreload 2
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [2]:
import jax
import jax.random as rnd
import jax.numpy as jnp
from jax import vmap, jit, grad, pmap
from jax.experimental.optimizers import adam
from jax import tree_util
from tqdm.notebook import trange

from functools import partial

from nn_ansatz import *

In [6]:
cfg = config = setup(system='LiSolidBCC',
               n_pre_it=0,
               n_walkers=256,
               n_layers=2,
               n_sh=32,
               step_size=0.02,
               n_ph=8,
               orbital_decay='isotropic',
               n_periodic_input=3,
               opt='adam',
               n_det=2,
               print_every=100,
               save_every=5000,
               n_it=20000,
               name='isotropic')

logger = Logging(**cfg)

keys = rnd.PRNGKey(cfg['seed'])
if bool(os.environ.get('DISTRIBUTE')) is True:
    keys = rnd.split(keys, cfg['n_devices']).reshape(cfg['n_devices'], 2)

mol = SystemAnsatz(**cfg)

pwf = pmap(create_wf(mol), in_axes=(None, 0))
vwf = create_wf(mol)
pwf_grad = grad(lambda x, y: pwf(x, y).sum(), argnums=(0,1))

params = initialise_params(mol, keys)

sampler = create_sampler(mol, vwf)

ke = pmap(create_local_kinetic_energy(vwf), in_axes=(None, 0))
pe = pmap(create_potential_energy(mol), in_axes=(0, None, None))
pmap_compute_ae_vectors_periodic_i = lambda x, y: compute_ae_vectors_periodic_i(x, y, mol.unit_cell_length)
pmap_compute_ae_vectors_periodic_i = pmap(vmap(pmap_compute_ae_vectors_periodic_i, in_axes=(0, None)), in_axes=(0, None))
pmap_compute_ee_vectors_i = pmap(vmap(compute_ee_vectors_i, in_axes=(0,)), in_axes=(0,))

grad_fn = create_grad_function(mol, vwf)

walkers = generate_walkers_around_nuclei(mol.n_el_atoms, mol.atom_positions, mol.n_walkers)
walkers = walkers.reshape(mol.n_devices, -1, *walkers.shape[1:])
walkers = keep_in_boundary(walkers, mol.real_basis, mol.inv_real_basis)

log_psi = pwf(params, walkers)

keys, subkeys = key_gen(keys)
sampler(params, walkers, subkeys, mol.step_size)

version 		 100921
seed 		 369
n_devices 		 1
save_every 		 5000
print_every 		 100
exp_dir 		 /home/amawi/projects/nn_ansatz/src/experiments/LiSolidBCC/100921/isotropic/adam_1lr-4_1d-3_1nc-4_m256_s32_p8_l2_det2/run0
events_dir 		 /home/amawi/projects/nn_ansatz/src/experiments/LiSolidBCC/100921/isotropic/adam_1lr-4_1d-3_1nc-4_m256_s32_p8_l2_det2/run0/events
models_dir 		 /home/amawi/projects/nn_ansatz/src/experiments/LiSolidBCC/100921/isotropic/adam_1lr-4_1d-3_1nc-4_m256_s32_p8_l2_det2/run0/models
opt_state_dir 		 /home/amawi/projects/nn_ansatz/src/experiments/LiSolidBCC/100921/isotropic/adam_1lr-4_1d-3_1nc-4_m256_s32_p8_l2_det2/run0/models/opt_state
pre_path 		 /home/amawi/projects/nn_ansatz/src/experiments/LiSolidBCC/pretrained/s32_p8_l2_det2_1lr-4_i0.pk
timing_dir 		 /home/amawi/projects/nn_ansatz/src/experiments/LiSolidBCC/100921/isotropic/adam_1lr-4_1d-3_1nc-4_m256_s32_p8_l2_det2/run0/events/timing
system 		 LiSolidBCC
r_atoms 		 [[0.  0.  0. ]
 [0.5 0.5 0.5]]
z_atoms 		 [3. 3.]
n_

(DeviceArray([[[[0.8582945 , 6.5696616 , 0.04186589],
                [6.2126694 , 1.3524095 , 6.116393  ],
                [3.8645022 , 2.912761  , 1.4493206 ],
                [5.17297   , 4.029324  , 2.7368865 ],
                [0.528752  , 1.3406394 , 0.7872262 ],
                [2.2903671 , 1.4516345 , 4.184484  ]],
 
               [[0.70812196, 0.56556445, 1.2961988 ],
                [6.2208285 , 6.382479  , 0.6443732 ],
                [4.8242397 , 4.1581893 , 3.683971  ],
                [3.9407113 , 5.073767  , 6.0349646 ],
                [6.1779075 , 0.34728247, 0.12393504],
                [2.1068974 , 2.329549  , 5.0772657 ]],
 
               [[0.33890557, 6.264277  , 0.5172672 ],
                [6.161936  , 0.6629444 , 0.09966614],
                [1.5763894 , 3.7691598 , 3.2538962 ],
                [2.111093  , 2.9599917 , 2.879323  ],
                [5.7157664 , 6.5899916 , 0.15455239],
                [3.3246207 , 4.7007947 , 2.2608042 ]],
 
               ...,

In [7]:
if cfg['opt'] == 'kfac':
    update, get_params, kfac_update, state = kfac(mol, params, walkers, cfg['lr'], cfg['damping'], cfg['norm_constraint'])
elif cfg['opt'] == 'adam':
    init, update, get_params = adam(cfg['lr'])
    update = jit(update)
    state = init(params)
else:
    exit('Optimiser not available')

steps = trange(1, cfg['n_it']+1, initial=1, total=cfg['n_it']+1, desc='training', disable=None)
step_size = split_variables_for_pmap(cfg['n_devices'], cfg['step_size'])

training:   0%|          | 1/20001 [00:00<?, ?it/s]

In [8]:


for step in steps:
    keys, subkeys = key_gen(keys)

    walkers, acceptance, step_size = sampler(params, walkers, subkeys, step_size)
    # gparam, gwalker = pwf_grad(params, walkers)
    # stop = capture_nan(walkers, 'walkers', False)

    # pote = pe(walkers, mol.r_atoms, mol.z_atoms)
    # pote_nan = check_if_nan(pote, 'x')

    # kine = ke(params, walkers)
    # kine_nan = check_if_nan(kine,'x')

    # ae_vectors = pmap_compute_ae_vectors_periodic_i(walkers, mol.r_atoms)
    # ee_vectors = pmap_compute_ee_vectors_i(walkers)
    # min_im_ee_vectors = apply_minimum_image_convention(ee_vectors, mol.unit_cell_length)
    # min_im_ae_vectors = apply_minimum_image_convention(ae_vectors, mol.unit_cell_length)

    # if kine_nan:
    #     print('nan in kinetic')
    #     break
    
    grads, e_locs = grad_fn(params, walkers)
    # stop = capture_nan(grads, 'e_locs', stop)
    # stop = capture_nan(grads, 'grads', stop)

    if cfg['opt'] == 'kfac':
        grads, state = kfac_update(step, grads, state, walkers)

    state = update(step, grads, state)
    params = get_params(state)

    steps.set_postfix(E=f'{jnp.mean(e_locs):.6f}')
    steps.refresh()

    logger.log(step,
                opt_state=state,
                params=params,
                e_locs=e_locs,
                acceptance=acceptance[0],
                walkers=walkers)

logger.walkers = walkers

step 100 | e_mean -3.4131 | e_std 4.1471 | e_mean_mean -3.0957 | acceptance 0.5539 | t_per_it 0.0518 |
step 200 | e_mean -5.5867 | e_std 4.3614 | e_mean_mean -5.1772 | acceptance 0.5031 | t_per_it 0.0517 |
step 300 | e_mean -5.9305 | e_std 4.2964 | e_mean_mean -5.6933 | acceptance 0.5121 | t_per_it 0.0516 |
step 400 | e_mean -5.6908 | e_std 3.3997 | e_mean_mean -6.0146 | acceptance 0.5320 | t_per_it 0.0514 |
step 500 | e_mean -5.9243 | e_std 3.3819 | e_mean_mean -6.0710 | acceptance 0.4941 | t_per_it 0.0513 |
step 600 | e_mean -5.9484 | e_std 3.6764 | e_mean_mean -6.1820 | acceptance 0.5152 | t_per_it 0.0512 |
step 700 | e_mean -6.1402 | e_std 2.6809 | e_mean_mean -6.2906 | acceptance 0.4934 | t_per_it 0.0512 |
step 800 | e_mean -6.1693 | e_std 2.8186 | e_mean_mean -6.3954 | acceptance 0.4754 | t_per_it 0.0513 |
step 900 | e_mean -6.3035 | e_std 2.9426 | e_mean_mean -6.4434 | acceptance 0.4875 | t_per_it 0.0513 |
step 1000 | e_mean -6.6934 | e_std 3.7546 | e_mean_mean -6.5475 | accepta

In [45]:
idx = jnp.argwhere(jnp.isnan(kine[0]))[0, 0]
print(kine[0, idx])
print(ae_vectors[0, idx, ...] / (mol.unit_cell_length / 2.))

nan
[[[-1.3872877e-01 -5.6188392e-09 -4.2735010e-01]
  [ 8.6127132e-01  1.0000000e+00  5.7264996e-01]]

 [[ 8.4430242e-01  9.5665139e-01 -8.6504728e-01]
  [-1.5569763e-01 -4.3348670e-02  1.3495275e-01]]

 [[ 7.8833884e-01 -7.3619252e-01 -5.7768416e-01]
  [-2.1166119e-01  2.6380754e-01  4.2231593e-01]]

 [[ 8.0202579e-01  8.5771257e-01  7.7705044e-01]
  [-1.9797423e-01 -1.4228749e-01 -2.2294964e-01]]

 [[-8.8907361e-01  9.6318859e-01  7.7088475e-01]
  [ 1.1092643e-01 -3.6811471e-02 -2.2911531e-01]]

 [[ 1.7893118e-01 -2.0605713e-01 -2.1177055e-02]
  [-8.2106888e-01  7.9394299e-01  9.7882301e-01]]]


In [44]:
print(idx)
walkers[0, idx, ...]
log_psi = pwf(params, walkers)
print(walkers[0, idx, ...])
print(log_psi[0, idx])
print(gwalker[0, 0, idx, :])

177
[[4.5988584e-01 1.8626451e-08 1.4166656e+00]
 [3.8311377e+00 3.4587009e+00 2.8676317e+00]
 [4.0166569e+00 2.4404781e+00 1.9150229e+00]
 [3.9712846e+00 3.7866831e+00 4.0540781e+00]
 [2.9472790e+00 3.4370301e+00 4.0745173e+00]
 [6.0368433e+00 6.8307936e-01 7.0201933e-02]]
-3.8840332
[ 0.1320289 -1.5310602 -0.8113647]


In [19]:
from jax import lax
from jax.tree_util import tree_flatten

def create_grad_function(mol, vwf):
    
    compute_energy = create_energy_fn(mol, vwf)

    def _forward_pass(params, walkers):
        e_locs = lax.stop_gradient(compute_energy(params, walkers))

        e_locs_centered = clip_and_center(e_locs) # takes the mean of the data on each device and does not distribute
        log_psi = vwf(params, walkers)

        return jnp.mean(e_locs_centered * log_psi), e_locs

    _param_grad_fn = grad(_forward_pass, has_aux=True)  # has_aux indicates the number of outputs is greater than 1
    
    if bool(os.environ.get('DISTRIBUTE')) is True:
        _param_grad_fn = pmap(_param_grad_fn, in_axes=(None, 0))

    '''nb: it is not possible to undevice variables within a pmap'''

    def _grad_fn(params, walkers):
        grads, e_locs = _param_grad_fn(params, walkers)
        grads = jax.device_put(grads, jax.devices()[0])
        grads, tree = tree_flatten(grads)
        grads = [g.mean(0) for g in grads]
        grads = tree_unflatten(tree, grads)
        return grads, jax.device_put(e_locs, jax.devices()[0]).reshape(-1)

    return jit(_grad_fn)

# compute_energy = create_energy_fn(mol, vwf)

# jnp.linalg.norm(mol.real_basis, axis=-1)#.mean()

print(jnp.linalg.norm(mol.real_basis, axis=-1))


RuntimeError: Unknown: an illegal memory access was encountered
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc(40): 'cuLinkCreate(0, nullptr, nullptr, &link_state)'

In [19]:
# jax.device_put(walkers, jax.devices()[0])
print(jax.devices())



[GpuDevice(id=0, process_index=0)]


In [26]:
import numpy as np
print(np.linalg.norm(np.array(mol.real_basis), axis=-1))
jnp.linalg.norm(mol.real_basis, axis=-1)

[6.63 6.63 6.63]


ValueError: Internal: Failed to launch CUDA kernel: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered

In [21]:
# print(apply_minimum_image_convention(walkers, unit_cell_length=mol.unit_cell_length) / mol.unit_cell_length)
w = walkers[0, 0]
print(w)
print((2 * w / mol.unit_cell_length).astype(int))
print((2 * w / mol.unit_cell_length).astype(w.dtype) * mol.unit_cell_length)
# displace = (2. * displacement_vectors / unit_cell_length).astype(int).astype(displacement_vectors.dtype) * unit_cell_length
    # displacement_vectors = displacement_vectors + lax.stop_gradient(displace)  # 

[[0.64362454 2.797233   2.0600808 ]
 [1.2021452  0.6426074  2.484054  ]
 [5.493576   5.7438507  5.119844  ]
 [4.999399   4.46594    4.9091105 ]
 [3.3290756  2.5294046  3.2106931 ]
 [4.6895256  3.9742057  5.5813255 ]]
[[0 0 0]
 [0 0 0]
 [1 1 1]
 [1 1 1]
 [1 0 0]
 [1 1 1]]
[[ 1.2872492  5.5944667  4.120162 ]
 [ 2.4042904  1.2852148  4.9681087]
 [10.987153  11.487702  10.239688 ]
 [ 9.998799   8.93188    9.818221 ]
 [ 6.6581516  5.0588093  6.4213862]
 [ 9.379052   7.9484124 11.162652 ]]


In [9]:
print(jnp.isinf(test).any())

False


In [None]:
# grad_fn = create_grad_function(mol, vwf)

# if cfg['opt'] == 'kfac':
#     update, get_params, kfac_update, state = kfac(mol, params, walkers, cfg['lr'], cfg['damping'], cfg['norm_constraint'])
# elif cfg['opt'] == 'adam':
#     init, update, get_params = adam(cfg['lr'])
#     update = jit(update)
#     state = init(params)
# else:
#     exit('Optimiser not available')

# steps = trange(1, cfg['n_it']+1, initial=1, total=cfg['n_it']+1, desc='training', disable=None)
# step_size = split_variables_for_pmap(cfg['n_devices'], cfg['step_size'])

# for step in steps:
#     keys, subkeys = key_gen(keys)

#     walkers, acceptance, step_size = sampler(params, walkers, subkeys, step_size)

KeyboardInterrupt: 