Skip to content

Commit

Permalink
Merge pull request #110 from douglas-treadwell/support_field_type_enum
Browse files Browse the repository at this point in the history
Adding support for field(type=EnumClass).
  • Loading branch information
tobgu committed Jun 4, 2017
2 parents e8d5b50 + 7e1083d commit 12caeac
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 4 deletions.
6 changes: 3 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ by providing an iterable of types.
PTypeError: Invalid type for field BRecord.x, was float
Note that the Enum type introduced in Python 3 is an iterable. This makes it impossible to
use as a single, standalone, type. This can be worked around by wrapping it in a tuple.
This is trick is valid for all types that are also iterables. See #83 for more information.
Custom types (classes) that are iterable should be wrapped in a tuple to prevent their
members being added to the set of valid types. Although Enums in particular are now
supported without wrapping, see #83 for more information.

Mandatory fields
****************
Expand Down
22 changes: 21 additions & 1 deletion pyrsistent/_field_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@
_restore_pickle)


try:
from enum import Enum as _Enum
except:
class _Enum(object): pass
# no objects will be instances of this class


def isenum(type_):
try:
return issubclass(type_, _Enum)
except TypeError:
return False # type_ is not a class


def set_fields(dct, bases, name):
dct[name] = dict(sum([list(b.__dict__.get(name, {}).items()) for b in bases], []))

Expand Down Expand Up @@ -76,7 +90,13 @@ def field(type=PFIELD_NO_TYPE, invariant=PFIELD_NO_INVARIANT, initial=PFIELD_NO_
:param factory: function called when field is set.
:param serializer: function that returns a serialized version of the field
"""
types = set(type) if isinstance(type, Iterable) and not isinstance(type, six.string_types) else set([type])

if isinstance(type, Iterable) and not isinstance(type, six.string_types) and not isenum(type):
# Enums and strings are iterable types
types = set(type)
else:
types = set([type])

invariant_function = wrap_invariant(invariant) if invariant != PFIELD_NO_INVARIANT and callable(invariant) else invariant
field = _PField(type=types, invariant=invariant_function, initial=initial,
mandatory=mandatory, factory=factory, serializer=serializer)
Expand Down
18 changes: 18 additions & 0 deletions tests/field_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pyrsistent import field


def test_enum():
try:
from enum import Enum
except ImportError:
# skip enum test if Enums are not available
return

class TestEnum(Enum):
x = 1
y = 2

f = field(type=TestEnum)

assert TestEnum in f.type
assert len(f.type) == 1
16 changes: 16 additions & 0 deletions tests/record_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,22 @@ class BRecord(PRecord):
assert a.y == 2


def test_enum_field():
try:
from enum import Enum
except ImportError:
return # Enum not supported in this environment

class TestEnum(Enum):
x = 1
y = 2

class RecordContainingEnum(PRecord):
enum_field = field(type=TestEnum)

r = RecordContainingEnum(enum_field=TestEnum.x)
assert r.enum_field == TestEnum.x

def test_type_specification_must_be_a_type():
with pytest.raises(TypeError):
class BRecord(PRecord):
Expand Down

0 comments on commit 12caeac

Please sign in to comment.