Skip to content

Commit

Permalink
Improved the docstrings for utils.py and developed test_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Feb 16, 2023
1 parent 9b5a74e commit d83e5ae
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 19 deletions.
9 changes: 8 additions & 1 deletion docs/requirements.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@ dependencies:
- pandoc
- ipykernel
- pip


# System package dependencies
- apt:
- libclang
- cmake
- mpich
- libmpich-dev

# Pip-only installs
- pip:
- sphinx
Expand Down
2 changes: 1 addition & 1 deletion ensemble_md/analysis/analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def parse_transmtx(log_file, expanded_ensemble=True):

def calc_equil_prob(trans_mtx):
"""
(**TODO**: Consider using PyEMMA instead.) Calculates the equilibrium probability of each
Calculates the equilibrium probability of each
state from the state transition matrix. The input state transition matrix can be either
left or right stochastic, although the left stochastic ones are not common in GROMACS.
Generally, transition matrices in GROMACS are either doubly stochastic (replica exchange),
Expand Down
2 changes: 1 addition & 1 deletion ensemble_md/cli/run_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main():

if len(EEXE.warnings) > args.maxwarn:
raise ParameterError(
f"The execution failed due to warning(s) about parameter spcificaiton. Consider setting maxwarn in the input YAML file if you want to ignore them.") # noqa: E501, F541
f"The execution failed due to warning(s) about parameter spcificaiton. Check the warnings, or consider setting maxwarn in the input YAML file if you find them harmless.") # noqa: E501, F541

# Step 2: If there is no checkpoint file found/provided, perform the 1st iteration (index 0)
if os.path.isfile(args.ckpt) is False:
Expand Down
14 changes: 10 additions & 4 deletions ensemble_md/ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,13 @@ def __init__(self, yml_file):
else:
self.fixed_weights = True

if self.template['symmetrized_transition_matrix'] == 'yes':
if 'lmc_seed' in self.template and self.template['lmc_seed'] != -1:
self.warnings.append('Warning: We recommend setting lmc_seed as -1 so the random seed is different for each iteration.') # noqa: E501

if 'gen_seed' in self.template and self.template['gen_seed'] != -1:
self.warnings.append('Warning: We recommend setting gen_seed as -1 so the random seed is different for each iteration.') # noqa: E501

if 'symmetrized_transition_matrix' in self.template and self.template['symmetrized_transition_matrix'] == 'yes': # noqa: E501
self.warnings.append('Warning: We recommend setting symmetrized-transition-matrix to no instead of yes.')

if self.template['nstlog'] > self.nst_sim:
Expand Down Expand Up @@ -364,7 +370,7 @@ def update_MDP(self, new_template, sim_idx, iter_idx, states, wl_delta, weights)
MDP : gmx_parser.MDP obj
An updated object of gmx_parser.MDP that can be used to write MDP files.
"""
new_template = gmx_parser.MDP(new_template) # turn into a gmx_parser.MDP object
new_template = gmx_parser.MDP(new_template) # turn into a gmx_parser.MDP object
MDP = copy.deepcopy(new_template)
MDP["tinit"] = self.nst_sim * self.dt * iter_idx
MDP["nsteps"] = self.nst_sim
Expand Down Expand Up @@ -506,7 +512,7 @@ def propose_swaps(self, states):
swap_list = random.choices(swappables, k=n_ex)
except IndexError:
# In the case that swappables is an empty list, i.e. no swappable pairs.
swap_list = None
swap_list = []

return swap_list

Expand Down Expand Up @@ -540,7 +546,7 @@ def get_swapping_pattern(self, swap_list, dhdl_files, states, lambda_vecs, weigh
A list that represents how the replicas should be swapped.
"""
swap_pattern = list(range(self.n_sim)) # Can be regarded as the indices of dhdl files/configurations
if swap_list is None:
if swap_list is []:
print('No swap is proposed because there is no swappable pair at all.')
else:
for i in range(len(swap_list)):
Expand Down
4 changes: 2 additions & 2 deletions ensemble_md/tests/data/expanded.mdp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pcoupl = no

gen_vel = yes
gen_temp = 298
gen_seed = 6722267
gen_seed = -1

; options for bonds
constraints = h-bonds
Expand All @@ -88,7 +88,7 @@ nstdhdl = 10
dhdl_print_energy = total

; Seed for Monte Carlo in lambda space
lmc_seed = 1000
lmc_seed = -1
lmc_gibbsdelta = -1
lmc_forced_nstart = 0
symmetrized_transition_matrix = yes
Expand Down
2 changes: 1 addition & 1 deletion ensemble_md/tests/test_ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_propose_swaps(self):
# Case 3: Empty swappable list
states = [10, 10, 10, 10]
swap_list = EEXE.propose_swaps(states)
assert swap_list is None
assert swap_list == []

def test_gest_swapped_configus(self):
EEXE.state_ranges = [
Expand Down
82 changes: 82 additions & 0 deletions ensemble_md/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
####################################################################
# #
# ensemble_md, #
# a python package for running GROMACS simulation ensembles #
# #
# Written by Wei-Tse Hsu <wehs7661@colorado.edu> #
# Copyright (c) 2022 University of Colorado Boulder #
# #
####################################################################
"""
Unit tests for the module utils.py.
"""
import sys
import tempfile
import numpy as np
from ensemble_md.utils import utils


def test_logger():
# Create a temporary file for the log
with tempfile.TemporaryFile(mode="w+t") as log_file:
# Get the file path for the temporary file
log_path = log_file.name

# Create a logger that redirects output to the temporary file
logger = utils.Logger(log_path)

# Redirect stdout to the logger
sys.stdout = logger

# Write some messages to stdout
print("Hello, world!")
print("Testing logger...")

# Flush the logger to ensure that all messages are written to the log
logger.flush()

# Reset stdout to the original stream
sys.stdout = sys.__stdout__


def test_format_time():
assert utils.format_time(0) == "0.0 second(s)"
assert utils.format_time(1) == "1.0 second(s)"
assert utils.format_time(59) == "59.0 second(s)"
assert utils.format_time(60) == "1 minute(s) 0 second(s)"
assert utils.format_time(61) == "1 minute(s) 1 second(s)"
assert utils.format_time(3599) == "59 minute(s) 59 second(s)"
assert utils.format_time(3600) == "1 hour(s) 0 minute(s) 0 second(s)"
assert utils.format_time(3661) == "1 hour(s) 1 minute(s) 1 second(s)"
assert utils.format_time(86399) == "23 hour(s) 59 minute(s) 59 second(s)"
assert utils.format_time(86400) == "1 day, 0 hour(s) 0 minute(s) 0 second(s)"
assert utils.format_time(90061) == "1 day, 1 hour(s) 1 minute(s) 1 second(s)"


def test_autoconvert():
# Test non-string input
assert utils.autoconvert(42) == 42

# Test string input that can be converted to int
assert utils.autoconvert("42") == 42

# Test string input that can be converted to float
assert utils.autoconvert("3.14159") == 3.14159

# Test string input that can be converted to a numpy array of ints
assert np.array_equal(utils.autoconvert("1 2 3"), np.array([1, 2, 3]))

# Test string input that can be converted to a numpy array of floats
assert np.allclose(utils.autoconvert("1.0 2.0 3.0"), np.array([1.0, 2.0, 3.0]))


def test_get_subplot_dimension():
assert utils.get_subplot_dimension(1) == (1, 1)
assert utils.get_subplot_dimension(2) == (1, 2)
assert utils.get_subplot_dimension(3) == (2, 2)
assert utils.get_subplot_dimension(4) == (2, 2)
assert utils.get_subplot_dimension(5) == (2, 3)
assert utils.get_subplot_dimension(6) == (2, 3)
assert utils.get_subplot_dimension(7) == (3, 3)
assert utils.get_subplot_dimension(8) == (3, 3)
assert utils.get_subplot_dimension(9) == (3, 3)
61 changes: 54 additions & 7 deletions ensemble_md/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,50 @@
class Logger:
"""
Redirects the STDOUT to a specified output file while preserving STDOUT on screen.
Parameters
----------
logfile : str
Name of the output file to write the logged messages.
Attributes
----------
terminal : file object
The file object that represents the original STDOUT (i.e., the screen).
log : file object
The file object that represents the logfile where messages will be written.
"""

def __init__(self, logfile):
"""
Initializes a Logger instance.
Parameters
----------
logfile : str
Name of the output file to write the logged messages.
"""
self.terminal = sys.stdout
self.log = open(logfile, "a")

def write(self, message):
"""
Writes the given message to both the STDOUT and the logfile.
Parameters
----------
message : str
The message to be written to STDOUT and logfile.
"""
self.terminal.write(message)
self.log.write(message)

def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
"""
This method is needed for Python 3 compatibility.
This handles the flush command by doing nothing.
You might want to specify some extra behavior here.
"""
# self.terminal.log()
pass

Expand Down Expand Up @@ -113,7 +143,7 @@ def format_time(t):
if "day" in hh_mm_ss[0]:
# hh_mm_ss[0] will contain "day" and cannot be converted to float
hh, mm, ss = hh_mm_ss[0], float(hh_mm_ss[1]), float(hh_mm_ss[2])
t_str = f"{hh_mm_ss[0]} {hh} hour(s) {mm:.0f} minute(s) {ss:.0f} second(s)"
t_str = f"{hh} hour(s) {mm:.0f} minute(s) {ss:.0f} second(s)"
else:
hh, mm, ss = float(hh_mm_ss[0]), float(hh_mm_ss[1]), float(hh_mm_ss[2])
if hh == 0:
Expand All @@ -133,8 +163,24 @@ def autoconvert(s):
Modified from `utilities.py in GromacsWrapper <https://github.com/Becksteinlab/GromacsWrapper>`_.
Copyright (c) 2009 Oliver Beckstein <orbeckst@gmail.com>
- A non-string object is returned as it is
- Try conversion to int, float, str.
Parameters
----------
s : str or any
The input value to be converted to a numerical type if possible. If :code:`s` is not a string,
it is returned as is.
Returns
-------
numerical : int, float, numpy.ndarray, or any
The converted numerical value. If :code:`s` can be converted to a single numerical value,
that value is returned as an :code:`int` or :code:`float`. If :code:`s` can be converted to
multiple numerical values, a :code:`numpy.ndarray` containing those values is returned.
If :code:`s` cannot be converted to a numerical value, :code:`s` is returned as is.
Raises
------
ValueError
If :code:`s` cannot be converted to a numerical value.
"""
if type(s) is not str:
return s
Expand All @@ -152,7 +198,8 @@ def autoconvert(s):

def get_subplot_dimension(n_panels):
"""
Gets the numbers of rows and columns in a subplot.
Gets the numbers of rows and columns in a subplot such that
the arrangement of the .
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ build:
image: latest

python:
version: 3.8
version: 3.9
install:
- method: pip
path: .

conda:
environment: docs/requirements.yaml
environment: docs/requirements.yaml

0 comments on commit d83e5ae

Please sign in to comment.