In [1]:
import json
import os
import pickle
import random
import sys
import pickle

from typing import Callable, Dict, List, Optional
import haiku as hk
import ase
import ase.io
import jax
import jax.numpy as jnp
import numpy as np
import optax
import yaml
import jraph


from model.datasets import datasets
from model.utils import (
    create_directory_with_random_name,
    compute_avg_num_neighbors,
)
from model.data_utils import (
    get_atomic_number_table_from_zs,
    compute_average_E0s,
    config_from_atoms,
    graph_from_configuration,
)
 
from model.predictors import predict_energy_forces_stress
from model.optimizer import optimizer
from model.energy_force_train import energy_force_train
from model.loss import WeightedEnergyFrocesStressLoss
from model.nequip_model import NequIP_JAXMD_model_Efield

from model.utils import (
    get_edge_relative_vectors,
    _safe_divide,
    sum_nodes_of_the_same_graph,
)


# Fail fast on numerical problems: with these JAX debugging flags enabled, a
# computation that produces a NaN or an Inf raises an error instead of silently
# propagating the bad value.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
# Print NumPy arrays with 3 decimals and without scientific notation.
np.set_printoptions(precision=3, suppress=True)



## General energy model training with the electric field dependence

First, we train the generalized energy model, which includes an electric-field dependence expanded up to second order (i.e., the energy expansion contains terms up to $E^2$).

The dataset consists of perturbed ethanol molecules placed in random electric fields; the corresponding perturbed energies and atomic forces were obtained from VASP DFT calculations.
The general model is trained to capture this field dependence by adding a polarization term and a polarizability term to the total energy.


In [2]:
# All hyper-parameters for the dataset, model and training come from this YAML file.
# NOTE(review): if the file only holds plain data, yaml.safe_load would be the
# more conservative loader — confirm before switching.
with open('data/train_Efield_pot.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Per-run name, later passed as `save_dir_name` to the model builder and the
# trainer. The tag has no file extension, so it is handed over as-is.
save_dir_name = create_directory_with_random_name('Efield_pot_training')


2024-03-12-18:44-Efield_pot_training-systematic-fleurette


In [3]:
# Build the graph data loaders; the helper also hands back the cutoff radius
# that is used for the model below.
train_loader, valid_loader, test_loader, r_max = datasets(
    r_max=config["cutoff"],
    config_dataset=config["dataset"],
)

# Report how many graphs ended up in the training and validation splits.
for loader in (train_loader, valid_loader):
    print(len(loader.graphs))

nums check 1500 9000 50
Loaded 50000 training configurations from 'data/ethanol-train-efield.xyz'
Using random 1000 configurations for validation
Total number of configurations: train=49000, valid=1000, test=0


100%|██████████████████████████████████| 49000/49000 [00:02<00:00, 23035.29it/s]
100%|████████████████████████████████████| 1000/1000 [00:00<00:00, 26654.53it/s]
0it [00:00, ?it/s]

49000
1000





In [4]:
# Build the E-field-aware NequIP energy model and its parameters.
# `reload` is optional in the config: a missing key is passed on as None.
model_fn, params, num_message_passing = NequIP_JAXMD_model_Efield(
    r_max=r_max,
    atomic_energies_dict={},
    train_graphs=train_loader.graphs,
    initialize_seed=config["model"]["seed"],
    num_species=config["model"]["num_species"],
    use_sc=True,
    graph_net_steps=config["model"]["num_layers"],
    hidden_irreps=config["model"]["internal_irreps"],
    nonlinearities={'e': 'swish', 'o': 'tanh'},
    save_dir_name=save_dir_name,
    reload=config["initialization"].get('reload'),
)

param_leaves = jax.tree_util.tree_leaves(params)
print("num_params:", sum(leaf.size for leaf in param_leaves))

# Jitted predictor: binds the weights `w` into the model and evaluates
# energy / forces / stress for the graph batch `g`.
predictor = jax.jit(
    lambda w, g: predict_energy_forces_stress(
        lambda *model_args: model_fn(w, *model_args), g
    )
)

# The optimizer helper returns the gradient transformation together with the
# schedule lengths that the training loop needs.
gradient_transform, steps_per_interval, max_num_intervals = optimizer(
    lr=config["training"]["learning_rate"],
    max_num_intervals=config["training"]["max_num_intervals"],
    steps_per_interval=config["training"]["steps_per_interval"],
    # weight_decay = config["training"]["weight_decay"],
)
optimizer_state = gradient_transform.init(params)
optimizer_leaves = jax.tree_util.tree_leaves(optimizer_state)
print("optimizer num_params:", sum(leaf.size for leaf in optimizer_leaves))

# Relative weights of the energy, force and stress terms in the training loss.
loss_fn = WeightedEnergyFrocesStressLoss(
    energy_weight=config["training"]["energy_weight"],
    forces_weight=config["training"]["forces_weight"],
    stress_weight=config["training"]["stress_weight"],
)
    
    

Compute the average number of neighbors: 8.000
Do not normalize the radial basis (avg_r_min=None)
Computed average Atomic Energies using least squares: {1: -6.7887986501063295, 6: -2.262932883368747, 8: -1.1314664416843736}
Create NequIP (JAX-MD version) with parameters {'use_sc': True, 'graph_net_steps': 2, 'hidden_irreps': '64x0e + 64x0o + 48x1o + 48x1e +32x2o + 32x2e', 'nonlinearities': {'e': 'swish', 'o': 'tanh'}, 'r_max': 5.0, 'avg_num_neighbors': 8.0, 'avg_r_min': None, 'num_species': 100, 'radial_basis': <function bessel_basis at 0x7f6f593b7ce0>, 'radial_envelope': <function soft_envelope at 0x7f6f593b7420>}
num_params: 1888176
optimizer num_params: 5664530


In [5]:
# Run the training loop. The first eleven arguments are positional, so their
# order must match the signature of `energy_force_train` (defined in
# model.energy_force_train, not visible in this notebook).
# The return value is not captured here: the trained parameters are obtained
# later by rebuilding the model with `reload=...`.
energy_force_train(
    predictor,
    params,
    optimizer_state,
    train_loader,
    valid_loader,
    test_loader,
    gradient_transform,
    loss_fn,
    max_num_intervals,
    steps_per_interval,
    save_dir_name,
    # presumably the decay of an exponential moving average of the parameters
    # and the early-stopping patience — TODO confirm against the trainer
    ema_decay = config["training"]["ema_decay"],
    patience = config["training"]["patience"],
)

print('Training done!')

Started training


eval_train:   0%|                                       | 0/999 [00:00<?, ?it/s]

Compiled function `model` for args:
cache size: 1


eval_train: 100%|████████████████████| 999/999 [00:24<00:00, 41.41it/s, n=48951]


Interval 0: eval_train: loss=94.2012, mae_e_per_atom=336.7 meV, mae_f=845.4 meV/Å, mae_s=1.4 meV/Å³


eval_valid:  86%|████████████████████▌   | 18/21 [00:00<00:00, 55.90it/s, n=980]

Compiled function `model` for args:
cache size: 2


eval_valid: 100%|███████████████████████| 21/21 [00:05<00:00,  3.54it/s, n=1000]


Interval 0: eval_valid: loss=95.6327, mae_e_per_atom=335.4 meV, mae_f=855.2 meV/Å, mae_s=1.5 meV/Å³


Train interval 0:   0%|           | 5/1000 [00:08<19:55,  1.20s/it, loss=40.756]

Compiled function `update_fn` for args:
Outout: loss= 99.151
Compilation time: 7.890s, cache size: 1
Compiled function `update_fn` for args:
Outout: loss= 132.404
Compilation time: 0.032s, cache size: 2


Train interval 0: 100%|█████████| 1000/1000 [00:36<00:00, 27.25it/s, loss=0.421]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.29it/s, n=48951]


Interval 1: eval_train: loss=0.3746, mae_e_per_atom=2.6 meV, mae_f=52.7 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.89it/s, n=1000]


Interval 1: eval_valid: loss=0.3900, mae_e_per_atom=2.6 meV, mae_f=53.7 meV/Å, mae_s=2.7 meV/Å³


Train interval 1: 100%|█████████| 1000/1000 [00:28<00:00, 34.59it/s, loss=0.332]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.00it/s, n=48951]


Interval 2: eval_train: loss=0.2556, mae_e_per_atom=2.0 meV, mae_f=43.4 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.95it/s, n=1000]


Interval 2: eval_valid: loss=0.2650, mae_e_per_atom=2.0 meV, mae_f=44.1 meV/Å, mae_s=2.7 meV/Å³


Train interval 2: 100%|█████████| 1000/1000 [00:28<00:00, 34.58it/s, loss=0.208]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.98it/s, n=48951]


Interval 3: eval_train: loss=0.1851, mae_e_per_atom=1.7 meV, mae_f=36.9 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.62it/s, n=1000]


Interval 3: eval_valid: loss=0.1918, mae_e_per_atom=1.7 meV, mae_f=37.5 meV/Å, mae_s=2.7 meV/Å³


Train interval 3: 100%|█████████| 1000/1000 [00:28<00:00, 34.53it/s, loss=0.189]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.97it/s, n=48951]


Interval 4: eval_train: loss=0.1494, mae_e_per_atom=1.5 meV, mae_f=33.1 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.04it/s, n=1000]


Interval 4: eval_valid: loss=0.1546, mae_e_per_atom=1.5 meV, mae_f=33.6 meV/Å, mae_s=2.7 meV/Å³


Train interval 4: 100%|█████████| 1000/1000 [00:28<00:00, 34.62it/s, loss=0.128]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.84it/s, n=48951]


Interval 5: eval_train: loss=0.1260, mae_e_per_atom=1.4 meV, mae_f=30.4 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.89it/s, n=1000]


Interval 5: eval_valid: loss=0.1302, mae_e_per_atom=1.4 meV, mae_f=30.9 meV/Å, mae_s=2.7 meV/Å³


Train interval 5: 100%|█████████| 1000/1000 [00:28<00:00, 34.59it/s, loss=0.165]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.69it/s, n=48951]


Interval 6: eval_train: loss=0.1079, mae_e_per_atom=1.2 meV, mae_f=28.1 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.17it/s, n=1000]


Interval 6: eval_valid: loss=0.1112, mae_e_per_atom=1.3 meV, mae_f=28.5 meV/Å, mae_s=2.7 meV/Å³


Train interval 6: 100%|█████████| 1000/1000 [00:28<00:00, 35.01it/s, loss=0.133]
eval_train: 100%|████████████████████| 999/999 [00:17<00:00, 55.52it/s, n=48951]


Interval 7: eval_train: loss=0.0953, mae_e_per_atom=1.2 meV, mae_f=26.5 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.69it/s, n=1000]


Interval 7: eval_valid: loss=0.0980, mae_e_per_atom=1.2 meV, mae_f=26.8 meV/Å, mae_s=2.7 meV/Å³


Train interval 7: 100%|█████████| 1000/1000 [00:28<00:00, 35.13it/s, loss=0.093]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.49it/s, n=48951]


Interval 8: eval_train: loss=0.0848, mae_e_per_atom=1.1 meV, mae_f=25.0 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.93it/s, n=1000]


Interval 8: eval_valid: loss=0.0867, mae_e_per_atom=1.2 meV, mae_f=25.3 meV/Å, mae_s=2.7 meV/Å³


Train interval 8: 100%|█████████| 1000/1000 [00:28<00:00, 35.07it/s, loss=0.078]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.75it/s, n=48951]


Interval 9: eval_train: loss=0.0768, mae_e_per_atom=1.0 meV, mae_f=23.8 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.03it/s, n=1000]


Interval 9: eval_valid: loss=0.0786, mae_e_per_atom=1.1 meV, mae_f=24.1 meV/Å, mae_s=2.7 meV/Å³


Train interval 9: 100%|█████████| 1000/1000 [00:28<00:00, 34.98it/s, loss=0.063]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.20it/s, n=48951]


Interval 10: eval_train: loss=0.0702, mae_e_per_atom=1.0 meV, mae_f=22.8 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.61it/s, n=1000]


Interval 10: eval_valid: loss=0.0718, mae_e_per_atom=1.0 meV, mae_f=23.0 meV/Å, mae_s=2.7 meV/Å³


Train interval 10: 100%|████████| 1000/1000 [00:28<00:00, 35.22it/s, loss=0.090]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.46it/s, n=48951]


Interval 11: eval_train: loss=0.0651, mae_e_per_atom=1.0 meV, mae_f=21.9 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.69it/s, n=1000]


Interval 11: eval_valid: loss=0.0668, mae_e_per_atom=1.0 meV, mae_f=22.2 meV/Å, mae_s=2.7 meV/Å³


Train interval 11: 100%|████████| 1000/1000 [00:28<00:00, 35.09it/s, loss=0.084]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.34it/s, n=48951]


Interval 12: eval_train: loss=0.0606, mae_e_per_atom=0.9 meV, mae_f=21.2 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.71it/s, n=1000]


Interval 12: eval_valid: loss=0.0622, mae_e_per_atom=1.0 meV, mae_f=21.4 meV/Å, mae_s=2.7 meV/Å³


Train interval 12: 100%|████████| 1000/1000 [00:28<00:00, 35.01it/s, loss=0.080]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.29it/s, n=48951]


Interval 13: eval_train: loss=0.0572, mae_e_per_atom=1.0 meV, mae_f=20.6 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.00it/s, n=1000]


Interval 13: eval_valid: loss=0.0589, mae_e_per_atom=1.0 meV, mae_f=20.8 meV/Å, mae_s=2.7 meV/Å³


Train interval 13: 100%|████████| 1000/1000 [00:28<00:00, 35.03it/s, loss=0.064]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.45it/s, n=48951]


Interval 14: eval_train: loss=0.0537, mae_e_per_atom=0.9 meV, mae_f=19.9 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.13it/s, n=1000]


Interval 14: eval_valid: loss=0.0552, mae_e_per_atom=0.9 meV, mae_f=20.2 meV/Å, mae_s=2.7 meV/Å³


Train interval 14: 100%|████████| 1000/1000 [00:28<00:00, 35.08it/s, loss=0.052]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.15it/s, n=48951]


Interval 15: eval_train: loss=0.0508, mae_e_per_atom=0.9 meV, mae_f=19.4 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.35it/s, n=1000]


Interval 15: eval_valid: loss=0.0526, mae_e_per_atom=0.9 meV, mae_f=19.7 meV/Å, mae_s=2.7 meV/Å³


Train interval 15: 100%|████████| 1000/1000 [00:28<00:00, 35.03it/s, loss=0.046]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.05it/s, n=48951]


Interval 16: eval_train: loss=0.0481, mae_e_per_atom=0.8 meV, mae_f=18.9 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.35it/s, n=1000]


Interval 16: eval_valid: loss=0.0495, mae_e_per_atom=0.9 meV, mae_f=19.1 meV/Å, mae_s=2.7 meV/Å³


Train interval 16: 100%|████████| 1000/1000 [00:28<00:00, 34.85it/s, loss=0.057]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.04it/s, n=48951]


Interval 17: eval_train: loss=0.0459, mae_e_per_atom=0.8 meV, mae_f=18.4 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.53it/s, n=1000]


Interval 17: eval_valid: loss=0.0474, mae_e_per_atom=0.9 meV, mae_f=18.7 meV/Å, mae_s=2.7 meV/Å³


Train interval 17: 100%|████████| 1000/1000 [00:28<00:00, 34.96it/s, loss=0.058]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.14it/s, n=48951]


Interval 18: eval_train: loss=0.0440, mae_e_per_atom=0.9 meV, mae_f=18.1 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.93it/s, n=1000]


Interval 18: eval_valid: loss=0.0456, mae_e_per_atom=0.9 meV, mae_f=18.4 meV/Å, mae_s=2.7 meV/Å³


Train interval 18: 100%|████████| 1000/1000 [00:28<00:00, 34.91it/s, loss=0.081]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.85it/s, n=48951]


Interval 19: eval_train: loss=0.0422, mae_e_per_atom=0.8 meV, mae_f=17.7 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 53.43it/s, n=1000]


Interval 19: eval_valid: loss=0.0438, mae_e_per_atom=0.9 meV, mae_f=18.0 meV/Å, mae_s=2.7 meV/Å³


Train interval 19: 100%|████████| 1000/1000 [00:28<00:00, 34.86it/s, loss=0.049]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.85it/s, n=48951]


Interval 20: eval_train: loss=0.0405, mae_e_per_atom=0.8 meV, mae_f=17.3 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 56.06it/s, n=1000]


Interval 20: eval_valid: loss=0.0419, mae_e_per_atom=0.8 meV, mae_f=17.6 meV/Å, mae_s=2.7 meV/Å³


Train interval 20: 100%|████████| 1000/1000 [00:28<00:00, 34.84it/s, loss=0.052]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 55.07it/s, n=48951]


Interval 21: eval_train: loss=0.0390, mae_e_per_atom=0.8 meV, mae_f=17.0 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.20it/s, n=1000]


Interval 21: eval_valid: loss=0.0404, mae_e_per_atom=0.9 meV, mae_f=17.3 meV/Å, mae_s=2.7 meV/Å³


Train interval 21: 100%|████████| 1000/1000 [00:28<00:00, 34.92it/s, loss=0.037]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.86it/s, n=48951]


Interval 22: eval_train: loss=0.0375, mae_e_per_atom=0.8 meV, mae_f=16.7 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.12it/s, n=1000]


Interval 22: eval_valid: loss=0.0389, mae_e_per_atom=0.8 meV, mae_f=16.9 meV/Å, mae_s=2.7 meV/Å³


Train interval 22: 100%|████████| 1000/1000 [00:28<00:00, 34.82it/s, loss=0.051]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.90it/s, n=48951]


Interval 23: eval_train: loss=0.0362, mae_e_per_atom=0.7 meV, mae_f=16.4 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.79it/s, n=1000]


Interval 23: eval_valid: loss=0.0375, mae_e_per_atom=0.8 meV, mae_f=16.6 meV/Å, mae_s=2.7 meV/Å³


Train interval 23: 100%|████████| 1000/1000 [00:28<00:00, 34.88it/s, loss=0.040]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.96it/s, n=48951]


Interval 24: eval_train: loss=0.0351, mae_e_per_atom=0.8 meV, mae_f=16.1 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.16it/s, n=1000]


Interval 24: eval_valid: loss=0.0365, mae_e_per_atom=0.8 meV, mae_f=16.4 meV/Å, mae_s=2.7 meV/Å³


Train interval 24: 100%|████████| 1000/1000 [00:28<00:00, 34.77it/s, loss=0.043]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.75it/s, n=48951]


Interval 25: eval_train: loss=0.0341, mae_e_per_atom=0.7 meV, mae_f=15.9 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.30it/s, n=1000]


Interval 25: eval_valid: loss=0.0354, mae_e_per_atom=0.7 meV, mae_f=16.1 meV/Å, mae_s=2.7 meV/Å³


Train interval 25: 100%|████████| 1000/1000 [00:28<00:00, 34.75it/s, loss=0.040]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.86it/s, n=48951]


Interval 26: eval_train: loss=0.0332, mae_e_per_atom=0.8 meV, mae_f=15.6 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.54it/s, n=1000]


Interval 26: eval_valid: loss=0.0346, mae_e_per_atom=0.8 meV, mae_f=15.9 meV/Å, mae_s=2.7 meV/Å³


Train interval 26: 100%|████████| 1000/1000 [00:28<00:00, 34.75it/s, loss=0.046]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.82it/s, n=48951]


Interval 27: eval_train: loss=0.0321, mae_e_per_atom=0.7 meV, mae_f=15.4 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.05it/s, n=1000]


Interval 27: eval_valid: loss=0.0334, mae_e_per_atom=0.7 meV, mae_f=15.6 meV/Å, mae_s=2.7 meV/Å³


Train interval 27: 100%|████████| 1000/1000 [00:28<00:00, 34.79it/s, loss=0.035]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.81it/s, n=48951]


Interval 28: eval_train: loss=0.0313, mae_e_per_atom=0.7 meV, mae_f=15.2 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.07it/s, n=1000]


Interval 28: eval_valid: loss=0.0326, mae_e_per_atom=0.8 meV, mae_f=15.4 meV/Å, mae_s=2.7 meV/Å³


Train interval 28: 100%|████████| 1000/1000 [00:28<00:00, 34.78it/s, loss=0.053]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.73it/s, n=48951]


Interval 29: eval_train: loss=0.0305, mae_e_per_atom=0.7 meV, mae_f=15.0 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 54.77it/s, n=1000]


Interval 29: eval_valid: loss=0.0320, mae_e_per_atom=0.7 meV, mae_f=15.3 meV/Å, mae_s=2.7 meV/Å³


Train interval 29: 100%|████████| 1000/1000 [00:28<00:00, 34.67it/s, loss=0.033]
eval_train: 100%|████████████████████| 999/999 [00:18<00:00, 54.51it/s, n=48951]


Interval 30: eval_train: loss=0.0296, mae_e_per_atom=0.7 meV, mae_f=14.8 meV/Å, mae_s=2.7 meV/Å³


eval_valid: 100%|███████████████████████| 21/21 [00:00<00:00, 55.46it/s, n=1000]


Interval 30: eval_valid: loss=0.0310, mae_e_per_atom=0.7 meV, mae_f=15.0 meV/Å, mae_s=2.7 meV/Å³
Training complete
Training done!


## Deriving the electric polarization vector and the Born effective charge matrix

With the energy model trained above (order = 2), which includes the electric-field dependence, we can derive the electric polarization vector and the Born effective charge matrix for a given (molecular) structure.
Here the model was trained on the ethanol molecule dataset, in which an electric field was applied in the DFT calculations.
We derive these properties for an ethanol molecule via automatic differentiation of the energy function with respect to the electric field vector.


First, load the trained energy model with the electric-field dependence.

In [6]:
# Rebuild the same architecture as in the training section, this time asking the
# builder to reload a finished run through `reload`.
# NOTE(review): the run name below is hard-coded to the directory printed when
# this notebook was last executed. After Restart & Run All a new run name is
# generated, so this string must be updated (or replaced by the current run
# name — confirm what `reload` expects) before trusting the results below.
model_fn, params, num_message_passing = NequIP_JAXMD_model_Efield(
    r_max=r_max,
    atomic_energies_dict={},
    train_graphs=train_loader.graphs,
    initialize_seed=config["model"]["seed"],
    num_species=config["model"]["num_species"],
    use_sc=True,
    graph_net_steps=config["model"]["num_layers"],
    hidden_irreps=config["model"]["internal_irreps"],
    nonlinearities={'e': 'swish', 'o': 'tanh'},
    save_dir_name=save_dir_name,
    reload='2024-03-12-18:44-Efield_pot_training-systematic-fleurette',
)

param_leaves = jax.tree_util.tree_leaves(params)
print("num_params:", sum(leaf.size for leaf in param_leaves))

# Jitted energy / forces / stress predictor bound to the reloaded model function.
predictor = jax.jit(
    lambda w, g: predict_energy_forces_stress(
        lambda *model_args: model_fn(w, *model_args), g
    )
)


Create NequIP (JAX-MD version) with parameters {'use_sc': True, 'graph_net_steps': 2, 'hidden_irreps': '64x0e + 64x0o + 48x1o + 48x1e +32x2o + 32x2e', 'nonlinearities': {'e': 'swish', 'o': 'tanh'}, 'r_max': 5.0, 'avg_num_neighbors': 8.0, 'avg_r_min': None, 'num_species': 100, 'radial_basis': <function bessel_basis at 0x7f6f593b7ce0>, 'radial_envelope': <function soft_envelope at 0x7f6f593b7420>}
num_params: 1888176


Load the ethanol molecule structure and prepare the computational graph needed

In [7]:
# Read the reference ethanol geometry from the VASP-format structure file.
ethanol_mol = ase.io.read('data/ethanol-PBE.vasp', format='vasp')

# Shift coordinates by one period of 10 so the molecule is not split across the
# box: components below -5 are moved up first, then components above +5 are
# moved down (same order as before; the second mask sees the shifted values).
# NOTE(review): 5.0 / 10.0 look like half / full box length of the DFT cell — confirm.
allpos = ethanol_mol.get_positions()
allpos[allpos < -5.0] += 10.0
allpos[allpos > 5.0] -= 10.0
ethanol_mol.set_positions(allpos)

# Convert the ASE structure into the graph consumed by the model, using the
# cutoff from the config.
ethanol_config = config_from_atoms(ethanol_mol)
ethanol_graph = graph_from_configuration(ethanol_config, cutoff=config["cutoff"])


Define functions that derive the electric polarization vector (a first-order derivative with respect to the electric field) and the Born effective charge matrix (a mixed second-order derivative with respect to the electric field and the atomic positions) via automatic differentiation.
With $U$ denoting the neural-network energy model, the electric polarization is computed as

$$
\vec{P} = \frac{\partial U}{\partial \vec{E}}
$$
and the Born effective charge matrix ($N \times 3 \times 3$) as
$$
Z^* = \frac{\partial^2 U}{\partial \vec{E} \partial \vec{u_i}}
$$

In [8]:

def predict_polarization(
    model, graph: jraph.GraphsTuple
) -> Dict[str, jnp.ndarray]:
    """Differentiate the model energy w.r.t. a per-node electric field at E = 0.

    Args:
        model: callable invoked as
            ``model(vectors, species, senders, receivers, efield)``; it must
            return one energy per node, i.e. an array of shape ``(n_nodes,)``.
        graph: graph providing ``nodes.positions``, ``nodes.species``,
            ``senders``, ``receivers``, ``edges.shifts``, ``n_edge`` and
            ``globals.cell``.

    Returns:
        ``{'polarization': dU/dE}``. The gradient has the shape of the field
        input, ``(n_nodes, 3)``: one row per node. No sign change or summation
        over nodes is applied here.
    """
    num_nodes = graph.nodes.positions.shape[0]

    def total_energy(positions, cell, Efield):
        edge_vectors = get_edge_relative_vectors(
            positions=positions,
            senders=graph.senders,
            receivers=graph.receivers,
            shifts=graph.edges.shifts,
            cell=cell,
            n_edge=graph.n_edge,
        )
        per_node_energy = model(
            edge_vectors,
            graph.nodes.species,
            graph.senders,
            graph.receivers,
            Efield,
        )  # [n_nodes, ]
        assert per_node_energy.shape == (
            len(positions),
        ), "model output needs to be an array of shape (n_nodes, )"
        return jnp.sum(per_node_energy)

    # The field is an explicit input so that it can be differentiated; it is
    # evaluated at zero for every node.
    zero_field = jnp.zeros((num_nodes, 3))
    energy_field_gradient = jax.grad(total_energy, argnums=2)(
        graph.nodes.positions, graph.globals.cell, zero_field
    )

    return {'polarization': energy_field_gradient}


def pol_predictor(w, g):
    """Evaluate ``predict_polarization`` for graph ``g`` with weights ``w`` bound into ``model_fn``."""
    return predict_polarization(lambda *x: model_fn(w, *x), g)



def predict_becs(
    model, graph: jraph.GraphsTuple
) -> Dict[str, jnp.ndarray]:
    """Mixed second derivative of the model energy w.r.t. positions and a uniform field.

    A single 3-vector field is repeated for every node, the gradient of the
    total energy w.r.t. the atomic positions is formed, and that gradient is
    differentiated w.r.t. the field at E = 0 (one forward-mode pass per
    Cartesian field direction).

    Args:
        model: callable invoked as
            ``model(vectors, species, senders, receivers, node_efield)``; it
            must return one energy per node, shape ``(n_nodes,)``.
        graph: graph providing ``nodes.positions``, ``nodes.species``,
            ``senders``, ``receivers``, ``edges.shifts``, ``n_edge`` and
            ``globals.cell``.

    Returns:
        ``{'becs': array}`` where ``array[i, a, b]`` is the derivative of
        ``dU/d(position of node i, component a)`` w.r.t. field component ``b``;
        the field direction is the last axis. No sign change is applied.
    """

    def total_energy(positions, cell, Efield):
        edge_vectors = get_edge_relative_vectors(
            positions=positions,
            senders=graph.senders,
            receivers=graph.receivers,
            shifts=graph.edges.shifts,
            cell=cell,
            n_edge=graph.n_edge,
        )

        # Same external field on every node.
        uniform_field = jnp.repeat(
            Efield[None, :], repeats=positions.shape[0], axis=0
        )

        per_node_energy = model(
            edge_vectors,
            graph.nodes.species,
            graph.senders,
            graph.receivers,
            uniform_field,
        )  # [n_nodes, ]
        assert per_node_energy.shape == (
            len(positions),
        ), "model output needs to be an array of shape (n_nodes, )"
        return jnp.sum(per_node_energy)

    def energy_position_gradient(Efield):
        # Plain dU/d(positions) as a function of the field (no minus sign).
        return jax.grad(total_energy, argnums=0)(
            graph.nodes.positions, graph.globals.cell, Efield
        )

    zero_field = jnp.zeros(3)

    # One JVP per unit field direction; each tangent output has the shape of
    # the position gradient and becomes one slice along the new last axis.
    field_derivatives = [
        jax.jvp(energy_position_gradient, (zero_field,), (unit_direction,))[1]
        for unit_direction in jnp.eye(3)
    ]

    return {'becs': jnp.stack(field_derivatives, axis=-1)}


def becs_predictor(w, g):
    """Evaluate ``predict_becs`` for graph ``g`` with weights ``w`` bound into ``model_fn``."""
    return predict_becs(lambda *x: model_fn(w, *x), g)



Evaluate the electric polarization

In [9]:
# Reference dipole from the VASP calculation (three Cartesian components).
vasp_dipole = jnp.array([0.004000, -0.318553, 0.017774])

# Model side: per-node dU/dE, summed over the nodes and sign-flipped so that it
# can be compared with the VASP dipole.
pol_pred = pol_predictor(params, ethanol_graph)
pred_dipole = -pol_pred['polarization'].sum(axis=0)

print('VASP dipole: ', vasp_dipole)
print('Predicted dipole: ', pred_dipole)
print('Electric polarization MAE: ', jnp.mean(jnp.abs(pred_dipole - vasp_dipole)))


VASP dipole:  [ 0.004 -0.319  0.018]
Predicted dipole:  [-0.003 -0.303  0.017]
Electric polarization MAE:  0.007493755


Evaluate the Born effective charge matrix

In [10]:
# Born effective charges copied from the VASP output: for each of the 9 ions a
# header line (" ion N") followed by three rows "<row index> <3 values>".
becs_vasp = """ ion    1
    1     0.76779    -0.22702    -0.00503
    2    -0.20500     0.41408     0.00527
    3    -0.00626     0.00546     0.34978
 ion    2
    1    -0.07405    -0.09377     0.01263
    2    -0.03967     0.05148     0.00411
    3     0.00892     0.00629     0.07288
 ion    3
    1    -0.83639     0.31798    -0.00369
    2     0.27561    -0.41805    -0.01616
    3    -0.00095    -0.01758    -0.48789
 ion    4
    1    -0.03316     0.04499    -0.01466
    2     0.00274    -0.07274     0.15055
    3     0.00513     0.10020    -0.12405
 ion    5
    1    -0.03347     0.03681     0.00317
    2    -0.01095    -0.10167    -0.15235
    3    -0.01139    -0.10634    -0.09470
 ion    6
    1    -0.07604     0.10381     0.00271
    2     0.06480    -0.01239     0.00087
    3     0.00416    -0.00072     0.07915
 ion    7
    1     0.05303    -0.05394    -0.01745
    2    -0.01201    -0.01803    -0.09756
    3     0.01182    -0.07743    -0.04958
 ion    8
    1     0.05303    -0.04470     0.01086
    2    -0.00360    -0.00045     0.09273
    3    -0.02200     0.07730    -0.06531
 ion    9
    1     0.17927    -0.08415     0.01146
    2    -0.07192     0.15777     0.01253
    3     0.01057     0.01283     0.31972 """

becs_vasp = becs_vasp.split('\n')

# Lines 0, 4, 8, ... are the " ion N" headers; every other line contributes its
# three values (the leading row index is dropped). The 27 rows are then grouped
# into one 3x3 block per ion.
becs_dft_data = jnp.array(
    [
        [float(token) for token in table_line.split()[1:]]
        for line_no, table_line in enumerate(becs_vasp)
        if line_no % 4 != 0
    ]
).reshape((9, 3, 3))

# Predict the Born effective charges from the mixed second derivatives of the
# energy model.
pred_out = becs_predictor(params, ethanol_graph)

# Swap the last two axes of the prediction so that its layout matches the
# parsed VASP table before comparing element-wise.
becs_pred = jnp.swapaxes(pred_out['becs'], 1, 2)

becs_abs_error = jnp.abs(becs_pred - becs_dft_data)
print('Born effective charge matrix MAE: ', jnp.mean(becs_abs_error))



Born effective charge matrix MAE:  0.0058555966
