In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import json
import logging
import os
from pprint import pprint

from mlff.src.data.preprocessing import split_data
from mlff.src.training import Optimizer, Coach, get_loss_fn, create_train_state
from mlff.src.io.io import create_directory, merge_dicts, bundle_dicts, save_dict
from mlff.src.data import DataTuple, DataSet
from mlff.src.indexing.indices import get_indices
from mlff.src.nn.stacknet import StackNet, get_obs_and_grad_obs_fn, get_obs_and_force_fn
from mlff.src.nn.embed import AtomTypeEmbed, GeometryEmbed
from mlff.src.nn.layer import So3kratesLayer
from mlff.src.nn.observable import Energy
import wandb
# Configure the root logger so that records of level INFO and above are displayed.
logging.basicConfig(level=logging.INFO)

In [None]:
# Location of the data the model is trained on, and of the directory in which the model checkpoints and
# the hyperparameter file are stored.
data_path = 'example_data/ethanol.npz'
save_path = 'example_model/'

# The value returned by `create_directory` is the checkpoint directory used throughout the notebook.
ckpt_dir = create_directory(os.path.join(save_path, 'module'), exists_ok=False)

# Property keys. The keys of this dictionary must not be changed, whereas the values can be chosen freely
# such that they match the names used in the data file. Properties that are not available are set to None.
prop_keys = dict(
    energy='E',
    force='F',
    atomic_type='z',
    atomic_position='R',
    hirshfeld_volume=None,
    total_charge=None,
    total_spin=None,
    partial_charge=None,
)

# Training

In [None]:
# We now load the data. We assume that the data is given or has been transformed to .npz format.
# `np.load` returns a lazily reading `NpzFile` for .npz archives which keeps the underlying file open.
# Converting it to a dict inside the `with` block reads all arrays into memory and closes the file afterwards.
with np.load(data_path) as npz_file:
    data = dict(npz_file)

# Initialize a DataSet object with the property keys and the loaded (and dict transformed) data set. For
# a more detailed introduction into the DataSet object take a look at the 00_Data_Preparation.ipynb example.
md17_dataset = DataSet(prop_keys=prop_keys, data=data)

# Since StackNets work on neighborhood index lists which often are not part of the input data, the split function
# takes an argument for the cutoff radius as input. Based on the cutoff radius, neighborhood lists are
# constructed which are saved as 'idx_i' and 'idx_j' in the returned dictionary. The former is the centering
# atom and the latter the neighboring atoms.
r_cut = 5.

# Next, we split the data into training, validation and testing data.
d = md17_dataset.random_split(n_train=100,
                              n_valid=100,
                              n_test=None,
                              training=True,
                              seed=0,
                              r_cut=r_cut)

# The resulting dictionary has multiple keys, which also include keys that have not been defined as quantities
# to split. For that reason they are in the upper level of the dictionary.
print(list(d.keys()))

# Store the split next to the checkpoints such that it can be restored for evaluation later on.
md17_dataset.save_splits_to_file(path=ckpt_dir, filename='my_first_training_split.json')

In [None]:
# We can also see that there are the keys called 'train', 'valid' and 'test', respectively. Each key contains
# another dictionary which contains the split quantities as numpy arrays.
# The name of the atomic positions is looked up in `prop_keys` instead of being hard-coded, such that this
# cell keeps working when the property names are adapted to a different data file.
position_key = prop_keys['atomic_position']
print('Keys in the training split: {}'.format(list(d['train'].keys())))
print('Shape of the atomic positions in the training set: {}'.format(d['train'][position_key].shape))
print('Shape of the atomic positions in the test set: {}'.format(d['test'][position_key].shape))

In [None]:
# Each StackNet consists of 4 building blocks.

# 1) A sequence of modules that embed the geometry of the molecule. Here we choose a single module that returns
# geometry related quantities such as the expansion of the interatomic distance vectors in spherical harmonics,
# the expansion of the interatomic distances in some radial basis function as well as cutoff function related
# quantities. It also takes the prop_keys dictionary as input in order to "know" the name of the atomic
# positions in the data.
geometry_embed = GeometryEmbed(degrees=[1, 2],
                               radial_basis_function='phys',
                               n_rbf=32,
                               radial_cutoff_fn='cosine_cutoff_fn',
                               r_cut=r_cut,
                               prop_keys=prop_keys,
                               sphc=True)
geometry_embeddings = [geometry_embed]

# 2) A list of modules that embed different input quantities. Since in our example we only have atomic types
# as input, we only use the `AtomTypeEmbed` module. It takes the atomic types and returns a feature vector
# of dimension `features` based on the atomic type.
atom_type_embed = AtomTypeEmbed(num_embeddings=100, features=32, prop_keys=prop_keys)
embeddings = [atom_type_embed]


# 3) A list of modules that represent layers. Here we use 2 So3krates layers.
def _new_so3krates_layer():
    """Return a newly constructed So3krates layer with the settings used in this example."""
    return So3kratesLayer(fb_filter='radial_spherical',
                          fb_rad_filter_features=[32, 32],
                          fb_sph_filter_features=[32, 32],
                          fb_attention='conv_att',
                          gb_filter='radial_spherical',
                          gb_rad_filter_features=[32, 32],
                          gb_sph_filter_features=[32, 32],
                          gb_attention='conv_att',
                          degrees=[1, 2],
                          n_heads=2,
                          chi_cut=None,
                          )


so3krates_layer = [_new_so3krates_layer() for _ in range(2)]


# 4) A list of observable modules that are not related to the input by some differential operator. E.g. forces
# are a gradient of the energy, thus they are defined as an extra observable in the next step. We additionally
# rescale the energy output of the network using a per atom scale and a per atom shift. Here we choose the mean
# over all training energies divided by the number of atoms as per atom shift and the standard deviation of all
# force components as scale. Note that one can also rescale the target data instead of the energy output. However,
# by making the scaling a quantity of the network itself, it can be applied later without any reference to the
# training data. The learning rate then has to be scaled accordingly, since the loss will be larger
# in the setting of rescaling the network output.
F_scale = d['train']['F_scale']
E_mean = d['train']['E_mean']
n_atoms = d['train']['z'].shape[-1]

# The same scale and shift value is repeated 20 times.
# NOTE(review): presumably one entry per atomic type index - confirm against the `Energy` module.
per_atom_scale = [F_scale.tolist()] * 20
per_atom_shift = [(E_mean / n_atoms).tolist()] * 20
energy_obs = Energy(per_atom_scale=per_atom_scale, per_atom_shift=per_atom_shift, prop_keys=prop_keys)
obs = [energy_obs]


# We now put everything together into the StackNet.
net = StackNet(geometry_embeddings=geometry_embeddings,
               feature_embeddings=embeddings,
               layers=so3krates_layer,
               observables=obs,
               prop_keys=prop_keys)

In [None]:
# All quantities that can be written as D_{input}(output), where D_{input} denotes a differential operator wrt
# some input quantity, can be defined as additional observables after the network has been initialized.
# For force fields, this additional quantity usually is the force, which is the negative gradient of the energy
# wrt the atomic positions.
# Derivatives are defined as a `Tuple[name: str, Tuple[output_key: str, input_key: str, transform: Callable]]`,
# where `name` determines the key in the final observable dictionary (output of `obs_fn`). The second tuple has
# three entries: the key of the output quantity, the key of the input quantity wrt which the derivative is
# taken, and a transformation which can be used to scale (here the negative sign) and perform further operations
# on the output.


def _gradient_to_force(y):
    """Negate the energy gradient and squeeze away its third-to-last axis.

    Since the energy output has shape (1), the gradient wrt the positions has shape (1,n_atoms,3), such that
    the returned forces have shape (n_atoms,3).
    """
    return -y.squeeze(-3)


force_derivative = (prop_keys['energy'], prop_keys['atomic_position'], _gradient_to_force)
D_force = (prop_keys['force'], force_derivative)
unbatched_obs_fn = get_obs_and_grad_obs_fn(net, derivatives=(D_force, ))

# Alternatively, as forces are often the only gradient observable, mlff also provides a pre-implemented
# function that returns an observable function for all observables in the StackNet as well as the forces.
# It could be used instead of `obs_fn` for all following steps.
obs_fn_ = get_obs_and_force_fn(net)

# Since all code internally assumes no batch dimension, we vmap the input over the batch dimension. The
# parameters (first argument) are shared across the batch.
obs_fn = jax.vmap(unbatched_obs_fn, in_axes=(None, 0))

In [None]:
from pprint import pformat

n_data = 2

# Only the quantities that the network consumes are kept as inputs.
train_split = d['train']
network_input_keys = [prop_keys['atomic_position'], prop_keys['atomic_type'], 'idx_i', 'idx_j']

# A single data point (no batch dimension) is used for the parameter initialization ...
init_input = {k: jnp.array(train_split[k][0]) for k in train_split if k in network_input_keys}
params = net.init(jax.random.PRNGKey(1234), init_input)

# ... whereas the first `n_data` data points are passed through the batched observable function.
fwd_input = {k: train_split[k][:n_data] for k in train_split if k in network_input_keys}

# Let's see what the output of the obs_fn looks like. It returns a dictionary where each observable has
# its own entry that can be accessed using the corresponding key of the observable that is specified in the
# prop_keys dictionary at the very beginning.
observables = obs_fn(params, fwd_input)
predicted_energy = observables[prop_keys['energy']]
predicted_force = observables[prop_keys['force']]
print('All observables:\n {}'.format(pformat(observables)))
print('Energy:\n {}'.format(pformat(predicted_energy)))
print('Force:\n {}'.format(pformat(predicted_force)))

In [None]:
# mlff provides a default optimizer that follows the default settings for the AdamW optimizer. It is created
# by simply instantiating the `Optimizer` class. Calling its `.get()` method with a learning rate returns the
# corresponding optax optimizer. Note that if we are using a network that scales and shifts its output, one has
# to use larger learning rates than when training on rescaled training data. This is due to the fact that the
# variance of the targets (and thus of the gradients wrt the loss function) is larger if the training data is
# not rescaled. Thus, usual learning rates have to be increased by the same order of magnitude. If not rescaling
# the data itself, we found training to be more stable when transforming kcal/mol to eV.
opt = Optimizer()
learning_rate = 1e-3 * np.sqrt(F_scale)
tx = opt.get(learning_rate=learning_rate)


# mlff also provides a Coach `dataclass` which is used as storage for all quantities that are associated
# with the training process.
# The `input_keys` determine which quantities shall be used during the training process.
# Here one has to make sure that the input data provides all necessary quantities that are used by
# the network. E.g. if one is using the `AtomTypeEmbed` module, one has to make sure that the atomic types
# are passed as an input.
# The `target_keys` determine which observables enter the calculation of the loss function. Here we
# chose energy and forces. The network can still output additional observables, e.g. the partial charges,
# which are just not taken into account for the calculation of the loss if not listed here. This can be
# useful if one wants to use the same model for different training routines where one starts by training
# only a subset of the quantities. However, the observable function must output all observables that are
# listed here.
# The `loss_weights` attribute assigns the scaling parameter to the different quantities used in the loss function.
energy_key = prop_keys['energy']
force_key = prop_keys['force']
coach_settings = dict(
    input_keys=[prop_keys['atomic_position'], prop_keys['atomic_type'], 'idx_i', 'idx_j'],
    target_keys=[energy_key, force_key],
    epochs=200,
    training_batch_size=2,
    validation_batch_size=2,
    loss_weights={energy_key: .01, force_key: .99},
    ckpt_dir=ckpt_dir,
    data_path=data_path,
    net_seed=0,
    training_seed=0,
)
coach = Coach(**coach_settings)

# The `get_loss_fn` method returns a loss function given the observable function and the `loss_weights`.
# As the loss function is accessing the `loss_weights` during training, the specific loss function only
# works for the given observables.
loss_fn = get_loss_fn(obs_fn=obs_fn, weights=coach.loss_weights)

In [None]:
# Before starting the training, the training and validation data have to be separated into input and target
# quantities. This is the job of the `DataTuple` class, which is set up with the `input_keys` and the
# `target_keys` (both are plain lists of keys) stored in the coach.
data_tuple = DataTuple(input_keys=coach.input_keys,
                       target_keys=coach.target_keys)

# Calling the `DataTuple` on a split from before yields the corresponding data set. `train_ds` and `valid_ds`
# are each unpacked further below as (inputs, targets): the first entry holds the inputs and the second
# the target data.
train_ds, valid_ds = (data_tuple(d[split_name]) for split_name in ('train', 'valid'))

In [None]:
# Show the first two entries of every input and target quantity.
# `jax.tree_util.tree_map` is used instead of the top-level alias `jax.tree_map`, which has been deprecated
# and is no longer available in recent JAX releases.
print('Inputs:\n {}'.format(pformat(jax.tree_util.tree_map(lambda y: y[:2], train_ds[0]))))
print('Outputs:\n {}'.format(pformat(jax.tree_util.tree_map(lambda y: y[:2], train_ds[1]))))

In [None]:
# As the last step we have to initialize the parameters of the network. As the init function can not be vmapped
# and internally we assume no batch dimension, one has to initialize with data that has no batch dimension.
# Here this is achieved using the `jax.tree_util.tree_map` function, which selects a single data point for all
# input quantities. (The top-level alias `jax.tree_map` has been deprecated and is no longer available in recent
# JAX releases, which is why the `jax.tree_util` variant is used.)
# After we have initialized the parameters, the `net`, its `params` and the optax optimizer `tx`
# are used to create the `train_state` which handles the gradient updates and checkpoints. The method further
# returns a dictionary representation for the train_state hyperparameters.
# The parameters are a `FrozenDict` and contain all initialized parameters of the network.
inputs = jax.tree_util.tree_map(lambda x: jnp.array(x[0, ...]), train_ds[0])
params = net.init(jax.random.PRNGKey(coach.net_seed), inputs)
train_state, h_train_state = create_train_state(net,
                                                params,
                                                tx,
                                                scheduled_lr_decay={'exponential': {'transition_steps': 50_000,
                                                                                    'decay_factor': 0.5}
                                                                   }
                                                )

In [None]:
# To make a run reproducible, every class in mlff implements a `__dict_repr__()` method which returns a
# dictionary representation of the object. These representations can be used e.g. to load the model after
# training and use it for evaluation. How to use a trained model is shown further below.
h_net, h_opt, h_coach, h_dataset = (
    component.__dict_repr__() for component in (net, opt, coach, md17_dataset)
)

# Bundle all representations (including the one of the train state) and store them next to the checkpoints.
hyperparameter_dicts = [h_net, h_opt, h_coach, h_dataset, h_train_state]
h = bundle_dicts(hyperparameter_dicts)
save_dict(path=ckpt_dir, filename='hyperparameters.json', data=h, exists_ok=True)

In [None]:
# Initialize the Weights & Biases run. The bundled hyperparameters `h` are passed as the run config.
# For all possible parameters passed to the .init() method check
# https://docs.wandb.ai/ref/python/init
wandb.init(project='mlff', name='my_first_force_field', config=h)  
# We can use the `.run()` method of the `Coach` class to run training.
# Note that the first call might take some time, since JAX compiles the computational graph
# for heavy optimization and parallelization.
coach.run(train_state=train_state, 
          train_ds=train_ds, 
          valid_ds=valid_ds, 
          loss_fn=loss_fn,
          log_every_t=1,
          eval_every_t=100,  # evaluate validation loss every t gradient steps
          ckpt_overwrite=False)
# After running the training once, see what changes if you try to run the training again with
# `ckpt_overwrite=False` and `ckpt_overwrite=True`.

# Evaluation on Same Data Set

In [None]:
from mlff.src.io.io import read_json
from mlff.src.nn.stacknet import init_stack_net
from flax.training import checkpoints

In [None]:
# Load the hyperparameter file, restore the coach and initialize the StackNet. Since Coach is a simple
# dataclass it can be directly created from the dictionary. For the StackNet, one needs to use the function
# `init_stack_net()` which initializes the underlying modules in the StackNet given the hyperparameters.
h = read_json(os.path.join(ckpt_dir, 'hyperparameters.json'))

coach = Coach(**h['coach'])
test_data_tuple = DataTuple(input_keys=coach.input_keys,
                            target_keys=coach.target_keys)

test_net = init_stack_net(h)

# `np.load` returns a lazily reading `NpzFile` for .npz archives which keeps the underlying file open.
# Converting it to a dict inside the `with` block reads all arrays into memory and closes the file afterwards.
with np.load(coach.data_path) as npz_file:
    test_data = dict(npz_file)
test_dataset = DataSet(data=test_data, prop_keys=prop_keys)

In [None]:
# Restore the split that has been saved during training. The neighborhood lists are rebuilt with the same
# cutoff radius `r_cut` that has been used for training (instead of repeating the number here), such that
# changing the cutoff in the training section can not silently lead to inconsistent inputs at evaluation time.
d_test = test_dataset.load_split(file=os.path.join(ckpt_dir, 'my_first_training_split.json'),
                                 r_cut=r_cut,
                                 n_test=200,
                                 split_name='random_split')
test_input, test_obs = test_data_tuple(d_test['test'])
# In total 400 geometries are used, since training and validation data are also restored. If you want to turn this
# off, set n_train and n_test to 0 in the above function.
# NOTE(review): the sentence above presumably means `n_train` and `n_valid` (setting `n_test` to 0 would remove
# the test data as well) - confirm against the signature of `load_split`.

In [None]:
# Observable function (all StackNet observables plus forces) of the restored network. It is first mapped over
# the batch axis of the inputs (the parameters are shared across the batch) and then jit compiled.
batched_test_obs_fn = jax.vmap(get_obs_and_force_fn(test_net), in_axes=(None, 0))
test_obs_fn = jax.jit(batched_test_obs_fn)

In [None]:
# Load the trained state from the checkpoint directory and pick the network parameters from it. For the
# behavior of the `.restore_checkpoint()` method check the FLAX github or documentation.
restored_state = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=None)
test_params = restored_state['params']

In [None]:
# Predict energy and forces and calculate the mean absolute errors.
# The names of the energy and the force are looked up in `prop_keys` instead of being hard-coded, such that
# this cell keeps working when the property names are adapted to a different data file.
energy_key = prop_keys['energy']
force_key = prop_keys['force']

obs_pred = test_obs_fn(test_params, test_input)
e_mae = np.abs(obs_pred[energy_key].reshape(-1) - test_obs[energy_key].reshape(-1)).mean()
f_mae = np.abs(obs_pred[force_key].reshape(-1) - test_obs[force_key].reshape(-1)).mean()
print('energy mae: {} kcal/mol // force mae: {} kcal/mol A'.format(e_mae, f_mae))

In [None]:
# mlff also ships a function for model evaluation, which can be used with the pre-implemented metric functions
# or with your own ones. Have a look at the source code if you want to implement your own metric function.

from mlff.src.inference.evaluation import evaluate_model, mae_metric, rmse_metric

# Metric functions to evaluate, keyed by the name under which they are passed to `evaluate_model`.
metric_fns = {'mae': mae_metric,
              'rmse': rmse_metric}
test_split = test_data_tuple(d_test['test'])

metrics, obs_pred_2 = evaluate_model(params=test_params,
                                     obs_fn=test_obs_fn,
                                     data=test_split,
                                     metric_fn=metric_fns,
                                     batch_size=10)
pprint(metrics)

# Evaluation on Different Data Set

In [None]:
# Let's assume we have trained a model on the MD17 benchmark as above and want to apply it to a different dataset,
# e.g. QM7-X, where the keys e.g. for atomic positions and atomic types are different from the ones in the MD17
# dataset. In order to deal with that, one can reset the property keys of the StackNet (and all its submodules).

qm7x_prop_keys = {'energy': 'ePBE0+MBD',
                  'force': 'totFOR',
                  'atomic_position': 'atXYZ',
                  'atomic_type': 'atNUM'}

qm7x_E_key = qm7x_prop_keys['energy']
qm7x_F_key = qm7x_prop_keys['force']
qm7x_R_key = qm7x_prop_keys['atomic_position']
qm7x_z_key = qm7x_prop_keys['atomic_type']

# `np.load` returns a lazily reading `NpzFile` for .npz archives which keeps the underlying file open.
# Converting it to a dict inside the `with` block reads all arrays into memory and closes the file afterwards.
with np.load('example_data/qm7x_226.npz') as npz_file:
    qm7x_data = dict(npz_file)
qm7x_dataset = DataSet(prop_keys=qm7x_prop_keys, data=qm7x_data)
qm7x_data_split = qm7x_dataset.random_split(n_train=10,
                                            n_valid=10,
                                            n_test=40,
                                            r_cut=5,
                                            training=False,
                                            seed=0)

# A separate name is used for the QM7-X `DataTuple` such that the `data_tuple` of the training section is not
# overwritten. Otherwise re-running the training cells after this cell would split the MD17 data using the
# QM7-X keys.
qm7x_data_tuple = DataTuple(input_keys=[qm7x_R_key, qm7x_z_key, 'idx_i', 'idx_j'],
                            target_keys=[qm7x_E_key, qm7x_F_key])
test_input, test_obs = qm7x_data_tuple(qm7x_data_split['test'])

In [None]:
# Display the names of all arrays that are available in the loaded QM7-X data.
qm7x_data.keys()

In [None]:
# Here we get an error, since the StackNet expects 'z' as key for the atomic numbers, which the QM7-X input
# does not provide. The error is the demonstration, so it is caught and displayed instead of being raised;
# otherwise "Restart & Run All" would stop at this cell and never reach the cell that fixes the problem.
# NOTE(review): the missing key is presumably reported as a KeyError; any other exception type still propagates.
try:
    test_obs_fn(test_params, test_input)
except KeyError as err:
    print('Expected failure ({}): {}'.format(type(err).__name__, err))

In [None]:
# Reset the property keys of the trained StackNet as well as of all its submodules to the QM7-X names.
test_net.reset_prop_keys(qm7x_prop_keys, sub_modules=True)

# Rebuild the observable function from the updated network: jit compile it and map it over the batch axis of
# the inputs (the parameters are shared across the batch).
qm7x_obs_fn = jax.jit(get_obs_and_force_fn(test_net))
test_obs_fn_resetted = jax.vmap(qm7x_obs_fn, (None, 0))
test_obs_fn_resetted(test_params, test_input)