Skip to content

Commit

Permalink
Clean up ObservableData (#145)
Browse files Browse the repository at this point in the history
Reduce code duplication
Add unit tests
  • Loading branch information
ptmerz committed Apr 29, 2021
1 parent f380afe commit 1ce021f
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 130 deletions.
6 changes: 0 additions & 6 deletions physical_validation/data/gromacs_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,6 @@ def get_simulation_data(self, mdp=None, top=None, edr=None, trr=None, gro=None):
volume = box[0] * box[1] * box[2]
elif box.ndim == 2:
volume = box[0, 0] * box[1, 1] * box[2, 2]
else:
warnings.warn(
"Constant volume simulation with undefined volume."
)
else:
warnings.warn("Constant volume simulation with undefined volume.")

if constant_temp and constant_press:
ens = "NPT"
Expand Down
166 changes: 50 additions & 116 deletions physical_validation/data/observable_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
Data structures carrying simulation data.
"""
import warnings
from typing import Optional

import numpy as np

import physical_validation.util.error as pv_error
from ..util import error as pv_error
from ..util.util import array_equal_shape_and_close


class ObservableData(object):
Expand Down Expand Up @@ -67,17 +69,8 @@ def __init__(
self.__pressure = None
self.__temperature = None
self.__constant_of_motion = None
self.__nframes = -1
self.__kinetic_energy_per_molec = None

self.kinetic_energy = kinetic_energy
self.potential_energy = potential_energy
self.total_energy = total_energy
self.volume = volume
self.pressure = pressure
self.temperature = temperature
self.constant_of_motion = constant_of_motion

self.__getters = {
"kinetic_energy": ObservableData.kinetic_energy.__get__,
"potential_energy": ObservableData.potential_energy.__get__,
Expand All @@ -98,6 +91,18 @@ def __init__(
"constant_of_motion": ObservableData.constant_of_motion.__set__,
}

# Consistency check
assert set(self.__getters.keys()) == set(self.__setters.keys())
assert set(self.__getters.keys()) == set(ObservableData.observables())

self.kinetic_energy = kinetic_energy
self.potential_energy = potential_energy
self.total_energy = total_energy
self.volume = volume
self.pressure = pressure
self.temperature = temperature
self.constant_of_motion = constant_of_motion

def get(self, key):
return self[key]

Expand All @@ -114,6 +119,18 @@ def __setitem__(self, key, value):
raise KeyError
self.__setters[key](self, value)

def __check_value(self, value, key: str) -> Optional[np.ndarray]:
if value is None:
return None
value = np.array(value)
if value.ndim != 1:
raise pv_error.InputError(key, "Expected 1-dimensional array.")
if self.nframes is not None and self.nframes != value.size:
warnings.warn(
"ObservableData: Mismatch in number of frames. Setting `nframes = None`."
)
return value

@property
def kinetic_energy(self):
"""Get kinetic_energy"""
Expand All @@ -122,18 +139,7 @@ def kinetic_energy(self):
@kinetic_energy.setter
def kinetic_energy(self, kinetic_energy):
"""Set kinetic_energy"""
if kinetic_energy is None:
self.__kinetic_energy = None
return
kinetic_energy = np.array(kinetic_energy)
if kinetic_energy.ndim != 1:
raise pv_error.InputError("kinetic_energy", "Expected 1-dimensional array.")
if self.nframes == -1:
self.__nframes = kinetic_energy.size
elif self.nframes != kinetic_energy.size:
warnings.warn("Mismatch in number of frames. Setting `nframes = None`.")
self.__nframes = None
self.__kinetic_energy = kinetic_energy
self.__kinetic_energy = self.__check_value(kinetic_energy, "kinetic_energy")

@property
def potential_energy(self):
Expand All @@ -143,20 +149,9 @@ def potential_energy(self):
@potential_energy.setter
def potential_energy(self, potential_energy):
"""Set potential_energy"""
if potential_energy is None:
self.__potential_energy = None
return
potential_energy = np.array(potential_energy)
if potential_energy.ndim != 1:
raise pv_error.InputError(
"potential_energy", "Expected 1-dimensional array."
)
if self.nframes == -1:
self.__nframes = potential_energy.size
elif self.nframes != potential_energy.size:
warnings.warn("Mismatch in number of frames. Setting `nframes = None`.")
self.__nframes = None
self.__potential_energy = potential_energy
self.__potential_energy = self.__check_value(
potential_energy, "potential_energy"
)

@property
def total_energy(self):
Expand All @@ -166,18 +161,7 @@ def total_energy(self):
@total_energy.setter
def total_energy(self, total_energy):
"""Set total_energy"""
if total_energy is None:
self.__total_energy = None
return
total_energy = np.array(total_energy)
if total_energy.ndim != 1:
raise pv_error.InputError("total_energy", "Expected 1-dimensional array.")
if self.nframes == -1:
self.__nframes = total_energy.size
elif self.nframes != total_energy.size:
warnings.warn("Mismatch in number of frames. Setting `nframes = None`.")
self.__nframes = None
self.__total_energy = total_energy
self.__total_energy = self.__check_value(total_energy, "total_energy")

@property
def volume(self):
Expand All @@ -187,18 +171,7 @@ def volume(self):
@volume.setter
def volume(self, volume):
"""Set volume"""
if volume is None:
self.__volume = None
return
volume = np.array(volume)
if volume.ndim != 1:
raise pv_error.InputError("volume", "Expected 1-dimensional array.")
if self.nframes == -1:
self.__nframes = volume.size
elif self.nframes != volume.size:
warnings.warn("Mismatch in number of frames. Setting `nframes = None`.")
self.__nframes = None
self.__volume = volume
self.__volume = self.__check_value(volume, "volume")

@property
def pressure(self):
Expand All @@ -208,18 +181,7 @@ def pressure(self):
@pressure.setter
def pressure(self, pressure):
"""Set pressure"""
if pressure is None:
self.__pressure = None
return
pressure = np.array(pressure)
if pressure.ndim != 1:
raise pv_error.InputError("pressure", "Expected 1-dimensional array.")
if self.nframes == -1:
self.__nframes = pressure.size
elif self.nframes != pressure.size:
warnings.warn("Mismatch in number of frames. Setting `nframes = None`.")
self.__nframes = None
self.__pressure = pressure
self.__pressure = self.__check_value(pressure, "pressure")

@property
def temperature(self):
Expand All @@ -229,18 +191,7 @@ def temperature(self):
@temperature.setter
def temperature(self, temperature):
"""Set temperature"""
if temperature is None:
self.__temperature = None
return
temperature = np.array(temperature)
if temperature.ndim != 1:
raise pv_error.InputError("temperature", "Expected 1-dimensional array.")
if self.nframes == -1:
self.__nframes = temperature.size
elif self.nframes != temperature.size:
warnings.warn("Mismatch in number of frames. Setting `nframes = None`.")
self.__nframes = None
self.__temperature = temperature
self.__temperature = self.__check_value(temperature, "temperature")

@property
def constant_of_motion(self):
Expand All @@ -250,30 +201,25 @@ def constant_of_motion(self):
@constant_of_motion.setter
def constant_of_motion(self, constant_of_motion):
"""Set constant_of_motion"""
if constant_of_motion is None:
self.__constant_of_motion = None
return
constant_of_motion = np.array(constant_of_motion)
if constant_of_motion.ndim != 1:
raise pv_error.InputError(
"constant_of_motion", "Expected 1-dimensional array."
)
if self.nframes == -1:
self.__nframes = constant_of_motion.size
elif self.nframes != constant_of_motion.size:
warnings.warn("Mismatch in number of frames. Setting `nframes = None`.")
self.__nframes = None
self.__constant_of_motion = constant_of_motion
self.__constant_of_motion = self.__check_value(
constant_of_motion, "constant_of_motion"
)

@property
def nframes(self):
"""Get number of frames"""
if self.__nframes is None:
warnings.warn(
"A mismatch in the number of frames between observables "
"was detected. Setting `nframes = None`."
)
return self.__nframes
frames = None
for observable in ObservableData.observables():
if self.get(observable) is not None:
if frames is not None:
if self.get(observable).size == frames:
continue
else:
return None
else:
frames = self.get(observable).size

return frames

@property
def kinetic_energy_per_molecule(self):
Expand All @@ -283,23 +229,12 @@ def kinetic_energy_per_molecule(self):
@kinetic_energy_per_molecule.setter
def kinetic_energy_per_molecule(self, kinetic_energy):
"""Set kinetic_energy per molecule - used internally"""
if kinetic_energy is None:
self.__kinetic_energy_per_molec = None
return
# used internally - check for consistency?
self.__kinetic_energy_per_molec = kinetic_energy

def __eq__(self, other):
if type(other) is not type(self):
return False

def array_equal_shape_and_close(array1: np.ndarray, array2: np.ndarray):
if array1 is None and array2 is None:
return True
if array1.shape != array2.shape:
return False
return np.allclose(array1, array2, rtol=1e-12, atol=1e-12)

return (
array_equal_shape_and_close(self.__kinetic_energy, other.__kinetic_energy)
and array_equal_shape_and_close(
Expand All @@ -315,5 +250,4 @@ def array_equal_shape_and_close(array1: np.ndarray, array2: np.ndarray):
and array_equal_shape_and_close(
self.__kinetic_energy_per_molec, other.__kinetic_energy_per_molec
)
and self.__nframes == other.__nframes
)
10 changes: 2 additions & 8 deletions physical_validation/data/trajectory_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import numpy as np

import physical_validation.util.error as pv_error
from ..util import error as pv_error
from ..util.util import array_equal_shape_and_close


class Box(object):
Expand Down Expand Up @@ -246,13 +247,6 @@ def __eq__(self, other):
if type(other) is not type(self):
return False

def array_equal_shape_and_close(array1: np.ndarray, array2: np.ndarray):
if array1 is None and array2 is None:
return True
if array1.shape != array2.shape:
return False
return np.allclose(array1, array2, rtol=1e-12, atol=1e-12)

return (
array_equal_shape_and_close(self.__position, other.__position)
and array_equal_shape_and_close(self.__velocity, other.__velocity)
Expand Down

0 comments on commit 1ce021f

Please sign in to comment.