In [1]:
import h5py

# `setup` builds the experiment config mapping; `run_vmc` consumes it (see below).
from nn_ansatz import setup
from nn_ansatz import run_vmc

# Open HDF5 files in append mode by default instead of h5py's read-only default.
# NOTE(review): presumably needed by a dependency that opens HDF5 files without an
# explicit mode (the "converged SCF energy" output below suggests pyscf). This config
# attribute is h5py-version dependent — confirm it exists in the installed version.
h5py.get_config().default_file_mode = 'a'


In [2]:
# using routines
# Hyperparameters for the Be run. These variables are the single source of truth:
# they are passed to `setup` below and reused by the manual loop further down.

lr, damping, nc = 1e-4, 1e-4, 1e-4  # learning rate, KFAC damping, norm constraint
n_pre_it = 500   # pretraining iterations
n_walkers = 512
n_layers = 2
n_sh = 64
n_ph = 16
n_det = 8
n_it = 1000      # optimisation iterations
# NOTE(review): `seed` is not passed to `setup`, so the config's seed (369 in the
# printed output) differs from this one, which only seeds the manual loop's PRNG key.
seed = 1


# Pass the variables (not duplicated literals) so the config cannot drift from them.
config = setup(system='Be',
               n_pre_it=n_pre_it,
               n_walkers=n_walkers,
               n_layers=n_layers,
               n_sh=n_sh,
               n_ph=n_ph,
               opt='kfac',
               n_det=n_det,
               print_every=1,
               save_every=5000,
               lr=lr,
               n_it=n_it,
               norm_constraint=nc,
               damping=damping)



version 		 060421
seed 		 369
save_every 		 5000
print_every 		 1
exp_dir 		 /home/xmax/nn_ansatz/src/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_s64_p16_l2_det8/run1
events_dir 		 /home/xmax/nn_ansatz/src/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_s64_p16_l2_det8/run1/events
models_dir 		 /home/xmax/nn_ansatz/src/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_s64_p16_l2_det8/run1/models
opt_state_dir 		 /home/xmax/nn_ansatz/src/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_s64_p16_l2_det8/run1/models/opt_state
pre_path 		 /home/xmax/nn_ansatz/src/experiments/Be/pretrained/s64_p16_l2_det8_1lr-4_i500.pk
timing_dir 		 /home/xmax/nn_ansatz/src/experiments/Be/junk/kfac_1lr-4_1d-4_1nc-4_s64_p16_l2_det8/run1/events/timing
system 		 Be
r_atoms 		 [[0. 0. 0.]]
z_atoms 		 [4.]
n_el 		 4
n_el_atoms 		 [4]
n_layers 		 2
n_sh 		 64
n_ph 		 16
n_det 		 8
opt 		 kfac
lr 		 0.0001
damping 		 0.0001
norm_constraint 		 0.0001
n_it 		 1000
load_it 		 0
n_walkers 		 512
step_size 		 0.02
pre_lr 		 0.0001
n_pre_it 		

In [3]:
# Drive the packaged VMC routine with every setting from the `setup` config above.
# Long-running: the recorded execution was stopped manually (KeyboardInterrupt).
run_vmc(**config)

System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 
 n_el    = 4 

Ansatz: 
 n_layers = 2 
 n_det    = 8 
 n_sh     = 64 
 n_ph     = 16 

converged SCF energy = -14.351880476202


KeyboardInterrupt: 

In [7]:
# using core functions
# System setup: build the system/ansatz description, the wavefunction, initial
# parameters, the MCMC sampler and the initial walkers from the config above.

from tqdm import trange

import jax.random as rnd 
import jax.numpy as jnp

from nn_ansatz import create_sampler
from nn_ansatz import create_wf
from nn_ansatz import initialise_params, initialise_d0s, expand_d0s
from nn_ansatz import SystemAnsatz
from nn_ansatz import pretrain_wf
from nn_ansatz import create_energy_fn, create_grad_function
from nn_ansatz import create_natural_gradients_fn, kfac
from nn_ansatz import Logging, load_pk, save_pk



# Read every setting from `config` so this manual path uses exactly the values
# the packaged `run_vmc` would, rather than a mix of config and cell globals.
r_atoms = config['r_atoms']
z_atoms = config['z_atoms']
n_el_atoms = config['n_el_atoms']
step_size = config['step_size']
n_el = config['n_el']
n_walkers = config['n_walkers']
n_layers = config['n_layers']
n_sh = config['n_sh']
n_ph = config['n_ph']
n_det = config['n_det']
norm_constraint = config['norm_constraint']
# NOTE(review): `load_pretrain` and `pre_path` are not referenced by the cells
# below; pretraining always runs from scratch.
load_pretrain = False
pre_path = config['pre_path']

key = rnd.PRNGKey(seed)

mol = SystemAnsatz(r_atoms,
                   z_atoms,
                   n_el,
                   n_el_atoms=n_el_atoms,
                   n_layers=n_layers,
                   n_sh=n_sh,
                   n_ph=n_ph,
                   n_det=n_det,
                   step_size=step_size)

wf, kfac_wf, wf_orbitals = create_wf(mol)
params = initialise_params(key, mol)
# d0s are replicated once per walker.
d0s = expand_d0s(initialise_d0s(mol), n_walkers)

# NOTE(review): correlation_length is presumably the number of MCMC steps per
# sampler call — confirm against create_sampler.
sampler, equilibrate = create_sampler(wf, mol, correlation_length=10)


walkers = mol.initialise_walkers(n_walkers=n_walkers)


System: 
 n_atoms = 1 
 n_up    = 2 
 n_down  = 2 
 n_el    = 4 

Ansatz: 
 n_layers = 2 
 n_det    = 8 
 n_sh     = 64 
 n_ph     = 16 

converged SCF energy = -14.351880476202


In [None]:
# Pretrain
# Pretrain the wavefunction (using its orbitals) before the VMC loop.
# Not idempotent: re-running this cell continues from the current params/walkers.

params, walkers = pretrain_wf(params,
                              wf,
                              wf_orbitals,
                              mol,
                              walkers,
                              n_it=n_pre_it,
                              # Take the pretraining lr from the config instead of
                              # duplicating its value as a literal.
                              lr=config['pre_lr'],
                              # NOTE(review): equilibration reuses the pretraining
                              # iteration count — confirm this is intended.
                              n_eq_it=n_pre_it)

In [10]:
# vmc loop with kfac
# Each step: move the walkers with the MCMC sampler, compute gradients and local
# energies, transform the gradients with the KFAC update, then apply the update.

grad_fn = create_grad_function(wf, mol)


update, get_params, kfac_update, state = kfac(kfac_wf, wf, mol, params, walkers, d0s,
                                              lr=lr,
                                              damping=damping,
                                              norm_constraint=norm_constraint)


# Keep tqdm's default `disable=False`: `disable=None` switches the bar off whenever
# the output stream is not a TTY, which includes the Jupyter kernel's streams, so
# the per-step energy would never be shown.
steps = trange(n_it, desc='training')
for step in steps:
    key, subkey = rnd.split(key)

    # The sampler returns a step size that is fed back in on the next iteration.
    walkers, acceptance, step_size = sampler(params, walkers, d0s, subkey, step_size)

    grads, e_locs = grad_fn(params, walkers, d0s)

    grads, state = kfac_update(step, grads, state, walkers, d0s)

    state = update(step, grads, state)
    params = get_params(state)

    # set_postfix refreshes the bar itself (refresh=True by default).
    steps.set_postfix(E=f'{jnp.mean(e_locs):.6f}')

exit
