In [1]:
import os
# Request deterministic GPU reductions from XLA for reproducibility. XLA_FLAGS
# must be in the environment before jax is imported (a few lines below).
# Append to any flags the user already exported instead of overwriting them,
# and skip the append when the flag is already present so re-running this cell
# is idempotent.
_DETERMINISTIC_XLA_FLAG = '--xla_gpu_deterministic_reductions'
_existing_xla_flags = os.environ.get('XLA_FLAGS', '')
if _DETERMINISTIC_XLA_FLAG not in _existing_xla_flags.split():
  os.environ['XLA_FLAGS'] = (
      f'{_existing_xla_flags} {_DETERMINISTIC_XLA_FLAG}'.strip())
import pathlib

import numpy as np
# `from jax.config import config` is deprecated and removed in recent JAX
# releases; the same config object is exposed on the top-level package.
from jax import config

from jax_dft import neural_xc
from jax_dft import datasets
from jax_dft import utils

from ks_regularizer import analysis
from ks_regularizer import ksr

# Enable float64 as the default dtype. Note: the dtype may switch from float64
# to float32 during e.g. evaluating/training a convolutional neural network.
# This can cause slight numerical differences and hence slight
# reproducibility issues.
config.update('jax_enable_x64', True)



In [3]:
# Directory containing the trained model checkpoint `optimal_ckpt.pkl`
# (i.e. path/to/directory/of/model). Fill this in before running; an empty
# string resolves relative to the current working directory.
model_dir = ''

mol = 'h2'
dissoc_dir = pathlib.Path('../../data/molecules') / mol

# Load the dissociation dataset on a 513-point grid and downsample it with a
# step of 2. Done as one expression so re-running the cell never downsamples an
# already-downsampled dataset.
dissoc_dataset = datasets.Dataset(dissoc_dir, num_grids=513).get_subdataset(
  downsample_step=2)
grids = dissoc_dataset.grids

# Spin KSR driver.
# NOTE(review): the original hint reads "use ksr.PureKSR if unrestricted (runs
# faster)"; confirm which of SpinKSR / PureKSR is the spin-restricted variant
# before switching.
sKSR = ksr.SpinKSR(grids)

# Evaluate on all molecules of the dissociation dataset.
sKSR.set_test_set(dissoc_dataset.get_molecules())

# Kohn-Sham parameters for evaluating the test set (the iteration count is
# raised for testing).
sKSR.set_ks_params(
  # The number of Kohn-Sham iterations to run.
  num_iterations=40,
  # The density linear mixing factor.
  alpha=0.5,
  # Decay factor of density linear mixing factor.
  alpha_decay=0.9,
  # Enforce reflection symmetry across the origin. Note: currently not supported
  # in unrestricted KS.
  # NOTE(review): this is True while using SpinKSR; confirm SpinKSR is not the
  # unrestricted variant referred to above.
  enforce_reflection_symmetry=True,
  # The number of density differences in the previous iterations to mix the
  # density. Linear mixing is num_mixing_iterations = 1.
  num_mixing_iterations=1,
  # The stopping criterion of the Kohn-Sham iteration on density. A negative
  # value presumably disables early stopping so that all iterations run --
  # confirm in ksr.
  density_mse_converge_tolerance=-1,
  # Apply stop gradient on the output state of this step and all steps
  # before. The first KS step is indexed as 0. Default -1, no stop gradient
  # is applied.
  stop_gradient_step=-1,
)

# ML model for the xc functional. The architecture (filters, layers,
# activation, minval/maxval, downsample_factor) presumably has to match the one
# the checkpoint in `model_dir` was trained with, or the loaded parameters will
# not fit.
network = neural_xc.build_global_local_conv_net_sigma(num_global_filters=8,
  num_local_filters=16, num_local_conv_layers=2, activation='swish',
  grids=grids, minval=0.1, maxval=2.385345,
  downsample_factor=0)

network = neural_xc.wrap_network_with_self_interaction_layer_sigma(
    network, grids=grids,
    interaction_fn=utils.exponential_coulomb)

init_fn, neural_xc_energy_density_fn = neural_xc.global_functional_sigma(
  network, grids=grids)

# Trailing `;` suppresses the repr of the returned SpinKSR object in the cell
# output.
sKSR.set_neural_xc_functional(model_dir=model_dir,
  neural_xc_energy_density_fn=neural_xc_energy_density_fn);


<ks_regularizer.ksr.SpinKSR at 0x7fe63830aca0>

In [4]:
# Restore the parameters stored in `<model_dir>/optimal_ckpt.pkl`, run the
# Kohn-Sham solver on the test molecules, then extract the final states from
# the returned trajectories.
# NOTE(review): the checkpoint is a pickle file -- only load checkpoints you
# trust, since unpickling can execute arbitrary code.
states = sKSR.get_test_states(
    optimal_ckpt_path=os.path.join(model_dir, 'optimal_ckpt.pkl'),
)
final_states = sKSR.get_final_states(states)

In [None]:
# Plot the dissociation energy curve: KSR model vs. exact reference total
# energy as a function of the distance between the two nuclei. The KS
# `total_energy` presumably excludes nuclear repulsion, which is added here for
# both curves.
import matplotlib.pyplot as plt

# Nuclear-nuclear interaction energy of each geometry, using the exponential
# Coulomb interaction.
nuclear_energy = utils.get_nuclear_interaction_energy_batch(
    dissoc_dataset.locations,
    dissoc_dataset.nuclear_charges,
    interaction_fn=utils.exponential_coulomb)

# Distance between nuclei 0 and 1 for each geometry.
distances = utils.compute_distances_between_nuclei(
    dissoc_dataset.locations, [0, 1])

# Explicit figure/axes interface rather than the pyplot state machine.
fig, ax = plt.subplots()
ax.plot(distances, final_states.total_energy + nuclear_energy,
        label='KSR-global')
ax.plot(distances, dissoc_dataset.total_energies + nuclear_energy, 'k--',
        label='exact')

ax.set_xlabel('$R$', fontsize=16)
ax.set_ylabel('$E$', fontsize=16)
ax.set_title(f'{mol.upper()} dissociation curve')
ax.legend()
ax.grid(alpha=0.4)

# Derive the file name from the molecule so changing `mol` does not overwrite
# the H2 figure ('h2_dissoc' for mol='h2').
fig_pdf_name = f'{mol}_dissoc'
fig.savefig(f'{fig_pdf_name}.pdf', bbox_inches='tight')
plt.show()