Skip to content

Commit

Permalink
Move distinct coefficient to jax for large speedup (#552)
Browse files Browse the repository at this point in the history
* move GK distinct to jax

* move einstein to jax

* Fix axis label on plot

* Update einstein_distinct_diffusion_coefficients.py

* move helper functions and add tests
  • Loading branch information
SamTov committed Aug 10, 2022
1 parent ca4080e commit e1139c1
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_experiment(traj_file, true_values, tmp_path):
)

project.experiments["NaCl"].run.GreenKuboDistinctDiffusionCoefficients(
plot=False, correlation_time=100
plot=False, correlation_time=500
)

# data_dict = (
Expand Down
1 change: 0 additions & 1 deletion CI/unit_tests/transformations/test_transformator_parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def test_transformation_on_new_data_(tmp_path):
project = mds.Project()

def check_trafo(trafo_class, exp_name):

project.add_experiment(name=exp_name, timestep=12345)
exp = project.experiments[exp_name]

Expand Down
71 changes: 70 additions & 1 deletion CI/unit_tests/utils/test_calculator_helper_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@
"""
import numpy as np
import pytest
from numpy.testing import assert_array_equal, assert_raises

from mdsuite.utils.calculator_helper_methods import fit_einstein_curve
from mdsuite.utils.calculator_helper_methods import (
correlate,
fit_einstein_curve,
msd_operation,
)


class TestCalculatorHelperMethods:
Expand Down Expand Up @@ -64,3 +69,67 @@ def test_fit_einstein_curve(self):
x_data=x_data, y_data=y_data, fit_max_index=999
)
assert popt[0] == pytest.approx(5.0, 0.01)

def test_correlate(self):
"""
Test the correlate helper function.
Returns
-------
Tests to see if the net cross correlation between a sine with itself and a sine
with a lagged version of itself is zero.
The first signal is auto-correlated, the second is perfectly anti-correlated.
Therefore, when summed, they should cancel to zero.
"""
# generate 10 points
t = np.arange(10)
# Create a 3d array
x_data = np.vstack((t, t, t)).reshape(10, 3)

sine_data = np.sin(x_data)
lagged_sine_data = np.sin(x_data + np.pi)

auto_correlation = np.array(correlate(sine_data, sine_data))
cross_correlation = np.array(correlate(sine_data, lagged_sine_data))

assert_raises(AssertionError, assert_array_equal, auto_correlation, np.zeros(10))
assert_raises(AssertionError, assert_array_equal, cross_correlation, np.zeros(10))

# Clip to correlate precision.
summed_data = auto_correlation + cross_correlation
summed_data[summed_data < 1e-10] = 0.0

assert summed_data.sum() == 0.0

def test_msd_operation(self):
"""
Test the msd helper function.
Returns
-------
Tests to see if the net cross correlation between a sine with itself and a sine
with a lagged version of itself is zero.
The first signal is auto-correlated, the second is perfectly anti-correlated.
Therefore, when summed, they should cancel to zero.
"""
# generate 10 points
t = np.arange(10)
# Create a 3d array
x_data = np.vstack((t, t, t)).reshape(10, 3)

sine_data = np.sin(x_data)
lagged_sine_data = np.sin(x_data + np.pi)

auto_correlation = np.array(msd_operation(sine_data, sine_data))
cross_correlation = np.array(msd_operation(sine_data, lagged_sine_data))

assert_raises(AssertionError, assert_array_equal, auto_correlation, np.zeros(10))
assert_raises(AssertionError, assert_array_equal, cross_correlation, np.zeros(10))

# Clip to correlate precision.
summed_data = auto_correlation + cross_correlation
summed_data[summed_data < 1e-10] = 0.0

assert summed_data.sum() == 0.0
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ mdanalysis
black
znvis
ase
astroid
140 changes: 78 additions & 62 deletions mdsuite/calculators/einstein_distinct_diffusion_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
from dataclasses import dataclass
from typing import Any, List, Union

import jax
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from mdsuite.calculators.calculator import call
from mdsuite.calculators.trajectory_calculator import TrajectoryCalculator
from mdsuite.database.mdsuite_properties import mdsuite_properties
from mdsuite.utils.calculator_helper_methods import fit_einstein_curve
from mdsuite.utils.calculator_helper_methods import fit_einstein_curve, msd_operation


@dataclass
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, **kwargs):

self.database_group = "Diffusion_Coefficients"
self.x_label = r"$$\text{Time} / s $$"
self.y_label = r"$$\text{VACF} / m^{2}/s^{2}$$"
self.y_label = r"$$\text{MSD} / m^{2}$$"
self.analysis_name = "Einstein_Distinct_Diffusion_Coefficients"
self.experimental = True
self.result_keys = ["diffusion_coefficient", "uncertainty"]
Expand Down Expand Up @@ -170,35 +170,75 @@ def __call__(
species=species,
fit_range=fit_range,
)
self.time = self._handle_tau_values()
self.time = self._handle_tau_values() * self.experiment.units.time

self.msd_array = np.zeros(self.args.data_range) # define empty msd array

def msd_operation(self, ensemble: tf.Tensor, square: bool = True):
def _map_over_particles(self, ds_a: np.ndarray, ds_b: np.ndarray) -> np.ndarray:
"""
Perform a simple msd operation.
Function to map a correlation in a Gram matrix style over two data sets.
This function will perform the nxm calculations to compute the correlation
between all particles in ds_a with all particles in ds_b.
Parameters
----------
ensemble : tf.Tensor
Trajectory over which to compute the msd.
square : bool
If true, square the result, else just return the difference.
ds_a : np.ndarray (n_particles, n_configurations, dimension)
Dataset to compute correlation with.
ds_b : np.ndarray (n_particles, n_configurations, dimension)
Other dataset to compute correlation with. Does not need to be the
same shape as ds_a along the zeroth (particle) axis.
Returns
-------
msd : tf.Tensor shape=(n_atoms, data_range, 3)
Mean square displacement.
"""
if square:
return tf.math.squared_difference(
tf.gather(ensemble, self.args.tau_values, axis=1), ensemble[:, None, 0]
)
else:
return tf.math.subtract(ensemble, ensemble[:, None, 0])

def ref_conf_map(ref_dataset, full_ds):
"""
Maps over the atoms axis in dataset
Parameters
----------
Returns
-------
"""

def test_conf_map(test_dataset):
"""
Map over atoms in test dataset.
Parameters
----------
test_dataset
Returns
-------
"""
return msd_operation(ref_dataset, test_dataset)

return np.mean(jax.vmap(test_conf_map, in_axes=0)(full_ds), axis=0)

acf_calc = jax.vmap(ref_conf_map, in_axes=(0, None))

return np.mean(acf_calc(ds_a, ds_b), axis=0)

def _compute_self_correlation(self, ds_a, ds_b):
"""
Compute the self correlation coefficients.
Parameters
----------
ds_a : np.ndarray (n_timesteps, n_atoms, dimension)
ds_b : np.ndarray (n_timesteps, n_atoms, dimension)
Returns
-------
"""
atomwise_vmap = jax.vmap(msd_operation, in_axes=0)

return np.mean(atomwise_vmap(ds_a, ds_b), axis=0)

def _compute_msd(self, data: dict, data_path: list, combination: tuple):
"""
Compute the vacf on the given dictionary of data.
Compute the msd on the given dictionary of data.
Parameters
----------
Expand All @@ -212,44 +252,17 @@ def _compute_msd(self, data: dict, data_path: list, combination: tuple):
-------
updates the class state
"""
# shape = (n_atoms, data_range, 3)
msd_a = self.msd_operation(data[data_path[0]], square=False)
msd_b = self.msd_operation(data[data_path[0]], square=False)

for i in range(len(data[data_path[0]])):
for j in range(len(data[data_path[1]])):
if combination[0] == combination[1] and i == j:
continue
else:
self.msd_array += self.prefactor * np.array(
tf.reduce_sum(msd_a[i] * msd_b[j], axis=1)
)

def _calculate_prefactor(self, species: Union[str, tuple] = None):
"""
calculate the calculator pre-factor.
msd_array = self._map_over_particles(
data[data_path[0]].numpy(), data[data_path[1]].numpy()
)

Parameters
----------
species : str
Species property if required.
Returns
-------
Updates the prefactor attribute of the class.
"""
if species[0] == species[1]:
atom_scale = self.experiment.species[species[0]].n_particles * (
self.experiment.species[species[1]].n_particles - 1
)
else:
atom_scale = (
self.experiment.species[species[0]].n_particles
* self.experiment.species[species[1]].n_particles
if combination[0] == combination[1]:
self_correction = self._compute_self_correlation(
data[data_path[0]].numpy(), data[data_path[1]].numpy()
)
msd_array -= self_correction

numerator = self.experiment.units.length**2
denominator = self.experiment.units.time * atom_scale
self.prefactor = numerator / denominator
self.msd_array += msd_array

def _apply_averaging_factor(self):
"""
Expand All @@ -267,29 +280,32 @@ def _post_operation_processes(self, species: Union[str, tuple] = None):
-------
"""
self._apply_averaging_factor() # update in place
self.msd_array *= self.experiment.units.length**2
try:
fit_values, covariance, gradients, gradient_errors = fit_einstein_curve(
x_data=self.time, y_data=self.msd_array, fit_max_index=self.args.fit_range
)
error = np.sqrt(np.diag(covariance))[0]

data = {
"diffusion_coefficient": fit_values[0],
"uncertainty": error,
"time": self.time.tolist(),
"msd": self.msd_array.tolist(),
self.result_keys[0]: 1 / 2 * fit_values[0],
self.result_keys[1]: 1 / 2 * error,
self.result_series_keys[0]: self.time.tolist(),
self.result_series_keys[1]: self.msd_array.tolist(),
}

except ValueError:
fit_values, covariance, gradients, gradient_errors = fit_einstein_curve(
x_data=self.time,
y_data=abs(self.msd_array),
fit_max_index=self.args.fit_range,
)
error = np.sqrt(np.diag(covariance))[0]

# division by dimension is performed in the mapping, therefore, only 2 here.
data = {
self.result_keys[0]: -1 / 6 * fit_values[0],
self.result_keys[1]: 1 / 6 * error,
self.result_keys[0]: -1 / 2 * fit_values[0],
self.result_keys[1]: 1 / 2 * error,
self.result_series_keys[0]: self.time.tolist(),
self.result_series_keys[1]: self.msd_array.tolist(),
}
Expand Down Expand Up @@ -320,7 +336,6 @@ def run_calculator(self):
for species in species_values
]
batch_ds = self.get_batch_dataset(species_values)
self._calculate_prefactor(combination)

for batch in tqdm(
batch_ds,
Expand All @@ -334,3 +349,4 @@ def run_calculator(self):
self._compute_msd(ensemble, dict_ref, combination)

self._post_operation_processes(combination)
self.msd_array = np.zeros(self.args.data_range) # define empty msd array
Loading

0 comments on commit e1139c1

Please sign in to comment.