Skip to content

Commit

Permalink
Introduce type hints, remove defaults for internal functions (#151)
Browse files Browse the repository at this point in the history
Introduce typing hints throughout the code
Remove default values in the non-user facing functions (closes #116)
Fix a bug using the default kB for equipartition in all cases

* Introduce type hints and remove low-level default arguments for kinetic energy module
* Equipartition was using default kB in all cases. This was fixed, now adapt regression test values.
* Introduce type hints and remove low-level default arguments for ensemble module
* Introduce type hints and remove low-level default arguments for integrator module
* Add remaining type hints
  • Loading branch information
ptmerz committed May 9, 2021
1 parent 39c367a commit ea66536
Show file tree
Hide file tree
Showing 23 changed files with 934 additions and 854 deletions.
21 changes: 11 additions & 10 deletions physical_validation/data/ensemble_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Data structures carrying simulation data.
"""
import warnings
from typing import Tuple

from ..util import error as pv_error

Expand All @@ -33,8 +34,8 @@ class EnsembleData(object):
"""

@staticmethod
def ensembles():
return ("NVE", "NVT", "NPT", "muVT")
def ensembles() -> Tuple[str, str, str, str]:
return "NVE", "NVT", "NPT", "muVT"

def __init__(
self,
Expand Down Expand Up @@ -100,41 +101,41 @@ def __init__(
self.__t = temperature

@property
def ensemble(self):
def ensemble(self) -> str:
"""Get ensemble"""
return self.__ensemble

@property
def natoms(self):
def natoms(self) -> int:
"""Get natoms"""
return self.__n

@property
def mu(self):
def mu(self) -> float:
"""Get mu"""
return self.__mu

@property
def volume(self):
def volume(self) -> float:
"""Get volume"""
return self.__v

@property
def pressure(self):
def pressure(self) -> float:
"""Get pressure"""
return self.__p

@property
def energy(self):
def energy(self) -> float:
"""Get energy"""
return self.__e

@property
def temperature(self):
def temperature(self) -> float:
"""Get temperature"""
return self.__t

def __eq__(self, other):
def __eq__(self, other) -> bool:
if type(other) is not type(self):
return False
return (
Expand Down
44 changes: 27 additions & 17 deletions physical_validation/data/flatfile_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
r"""
flatfile_parser.py
"""
from . import ObservableData, SimulationData, TrajectoryData, parser
from typing import List, Optional

from . import (
EnsembleData,
ObservableData,
SimulationData,
SystemData,
TrajectoryData,
UnitData,
parser,
)


class FlatfileParser(parser.Parser):
Expand All @@ -26,20 +36,20 @@ def __init__(self):

def get_simulation_data(
self,
units=None,
ensemble=None,
system=None,
dt=None,
position_file=None,
velocity_file=None,
kinetic_ene_file=None,
potential_ene_file=None,
total_ene_file=None,
volume_file=None,
pressure_file=None,
temperature_file=None,
const_of_mot_file=None,
):
units: Optional[UnitData] = None,
ensemble: Optional[EnsembleData] = None,
system: Optional[SystemData] = None,
dt: Optional[float] = None,
position_file: Optional[str] = None,
velocity_file: Optional[str] = None,
kinetic_ene_file: Optional[str] = None,
potential_ene_file: Optional[str] = None,
total_ene_file: Optional[str] = None,
volume_file: Optional[str] = None,
pressure_file: Optional[str] = None,
temperature_file: Optional[str] = None,
const_of_mot_file: Optional[str] = None,
) -> SimulationData:
r"""Read simulation data from flat files
Returns a SimulationData object created from (optionally) provided UnitData, EnsembleData
Expand Down Expand Up @@ -136,7 +146,7 @@ def get_simulation_data(
return result

@staticmethod
def __read_xyz(filename):
def __read_xyz(filename: str) -> List[List[List[float]]]:
result = []
with open(filename) as f:
frame = []
Expand All @@ -159,7 +169,7 @@ def __read_xyz(filename):
return result

@staticmethod
def __read_1d(filename):
def __read_1d(filename: str) -> List[float]:
result = []
with open(filename) as f:
for line in f:
Expand Down
22 changes: 17 additions & 5 deletions physical_validation/data/gromacs_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
gromacs_parser.py
"""
import warnings
from typing import List, Optional, Union

import numpy as np

Expand All @@ -37,7 +38,7 @@ class GromacsParser(parser.Parser):
"""

@staticmethod
def units():
def units() -> UnitData:
# Gromacs uses kJ/mol
return UnitData(
kb=8.314462435405199e-3,
Expand All @@ -55,7 +56,9 @@ def units():
time_conversion=1.0,
)

def __init__(self, exe=None, includepath=None):
def __init__(
self, exe: Optional[str] = None, includepath: Union[str, List[str]] = None
):
r"""
Create a GromacsParser object
Expand Down Expand Up @@ -86,7 +89,14 @@ def __init__(self, exe=None, includepath=None):
"constant_of_motion": "Conserved-En.",
}

def get_simulation_data(self, mdp=None, top=None, edr=None, trr=None, gro=None):
def get_simulation_data(
self,
mdp: Optional[str] = None,
top: Optional[str] = None,
edr: Optional[str] = None,
trr: Optional[str] = None,
gro: Optional[str] = None,
) -> SimulationData:
r"""
Parameters
Expand Down Expand Up @@ -134,7 +144,9 @@ def get_simulation_data(self, mdp=None, top=None, edr=None, trr=None, gro=None):
).any():
raise NotImplementedError("Triclinic boxes not implemented.")
else:
box = RectangularBox([np.diag(b) for b in trajectory_dict["box"]])
box = RectangularBox(
np.array([np.diag(b) for b in trajectory_dict["box"]])
)
else:
raise RuntimeError("Unknown box shape.")
trajectory_dict["box"] = box
Expand Down Expand Up @@ -335,7 +347,7 @@ def get_simulation_data(self, mdp=None, top=None, edr=None, trr=None, gro=None):

if edr is not None:
observable_dict = self.__interface.get_quantities(
edr, self.__gmx_energy_names.values(), args=["-dp"]
edr, list(self.__gmx_energy_names.values()), args=["-dp"]
)

# constant volume simulations don't write out the volume in .edr file
Expand Down
57 changes: 40 additions & 17 deletions physical_validation/data/lammps_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
r"""
lammps_parser.py
"""
from typing import Dict, List, Optional, Union

import numpy as np

from ..util import error as pv_error
from . import (
EnsembleData,
ObservableData,
SimulationData,
SystemData,
Expand Down Expand Up @@ -86,20 +89,25 @@ def __init__(self):
)

def get_simulation_data(
self, ensemble=None, in_file=None, log_file=None, data_file=None, dump_file=None
):
self,
ensemble: Optional[EnsembleData] = None,
in_file: Optional[str] = None,
log_file: Optional[str] = None,
data_file: Optional[str] = None,
dump_file: Optional[str] = None,
) -> SimulationData:
"""
Parameters
----------
ensemble: EnsembleData, optional
in_file: str, optional
log_file: str, optional
data_file: str, optional
dump_file: str, optional
ensemble
in_file
log_file
data_file
dump_file
Returns
-------
result: SimulationData
SimulationData
"""

Expand Down Expand Up @@ -195,7 +203,9 @@ def get_simulation_data(
return result

@staticmethod
def __read_input_file(name):
def __read_input_file(
name: str,
) -> Dict[str, Union[str, List[str], List[Dict[str, Union[str, List[str]]]]]]:
# parse input file
input_dict = {}
with open(name) as f:
Expand Down Expand Up @@ -235,7 +245,16 @@ def __read_input_file(name):
return input_dict

@staticmethod
def __read_data_file(name):
def __read_data_file(
name: str,
) -> Dict[
str,
Union[
Dict[str, Union[float, List[float]]],
Dict[int, List[Union[str, float]]],
List[Dict[str, Union[int, float]]],
],
]:
# > available blocks
blocks = [
"Header", # 0
Expand Down Expand Up @@ -298,7 +317,9 @@ def __read_data_file(name):
]
header_double = ["xlo xhi", "ylo yhi", "zlo zhi"]
# default values
# Dict[str, float]
data_dict[block] = {hs: 0 for hs in header_single}
# Dict[str, List[float]]
data_dict[block].update({hd: [0.0, 0.0] for hd in header_double})
# read out
for line in file_blocks[block]:
Expand All @@ -322,6 +343,7 @@ def __read_data_file(name):
data_dict[block] = {}
for line in file_blocks[block]:
line = line.split()
# Dict[int, List[Union[str, float]]]
data_dict[block][int(line[0])] = [line[1]] + [
float(c) for c in line[2:]
]
Expand All @@ -332,6 +354,7 @@ def __read_data_file(name):
data_dict[block] = []
for line in file_blocks[block]:
line = line.split()
# List[Dict[str, Union[int, float]]]
if len(line) == 7:
data_dict[block].append(
{
Expand Down Expand Up @@ -392,9 +415,9 @@ def __read_data_file(name):
return data_dict

@staticmethod
def __read_log_file(name):
def __read_log_file(name: str) -> Dict[str, List[float]]:
# parse log file
def start_single(line1, line2):
def start_single(line1: str, line2: str) -> bool:
if not line1.split():
return False
if len(line1.split()) != len(line2.split()):
Expand All @@ -405,7 +428,7 @@ def start_single(line1, line2):
return False
return True

def end_single(line, length):
def end_single(line: str, length: int) -> bool:
if len(line.split()) != length:
return True
try:
Expand All @@ -414,12 +437,12 @@ def end_single(line, length):
return True
return False

def start_multi(line):
def start_multi(line: str) -> bool:
if "---- Step" in line and "- CPU =" in line:
return True
return False

def end_multi(line):
def end_multi(line: str) -> bool:
line = line.split()
# right length (is it actually always 9??)
if len(line) == 0 or len(line) % 3 != 0:
Expand Down Expand Up @@ -491,13 +514,13 @@ def end_multi(line):
return ene_traj

@staticmethod
def __read_dump_file(name):
def __read_dump_file(name: str) -> Dict[str, List[List[Union[float, List[float]]]]]:
# parse dump file
# the dictionary to be filled
dump_dict = {"position": [], "velocity": [], "box": []}

# helper function checking line items
def check_item(line_str, item):
def check_item(line_str: str, item: str) -> str:
item = "ITEM: " + item
if not line_str.startswith(item):
raise pv_error.FileFormatError(name, "dump file: was expecting " + item)
Expand Down

0 comments on commit ea66536

Please sign in to comment.