Skip to content

Commit

Permalink
easier chopping/slicing operations on Data (#1078)
Browse files Browse the repository at this point in the history
* Update _data.py

refactor chop
chop works still

* remove type hint

python 3.6 has different method than 3.7+

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _data.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _data.py

* Update _data.py

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _data.py

* Update _data.py

* Update CHANGELOG.md

* Update _data.py

* cleanup

* string2identifier

filter for operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor axis identifiers

* Update _data.py

stage for getitem slicing

* Data.__getitem__

* Update _data.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* codeql

* string2identifier fix

retain usual behavior but optional kwargs to change

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _utilities.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move operator_to_identifier

* Update _data.py

* chop, at

fix parsing of coordinates

* fix Data.at

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _utilities.py

* Update _utilities.py

* Update CHANGELOG.md

* Update _utilities.py

* Update _data.py

move axis creation to _from_slice method

* first test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 3.7+ type hints

* Update WrightTools/data/_data.py

Co-authored-by: Kyle Sunden <git@ksunden.space>

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kyle Sunden <git@ksunden.space>
  • Loading branch information
3 people committed Jul 27, 2022
1 parent 45c1efe commit bc6f240
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 79 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).
## [Unreleased]

### Added
- new `Data.at` method: syntactic sugar for chop with "at" argument.
- `Data.__getitem__` supports array slicing
- `artists.interact2D` supports `cmap` kwarg.
- IPython integration: autocomplete includes axis, variable, and channel names

### Changed
- `Data.chop` refactored to make steps modular
- `artists.interact2D` uses matplotlib norm objects to control colormap scaling

### Fixed
Expand Down
4 changes: 1 addition & 3 deletions WrightTools/data/_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ def label(self) -> str:
def natural_name(self) -> str:
"""Valid python identifier representation of the expression."""
name = self.expression.strip()
for op in operators:
name = name.replace(op, operator_to_identifier[op])
return wt_kit.string2identifier(name)
return wt_kit.string2identifier(name, replace=operator_to_identifier)

@property
def ndim(self) -> int:
Expand Down
212 changes: 143 additions & 69 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# --- import --------------------------------------------------------------------------------------

from __future__ import annotations

import collections
import operator
Expand All @@ -21,7 +22,7 @@
from .. import exceptions as wt_exceptions
from .. import kit as wt_kit
from .. import units as wt_units
from ._axis import Axis, identifier_to_operator
from ._axis import Axis, identifier_to_operator, operator_to_identifier
from ._channel import Channel
from ._constant import Constant
from ._variable import Variable
Expand Down Expand Up @@ -313,6 +314,51 @@ def bring_to_front(self, channel):
new.insert(0, new.pop(channel_index))
self.channel_names = new

def at(self, parent=None, name=None, **at) -> Data:
    """Return the subset of the data found at the given axis position(s).

    Parameters
    ----------
    parent : WrightTools.Collection (optional)
        Parent of the new object. Default is None (no parent).
    name : string (optional)
        Name of the new data object. Default is None.
    **at
        Keys are axis identifiers (``Axis.natural_name``); values are
        ``[position, units]`` pairs. When no coordinates are given, a copy
        of the whole array is returned. Every specified axis must be
        one-dimensional in the array.

    Returns
    -------
    WrightTools.Data
        New data object that retains the dimensions of the unspecified axes.

    Raises
    ------
    ValueError
        If a key is not one of ``Data.axis_names``.
    WrightTools.exceptions.MultidimensionalAxisError
        If a specified axis spans more than one array dimension.

    Examples
    --------
    >>> from WrightTools import datasets
    >>> data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie)  # axes w2, w1=wm, d2
    >>> probed_at_resonance = data.at(w1__e__wm=[690, "nm"])
    >>> zero_delay_data = data.at(d2=[0, "fs"])

    Notes
    -----
    * Most attrs are not retained in the new data object.
    * Some axis expressions contain characters that cannot appear in a
      keyword name (e.g. "="). For those, use the substituted identifier
      (e.g. "=" -> "__e__"). ``Data.axis_names`` or ``Axis.natural_name``
      show the identifier form of an expression; ``kit.string2identifier``
      can also be used to produce it.

    See Also
    --------
    Data.chop
        Also reduces dimensionality, but returns a collection of one or
        more data objects; kept axes can be sliced into a collection.
    """
    # coordinates -> array index -> new Data built from that index
    return self._from_slice(self._at_to_slice(**at), name=name, parent=parent)

def chop(self, *args, at=None, parent=None, verbose=True) -> wt_collection.Collection:
"""Divide the dataset into its lower-dimensionality components.
Expand Down Expand Up @@ -365,100 +411,128 @@ def chop(self, *args, at=None, parent=None, verbose=True) -> wt_collection.Colle
split
Split the dataset while maintaining its dimensionality.
"""
from ._axis import operators, operator_to_identifier

# parse args
args = list(args)
for i, arg in enumerate(args):
if isinstance(arg, int):
args[i] = self._axes[arg].natural_name
elif isinstance(arg, str):
# same normalization that occurs in the natural_name @property
arg = arg.strip()
for op in operators:
arg = arg.replace(op, operator_to_identifier[op])
args[i] = wt_kit.string2identifier(arg)
args[i] = wt_kit.string2identifier(arg, replace=operator_to_identifier)

# normalize the at keys to the natural name
if at is None:
at = {}
for k in [ak for ak in at.keys() if type(ak) == str]:
for op in operators:
if op in k:
nk = k.replace(op, operator_to_identifier[op])
at[nk] = at[k]
at.pop(k)
k = nk

# get output collection
out = wt_collection.Collection(name="chop", parent=parent)
# get output shape
kept = args + [ak for ak in at.keys() if type(ak) == str]
# normalize the at keys to the natural name
for k in list(at.keys()):
k = k.strip()
nk = wt_kit.string2identifier(k, replace=operator_to_identifier)
if nk != k:
at[nk] = at[k]
at.pop(k)

# distinguish looping and non-looping indices
at_idx = self._at_to_slice(**at)
at_axes = at_idx != slice(None)

kept = args + [ak for ak in at.keys()]
kept_axes = [self._axes[self.axis_names.index(a)] for a in kept]

removed_axes = [a for a in self._axes if a not in kept_axes]
removed_shape = wt_kit.joint_shape(*removed_axes)
if removed_shape == ():
removed_shape = (1,) * self.ndim
removed_shape = list(removed_shape)
for i in at.keys():
if type(i) == int:
removed_shape[i] = 1
for ax in kept_axes:
if ax.shape.count(1) == ax.ndim - 1:
removed_shape[ax.shape.index(ax.size)] = 1
removed_shape = tuple(removed_shape)

# get output collection
out = wt_collection.Collection(name="chop", parent=parent)
i_digits = int(np.log10(np.prod(removed_shape))) + 1
# iterate
i = 0
for idx in np.ndindex(removed_shape):
for i, idx in enumerate(np.ndindex(removed_shape)):
name = f"chop{i:0>{i_digits}}"
idx = np.array(idx, dtype=object)
idx[np.array(removed_shape) == 1] = slice(None)
for axis, point in at.items():
if type(axis) == int:
idx[axis] = point
continue
point, units = point
destination_units = self._axes[self.axis_names.index(axis)].units
point = wt_units.converter(point, units, destination_units)
axis_index = self.axis_names.index(axis)
axis = self._axes[axis_index]
idx_index = np.array(axis.shape) > 1
if np.sum(idx_index) > 1:
raise wt_exceptions.MultidimensionalAxisError("chop", axis.natural_name)
idx_index = list(idx_index).index(True)
idx[idx_index] = np.argmin(np.abs(axis[tuple(idx)] - point))
data = out.create_data(name="chop%03i" % i)
for v in self.variables:
kwargs = {}
kwargs["name"] = v.natural_name
kwargs["values"] = v[idx]
kwargs["units"] = v.units
kwargs["label"] = v.label
kwargs.update(v.attrs)
data.create_variable(**kwargs)
for c in self.channels:
kwargs = {}
kwargs["name"] = c.natural_name
kwargs["values"] = c[idx]
kwargs["units"] = c.units
kwargs["label"] = c.label
kwargs["signed"] = c.signed
kwargs.update(c.attrs)
data.create_channel(**kwargs)
new_axes = [a.expression for a in kept_axes if a.expression not in at.keys()]
new_axis_units = [a.units for a in kept_axes if a.expression not in at.keys()]
data.transform(*new_axes)
for const in self.constant_expressions:
data.create_constant(const, verbose=False)
for ax in self.axis_expressions:
if ax not in new_axes:
data.create_constant(ax, verbose=False)
for j, units in enumerate(new_axis_units):
data.axes[j].convert(units)
i += 1
idx[at_axes] = at_idx[at_axes]
self._from_slice(idx, name=name, parent=out)
out.flush()
# return
if verbose:
print("chopped data into %d piece(s)" % len(out), "in", new_axes)
print("chopped data into %d piece(s)" % len(out), "in", out[0].axis_expressions)
return out

def __getitem__(self, key) -> Data:
    """Slice the data like an array, or fall back to the parent lookup.

    ``data[5, :3]`` returns a new data object built from those array
    slices (see ``_from_slice``).

    Parameters
    ----------
    key : int, slice, tuple of int and/or slice, or other
        Integers (builtin or NumPy), slices, and tuples made only of those
        are treated as an array index. Booleans are not treated as
        integers. Any other key is passed to the parent class's
        ``__getitem__``.

    Returns
    -------
    WrightTools.Data
        New data object when ``key`` is an array index; otherwise whatever
        the parent class's ``__getitem__`` returns for ``key``.
    """

    def is_array_index(item):
        # bool is a subclass of int, but True/False are not meant as positions
        if isinstance(item, bool):
            return False
        return isinstance(item, (int, np.integer, slice))

    if is_array_index(key):
        return self._from_slice(key)
    if isinstance(key, tuple) and all(is_array_index(item) for item in key):
        return self._from_slice(key)
    return super().__getitem__(key)

def _at_to_slice(self, **at) -> np.ndarray:
    """Translate axis coordinates into an array index.

    Parameters
    ----------
    **at
        Keys are axis identifiers (entries of ``self.axis_names``); values
        are ``[position, units]`` pairs.

    Returns
    -------
    numpy.ndarray
        Object array with one entry per array dimension (``self.ndim``).
        Each entry is ``slice(None)``, except along the dimension spanned
        by a requested axis, where it is the integer position whose axis
        value is closest to the requested coordinate.

    Raises
    ------
    ValueError
        If a key is not in ``self.axis_names``, or if a requested axis has
        no dimension longer than one.
    WrightTools.exceptions.MultidimensionalAxisError
        If a requested axis spans more than one array dimension.
    """
    for identifier in at:
        if identifier not in self.axis_names:
            raise ValueError(f"Axis identifier {identifier} not in Axes: {self.axis_names}")
    # one entry per array dimension; the number of axes may differ from ndim
    idx = np.array([slice(None)] * self.ndim, dtype=object)
    for identifier, (position, units) in at.items():
        axis = self._axes[self.axis_names.index(identifier)]
        position = wt_units.converter(position, units, axis.units)
        long_dims = np.flatnonzero(np.array(axis.shape) > 1)
        if long_dims.size > 1:
            # we don't know how to handle a position within multiple array
            # dimensions
            raise wt_exceptions.MultidimensionalAxisError(
                "Data._at_to_slice", axis.natural_name
            )
        if long_dims.size == 0:
            raise ValueError(
                f"Axis {axis.natural_name} has no dimension longer than one; "
                "cannot select a position along it"
            )
        # indices chosen for earlier axes are already in idx when this axis is evaluated
        idx[long_dims[0]] = np.argmin(np.abs(axis[tuple(idx)] - position))
    return idx

def _from_slice(self, idx, name=None, parent=None) -> Data:
    """Build a new data object from an array index applied to this one.

    Parameters
    ----------
    idx : index
        Index applied to every variable, channel and axis of ``self``.
    name : string (optional)
        Name of the new data object. Default is None.
    parent : object with ``create_data`` (optional)
        If given, the new data is created through ``parent.create_data``;
        otherwise a standalone ``Data`` is instantiated.

    Returns
    -------
    WrightTools.Data
        Data holding the indexed variables and channels. Axes whose indexed
        values hold more than one element are kept as axes (converted back
        to the units they have in ``self``). Existing constant expressions,
        followed by the remaining axis expressions, are added as constants.
    """
    out = Data(name=name) if parent is None else parent.create_data(name=name)

    for variable in self.variables:
        spec = {
            "name": variable.natural_name,
            "values": variable[idx],
            "units": variable.units,
            "label": variable.label,
        }
        # attrs are applied last, so they take precedence over the keys above
        spec.update(variable.attrs)
        out.create_variable(**spec)
    for channel in self.channels:
        spec = {
            "name": channel.natural_name,
            "values": channel[idx],
            "units": channel.units,
            "label": channel.label,
            "signed": channel.signed,
        }
        spec.update(channel.attrs)
        out.create_channel(**spec)

    # an axis survives only if the index leaves it more than one point
    kept = [axis.expression for axis in self.axes if axis[idx].size > 1]
    out.transform(*kept)

    collapsed = [expression for expression in self.axis_expressions if expression not in kept]
    kept_units = [axis.units for axis in self.axes if axis.expression in out.axis_expressions]

    for expression in [*self.constant_expressions, *collapsed]:
        out.create_constant(expression, verbose=False)
    for position, units in enumerate(kept_units):
        out.axes[position].convert(units)

    return out

def gradient(self, axis, *, channel=0):
Expand Down
21 changes: 14 additions & 7 deletions WrightTools/kit/_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,40 @@
# --- functions -----------------------------------------------------------------------------------


def string2identifier(s):
def string2identifier(s, replace=None):
"""Turn a string into a valid python identifier.
Currently only allows ASCII letters and underscore. Illegal characters
are replaced with underscore. This is slightly more opinionated than
python 3 itself, and may be refactored in future (see PEP 3131).
This method restricts identifier characters to ASCII letters, numbers, and
underscore. The characters are slightly more restrictive than python 3
itself, and may be refactored in future (see PEP 3131).
For non-valid characters, the default replacement is "_". Replacement
assignments can be customized with the replace kwarg.
Parameters
----------
s : string
string to convert
replace: dictionary[str, str] (optional)
dictionary of characters (keys) and their replacements (values). Values
should be ASCII or underscore. Unspecified non-ascii characters are
converted to underscore.
Returns
-------
str
valid python identifier.
"""
# https://docs.python.org/3/reference/lexical_analysis.html#identifiers
# https://www.python.org/dev/peps/pep-3131/
if len(s) == 0:
return "_"
if s[0] not in string.ascii_letters:
s = "_" + s
valids = string.ascii_letters + string.digits + "_"
out = ""
for i, char in enumerate(s):
if char in valids:
if replace and (char in replace.keys()):
out += replace[char]
elif char in valids:
out += char
else:
out += "_"
Expand Down
35 changes: 35 additions & 0 deletions tests/data/at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#! /usr/bin/env python3
"""Test at."""


# --- import -------------------------------------------------------------------------------------


import WrightTools as wt
from WrightTools import datasets


# --- tests --------------------------------------------------------------------------------------


def test_3D_to_1D():
    """Fixing two of the three axes with ``Data.at`` leaves only the third."""
    cases = [
        ({"d2": [-50, "fs"], "w2": [700, "nm"]}, ("w1=wm",)),
        ({"w1__e__wm": [605, "nm"], "w2": [700, "nm"]}, ("d2",)),
    ]
    data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie)
    try:
        for coordinates, expected in cases:
            sliced = data.at(**coordinates)
            try:
                assert sliced.axis_expressions == expected
            finally:
                # close every sliced object, including when the assertion fails
                sliced.close()
    finally:
        data.close()


def test_chop_equivalence():
    """``Data.at`` gives the same shape and axes as ``Data.chop`` with ``at``."""
    coordinates = {"d2": [-50, "fs"], "w2": [700, "nm"]}
    data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie)
    at_data = None
    chopped = None
    try:
        at_data = data.at(**coordinates)
        # chop gets its own dict so it cannot modify the one used above
        chopped = data.chop("w1=wm", at=dict(coordinates))
        chop_data = chopped[0]
        assert at_data.shape == chop_data.shape
        assert at_data.axis_expressions == chop_data.axis_expressions
    finally:
        # release every object that was opened, even if an assertion failed
        for opened in (at_data, chopped, data):
            if opened is not None:
                opened.close()


if __name__ == "__main__":
    # allow running this module directly, without a test runner
    for _test in (test_3D_to_1D, test_chop_equivalence):
        _test()

0 comments on commit bc6f240

Please sign in to comment.