-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #79 from wmayner/feature/tpm-class
ExplicitTPM: Numpy universal funcs and NDArrayOperatorsMixin
- Loading branch information
Showing
13 changed files
with
263 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
# data_structures/__init__.py | ||
|
||
from ordered_set import OrderedSet | ||
|
||
from .array_like import ArrayLike | ||
from .frozen_map import FrozenMap | ||
from .hashable_ordered_set import HashableOrderedSet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# -*- coding: utf-8 -*- | ||
# data_structures/array_like.py | ||
|
||
import numpy as np | ||
from numpy.lib.mixins import NDArrayOperatorsMixin | ||
from numbers import Number | ||
|
||
|
||
class ArrayLike(NDArrayOperatorsMixin): | ||
# Only support operations with instances of _HANDLED_TYPES. | ||
_HANDLED_TYPES = (np.ndarray, list, Number) | ||
|
||
# TODO(tpm) populate this list | ||
_TYPE_CLOSED_FUNCTIONS = ( | ||
np.concatenate, | ||
np.stack, | ||
np.all, | ||
np.sum, | ||
) | ||
|
||
# Holds the underlying array | ||
_VALUE_ATTR = "value" | ||
|
||
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): | ||
out = kwargs.get("out", ()) | ||
for x in inputs + out: | ||
# Only support operations with instances of _HANDLED_TYPES. | ||
# Use ArrayLike instead of type(self) for isinstance to | ||
# allow subclasses that don't override __array_ufunc__ to | ||
# handle ArrayLike objects. | ||
if not isinstance(x, self._HANDLED_TYPES + (ArrayLike,)): | ||
return NotImplemented | ||
|
||
# Defer to the implementation of the ufunc on unwrapped values. | ||
inputs = tuple(self._unwrap_arraylike(inputs)) | ||
if out: | ||
kwargs["out"] = tuple(self._unwrap_arraylike(out)) | ||
result = getattr(ufunc, method)(*inputs, **kwargs) | ||
|
||
if type(result) is tuple: | ||
# Multiple return values | ||
return tuple(type(self)(x) for x in result) | ||
elif method == "at": | ||
# No return value | ||
return None | ||
else: | ||
# one return value | ||
return type(self)(result) | ||
|
||
@staticmethod | ||
def _unwrap_arraylike(values): | ||
return ( | ||
getattr(x, x._VALUE_ATTR) if isinstance(x, ArrayLike) else x for x in values | ||
) | ||
|
||
def __array_function__(self, func, types, args, kwargs): | ||
if func not in self._TYPE_CLOSED_FUNCTIONS: | ||
return NotImplemented | ||
# Note: this allows subclasses that don't override | ||
# __array_function__ to handle MyArray objects | ||
if not all(issubclass(t, ArrayLike) for t in types): | ||
return NotImplemented | ||
# extract wrapped array-like objects from args | ||
updated_args = [] | ||
|
||
for arg in args: | ||
if hasattr(arg, self._VALUE_ATTR): | ||
updated_args.append(arg.__getattribute__(self._VALUE_ATTR)) | ||
else: | ||
updated_args.append(arg) | ||
|
||
# defer to NumPy implementation | ||
result = func(*updated_args, **kwargs) | ||
|
||
# cast to original wrapper if possible | ||
return type(self)(result) if type(result) in self._HANDLED_TYPES else result | ||
|
||
def __array__(self, dtype=None): | ||
# TODO(tpm) We should use `np.asarray` instead of accessing `.tpm` | ||
# whenever the underlying array is needed | ||
return np.asarray(self.__getattribute__(self._VALUE_ATTR), dtype=dtype) | ||
|
||
def __getattr__(self, name): | ||
return getattr(self.__getattribute__(self._VALUE_ATTR), name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# -*- coding: utf-8 -*- | ||
# data_structures/frozen_map.py | ||
|
||
import typing | ||
|
||
K = typing.TypeVar("K") | ||
V = typing.TypeVar("V") | ||
|
||
|
||
class FrozenMap(typing.Mapping[K, V]): | ||
"""An immutable mapping from keys to values.""" | ||
|
||
__slots__ = ("_dict", "_hash") | ||
|
||
def __init__(self, *args, **kwargs): | ||
self._dict: typing.Dict[K, V] = dict(*args, **kwargs) | ||
self._hash: typing.Optional[int] = None | ||
|
||
def __getitem__(self, key: K) -> V: | ||
return self._dict[key] | ||
|
||
def __contains__(self, key: K) -> bool: | ||
return key in self._dict | ||
|
||
def __iter__(self) -> typing.Iterator[K]: | ||
return iter(self._dict) | ||
|
||
def __len__(self) -> int: | ||
return len(self._dict) | ||
|
||
def __repr__(self) -> str: | ||
return f"FrozenMap({repr(self._dict)})" | ||
|
||
def __hash__(self) -> int: | ||
if self._hash is None: | ||
self._hash = hash( | ||
(frozenset(self._dict), frozenset(iter(self._dict.values()))) | ||
) | ||
return self._hash | ||
|
||
def replace(self, /, **changes): | ||
return self.__class__(self, **changes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.