In [3]:
%load_ext autoreload
%autoreload 2
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
sys.path.append('/home/amawi/projects/nn_ansatz/src')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import jax
import jax.random as rnd
import jax.numpy as jnp
from jax import vmap, jit, grad, pmap
from jax.experimental.optimizers import adam
from jax import tree_util
from tqdm.notebook import trange

from functools import partial

from nn_ansatz import *

In [5]:
cfg = config = setup(system='LiSolidBCC',
               n_pre_it=0,
               n_walkers=64,
               n_layers=2,
               n_sh=32,
               n_ph=8,
               opt='kfac',
               n_det=2,
               print_every=1,
               save_every=5000,
               n_it=1000)

logger = Logging(**cfg)

keys = rnd.PRNGKey(cfg['seed'])
if bool(os.environ.get('DISTRIBUTE')) is True:
    keys = rnd.split(keys, cfg['n_devices']).reshape(cfg['n_devices'], 2)

mol = SystemAnsatz(**cfg)

pwf = pmap(create_wf(mol), in_axes=(None, 0))
vwf = create_wf(mol)

params = initialise_params(mol, keys)

sampler = create_sampler(mol, vwf)

ke = pmap(create_local_kinetic_energy(vwf), in_axes=(None, 0))
pe = pmap(create_potential_energy(mol), in_axes=(0, None, None))
grad_fn = create_grad_function(mol, vwf)

walkers = generate_walkers_around_nuclei(mol.n_el_atoms, mol.atom_positions, mol.n_walkers)
walkers = walkers.reshape(mol.n_devices, -1, *walkers.shape[1:])
walkers = keep_in_boundary(walkers, mol.real_basis, mol.inv_real_basis)

log_psi = pwf(params, walkers)

keys, subkeys = key_gen(keys)
sampler(params, walkers, subkeys, mol.step_size)

version 		 100921
seed 		 369
n_devices 		 1
save_every 		 5000
print_every 		 1
exp_dir 		 /home/amawi/projects/nn_ansatz/src/scripts/validation/100921_ewalds/experiments/LiSolidBCC/100921/junk/kfac_1lr-4_1d-3_1nc-4_m64_s32_p8_l2_det2/run0
events_dir 		 /home/amawi/projects/nn_ansatz/src/scripts/validation/100921_ewalds/experiments/LiSolidBCC/100921/junk/kfac_1lr-4_1d-3_1nc-4_m64_s32_p8_l2_det2/run0/events
models_dir 		 /home/amawi/projects/nn_ansatz/src/scripts/validation/100921_ewalds/experiments/LiSolidBCC/100921/junk/kfac_1lr-4_1d-3_1nc-4_m64_s32_p8_l2_det2/run0/models
opt_state_dir 		 /home/amawi/projects/nn_ansatz/src/scripts/validation/100921_ewalds/experiments/LiSolidBCC/100921/junk/kfac_1lr-4_1d-3_1nc-4_m64_s32_p8_l2_det2/run0/models/opt_state
pre_path 		 /home/amawi/projects/nn_ansatz/src/scripts/validation/100921_ewalds/experiments/LiSolidBCC/pretrained/s32_p8_l2_det2_1lr-4_i0.pk
timing_dir 		 /home/amawi/projects/nn_ansatz/src/scripts/validation/100921_ewalds/experiments/L

(DeviceArray([[[[5.583473  , 0.988975  , 0.39664713],
                [6.0973487 , 5.6345315 , 0.82166576],
                [3.9083393 , 4.1259055 , 3.3362308 ],
                [3.4474747 , 2.8105505 , 3.3686457 ],
                [1.6237023 , 0.90791935, 1.5795434 ],
                [3.0421147 , 2.3331306 , 3.9409254 ]],
 
               [[0.96281385, 0.97168374, 2.1809282 ],
                [0.07163441, 6.5877647 , 0.5054116 ],
                [3.4621022 , 3.1513455 , 4.9210405 ],
                [1.7614384 , 4.3776894 , 2.19134   ],
                [0.34704062, 1.7658527 , 5.586016  ],
                [3.6268432 , 4.158191  , 3.458527  ]],
 
               [[1.0535563 , 1.3553293 , 0.09168262],
                [1.4582154 , 6.0680437 , 1.288484  ],
                [3.3828547 , 2.3205283 , 4.911339  ],
                [1.4159553 , 2.8223505 , 4.010165  ],
                [0.9543327 , 5.8596525 , 3.949174  ],
                [2.2789385 , 3.3533428 , 3.423867  ]],
 
               ...,

In [15]:
walker = walkers[0, 0, ...][None, None, ...]

In [18]:
energy = float(pe(walker, mol.r_atoms, mol.z_atoms))

In [24]:
import csv 

data = []
for i in range(6):
    r = walker[0, 0, i, :]
    tmp = [str(float(x)) for x in r]
    tmp.insert(0, 'e%i'%i)
    data.append(tmp)

for i in range(2):
    r = mol.r_atoms[i, :]
    tmp = [str(float(x)) for x in r]
    tmp.insert(0, 'a%i'%i)
    data.append(tmp)

header = ['pe=%.4f'%energy,]
with open('configuration.csv', 'w') as f:
    write = csv.writer(f)
    write.writerow(header)
    for row in data:
        write.writerow(row)




