<a href="https://colab.research.google.com/github/path/to/.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Environment Setup

Switch to a GPU runtime: Runtime -> Change runtime type -> Hardware accelerator -> GPU.

In [None]:
# Check the installed CUDA toolkit version. The jaxlib wheel installed in the
# next code cell must be built for this version (e.g. the `+cuda111` suffix
# corresponds to CUDA 11.1).
! nvcc --version

The jaxlib wheel must be built for the CUDA version already installed, as reported by `nvcc --version` above. The CUDA version is encoded in the wheel's version suffix, e.g. `cuda111` for CUDA 11.1.

In [None]:
# For GPU runtime
! pip install --upgrade pip
! pip install --upgrade jaxlib==0.1.72+cuda111 jax==0.2.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [None]:
# Install ksr-dft
! git clone https://github.com/pedersor/ksr_dft.git
! pip install ksr_dft

## Imports and setup

In [1]:
import os
from pathlib import Path

import numpy as np
import scipy
import jax
from jax import random
from jax import tree_util
from jax.config import config
import jax.numpy as jnp

from ksr_dft import datasets
from ksr_dft import jit_scf
from ksr_dft import losses
from ksr_dft import neural_xc
from ksr_dft import np_utils
from ksr_dft import scf
from ksr_dft import utils
from ksr_dft import xc
from ksr_dft import analysis
from ksr_dft import ksr


# Request deterministic GPU reductions from XLA.
# - The flag is appended to any XLA_FLAGS already set, not written over them,
#   and is only added once so re-running this cell is idempotent.
# - NOTE: XLA is expected to read this variable when the JAX backend is first
#   initialized, so it should be set before any JAX computation runs.
xla_flags = os.environ.get('XLA_FLAGS', '')
if '--xla_gpu_deterministic_reductions' not in xla_flags.split():
  os.environ['XLA_FLAGS'] = (
      xla_flags + ' --xla_gpu_deterministic_reductions').strip()
# Set the default dtype as float64. Note: the dtype may switch from float64 to
# float32 during e.g. evaluating/training a convolution neural net.
config.update('jax_enable_x64', True)



In [2]:
# Locate the ksr_dft repository, which holds the reference data loaded below.
# In Colab the repo is expected at /content/ksr_dft (where the clone cell puts
# it when run from Colab's default working directory). Otherwise, assume the
# notebook is run from a subdirectory of a local checkout.
try:
  import google.colab  # noqa: F401  (imported only to detect Colab)
  # in Colab
  KSR_DFT_PATH = Path('/content/ksr_dft/')
except ImportError:
  # Not in Colab: running in a local directory of the repo. Only a missing
  # google.colab module means "not in Colab"; other errors are not swallowed.
  KSR_DFT_PATH = Path('../')

## Load data

In [None]:
# Load the atomic/ion dataset stored under data/ions/dmrg in the ksr_dft repo
# (presumably DMRG reference results — confirm), using num_grids=513. The
# dataset's grid is exposed as `grids`, which the trainer and the neural XC
# functional are built on.
ions_dataset = datasets.Dataset(
    KSR_DFT_PATH / 'data' / 'ions' / 'dmrg',
    num_grids=513,
)
grids = ions_dataset.grids

## Train a neural XC functional with the Kohn-Sham regularizer (KSR)

![KSR training schematic](../images/sksr_global.png)

In [3]:
# Spin KSR trainer built on the dataset's grid.
trainer = ksr.SpinKSR(grids)

# Checkpoint directory for the ML model of the XC functional. Kept as a string
# with a trailing slash; the training log prints checkpoint paths such as
# 'example_model/ckpt-00000'.
# NOTE(review): confirm ksr does not rely on string concatenation before
# switching this to a Path.
model_dir = 'example_model/'
# exist_ok=True replaces the separate exists() check and its check-then-create
# race.
os.makedirs(model_dir, exist_ok=True)

# Architecture: sliding-window network (window_size=1), three layers of 16
# filters, swish activations.
network = neural_xc.build_sliding_net(
    window_size=1,
    num_filters_list=[16, 16, 16],
    activation='swish',
)
# Wrap the network as a global XC functional on `grids`. This returns the
# parameter initializer and the XC energy-density function.
init_fn, neural_xc_energy_density_fn = neural_xc.global_functional_sigma(
    network, grids=grids)

trainer.set_neural_xc_functional(
    model_dir=model_dir,
    neural_xc_energy_density_fn=neural_xc_energy_density_fn)

# Set initial params from init_fn, using a fixed PRNG seed for reproducibility.
key = jax.random.PRNGKey(0)
trainer.set_init_model_params(init_fn, key)

# Kohn-Sham (KS) settings used when the functional is evaluated during
# training and validation.
# NOTE(review): the -1 values presumably disable the density-convergence early
# stop and the stop-gradient cutoff — confirm in ksr/scf.
trainer.set_ks_params(
    num_iterations=6,
    alpha=0.5,
    alpha_decay=0.9,
    enforce_reflection_symmetry=False,
    num_mixing_iterations=1,
    density_mse_converge_tolerance=-1.,
    stop_gradient_step=-1,
)

## Train Ions

# Training set of ions. Entries are presumably (nuclear charge, number of
# electrons) pairs, i.e. H and He here — confirm in datasets.get_ions.
to_train = [(1, 1), (2, 2)]
training_set = ions_dataset.get_ions(to_train)
trainer.set_training_set(training_set)

# Optimization settings: start from checkpoint index 0, save a checkpoint
# every 10 steps, and run for at most 100 steps.
trainer.setup_optimization(
    initial_checkpoint_index=0,
    save_every_n=10,
    max_train_steps=100,
    # number of iterations skipped in energy loss evaluation,
    # a value of -1 corresponds to using the final KS only.
    num_skipped_energies=-1,
    # can also modify energy vs density weight in loss function:
    # energy_loss_weight=0.5,
)

# Run the L-BFGS optimization. verbose=1 prints the per-step loss and timing.
trainer.do_lbfgs_optimization(verbose=1)

## Validate Ions

# Held-out validation set (presumably Li: Z=3 with 3 electrons).
to_validate = [(3, 3)]
validation_set = ions_dataset.get_ions(to_validate)
trainer.set_validation_set(validation_set)
# Select the optimal saved checkpoint using the validation set.
trainer.get_optimal_ckpt(model_dir)


number of parameters = 560
step 0, loss 2.087971552383939 in 73.64164400100708 sec
Save checkpoint example_model/ckpt-00000
step 1, loss 0.09922457311561972 in 2.027907133102417 sec
step 2, loss 0.04648410970686181 in 2.2076478004455566 sec
step 3, loss 0.04505530701520058 in 2.4542808532714844 sec
step 4, loss 0.04133870938155233 in 2.117668390274048 sec
step 5, loss 0.03480095597706702 in 1.8888216018676758 sec
step 6, loss 0.02358349513211198 in 2.8329248428344727 sec
step 7, loss 0.01308982869115781 in 2.187788963317871 sec
step 8, loss 0.006527987203520393 in 2.594045877456665 sec
step 9, loss 0.006270831470625427 in 1.9054477214813232 sec
step 10, loss 0.006243509754087787 in 2.2791924476623535 sec
Save checkpoint example_model/ckpt-00010
step 11, loss 0.006042104814021789 in 2.465707302093506 sec
step 12, loss 0.005671039070936974 in 3.165398359298706 sec
step 13, loss 0.005262280266105735 in 2.6839404106140137 sec
step 14, loss 0.005099102132806769 in 2.7878565788269043 sec
ste

KeyboardInterrupt: 