Skip to content

Commit

Permalink
bring topas API in line with superclass (#79)
Browse files Browse the repository at this point in the history
* bring topas API in line with superclass

* Update test for current return

* Add plotting and verbose output to TOPAS save

* Pass full and verbose kwargs when saving subcurves

* Fix variable name
  • Loading branch information
ksunden authored and untzag committed Dec 27, 2019
1 parent 2aac816 commit e765e77
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion attune/curve/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def save(self, save_directory=None, plot=True, verbose=True, full=False):
np.savetxt(f, out_arr.T, fmt=self.fmt, delimiter="\t")
# save subcurve
if full and self.subcurve:
self.subcurve.save(save_directory=save_directory)
self.subcurve.save(save_directory=save_directory, full=True, verbose=verbose)
# plot
if plot:
image_path = out_path.with_suffix(".png")
Expand Down
24 changes: 17 additions & 7 deletions attune/curve/_topas.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _read_file(cls, filepath):
f.close()
return curves

def save(self, save_directory, full=True):
def save(self, save_directory=None, plot=True, verbose=True, full=False):
"""Save a curve object.
Parameters
Expand Down Expand Up @@ -181,16 +181,26 @@ def save(self, save_directory, full=True):
to_insert["NON-NON-NON-Idl"] = _convert(curve)
to_insert["NON-NON-NON-Idl"].interaction = "NON-NON-NON-Idl"

save_directory = pathlib.Path(save_directory)
# get save directory
if save_directory is None:
save_directory = pathlib.Path()
else:
save_directory = pathlib.Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
timestamp = wt.kit.TimeStamp().path
out_paths = []

ret_name = curve.kind + "- " + timestamp
ret_path = (save_directory / ret_name).with_suffix(".crv")

if plot:
image_path = ret_path.with_suffix(".png")
title = ret_path.stem
self.plot(autosave=True, save_path=image_path, title=title)

while len(to_insert):
_, curve = to_insert.popitem()
out_name = curve.kind + "- " + timestamp
out_path = (save_directory / out_name).with_suffix(".crv")
out_paths.append(out_path)
all_sibs = [curve]
if curve.siblings:
all_sibs += curve.siblings
Expand All @@ -201,8 +211,10 @@ def save(self, save_directory, full=True):
for c in all_sibs:
_write_curve(new_crv, c)
to_insert.pop(c.interaction, None)
if verbose:
print("curve saved at", out_path)

return out_paths
return ret_path

def _get_family_dict(self, start=None):
if start is None:
Expand All @@ -229,8 +241,6 @@ def _insert(curve):
arr[0] = curve.source_setpoints[:]
arr[1] = curve.setpoints[:]
arr[2] = len(motor_indexes)
print(motor_indexes)
print([d.index for d in curve.dependents.values()])
for i, m in enumerate(motor_indexes):
arr[3 + i] = next(d for d in curve.dependents.values() if d.index == m)[:]
return arr.T
Expand Down
3 changes: 2 additions & 1 deletion tests/TopasCurve/read/2018-10-26/niuce.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def test_round_trip():
curve = attune.TopasCurve.read(paths, interaction_string="NON-SH-NON-Sig")
with tempfile.TemporaryDirectory() as td:
td = pathlib.Path(td)
paths = curve.save(save_directory=td)
curve.save(save_directory=td, full=True)
paths = td.glob("*.crv")
read_curve = attune.TopasCurve.read(paths, interaction_string="NON-SH-NON-Sig")
assert np.allclose(curve.setpoints[:], read_curve.setpoints[:])
assert curve.dependent_names == read_curve.dependent_names
Expand Down

0 comments on commit e765e77

Please sign in to comment.