In [None]:
import numpy as np

from pprint import pprint

from mlff.src.data import DataSet

# Map each chemical quantity to the key under which it is stored in the npz file.
prop_keys = {
    'energy': 'E',
    'force': 'F',
    'atomic_type': 'z',
    'atomic_position': 'R',
}

# Short aliases for the npz keys, used when indexing into data splits below.
E_key, F_key, R_key, z_key = (
    prop_keys[quantity]
    for quantity in ('energy', 'force', 'atomic_position', 'atomic_type')
)

In [None]:
# Load the npz data. As the DataSet object works with dictionaries, the arrays are copied into a plain dict first.
# `np.load` on an .npz file returns an NpzFile backed by an open zip file; using it as a context manager
# closes that file once every array has been read into memory by `dict(...)`.
data_path = 'example_data/ethanol.npz'
with np.load(data_path) as npz_file:
    data = dict(npz_file)

# Initialize a DataSet object with the property keys and the loaded (dict-converted) data set.
md17_dataset = DataSet(prop_keys=prop_keys, data=data)

In [None]:
# A data set object supports three split functions, `random_split`, `strat_split` and `index_split`, all of
# which split the data into training, validation and test data. The `random_split` function randomly selects
# `n_train`, `n_valid` and `n_test` data points. The `strat_split` function selects data points such that the
# distribution of the quantity given by `strat_key` is matched as closely as possible. The `index_split`
# function splits the data according to explicitly given indices of the data points in the data file.

# If `n_test = None`, all points that are not part of the training and validation data are used as test
# points. If `r_cut` is not `None`, the function also calculates the neighborhood lists for all geometries. The
# `training` argument controls whether the neighborhood lists are also calculated for the test data set: since
# the test data is not required during training, skipping them saves, for large data sets such as MD17, the
# overhead of calculating neighborhood lists that are not needed at training time. The `split_name` argument
# gives the split a custom name so that it can be saved and loaded afterwards. If no `split_name` is passed, it
# defaults to `random_split`, `strat_split` and `index_split`, respectively. If the same name is used twice,
# the earlier split is overwritten internally.


# Random split: n_test=None takes every remaining point as a test point.
random_split = md17_dataset.random_split(
    n_train=20,
    n_valid=20,
    n_test=None,
    training=True,
    r_cut=5.,
    seed=0,
    split_name='my_random_split',
)

# Stratified split: points are chosen to follow the distribution of the energies (strat_key=E_key).
strat_split = md17_dataset.strat_split(
    n_train=30,
    n_valid=30,
    n_test=10,
    training=True,
    r_cut=5.,
    seed=0,
    strat_key=E_key,
    split_name='my_strat_split',
)

# Index split: the indices into the original data file are given explicitly for each subset.
index_split = md17_dataset.index_split(
    data_idx_train=np.array([0, 1, 2]),
    data_idx_valid=np.array([3, 4]),
    data_idx_test=np.array([5], dtype=int),
    r_cut=5.,
    split_name='my_index_split',
)

In [None]:
# At the top level, we find the data indices that give the index of each data point in the original data file.
# Shown as the cell's last expression (rich display) rather than via print(), like the other inspection cells.
list(strat_split.keys())

In [None]:
# Going one level deeper, e.g. into the training data, each quantity has its own entry holding the data.
print(strat_split['train'].keys())
# The training-set energies are accessed via the energy key defined in `prop_keys` (E_key),
# rather than a hard-coded 'E', so the cell stays correct if the key mapping changes.
print('Energy data shape: {}'.format(strat_split['train'][E_key].shape))
print('Energy data:\n {}'.format(strat_split['train'][E_key]))

# Save a data split

In [None]:
# The `splits` attribute of the data set object holds the index lists of every split created so far.
# To recover splits later, e.g. for testing, write them to a file with `save_splits_to_file`.

pprint(md17_dataset.splits)
md17_dataset.save_splits_to_file(filename='my_splits.json', path='')

# Load a data split

In [None]:
# Read the previously saved index lists back with `load_splits_from_file` and list the split names.

my_saved_splits = md17_dataset.load_splits_from_file(filename='my_splits.json', path='')
list(my_saved_splits.keys())

In [None]:
# To get the data itself instead of just the index lists, rebuild a named split with `load_split`.
rec_strat_split = md17_dataset.load_split(
    split_name='my_strat_split',
    file='my_splits.json',
    r_cut=5.,
)

In [None]:
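# One can additionally pass `n_train`, `n_valid` and `n_test` arguments to the function to recover only part
# of the data, e.g. a subset of the test points. Passing `None` presumably keeps all saved points of that
# subset (assumed; confirm against `DataSet.load_split`).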
rec_random_split = md17_dataset.load_split(
    split_name='my_random_split',
    file='my_splits.json',
    r_cut=5.,
    n_train=None,
    n_valid=None,
    n_test=100,
)

In [None]:
# Check that the recovered training data actually matches the original training data.
# The arrays are compared element-wise: calling `.all()` on each array first would reduce both to a
# single truth value ("are all entries non-zero?") and could not detect differing positions or energies.
for key in (R_key, E_key):
    np.testing.assert_array_equal(rec_random_split['train'][key],
                                  random_split['train'][key])