Skip to content

Commit

Permalink
Port of w_pdist custom weights for WESTPA 2.0 (#249)
Browse files Browse the repository at this point in the history
* port of w_pdist custom weights

Co-authored-by: ASinanSaglam <asinansaglam@gmail.com>

* fix w_pdist test

* fix w_trace test memory leak while I'm at it

Co-authored-by: ASinanSaglam <asinansaglam@gmail.com>
  • Loading branch information
jeremyleung521 and ASinanSaglam committed Jun 15, 2022
1 parent 1133340 commit 747daee
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 4 deletions.
25 changes: 22 additions & 3 deletions src/westpa/cli/tools/w_pdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
import h5py
import numpy as np

from westpa.tools import WESTParallelTool, WESTDataReader, WESTDSSynthesizer, IterRangeSelection, ProgressIndicatorComponent
from westpa.tools import (
WESTParallelTool,
WESTDataReader,
WESTDSSynthesizer,
WESTWDSSynthesizer,
IterRangeSelection,
ProgressIndicatorComponent,
)

from westpa.fasthist import histnd, normhistnd
from westpa.core import h5io
from westpa.core.h5io import SingleIterDSSpec


log = logging.getLogger('w_pdist')
Expand Down Expand Up @@ -191,6 +197,7 @@ def __init__(self):
self.progress = ProgressIndicatorComponent()
self.data_reader = WESTDataReader()
self.input_dssynth = WESTDSSynthesizer(default_dsname='pcoord')
self.input_wdssynth = WESTWDSSynthesizer(default_dsname='seg_index')
self.iter_range = IterRangeSelection(self.data_reader)
self.iter_range.include_args['iter_step'] = False
self.binspec = None
Expand Down Expand Up @@ -266,6 +273,15 @@ def add_args(self, parser):
'--dsspecs', nargs='+', metavar='DSSPEC', help='''Construct probability distribution from one or more DSSPECs.'''
)

wgroup = parser.add_argument_group('input weight dataset options').add_mutually_exclusive_group(required=False)
wgroup.add_argument(
'--construct-wdataset',
help='''Use the given function (as in module.function) to extract weight data.
This function will be called once per iteration as function(n_iter, iter_group)
to construct data for one iteration. Data returned must be indexable as
[seg_id]''',
)

self.progress.add_args(parser)

def process_args(self, args):
Expand All @@ -281,7 +297,10 @@ def process_args(self, args):
with self.data_reader:
self.iter_range.process_args(args)

self.wt_dsspec = SingleIterDSSpec(self.data_reader.we_h5filename, 'seg_index', slice=np.index_exp['weight'])
# Reading potential custom weights
self.input_wdssynth.h5filename = self.data_reader.we_h5filename
self.input_wdssynth.process_args(args)
self.wt_dsspec = self.input_wdssynth.dsspec

self.binspec = args.bins
self.output_filename = args.output
Expand Down
3 changes: 2 additions & 1 deletion src/westpa/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'''tools -- classes for implementing command-line tools for WESTPA'''
from .core import WESTTool, WESTParallelTool, WESTToolComponent, WESTSubcommand, WESTMasterCommand, WESTMultiTool
from .data_reader import WESTDataReader, WESTDSSynthesizer
from .data_reader import WESTDataReader, WESTDSSynthesizer, WESTWDSSynthesizer
from .iter_range import IterRangeSelection
from .selected_segs import SegSelector
from .binning import BinMappingComponent, mapper_from_dict
Expand All @@ -18,6 +18,7 @@
'WESTMultiTool',
'WESTDataReader',
'WESTDSSynthesizer',
'WESTWDSSynthesizer',
'IterRangeSelection',
'SegSelector',
'BinMappingComponent',
Expand Down
35 changes: 35 additions & 0 deletions src/westpa/tools/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,38 @@ def process_args(self, args):
# we can only get here if a default dataset name was specified
assert self.default_dsname
self.dsspec = SingleSegmentDSSpec(self.h5filename, self.default_dsname)


class WESTWDSSynthesizer(WESTToolComponent):
group_name = 'weight dataset options'

def __init__(self, default_dsname=None, h5filename=None):
super(WESTWDSSynthesizer, self).__init__()

self.h5filename = h5filename
self.default_dsname = default_dsname

self.dsspec = None

def add_args(self, parser):
wgroup = parser.add_argument_group(self.group_name).add_mutually_exclusive_group(required=not bool(self.default_dsname))

wgroup.add_argument(
'--construct-wdataset',
help='''Use the given function (as in module.function) to extract source data.
This function will be called once per iteration as function(n_iter, iter_group)
to construct data for one iteration. Data returned must be indexable as
[seg_id]''',
)
wgroup.add_argument('--wdsspecs', nargs='+', metavar='WDSSPEC', help='''Construct weight data from one or more DSSPECs.''')

def process_args(self, args):
if args.construct_wdataset:
self.dsspec = FnDSSpec(self.h5filename, get_object(args.construct_wdataset, path=['.']))
elif args.dsspecs:
self.dsspec = MultiDSSpec([SingleSegmentDSSpec.from_string(dsspec, self.h5filename) for dsspec in args.dsspecs])
else:
# we can only get here if a default dataset name was specified
assert self.default_dsname
# we gotta slice by weight for weights if we want to get the default to work
self.dsspec = SingleIterDSSpec(self.h5filename, self.default_dsname, slice=np.index_exp['weight'])
1 change: 1 addition & 0 deletions tests/test_tools/test_w_pdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_run_w_pdist(self, ref_50iter):
compress=False,
work_manager=None,
n_workers=None,
construct_wdataset=None,
),
):
entry_point()
Expand Down
2 changes: 2 additions & 0 deletions tests/test_tools/test_w_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def test_trace(self, ref_50iter):
):
entry_point()

self.outfile.close()

# Compare text file output
assert cmp(
os.path.join(test_dir, output_txt), os.path.join(ref_dir, output_txt), shallow=False
Expand Down

0 comments on commit 747daee

Please sign in to comment.