Skip to content

Commit

Permalink
Developed plot_state_hist in analyze_traj.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Feb 14, 2023
1 parent 2640f82 commit 9b5a74e
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 18 deletions.
3 changes: 2 additions & 1 deletion docs/requirements.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ dependencies:
- matplotlib
- pyemma
- pymbar==4.0.1
# - mpi4py
- mpi4py
- gmxapi==0.4.0
2 changes: 1 addition & 1 deletion ensemble_md/analysis/analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def calc_spectral_gap(trans_mtx):
if np.isclose(eig_vals[0], 1, atol=1e-4) is False:
raise ParameterError(f'The largest eigenvalue of the input transition matrix {eig_vals[0]} is not close to 1.')

spectral_gap = eig_vals[0] - eig_vals[1]
spectral_gap = np.abs(eig_vals[0]) - np.abs(eig_vals[1])

return spectral_gap

Expand Down
61 changes: 55 additions & 6 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def plot_rep_trajs(trajs, fig_name, dt=None, stride=None):
plt.savefig(f'{fig_name}', dpi=600)


def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None):
def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=1):
"""
Plots the time series of states visited by each configuration in a subplot.
Expand All @@ -187,13 +187,13 @@ def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None):
trajs : list
A list of arrays that represent the state space trajectories of all configurations.
state_ranges : list
A list of sets of state indices. (Like the attribute :code:`state_ranges` in :code:`EnsemblEXE`.)
A list of lists of state indices. (Like the attribute :code:`state_ranges` in :code:`EnsemblEXE`.)
fig_name : str
The file name of the png file to be saved (with the extension).
dt : str or float
One trajectory timestep in ps. If None, it assumes there are no timeframes but MC steps.
stride : int
The stride for plotting the time series. The default is 100 if the length of
The stride for plotting the time series. The default is 1. If None is passed, the stride
will be 10 if the first trajectory has more than 100,000 frames, and 1 otherwise. Typically
plotting more than 10 million frames can take a lot of memory.
"""
Expand All @@ -213,7 +213,7 @@ def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None):

if stride is None:
if len(trajs[0]) > 100000:
stride = 100
stride = 10
else:
stride = 1

Expand All @@ -235,8 +235,14 @@ def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None):
bounds[1] += 0.5
plt.fill_between(x_range, y1=bounds[1], y2=bounds[0], color=colors[j], alpha=0.1)

# Then plot the trajectories
plt.plot(x[::stride], trajs[i][::stride], color=colors[i])
if len(trajs[0]) > 100000:
linewidth = 0.01
else:
linewidth = 1 # this is the default

# Finally, plot the trajectories
linewidth = 1 # this is the default
plt.plot(x[::stride], trajs[i][::stride], color=colors[i], linewidth=linewidth)
if dt is None:
plt.xlabel('MC moves')
else:
Expand All @@ -259,6 +265,49 @@ def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=None):
plt.savefig(f'{fig_name}', dpi=600)


def plot_state_hist(trajs, state_ranges, fig_name):
    """
    Plots the histograms of the state index for each configuration.

    Parameters
    ----------
    trajs : list
        A list of arrays that represent the state space trajectories of all configurations.
    state_ranges : list
        A list of lists of state indices. (Like the attribute :code:`state_ranges` in :code:`EnsembleEXE`.)
    fig_name : str
        The file name of the png file to be saved (with the extension).
    """
    n_configs = len(trajs)
    cmap = plt.cm.ocean  # other good options are CMRmap, gnuplot, terrain, turbo, brg, etc.
    colors = [cmap(i) for i in np.arange(n_configs) / n_configs]

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # The bin edges sit at half-integers so that each bin is centered on a state index.
    # The bounds are taken over ALL trajectories (not just the first and the last one):
    # samples falling outside the bin edges are silently dropped by plt.hist.
    lower_bound = min(min(traj) for traj in trajs) - 0.5
    upper_bound = max(max(traj) for traj in trajs) + 0.5
    bin_edges = np.arange(lower_bound, upper_bound + 1, 1)

    for i, traj in enumerate(trajs):
        plt.hist(traj, bin_edges, label=f'Configuration {i}', alpha=0.5, edgecolor='black', color=colors[i])  # noqa: E501
    plt.xticks(range(max(state_ranges[-1]) + 1))

    # Here we color the different regions to show alchemical ranges
    # The y-limits are recorded first and restored below so that the shading does not rescale the axis.
    y_min, y_max = ax.get_ylim()
    for i in range(n_configs):
        states = list(state_ranges[i])
        bounds = [states[0], states[-1]]
        # The outermost ranges are extended by an extra half bin; any excess beyond
        # [lower_bound, upper_bound] is cut off by the x-limits set below.
        if i == 0:
            bounds[0] -= 0.5
        if i == n_configs - 1:
            bounds[1] += 0.5
        plt.fill_betweenx([y_min, y_max], x1=bounds[1] + 0.5, x2=bounds[0] - 0.5, color=colors[i], alpha=0.1, zorder=0)  # noqa: E501

    plt.xlim([lower_bound, upper_bound])
    plt.ylim([y_min, y_max])
    plt.xlabel('State index')
    plt.ylabel('Count')
    plt.grid()
    plt.legend()
    plt.savefig(f'{fig_name}', dpi=600)


def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
"""
Calculates and plots the average transit times for each configuration, including the time
Expand Down
26 changes: 16 additions & 10 deletions ensemble_md/cli/analyze_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
from datetime import datetime
from deeptime.markov.tools.analysis import is_transition_matrix
warnings.simplefilter(action='ignore', category=UserWarning)

Expand Down Expand Up @@ -89,6 +90,7 @@ def main():
rc('mathtext', **{'default': 'regular'})
plt.rc('font', family='serif')

print(f'Current time: {datetime.now().strftime("%d/%m/%Y %H:%M:%S")}')
print(f'Command line: {" ".join(sys.argv)}')

EEXE = EnsembleEXE(args.yaml)
Expand Down Expand Up @@ -156,24 +158,28 @@ def main():
dt_traj = EEXE.dt * EEXE.template['nstdhdl'] # in ps
analyze_traj.plot_state_trajs(state_trajs, EEXE.state_ranges, f'{args.dir}/state_trajs.png', dt_traj)

# 2-2. Plot the overall state transition matrices calculated from the state-space trajectories
print('\n2-2. Plotting the overall state transition matrices ...')
# 2-2. Plot the histograms for the states
print('\n2-2. Plotting the histograms of the state index ...')
analyze_traj.plot_state_hist(state_trajs, EEXE.state_ranges, f'{args.dir}/state_hist.png')

# 2-3. Plot the overall state transition matrices calculated from the state-space trajectories
print('\n2-3. Plotting the overall state transition matrices ...')
mtx_list = []
for i in range(EEXE.n_sim):
mtx = analyze_traj.traj2transmtx(state_trajs[i], EEXE.n_tot)
mtx_list.append(mtx)
analyze_matrix.plot_matrix(mtx, f'{args.dir}/config_{i}_state_transmtx.png')

# 2-3. For each configurration, calculate the spectral gap of the overall transition matrix obtained in step 2-2.
print('\n2-3. Calculating the spectral gap of the state transition matrices ...')
# 2-4. For each configuration, calculate the spectral gap of the overall transition matrix obtained in step 2-3.
print('\n2-4. Calculating the spectral gap of the state transition matrices ...')
spectral_gaps = [analyze_matrix.calc_spectral_gap(mtx) for mtx in mtx_list]
if None not in spectral_gaps:
for i in range(EEXE.n_sim):
print(f' - Configuration {i}: {spectral_gaps[i]:.3f}')
print(f' - Average of the above: {np.mean(spectral_gaps):.3f} (std: {np.std(spectral_gaps, ddof=1):.3f})')

# 2-4. For each configuration, calculate the stationary distribution from the overall transition matrix obtained in step 2-2. # noqa: E501
print('\n2-4. Calculating the stationary distributions ...')
# 2-5. For each configuration, calculate the stationary distribution from the overall transition matrix obtained in step 2-3. # noqa: E501
print('\n2-5. Calculating the stationary distributions ...')
pi_list = [analyze_matrix.calc_equil_prob(mtx) for mtx in mtx_list]
if any([x is None for x in pi_list]):
pass # None is in the list
Expand All @@ -183,15 +189,15 @@ def main():
if len({len(i) for i in pi_list}) == 1: # all lists in pi_list have the same length
print(f' - Average of the above: {", ".join([f"{i:.3f}" for i in np.mean(pi_list, axis=0).reshape(-1)])}') # noqa: E501

# 2-5. Calculate the state index correlation time for each configuration (this step is more time-consuming one)
print('\n2-5. Calculating the state index correlation time ...')
# 2-6. Calculate the state index correlation time for each configuration (this is one of the more time-consuming steps)
print('\n2-6. Calculating the state index correlation time ...')
tau_list = [(pymbar.timeseries.statistical_inefficiency(state_trajs[i], fast=True) - 1) / 2 * dt_traj for i in range(EEXE.n_sim)] # noqa: E501
for i in range(EEXE.n_sim):
print(f' - Configuration {i}: {tau_list[i]:.1f} ps')
print(f' - Average of the above: {np.mean(tau_list):.1f} ps (std: {np.std(tau_list, ddof=1):.1f} ps)')

# 2-6. Calculate transit times for each configuration
print('\n2-6. Plotting the average transit times ...')
# 2-7. Calculate transit times for each configuration
print('\n2-7. Plotting the average transit times ...')
t_0k_list, t_k0_list, t_roundtrip_list, units = analyze_traj.plot_transit_time(state_trajs, EEXE.n_tot, dt=dt_traj, folder=args.dir) # noqa: E501
meta_list = [t_0k_list, t_k0_list, t_roundtrip_list]
t_names = [
Expand Down
2 changes: 2 additions & 0 deletions ensemble_md/cli/run_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import numpy as np
from mpi4py import MPI
from datetime import datetime

from ensemble_md.utils import utils
from ensemble_md.ensemble_EXE import EnsembleEXE
Expand Down Expand Up @@ -69,6 +70,7 @@ def main():
rank = comm.Get_rank() # Note that this is a GLOBAL variable

if rank == 0:
print(f'Current time: {datetime.now().strftime("%d/%m/%Y %H:%M:%S")}')
print(f'Command line: {" ".join(sys.argv)}\n')

EEXE = EnsembleEXE(args.yaml)
Expand Down

0 comments on commit 9b5a74e

Please sign in to comment.