Skip to content

Commit

Permalink
Merge pull request #1134 from wright-group/data-drop-method
Browse files Browse the repository at this point in the history
`Data.squeeze`
  • Loading branch information
kameyer226 committed Jan 30, 2024
2 parents 6f9c549 + b71c838 commit c6f4378
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Added
- `Data.squeeze`: squeezing the data object to the shape of the axes.

### Fixed
- `interact2D`: fixed bug where use_imshow broke the sliders
- `data.join` ensures valid `method` is selected
Expand Down
10 changes: 10 additions & 0 deletions WrightTools/data/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,13 @@ def trim(self, neighborhood, method="ztest", factor=3, replace="nan", verbose=Tr
if verbose:
print("%i outliers removed" % len(outliers))
return outliers

def _to_dict(self):
out = {}
out["name"] = self.natural_name
out["values"] = self[:]
out["units"] = self.units
out["label"] = self.label
out["signed"] = self.signed
out.update(self.attrs)
return out
85 changes: 74 additions & 11 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,78 @@ def at(self, parent=None, name=None, **at) -> Data:
idx = self._at_to_slice(**at)
return self._from_slice(idx, name=name, parent=parent)

def squeeze(self, name=None, parent=None):
"""Reduce the data to the dimensionality of the (non-trivial) span of the axes.
i.e. if the joint shape of the axes has an array dimension with length 1, this
array dimension is squeezed.
channels and variables that span beyond the axes are removed.
Parameters
----------
name : string (optional)
name of the new Data.
parent : WrightTools Collection instance (optional)
Collection to place the new "chop" collection within. Default is
None (new parent).
Returns
-------
out : wt.Data
new data object. The new data object has dimensions of the
(non-trivial) span of the current axes
Examples
--------
>>> ...
See also
--------
Data.chop: Divide the dataset into its lower-dimensionality components.
...
"""
new = Data(name=name, parent=parent)

attrs = {
k: v
for k, v in self.attrs.items()
if k
not in [
"axes",
"channel_names",
"constants",
"name",
"source",
"item_names",
"variable_names",
]
}
new.attrs.update(attrs)

joint_shape = wt_kit.joint_shape(*[ai[:] for ai in self.axes])
cull_dims = [j == 1 for j in joint_shape]
sl = [0 if cull else slice(None) for cull in cull_dims]
matches_broadcast_axes = lambda a: all(
[a.shape[i] == 1 for i in range(self.ndim) if cull_dims[i]]
)

for v in filter(matches_broadcast_axes, self.variables):
kwargs = v._to_dict()
kwargs["values"] = v[sl]
new.create_variable(**kwargs)

for c in filter(matches_broadcast_axes, self.channels):
kwargs = c._to_dict()
kwargs["values"] = c[sl]
new.create_channel(**kwargs)

# inherit constants
for c in self.constants:
new.create_constant(c.expression)

new.transform(*self.axis_expressions)
return new

def chop(self, *args, at=None, parent=None, verbose=True) -> wt_collection.Collection:
"""Divide the dataset into its lower-dimensionality components.
Expand Down Expand Up @@ -508,21 +580,12 @@ def _from_slice(self, idx, name=None, parent=None) -> Data:
out = parent.create_data(name=name)

for v in self.variables:
kwargs = {}
kwargs["name"] = v.natural_name
kwargs = v._to_dict()
kwargs["values"] = v[idx]
kwargs["units"] = v.units
kwargs["label"] = v.label
kwargs.update(v.attrs)
out.create_variable(**kwargs)
for c in self.channels:
kwargs = {}
kwargs["name"] = c.natural_name
kwargs = c._to_dict()
kwargs["values"] = c[idx]
kwargs["units"] = c.units
kwargs["label"] = c.label
kwargs["signed"] = c.signed
kwargs.update(c.attrs)
out.create_channel(**kwargs)

new_axes = [a.expression for a in self.axes if a[idx].size > 1]
Expand Down
9 changes: 9 additions & 0 deletions WrightTools/data/_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ def label(self) -> str:
@label.setter
def label(self, label):
self.attrs["label"] = label

def _to_dict(self):
out = {}
out["name"] = self.natural_name
out["values"] = self[:]
out["units"] = self.units
out["label"] = self.label
out.update(self.attrs)
return out
46 changes: 46 additions & 0 deletions tests/data/squeeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#! /usr/bin/env python3
"""Test squeeze."""


# --- import -------------------------------------------------------------------------------------


import numpy as np
import WrightTools as wt
from WrightTools import datasets


# --- tests --------------------------------------------------------------------------------------


def test_squeeze():
d = wt.Data(name="test")
d.create_variable("x", values=np.arange(5)[:, None, None])
d.create_variable("y", values=np.arange(4)[None, :, None])
d.create_variable("redundant_array", values=np.tile(np.arange(3), (5, 4, 1)))

d.create_channel("keep", values=d.x[:] + d.y[:])
d.create_channel("throw_away", values=np.zeros((5, 4, 3)))

d.transform("x", "y")
d = d.squeeze() # make sure it runs error free
assert d.ndim == 2
assert d.shape == (5, 4)


def test_constants():
d = wt.Data(name="test")
d.create_variable("x", values=np.array([1]).reshape(1, 1))
d.create_constant("x")
d.create_variable("y", values=np.linspace(3, 5, 4).reshape(-1, 1))
d.create_variable("z", values=np.linspace(0, 1, 6).reshape(1, -1))
d.transform("y")
ds = d.squeeze()
assert "x" in ds.constant_expressions
d.print_tree()
ds.print_tree()


if __name__ == "__main__":
test_squeeze()
test_constants()

0 comments on commit c6f4378

Please sign in to comment.