Skip to content

Commit

Permalink
Merge pull request #79 from wmayner/feature/tpm-class
Browse files Browse the repository at this point in the history
ExplicitTPM: Numpy universal funcs and NDArrayOperatorsMixin
  • Loading branch information
wmayner committed Jan 3, 2023
2 parents 79f2c8d + a6eb15a commit ebc19e6
Show file tree
Hide file tree
Showing 13 changed files with 263 additions and 106 deletions.
8 changes: 3 additions & 5 deletions pyphi/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,12 +588,10 @@ def always_zero(a, b):
)

PARALLEL_PURVIEW_EVALUATION = Option(
4.0,
type=float,
False,
doc="""
Controls parallel evaluation of candidate purviews. If mechanism size is
greater or equal than this floating point value, parallelization will occur. A
value of ``math.inf`` will enforce sequential processing.""",
Controls parallel evaluation of candidate purviews. A numeric value may
be used to threshold parallelization on mechanism size (inclusive).""",
)

PARALLEL_MECHANISM_PARTITION_EVALUATION = Option(
Expand Down
9 changes: 9 additions & 0 deletions pyphi/data_structures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# data_structures/__init__.py

from ordered_set import OrderedSet

from .array_like import ArrayLike
from .frozen_map import FrozenMap
from .hashable_ordered_set import HashableOrderedSet
84 changes: 84 additions & 0 deletions pyphi/data_structures/array_like.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# data_structures/array_like.py

import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin
from numbers import Number


class ArrayLike(NDArrayOperatorsMixin):
# Only support operations with instances of _HANDLED_TYPES.
_HANDLED_TYPES = (np.ndarray, list, Number)

# TODO(tpm) populate this list
_TYPE_CLOSED_FUNCTIONS = (
np.concatenate,
np.stack,
np.all,
np.sum,
)

# Holds the underlying array
_VALUE_ATTR = "value"

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out = kwargs.get("out", ())
for x in inputs + out:
# Only support operations with instances of _HANDLED_TYPES.
# Use ArrayLike instead of type(self) for isinstance to
# allow subclasses that don't override __array_ufunc__ to
# handle ArrayLike objects.
if not isinstance(x, self._HANDLED_TYPES + (ArrayLike,)):
return NotImplemented

# Defer to the implementation of the ufunc on unwrapped values.
inputs = tuple(self._unwrap_arraylike(inputs))
if out:
kwargs["out"] = tuple(self._unwrap_arraylike(out))
result = getattr(ufunc, method)(*inputs, **kwargs)

if type(result) is tuple:
# Multiple return values
return tuple(type(self)(x) for x in result)
elif method == "at":
# No return value
return None
else:
# one return value
return type(self)(result)

@staticmethod
def _unwrap_arraylike(values):
return (
getattr(x, x._VALUE_ATTR) if isinstance(x, ArrayLike) else x for x in values
)

def __array_function__(self, func, types, args, kwargs):
if func not in self._TYPE_CLOSED_FUNCTIONS:
return NotImplemented
# Note: this allows subclasses that don't override
# __array_function__ to handle MyArray objects
if not all(issubclass(t, ArrayLike) for t in types):
return NotImplemented
# extract wrapped array-like objects from args
updated_args = []

for arg in args:
if hasattr(arg, self._VALUE_ATTR):
updated_args.append(arg.__getattribute__(self._VALUE_ATTR))
else:
updated_args.append(arg)

# defer to NumPy implementation
result = func(*updated_args, **kwargs)

# cast to original wrapper if possible
return type(self)(result) if type(result) in self._HANDLED_TYPES else result

def __array__(self, dtype=None):
# TODO(tpm) We should use `np.asarray` instead of accessing `.tpm`
# whenever the underlying array is needed
return np.asarray(self.__getattribute__(self._VALUE_ATTR), dtype=dtype)

def __getattr__(self, name):
return getattr(self.__getattribute__(self._VALUE_ATTR), name)
42 changes: 42 additions & 0 deletions pyphi/data_structures/frozen_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# data_structures/frozen_map.py

import typing

K = typing.TypeVar("K")
V = typing.TypeVar("V")


class FrozenMap(typing.Mapping[K, V]):
"""An immutable mapping from keys to values."""

__slots__ = ("_dict", "_hash")

def __init__(self, *args, **kwargs):
self._dict: typing.Dict[K, V] = dict(*args, **kwargs)
self._hash: typing.Optional[int] = None

def __getitem__(self, key: K) -> V:
return self._dict[key]

def __contains__(self, key: K) -> bool:
return key in self._dict

def __iter__(self) -> typing.Iterator[K]:
return iter(self._dict)

def __len__(self) -> int:
return len(self._dict)

def __repr__(self) -> str:
return f"FrozenMap({repr(self._dict)})"

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash(
(frozenset(self._dict), frozenset(iter(self._dict.values())))
)
return self._hash

def replace(self, /, **changes):
return self.__class__(self, **changes)
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# data_structures.py

import typing
# data_structures/hashable_ordered_set.py

from ordered_set import OrderedSet


class HashableOrderedSet(OrderedSet):
"""An OrderedSet that implements the hash method.
Expand Down Expand Up @@ -51,42 +47,3 @@ def __getstate__(self):

def __setstate__(self, state):
self.__init__(state[0])


K = typing.TypeVar("K")
V = typing.TypeVar("V")


class FrozenMap(typing.Mapping[K, V]):
"""An immutable mapping from keys to values."""

__slots__ = ("_dict", "_hash")

def __init__(self, *args, **kwargs):
self._dict: typing.Dict[K, V] = dict(*args, **kwargs)
self._hash: typing.Optional[int] = None

def __getitem__(self, key: K) -> V:
return self._dict[key]

def __contains__(self, key: K) -> bool:
return key in self._dict

def __iter__(self) -> typing.Iterator[K]:
return iter(self._dict)

def __len__(self) -> int:
return len(self._dict)

def __repr__(self) -> str:
return f"FrozenMap({repr(self._dict)})"

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash(
(frozenset(self._dict), frozenset(iter(self._dict.values())))
)
return self._hash

def replace(self, /, **changes):
return self.__class__(self, **changes)
7 changes: 5 additions & 2 deletions pyphi/metrics/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,10 @@ def repertoire_distance(r1, r2, direction=None, repertoire_distance=None, **kwar
func_key = fallback(repertoire_distance, config.REPERTOIRE_DISTANCE)
func = measures[func_key]
try:
distance = func(r1, r2, direction=direction, **kwargs)
try:
distance = func(r1, r2, direction=direction, **kwargs)
except TypeError:
distance = func(r1, r2, **kwargs)
except TypeError:
distance = func(r1, r2, **kwargs)
distance = func(r1, r2, direction=direction)
return round(distance, config.PRECISION)
11 changes: 7 additions & 4 deletions pyphi/subsystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
):
# The network this subsystem belongs to.
validate.is_network(network)
network._tpm = ExplicitTPM.enforce(network.tpm)
network._tpm = network.tpm
self.network = network

self.node_labels = network.node_labels
Expand Down Expand Up @@ -329,7 +329,7 @@ def _single_node_cause_repertoire(self, mechanism_node_index, purview):
mechanism_node = self._index2node[mechanism_node_index]
# We're conditioning on this node's state, so take the TPM for the node
# being in that state.
tpm = ExplicitTPM.enforce(mechanism_node.tpm[..., mechanism_node.state])
tpm = mechanism_node.tpm[..., mechanism_node.state]
# Marginalize-out all parents of this mechanism node that aren't in the
# purview.
return tpm.marginalize_out((mechanism_node.inputs - purview)).tpm
Expand Down Expand Up @@ -392,7 +392,7 @@ def _single_node_effect_repertoire(
# pylint: disable=missing-docstring
purview_node = self._index2node[purview_node_index]
# Condition on the state of the purview inputs that are in the mechanism
purview_node.tpm = ExplicitTPM.enforce(purview_node.tpm)
purview_node.tpm = purview_node.tpm
tpm = purview_node.tpm.condition_tpm(condition)
# TODO(4.0) remove reference to TPM
# Marginalize-out the inputs that aren't in the mechanism.
Expand Down Expand Up @@ -1076,7 +1076,10 @@ def find_mice(self, direction, mechanism, purviews=False, **kwargs):
Returns:
MaximallyIrreducibleCauseOrEffect: The |MIC| or |MIE|.
"""
parallel = len(mechanism) >= config.PARALLEL_PURVIEW_EVALUATION
parallel = (
bool(config.PARALLEL_PURVIEW_EVALUATION)
and len(mechanism) >= config.PARALLEL_PURVIEW_EVALUATION
)

purviews = self.potential_purviews(direction, mechanism, purviews)

Expand Down
65 changes: 33 additions & 32 deletions pyphi/tpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
# tpm.py

"""
Provides the TPM, ExplicitTPM, and ImplicitTPM classes.
Provides the ExplicitTPM and related classes.
"""

from itertools import chain
from typing import Mapping, Set

import numpy as np

from . import config, convert, exceptions
from . import config, convert, data_structures, exceptions
from .constants import OFF, ON
from .data_structures import FrozenMap
from .utils import all_states, np_hash, np_immutable


# TODO(tpm) remove pending ArrayLike refactor
class ProxyMetaclass(type):
"""A metaclass to create wrappers for the TPM array's special attributes.
Expand All @@ -30,7 +32,7 @@ class ProxyMetaclass(type):
1. Manually "overload" all the necessary methods.
2. Use this metaclass to introspect the underlying array
and automatically overload methods in our custom TPM class definition.
and automatically overload methods in our custom TPM class definition.
"""

def __init__(cls, type_name, bases, dct):
Expand Down Expand Up @@ -116,16 +118,19 @@ def proxy(self):

type.__init__(cls, type_name, bases, dct)

if cls.__wraps__:
ignore = cls.__ignore__
# Go through all the attribute strings in the wrapped array type.
for name in dir(cls.__wraps__):
# Filter special attributes. The rest will be handled
# by `__getattr__()`.
if name.startswith("__") and name not in ignore and name not in dct:
# Create proxy function for `name` and bind it to future
# instances of cls.
setattr(cls, name, property(make_proxy(name)))
if not cls.__wraps__:
return

ignore = cls.__ignore__

# Go through all the attribute strings in the wrapped array type.
for name in dir(cls.__wraps__):
# Filter special attributes, rest will be handled by `__getattr__()`
if any([not name.startswith("__"), name in ignore, name in dct]):
continue

# Create function for `name` and bind to future instances of `cls`.
setattr(cls, name, property(make_proxy(name)))


class Wrapper(metaclass=ProxyMetaclass):
Expand Down Expand Up @@ -153,11 +158,16 @@ def __init__(self):
raise ValueError(f"Wrapped object must be of type {self.__wraps__}")


class ExplicitTPM(Wrapper):
class ExplicitTPM(data_structures.ArrayLike):

"""An explicit network TPM in multidimensional form."""

_VALUE_ATTR = "_tpm"

# TODO(tpm) remove pending ArrayLike refactor
__wraps__ = np.ndarray

# TODO(tpm) remove pending ArrayLike refactor
# Casting semantics: values belonging to our custom TPM class should
# remain closed under the following methods:

Expand Down Expand Up @@ -209,13 +219,11 @@ class ExplicitTPM(Wrapper):
}
)

# Proxy access to regular attributes of the wrapped array.
def __getattr__(self, name):
# Fix error with serialization. TODO: Implement dumps(), tobytes()?
if "_tpm" not in vars(self):
raise AttributeError

return _new_attribute(name, self.__closures__, self._tpm)
if name in self.__closures__:
return _new_attribute(name, self.__closures__, self._tpm)
else:
return getattr(self.__getattribute__(self._VALUE_ATTR), name)

def __init__(self, tpm, validate=False):
self._tpm = np.array(tpm)
Expand Down Expand Up @@ -509,17 +517,6 @@ def array_equal(self, o: object):
"""
return isinstance(o, type(self)) and np.array_equal(self._tpm, o._tpm)

@classmethod
def enforce(cls, tpm: object):
"""Create a new TPM object if necessary.
This acts as a partially applied ternary operator with the condition set
to type-checking the input.
"""
if not isinstance(tpm, cls):
return cls(tpm)
return tpm

def __str__(self):
return self.__repr__()

Expand Down Expand Up @@ -552,8 +549,12 @@ def reconstitute_tpm(subsystem):
return np.concatenate(node_tpms, axis=-1)


# TODO(tpm) remove pending ArrayLike refactor
def _new_attribute(
name: str, closures: Set[str], tpm: ExplicitTPM.__wraps__, cls=ExplicitTPM
name: str,
closures: Set[str],
tpm: ExplicitTPM.__wraps__,
cls=ExplicitTPM
) -> object:
"""Helper function to return adequate proxy attributes for TPM arrays.
Expand Down

0 comments on commit ebc19e6

Please sign in to comment.