Skip to content

Commit

Permalink
chop respects axis order of args (#1092)
Browse files Browse the repository at this point in the history
* chop respects axis order of args

* Update _data.py

* Update _data.py
  • Loading branch information
ddkohler committed Jul 30, 2022
1 parent 773ad71 commit 60e5e0c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ def chop(self, *args, at=None, parent=None, verbose=True) -> wt_collection.Colle
arg = arg.strip()
args[i] = wt_kit.string2identifier(arg, replace=operator_to_identifier)

transform_expression = [self._axes[self.axis_names.index(a)].expression for a in args]

if at is None:
at = {}
# normalize the at keys to the natural name
Expand Down Expand Up @@ -457,6 +459,7 @@ def chop(self, *args, at=None, parent=None, verbose=True) -> wt_collection.Colle
idx[np.array(removed_shape) == 1] = slice(None)
idx[at_axes] = at_idx[at_axes]
self._from_slice(idx, name=name, parent=out)
out[name].transform(*transform_expression)
out.flush()
# return
if verbose:
Expand Down
18 changes: 18 additions & 0 deletions tests/data/chop.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def test_parent():
assert chop.parent is parent


def test_axes_order():
x = np.arange(6)
y = x[::2].copy()
z = x[::3].copy()
chan = np.arange(x.size * y.size * z.size).reshape(x.size, y.size, z.size).astype("float")
data = wt.data.Data(name="data")
data.create_channel("chan", values=chan, signed=False)
data.create_variable("x", values=x[:, None, None], units="wn")
data.create_variable("y", values=y[None, :, None], units="wn")
data.create_variable("z", values=z[None, None, :], units="wn")

data.transform("y", "x", "z")

d = data.chop("z", "x", at={"y": (2, "wn")})[0]
assert d.axis_names == ("z", "x")


def test_transformed():
x = np.arange(6)
y = x[::2].copy()
Expand Down Expand Up @@ -174,6 +191,7 @@ def test_rmd_axis_full_shape():

if __name__ == "__main__":
test_transformed()
test_axes_order()
test_2D_to_1D()
test_3D_to_1D()
test_3D_to_1D_at()
Expand Down

0 comments on commit 60e5e0c

Please sign in to comment.