Skip to content

Commit

Permalink
Merge pull request #125 from neilvyas/neilvyas/fix-pvector-field-type…
Browse files Browse the repository at this point in the history
…-enum-arg

Fix how fields handle type arguments.
  • Loading branch information
tobgu committed Mar 26, 2018
2 parents 66caaf1 + f05ea24 commit def099e
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 53 deletions.
90 changes: 68 additions & 22 deletions pyrsistent/_checked_types.py
@@ -1,5 +1,7 @@
from collections import Iterable
import six

from pyrsistent._compat import Enum, string_types
from pyrsistent._pmap import PMap, pmap
from pyrsistent._pset import PSet, pset
from pyrsistent._pvector import PythonPVector, python_pvector
Expand Down Expand Up @@ -45,17 +47,61 @@ def __str__(self):
missing_fields=', '.join(self.missing_fields))


_preserved_iterable_types = (
Enum,
)
"""Some types are themselves iterable, but we want to use the type itself and
not its members for the type specification. This defines a set of such types
that we explicitly preserve.
Note that strings are not such types because the string inputs we pass in are
values, not types.
"""


def maybe_parse_user_type(t):
"""Try to coerce a user-supplied type directive into a list of types.
This function should be used in all places where a user specifies a type,
for consistency.
The policy for what defines valid user input should be clear from the implementation.
"""
is_type = isinstance(t, type)
is_preserved = isinstance(t, type) and issubclass(t, _preserved_iterable_types)
is_string = isinstance(t, string_types)
is_iterable = isinstance(t, Iterable)

if is_preserved:
return [t]
elif is_string:
return [t]
elif is_type and not is_iterable:
return [t]
elif is_iterable:
# Recur to validate contained types as well.
ts = t
return tuple(e for t in ts for e in maybe_parse_user_type(t))
else:
# If this raises because `t` cannot be formatted, so be it.
raise TypeError(
'Type specifications must be types or strings. Input: {}'.format(t)
)


def maybe_parse_many_user_types(ts):
# Just a different name to communicate that you're parsing multiple user
# inputs. `maybe_parse_user_type` handles the iterable case anyway.
return maybe_parse_user_type(ts)


def _store_types(dct, bases, destination_name, source_name):
def to_list(elem):
if not isinstance(elem, Iterable) or isinstance(elem, six.string_types):
return [elem]
return list(elem)
maybe_types = maybe_parse_many_user_types([
d[source_name]
for d in ([dct] + [b.__dict__ for b in bases]) if source_name in d
])

dct[destination_name] = to_list(dct[source_name]) if source_name in dct else []
dct[destination_name] += sum([to_list(b.__dict__[source_name]) for b in bases if source_name in b.__dict__], [])
dct[destination_name] = tuple(dct[destination_name])
if not all(isinstance(t, type) or isinstance(t, six.string_types) for t in dct[destination_name]):
raise TypeError('Type specifications must be types or strings')
dct[destination_name] = maybe_types


def _merge_invariant_results(result):
Expand Down Expand Up @@ -208,19 +254,19 @@ def optional(*typs):


def _checked_type_create(cls, source_data, _factory_fields=None):
if isinstance(source_data, cls):
return source_data

# Recursively apply create methods of checked types if the types of the supplied data
# does not match any of the valid types.
types = get_types(cls._checked_types)
checked_type = next((t for t in types if issubclass(t, CheckedType)), None)
if checked_type:
return cls([checked_type.create(data)
if not any(isinstance(data, t) for t in types) else data
for data in source_data])

return cls(source_data)
if isinstance(source_data, cls):
return source_data

# Recursively apply create methods of checked types if the types of the supplied data
# does not match any of the valid types.
types = get_types(cls._checked_types)
checked_type = next((t for t in types if issubclass(t, CheckedType)), None)
if checked_type:
return cls([checked_type.create(data)
if not any(isinstance(data, t) for t in types) else data
for data in source_data])

return cls(source_data)

@six.add_metaclass(_CheckedTypeMeta)
class CheckedPVector(PythonPVector, CheckedType):
Expand Down
9 changes: 9 additions & 0 deletions pyrsistent/_compat.py
@@ -0,0 +1,9 @@
from six import string_types


# enum compat
try:
from enum import Enum
except:
class Enum(object): pass
# no objects will be instances of this class
47 changes: 26 additions & 21 deletions pyrsistent/_field_common.py
@@ -1,23 +1,20 @@
from collections import Iterable
import six
from pyrsistent._checked_types import (
CheckedType, CheckedPSet, CheckedPMap, CheckedPVector,
optional as optional_type, InvariantException, get_type, wrap_invariant,
_restore_pickle, get_type)


try:
from enum import Enum as _Enum
except:
class _Enum(object): pass
# no objects will be instances of this class


def isenum(type_):
try:
return issubclass(type_, _Enum)
except TypeError:
return False # type_ is not a class
from pyrsistent._checked_types import (
CheckedPMap,
CheckedPSet,
CheckedPVector,
CheckedType,
InvariantException,
_restore_pickle,
get_type,
maybe_parse_user_type,
maybe_parse_many_user_types,
)
from pyrsistent._checked_types import optional as optional_type
from pyrsistent._checked_types import wrap_invariant
from pyrsistent._compat import Enum


def set_fields(dct, bases, name):
Expand Down Expand Up @@ -91,11 +88,19 @@ def field(type=PFIELD_NO_TYPE, invariant=PFIELD_NO_INVARIANT, initial=PFIELD_NO_
:param serializer: function that returns a serialized version of the field
"""

if isinstance(type, Iterable) and not isinstance(type, six.string_types) and not isenum(type):
# Enums and strings are iterable types
types = set(type)
# NB: We have to check this predicate separately from the predicates in
# `maybe_parse_user_type` et al. because this one is related to supporting
# the argspec for `field`, while those are related to supporting the valid
# ways to specify types.

# Multiple types must be passed in one of the following containers. Note
# that a type that is a subclass of one of these containers, like a
# `collections.namedtuple`, will work as expected, since we check
# `isinstance` and not `issubclass`.
if isinstance(type, (list, set, tuple)):
types = set(maybe_parse_many_user_types(type))
else:
types = set([type])
types = set(maybe_parse_user_type(type))

invariant_function = wrap_invariant(invariant) if invariant != PFIELD_NO_INVARIANT and callable(invariant) else invariant
field = _PField(type=types, invariant=invariant_function, initial=initial,
Expand Down
29 changes: 19 additions & 10 deletions tests/field_test.py
@@ -1,18 +1,27 @@
from pyrsistent import field
from pyrsistent._compat import Enum

from pyrsistent import field, pvector_field

def test_enum():
try:
from enum import Enum
except ImportError:
# skip enum test if Enums are not available
return

class TestEnum(Enum):
x = 1
y = 2
# NB: This derives from the internal `pyrsistent._compat.Enum` in order to
# simplify coverage across python versions. Since we use
# `pyrsistent._compat.Enum` in `pyrsistent`'s implementation, it's useful to
# use it in the test coverage as well, for consistency.
class TestEnum(Enum):
x = 1
y = 2


def test_enum():
f = field(type=TestEnum)

assert TestEnum in f.type
assert len(f.type) == 1


# This is meant to exercise `_seq_field`.
def test_pvector_field_enum_type():
f = pvector_field(TestEnum)

assert len(f.type) == 1
assert TestEnum is list(f.type)[0].__type__

0 comments on commit def099e

Please sign in to comment.