Skip to content

Commit

Permalink
Merge 6219cb2 into 66caaf1
Browse files Browse the repository at this point in the history
  • Loading branch information
neilvyas committed Feb 20, 2018
2 parents 66caaf1 + 6219cb2 commit c718e59
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 53 deletions.
78 changes: 56 additions & 22 deletions pyrsistent/_checked_types.py
@@ -1,5 +1,7 @@
from collections import Iterable
import six

from pyrsistent._compat import Enum, string_types
from pyrsistent._pmap import PMap, pmap
from pyrsistent._pset import PSet, pset
from pyrsistent._pvector import PythonPVector, python_pvector
Expand Down Expand Up @@ -45,18 +47,50 @@ def __str__(self):
missing_fields=', '.join(self.missing_fields))


_preserved_iterables = (
Enum,
)
"""For types that are themselves iterable, we would mistakenly take their
elements for the type annotation. This set defines types to 'preserve' by not
iterating over them."""


def maybe_type_to_list(t):
"""Try to coerce a user-supplied type directive into a list of types."""
# TODO FIXME: Move type validation inside this function to consolidate it.
# Preserve type annotations that are thmeselves iterable.
if isinstance(t, type) and issubclass(t, _preserved_iterables):
return [t]
# Preserve type annotations given as strings.
elif isinstance(t, string_types):
return [t]
else:
if not isinstance(t, Iterable):
return [t]
else:
return list(t)


def maybe_types_to_list(ts):
res = []

for t in ts:
res.extend(maybe_type_to_list(t))

return tuple(res)


def _store_types(dct, bases, destination_name, source_name):
def to_list(elem):
if not isinstance(elem, Iterable) or isinstance(elem, six.string_types):
return [elem]
return list(elem)

dct[destination_name] = to_list(dct[source_name]) if source_name in dct else []
dct[destination_name] += sum([to_list(b.__dict__[source_name]) for b in bases if source_name in b.__dict__], [])
dct[destination_name] = tuple(dct[destination_name])
if not all(isinstance(t, type) or isinstance(t, six.string_types) for t in dct[destination_name]):
maybe_types = maybe_types_to_list([
d[source_name]
for d in ([dct] + [b.__dict__ for b in bases]) if source_name in d
])

if not all(isinstance(t, type) or isinstance(t, string_types) for t in maybe_types):
raise TypeError('Type specifications must be types or strings')

dct[destination_name] = maybe_types


def _merge_invariant_results(result):
verdict = True
Expand Down Expand Up @@ -208,19 +242,19 @@ def optional(*typs):


def _checked_type_create(cls, source_data, _factory_fields=None):
if isinstance(source_data, cls):
return source_data

# Recursively apply create methods of checked types if the types of the supplied data
# does not match any of the valid types.
types = get_types(cls._checked_types)
checked_type = next((t for t in types if issubclass(t, CheckedType)), None)
if checked_type:
return cls([checked_type.create(data)
if not any(isinstance(data, t) for t in types) else data
for data in source_data])

return cls(source_data)
if isinstance(source_data, cls):
return source_data

# Recursively apply create methods of checked types if the types of the supplied data
# does not match any of the valid types.
types = get_types(cls._checked_types)
checked_type = next((t for t in types if issubclass(t, CheckedType)), None)
if checked_type:
return cls([checked_type.create(data)
if not any(isinstance(data, t) for t in types) else data
for data in source_data])

return cls(source_data)

@six.add_metaclass(_CheckedTypeMeta)
class CheckedPVector(PythonPVector, CheckedType):
Expand Down
9 changes: 9 additions & 0 deletions pyrsistent/_compat.py
@@ -0,0 +1,9 @@
from six import string_types


# enum compat
try:
from enum import Enum
except:
class Enum(object): pass
# no objects will be instances of this class
43 changes: 22 additions & 21 deletions pyrsistent/_field_common.py
@@ -1,23 +1,20 @@
from collections import Iterable
import six
from pyrsistent._checked_types import (
CheckedType, CheckedPSet, CheckedPMap, CheckedPVector,
optional as optional_type, InvariantException, get_type, wrap_invariant,
_restore_pickle, get_type)


try:
from enum import Enum as _Enum
except:
class _Enum(object): pass
# no objects will be instances of this class


def isenum(type_):
try:
return issubclass(type_, _Enum)
except TypeError:
return False # type_ is not a class
from pyrsistent._checked_types import (
CheckedPMap,
CheckedPSet,
CheckedPVector,
CheckedType,
InvariantException,
_restore_pickle,
get_type,
maybe_type_to_list,
maybe_types_to_list,
)
from pyrsistent._checked_types import optional as optional_type
from pyrsistent._checked_types import wrap_invariant
from pyrsistent._compat import Enum


def set_fields(dct, bases, name):
Expand Down Expand Up @@ -91,11 +88,15 @@ def field(type=PFIELD_NO_TYPE, invariant=PFIELD_NO_INVARIANT, initial=PFIELD_NO_
:param serializer: function that returns a serialized version of the field
"""

if isinstance(type, Iterable) and not isinstance(type, six.string_types) and not isenum(type):
# Enums and strings are iterable types
types = set(type)
# NB: We have to check this predicate separately from the predicates in
# `maybe_type_to_list` et al. because this one is related to supporting the
# argspec for `field`, while those are related to supporting the valid ways
# to specify types.
# Multiple types must be passed in a tuple or a list.
if isinstance(type, (tuple, list)):
types = set(maybe_types_to_list(type))
else:
types = set([type])
types = set(maybe_type_to_list(type))

invariant_function = wrap_invariant(invariant) if invariant != PFIELD_NO_INVARIANT and callable(invariant) else invariant
field = _PField(type=types, invariant=invariant_function, initial=initial,
Expand Down
25 changes: 15 additions & 10 deletions tests/field_test.py
@@ -1,18 +1,23 @@
from pyrsistent import field
from pyrsistent._compat import Enum

from pyrsistent import field, pvector_field

def test_enum():
try:
from enum import Enum
except ImportError:
# skip enum test if Enums are not available
return

class TestEnum(Enum):
x = 1
y = 2
class TestEnum(Enum):
x = 1
y = 2


def test_enum():
f = field(type=TestEnum)

assert TestEnum in f.type
assert len(f.type) == 1


# This is meant to exercise `_seq_field`.
def test_pvector_field_enum_type():
f = pvector_field(TestEnum)

assert len(f.type) == 1
assert TestEnum is list(f.type)[0].__type__

0 comments on commit c718e59

Please sign in to comment.