Skip to content

Commit

Permalink
Make expected behavior when chopping axes with a removed axis that sp… (
Browse files Browse the repository at this point in the history
#984)

* Make expected behavior when chopping axes with a removed axis that spans the whole shape

particualarly useful for array detector when plotting channel other than the array detector, but with the array axis present

* Add test of new behavior

old master failed at the len check, where it had 36 items in the chop

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update changelog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ksunden and pre-commit-ci[bot] committed Nov 17, 2020
1 parent 8fa0116 commit a2c7f4e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Changed
- Improved chopping with axes that span the kept axes removed

## [3.3.2]

## Added
Expand Down
3 changes: 3 additions & 0 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,9 @@ def chop(self, *args, at={}, parent=None, verbose=True) -> wt_collection.Collect
for i in at.keys():
if type(i) == int:
removed_shape[i] = 1
for ax in kept_axes:
if ax.shape.count(1) == ax.ndim - 1:
removed_shape[ax.shape.index(ax.size)] = 1
removed_shape = tuple(removed_shape)
# iterate
i = 0
Expand Down
19 changes: 19 additions & 0 deletions tests/data/chop.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,25 @@ def test_transformed():
assert d[0]["y"].shape == (1,)


def test_rmd_axis_full_shape():
x = np.arange(6)
y = x[::2].copy()
z = np.arange(x.size * y.size * 2).reshape(x.size, y.size, 2).astype("float")
z[:, y < 2] *= 0
data = wt.data.Data(name="data")
data.create_channel("signal", values=z, signed=False)
data.create_variable("x", values=x[:, None, None], units="wn")
data.create_variable("y", values=y[None, :, None], units="wn")
data.create_variable("z", values=z, units="wn")

data.transform("x", "y", "z")

c = data.chop("x", "y")

assert len(c) == 2
assert c[0].shape == (6, 3)


# --- run -----------------------------------------------------------------------------------------


Expand Down

0 comments on commit a2c7f4e

Please sign in to comment.