Skip to content

Commit

Permalink
Finished refactoring all unit tests for ensemble_EXE.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Mar 3, 2023
1 parent 8c8ed7e commit 09adb37
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 37 deletions.
2 changes: 1 addition & 1 deletion docs/theory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ Below we elaborate the details of each step carried out by our method.

Step 1: Convert the weights into probabilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For weight :math:`g_ij` that corresponds to state :math:`j` in replica :math:`i`, we can calculate its
For weight :math:`g_{ij}` that corresponds to state :math:`j` in replica :math:`i`, we can calculate its
corresponding probability as follows:

.. math::
Expand Down
42 changes: 34 additions & 8 deletions ensemble_md/ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,38 @@
class EnsembleEXE:
"""
This class provides a variety of functions useful for setting up and running
an ensemble of expanded ensemble. Below is a list of attributes of the class:
:ivar attribute1: Description of attribute1.
:ivar attribute2: Description of attribute2.
:ivar attribute3: Description of attribute3.
an ensemble of expanded ensemble. Upon instantiation, all parameters in the YAML
file will be assigned to an attribute in the class. In addition to these variables,
below is a list of attributes of the class. (All the attributes are assigned by
:obj:`set_params` unless otherwise noted.)
:ivar yaml: The input YAML file used to instantiate the class. Assigned by the :code:`__init__` function.
:ivar warnings: Warnings about parameter specification in either YAML or MDP files.
:ivar reformatted_mdp: Whether the templated MDP file has been reformatted by replacing hyphens
with underscores or not.
:ivar template: The instance of the :obj:`MDP` class based on the template MDP file.
:ivar nsteps: The number of steps per iteration.
:ivar dt: The simulation timestep in ps.
:ivar temp: The simulation temperature in Kelvin.
:ivar fixed_weights: Whether the weights will be fixed during the simulation (according to the template MDP file).
:ivar kT: 1 kT in kJ/mol at the simulation temperature.
:ivar lambda_types: The types of lambda variables involved in expanded ensemble simulations, e.g.
:code:`fep_lambdas`, :code:`mass_lambdas`, :code:`coul_lambdas`, etc.
:ivar n_tot: The total number of states for all replicas.
:ivar n_sub: The number of states for each replica. The current implementation assumes homogeneous replicas.
:ivar state_ranges: A list of lists of state indices for each replica.
:ivar equil: A list of times it took to equilibrate the weights for different replicas.
:ivar lambda_dict: A dictionary with keys being the tuples of coupling parameters used in each replica and
values being the corresponding global index (starting from 0). Assigned by :obj:`map_lambda2state`.
:ivar lambda_ranges: A list of lambda vectors of state range of each replica. Assigned by :obj:`map_lambda2state`.
:ivar n_rejected: The number of proposed exchanges that have been rejected. Updated by :obj:`accept_or_reject`.
:ivar n_swap_attempts: The number of swaps attempted so far. This does not include the cases
where there is no swappable pair. Updated by :obj:`get_swapping_pattern`.
:ivar rep_trajs: The replica-space trajectories of all replicas.
:ivar get_u_nk: Whether to get the :math:`u_{nk}` dataset from the DHDL files. Only meaningful during
data analysis and if :code:`df_method` is specified.
:ivar get_dHdl: Whether to get the :math:`dH/dλ` dataset from the DHDL files. Only meaningful
during data analysis and if :code:`df_method` is specified.
"""

def __init__(self, yaml_file):
Expand Down Expand Up @@ -263,7 +290,7 @@ def set_params(self):

# 6-10. The time series of the (processed) whole-range alchemical weights
# If no weight combination is applied, self.g_vecs will just be a list of None's.
self.g_vecs = []
# self.g_vecs = []

# 6-11. Data analysis
if self.df_method == 'MBAR':
Expand Down Expand Up @@ -550,8 +577,6 @@ def propose_swaps(self, states):
n_ex = len(swappables) ** 3
else:
n_ex = self.n_ex

self.n_swap_attempts += n_ex
print(f"Swappable pairs: {swappables}")

try:
Expand Down Expand Up @@ -596,6 +621,7 @@ def get_swapping_pattern(self, swap_list, dhdl_files, states, lambda_vecs, weigh
swapping, simulations/replicas with indices 0, 1, 2, and 3 should be in configurations 0, 1, 3,
respectively.
"""
self.n_swap_attempts += len(swap_list)
swap_pattern = list(range(self.n_sim)) # Can be regarded as the indices of DHDL files/configurations
if swap_list is []:
print('No swap is proposed because there is no swappable pair at all.')
Expand Down
72 changes: 44 additions & 28 deletions ensemble_md/tests/test_ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
import os
import sys
import yaml
import random
import shutil
import pytest
import numpy as np
import ensemble_md
import gmxapi as gmx
from mpi4py import MPI
from ensemble_md.utils import gmx_parser
from ensemble_md.ensemble_EXE import EnsembleEXE
from ensemble_md.utils.exceptions import ParameterError
Expand Down Expand Up @@ -186,6 +189,7 @@ def test_set_params(self, params_dict):
assert EEXE.fixed_weights is False

# 4. Checked the derived parameters
# Note that lambda_dict and lambda_ranges will also be tested in test_map_lambda2state.
k = 1.380649e-23
NA = 6.0221408e23
assert EEXE.kT == k * NA * 298 / 1000
Expand All @@ -201,7 +205,6 @@ def test_set_params(self, params_dict):
assert EEXE.n_rejected == 0
assert EEXE.n_swap_attempts == 0
assert EEXE.rep_trajs == [[0], [1], [2], [3]]
assert EEXE.g_vecs == []

params_dict['df_method'] = 'MBAR'
EEXE = get_EEXE_instance(params_dict)
Expand Down Expand Up @@ -432,13 +435,11 @@ def test_extract_final_log_info(self, params_dict):
[0, 0, 0, 1, 18, 31], ]
assert EEXE.equil == [-1, -1, -1, -1]


"""
def test_propose_swaps(self):
def test_propose_swaps(self, params_dict):
random.seed(0)
EEXE.n_sim = 4
EEXE = get_EEXE_instance(params_dict)
EEXE.state_ranges = [list(range(i, i + 5)) for i in range(EEXE.n_sim)] # 5 states per replica
states = [5, 2, 2, 7] # This would lead to the swappables: [(0, 1), (0, 2), (1, 2)]
states = [4, 2, 2, 7] # This would lead to the swappables: [(0, 1), (0, 2), (1, 2)]

# Case 1: Neighboring swapping (n_ex = 0 --> swappables = [(0, 1), (1, 2)])
EEXE.n_ex = 0
Expand All @@ -455,7 +456,8 @@ def test_propose_swaps(self):
swap_list = EEXE.propose_swaps(states)
assert swap_list == []

def test_get_swapped_configs(self):
def test_get_swapping_pattern(self, params_dict):
EEXE = get_EEXE_instance(params_dict)
EEXE.state_ranges = [
[0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 6],
Expand All @@ -468,21 +470,27 @@ def test_get_swapped_configs(self):
[0, 1.22635, 2.30707, 2.44120, 4.10308, 6.03106],
[0, 0.66431, 1.25475, 1.24443, 0.59472, 0.70726], # the 4th prob was ajusted (from 0.24443) to tweak prob_acc # noqa: E501
[0, 0.09620, 1.59937, -4.31679, -22.89436, -28.08701]]
dhdl_files = [os.path.join(input_path, f"dhdl_{i}.xvg") for i in range(4)]
dhdl_files = [os.path.join(input_path, f"dhdl/dhdl_{i}.xvg") for i in range(4)]
EEXE.mc_scheme = "metropolis"

# Case 1: Empty swap list
swap_list = []
configs_1 = EEXE.get_swapping_pattern(swap_list, dhdl_files, states, lambda_vecs, weights)
# When counting n_swap_attempts and n_rejected, we do not consider cases where swap_list is empty.
assert EEXE.n_swap_attempts == 0
assert EEXE.n_rejected == 0
assert configs_1 == [0, 1, 2, 3]

# Case 2: Multiple swaps
swap_list = [(0, 2) for i in range(5)] # prob_acc should be around 0.516
random.seed(0) # r1 = 0.844, r2 = 0.758, r3=0.421, r4=0.259 r5=0.511 --> 3 accepted moves --> [2, 1, 0, 3]
configs_2 = EEXE.get_swapping_pattern(swap_list, dhdl_files, states, lambda_vecs, weights)
assert EEXE.n_swap_attempts == 5
assert EEXE.n_rejected == 2
assert configs_2 == [2, 1, 0, 3]

def test_calc_prob_acc(self):
def test_calc_prob_acc(self, params_dict):
EEXE = get_EEXE_instance(params_dict)
EEXE.state_ranges = [
[0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 6],
Expand All @@ -495,7 +503,7 @@ def test_calc_prob_acc(self):
[0, 1.22635, 2.30707, 2.44120, 4.10308, 6.03106],
[0, 0.66431, 1.25475, 0.24443, 0.59472, 0.70726],
[0, 0.09620, 1.59937, -4.31679, -22.89436, -28.08701]]
dhdl_files = [os.path.join(input_path, f"dhdl_{i}.xvg") for i in range(4)]
dhdl_files = [os.path.join(input_path, f"dhdl/dhdl_{i}.xvg") for i in range(4)]

# Test 1: Same-state swapping (True)
swap = (1, 2)
Expand All @@ -518,19 +526,25 @@ def test_calc_prob_acc(self):
swap = (0, 2)
EEXE.mc_scheme = "metropolis"
prob_acc_4 = EEXE.calc_prob_acc(swap, dhdl_files, states, lambda_vecs, weights)
# dH ~-1.67 kT as calculated above, dg = (2.55736 - 6.13408) + (0.24443 - 0) ~ -3.33229 kT
# dU - dg ~ 1.66212 kT, so p_acc ~ 0.189 ...
assert prob_acc_4 == pytest.approx(0.18989559074633955) # check this number again

def test_accept_or_reject(self):
def test_accept_or_reject(self, params_dict):
EEXE = get_EEXE_instance(params_dict)
random.seed(0)
swap_bool_1 = EEXE.accept_or_reject(0)
swap_bool_2 = EEXE.accept_or_reject(0.8) # rand = 0.844
swap_bool_3 = EEXE.accept_or_reject(0.8) # rand = 0.758

assert EEXE.n_swap_attempts == 0 # since we didn't use get_swapping_pattern
assert EEXE.n_rejected == 2
assert swap_bool_1 is False
assert swap_bool_2 is False
assert swap_bool_3 is True

def test_historgam_correction(self):
def test_historgam_correction(self, params_dict):
EEXE = get_EEXE_instance(params_dict)
# Case 1: No histogram correction
EEXE.N_cutoff = -1
weights_1 = [[0, 10.304, 20.073, 29.364]]
Expand All @@ -539,7 +553,6 @@ def test_historgam_correction(self):
assert weights_1 == [[0, 10.304, 20.073, 29.364]]

# Case 2: Perform histogram correction (N_cutoff reached)
EEXE.verbose = False
EEXE.N_cutoff = 5000
weights_1 = EEXE.histogram_correction(weights_1, counts_1)
assert np.allclose(weights_1, [
Expand All @@ -551,50 +564,54 @@ def test_historgam_correction(self):
]
]) # noqa: E501

# Case 3: Perform histogram correction (N_cutoff not reached)
EEXE.verbose = True
# Case 3: Perform histogram correction (N_cutoff not reached by both N_k and N_{k-1})
weights_2 = [[0, 10.304, 20.073, 29.364]]
counts_2 = [[3141, 4570, 5545, 5955]]
weights_2 = EEXE.histogram_correction(weights_2, counts_2)
assert np.allclose(weights_2, [[0, 10.304, 20.073, 29.364 + np.log(5545 / 5955)]])

def test_combine_weights(self):
def test_combine_weights(self, params_dict):
EEXE = get_EEXE_instance(params_dict)
EEXE.n_tot = 6
EEXE.n_sub = 4
EEXE.s = 1
EEXE.n_sim = 3
EEXE.state_ranges = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]]

# Method: None
w1, g_vec_1 = EEXE.combine_weights(weights, method=None)
w2, g_vec_2 = EEXE.combine_weights(weights, method='mean')
EEXE.verbose = False # just to reach some print statementss with verbose = False
w3, g_vec_3 = EEXE.combine_weights(weights, method='geo-mean')
w4, g_vec_4 = EEXE.combine_weights(weights, method='g-diff')
assert np.allclose(w1, weights)
assert g_vec_1 is None

# Method: mean
w2, g_vec_2 = EEXE.combine_weights(weights, method='mean')
assert np.allclose(w2, [
[0.0, 2.20097, 3.99803, 3.59516],
[0.0, 1.79706, 1.39419, 2.69607],
[0.0, -0.40286, 0.89901, 1.88303]])
assert np.allclose(list(g_vec_2), [0.0, 2.200968785917372, 3.9980269151210854, 3.5951633659351256, 4.897041830662871, 5.881054277773005]) # noqa: E501

# Method: geo-mean
w3, g_vec_3 = EEXE.combine_weights(weights, method='geo-mean')
assert np.allclose(w3, [
[0.0, 2.2, 3.98889, 3.58889],
[0.0, 1.78889, 1.38889, 2.68333],
[0.0, -0.4, 0.89444, 1.87778]])
assert np.allclose(list(g_vec_3), [0.0, 2.1999999999999997, 3.9888888888888885, 3.5888888888888886, 4.883333333333334, 5.866666666666667]) # noqa: E501

# Method: g-diff
w4, g_vec_4 = EEXE.combine_weights(weights, method='g-diff')
assert np.allclose(w4, [
[0, 2.1, 3.9, 3.5],
[0, 1.8, 1.4, 2.75],
[0, -0.4, 0.95, 1.95]])
assert g_vec_1 is None
assert np.allclose(list(g_vec_2), [0.0, 2.200968785917372, 3.9980269151210854, 3.5951633659351256, 4.897041830662871, 5.881054277773005]) # noqa: E501
assert np.allclose(list(g_vec_3), [0.0, 2.1999999999999997, 3.9888888888888885, 3.5888888888888886, 4.883333333333334, 5.866666666666667]) # noqa: E501
assert np.allclose(list(g_vec_4), [0, 2.1, 3.9, 3.5, 4.85, 5.85])

def test_run_EEXE(self):
def test_run_EEXE(self, params_dict):
# We probably can only test serial EEXE
rank = MPI.COMM_WORLD.Get_rank()
EEXE = EnsembleEXE('ensemble_md/tests/data/params.yaml')
EEXE = get_EEXE_instance(params_dict)
if rank == 0:
for i in range(EEXE.n_sim):
os.mkdir(f'sim_{i}')
Expand All @@ -610,4 +627,3 @@ def test_run_EEXE(self):
if rank == 0:
os.system('rm -r sim_*')
os.system('rm -r gmxapi.commandline.cli*_i0*')
"""

0 comments on commit 09adb37

Please sign in to comment.