Skip to content

Commit

Permalink
Clean up TrajectoryData (#146)
Browse files Browse the repository at this point in the history
Remove unneeded code
Add unit tests
Also silences a warning in the GROMACS parser test and fixes some comments
  • Loading branch information
ptmerz committed Apr 29, 2021
1 parent 1ce021f commit 1623f3f
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 117 deletions.
2 changes: 1 addition & 1 deletion physical_validation/data/gromacs_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def get_simulation_data(self, mdp=None, top=None, edr=None, trr=None, gro=None):
)
else:
if trajectory_dict is not None:
box = trajectory_dict["box"]["box"][0]
box = trajectory_dict["box"].box[0]
# Different box shapes?
if box.ndim == 1:
volume = box[0] * box[1] * box[2]
Expand Down
12 changes: 3 additions & 9 deletions physical_validation/data/observable_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,11 @@ def __init__(
self.temperature = temperature
self.constant_of_motion = constant_of_motion

def get(self, key):
return self[key]

def __getitem__(self, key):
if key not in self.observables():
raise KeyError
return self.__getters[key](self)

def set(self, key, value):
self[key] = value

def __setitem__(self, key, value):
if key not in self.observables():
raise KeyError
Expand Down Expand Up @@ -210,14 +204,14 @@ def nframes(self):
"""Get number of frames"""
frames = None
for observable in ObservableData.observables():
if self.get(observable) is not None:
if self[observable] is not None:
if frames is not None:
if self.get(observable).size == frames:
if self[observable].size == frames:
continue
else:
return None
else:
frames = self.get(observable).size
frames = self[observable].size

return frames

Expand Down
144 changes: 40 additions & 104 deletions physical_validation/data/trajectory_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,81 +13,22 @@
r"""
Data structures carrying simulation data.
"""
import warnings
from typing import Optional

import numpy as np

from ..util import error as pv_error
from ..util.util import array_equal_shape_and_close


class Box(object):
class RectangularBox:
def __init__(self, box=None):
pass

def get(self, key):
raise NotImplementedError

def __getitem__(self, key):
raise NotImplementedError

def set(self, key, value):
raise NotImplementedError

def __setitem__(self, key, value):
raise NotImplementedError

@property
def volume(self):
raise NotImplementedError

@property
def box(self):
raise NotImplementedError

@box.setter
def box(self, b):
raise NotImplementedError

def gather(self, positions, bonds, molec_idx):
raise NotImplementedError


class RectangularBox(Box):
def __init__(self, box=None):
Box.__init__(self)
self.__box = None
self.__nframes = 0

if box is not None:
self.box = box

self.__getters = {
"box": RectangularBox.box.__get__,
"volume": RectangularBox.volume.__get__,
}
self.__setters = {"box": RectangularBox.box.__set__}

def get(self, key):
return self[key]

def __getitem__(self, key):
if key not in self.__getters:
raise KeyError
return self.__getters[key](self)

def set(self, key, value):
self[key] = value

def __setitem__(self, key, value):
if key not in self.__setters:
raise KeyError
self.__setters[key](self, value)

@property
def volume(self):
return np.prod(self.box, axis=1)

@property
def box(self):
return self.__box
Expand Down Expand Up @@ -156,17 +97,13 @@ class TrajectoryData(object):

@staticmethod
def trajectories():
return ("position", "velocity")
return "position", "velocity"

def __init__(self, position=None, velocity=None):
self.__position = None
self.__velocity = None
self.__nframes = 0

if position is not None:
self.position = position
if velocity is not None:
self.velocity = velocity
self.__nframes = None
self.__natoms = None

self.__getters = {
"position": TrajectoryData.position.__get__,
Expand All @@ -178,22 +115,50 @@ def __init__(self, position=None, velocity=None):
"velocity": TrajectoryData.velocity.__set__,
}

def get(self, key):
return self[key]
# Consistency check
assert set(self.__getters.keys()) == set(self.__setters.keys())
assert set(self.__getters.keys()) == set(TrajectoryData.trajectories())

if position is not None:
self.position = position
if velocity is not None:
self.velocity = velocity

def __getitem__(self, key):
if key not in self.trajectories():
raise KeyError
return self.__getters[key](self)

def set(self, key, value):
self[key] = value

def __setitem__(self, key, value):
if key not in self.trajectories():
raise KeyError
self.__setters[key](self, value)

def __check_value(self, value, key: str) -> Optional[np.ndarray]:
value = np.array(value)
if value.ndim == 2:
# create 3-dimensional array
value = np.array([value])
if value.ndim != 3:
raise pv_error.InputError([key], "Expected 2- or 3-dimensional array.")
if self.__nframes is None:
self.__nframes = value.shape[0]
elif self.__nframes != value.shape[0]:
raise pv_error.InputError(
[key], "Expected equal number of frames as in all trajectories."
)
if self.__natoms is None:
self.__natoms = value.shape[1]
elif self.__natoms != value.shape[1]:
raise pv_error.InputError(
[key], "Expected equal number of atoms as in all trajectories."
)
if value.shape[2] != 3:
raise pv_error.InputError(
[key], "Expected 3 spatial dimensions (#frames x #atoms x 3)."
)
return value

@property
def position(self):
"""Get position"""
Expand All @@ -202,19 +167,7 @@ def position(self):
@position.setter
def position(self, pos):
"""Set position"""
pos = np.array(pos)
if pos.ndim == 2:
# create 3-dimensional array
pos = np.array([pos])
if pos.ndim != 3:
warnings.warn("Expected 2- or 3-dimensional array.")
if self.__nframes == 0 and self.__velocity is None:
self.__nframes = pos.shape[0]
elif self.__nframes != pos.shape[0]:
raise pv_error.InputError(
["pos"], "Expected equal number of frames as in velocity trajectory."
)
self.__position = pos
self.__position = self.__check_value(pos, "position")

@property
def velocity(self):
Expand All @@ -224,24 +177,7 @@ def velocity(self):
@velocity.setter
def velocity(self, vel):
"""Set velocity"""
vel = np.array(vel)
if vel.ndim == 2:
# create 3-dimensional array
vel = np.array([vel])
if vel.ndim != 3:
warnings.warn("Expected 2- or 3-dimensional array.")
if self.__nframes == 0 and self.__position is None:
self.__nframes = vel.shape[0]
elif self.__nframes != vel.shape[0]:
raise pv_error.InputError(
["vel"], "Expected equal number of frames as in position trajectory."
)
self.__velocity = vel

@property
def nframes(self):
"""Get number of frames"""
return self.__nframes
self.__velocity = self.__check_value(vel, "velocity")

def __eq__(self, other):
if type(other) is not type(self):
Expand Down
1 change: 1 addition & 0 deletions physical_validation/tests/test_data_gromacs_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_gromacs_topology_exception() -> None:
)

@staticmethod
@pytest.mark.filterwarnings("ignore:NVT with undefined volume")
def test_gromacs_topology_with_bonds() -> None:
r"""
Check that GROMACS parser reads a system with bonds and angle
Expand Down
6 changes: 3 additions & 3 deletions physical_validation/tests/test_data_observable_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# #
###########################################################################
r"""
This file contains tests for the `physical_validation.util.util` module.
This file contains tests for the `physical_validation.data.observable_data` module.
"""
import numpy as np
import pytest
Expand All @@ -22,7 +22,7 @@

def test_observable_data_getters_and_setters() -> None:

# Check that newly create observable data object has None
# Check that newly created observable data object has `None` frames
observable_data = ObservableData()
assert observable_data.nframes is None

Expand Down Expand Up @@ -89,7 +89,7 @@ def test_observable_data_getters_and_setters() -> None:

# Check that all observables can be read in two ways
for observable in ObservableData.observables():
observable_data.set(observable, np.random.random(num_frames))
observable_data[observable] = np.random.random(num_frames)
assert np.array_equal(
observable_data.kinetic_energy, observable_data["kinetic_energy"]
)
Expand Down
63 changes: 63 additions & 0 deletions physical_validation/tests/test_data_trajectory_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
###########################################################################
# #
# physical_validation, #
# a python package to test the physical validity of MD results #
# #
# Written by Pascal T. Merz <pascal.merz@me.com> #
# Michael R. Shirts <michael.shirts@colorado.edu> #
# #
# Copyright (c) 2017-2021 University of Colorado Boulder #
# (c) 2012 The University of Virginia #
# #
###########################################################################
r"""
This file contains tests for the `physical_validation.data.trajectory` module.
"""
import numpy as np
import pytest

from ..data.trajectory_data import TrajectoryData
from ..util import error as pv_error


def test_trajectory_data_getters_and_setters() -> None:

num_frames = 3
num_atoms = 5
position = np.random.random((num_frames, num_atoms, 3))
velocity = np.random.random((num_frames, num_atoms, 3))

# Check that objects created and populated in different ways are equivalent
trajectory_data_1 = TrajectoryData()
trajectory_data_1["position"] = position
trajectory_data_1["velocity"] = velocity

trajectory_data_2 = TrajectoryData()
trajectory_data_2.velocity = velocity
trajectory_data_2.position = position

trajectory_data_3 = TrajectoryData(position=position, velocity=velocity)

trajectory_data_2 = TrajectoryData()
trajectory_data_2.velocity = velocity
trajectory_data_2.position = position

assert trajectory_data_1 == trajectory_data_2
assert trajectory_data_1 == trajectory_data_3

# Check setter error messages
with pytest.raises(pv_error.InputError):
# different number of frames
trajectory_data_1.position = np.random.random((num_frames - 1, num_atoms, 3))
with pytest.raises(pv_error.InputError):
# different number of atoms
trajectory_data_1.position = np.random.random((num_frames, num_atoms + 1, 3))
with pytest.raises(pv_error.InputError):
# different number of spatial dimensions
trajectory_data_1.position = np.random.random((num_frames, num_atoms, 2))
with pytest.raises(pv_error.InputError):
# extra dimension in input
trajectory_data_1.position = np.random.random((1, num_frames, num_atoms, 3))
with pytest.raises(pv_error.InputError):
# missing dimension in input
trajectory_data_1.position = np.random.random(num_frames)

0 comments on commit 1623f3f

Please sign in to comment.