Skip to content

Commit

Permalink
Use built-in __subclasses__() instead of registering
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Jurke authored and lkraider committed Jun 27, 2017
1 parent 82943da commit 60f8285
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 16 deletions.
6 changes: 0 additions & 6 deletions schematics/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,6 @@ def set_owner_model(field, klass):
set_owner_model(field, klass)
field.name = field_name

# Register class on ancestor models
klass._subclasses = []
for base in klass.__mro__[1:]:
if isinstance(base, ModelMeta):
base._subclasses.append(klass)

return klass

@classmethod
Expand Down
26 changes: 22 additions & 4 deletions schematics/types/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..exceptions import ValidationError, ConversionError, ModelValidationError, StopValidation
from ..transforms import export_loop, EMPTY_LIST, EMPTY_DICT

from .base import BaseType

from six import iteritems
Expand Down Expand Up @@ -357,10 +358,7 @@ def find_model(self, data):
if self.claim_function:
chosen_class = self.claim_function(self, data)
else:
candidates = self.model_classes
if self.allow_subclasses:
candidates = itertools.chain.from_iterable(
([m] + m._subclasses for m in candidates))
candidates = self._get_candidates()
fallback = None
matching_classes = []
for cls in candidates:
Expand Down Expand Up @@ -400,5 +398,25 @@ def export_loop(self, model_instance, field_converter,
elif print_none:
return shaped

def _get_candidates(self):
candidates = self.model_classes

if self.allow_subclasses:
candidates = itertools.chain.from_iterable(
([m] + get_all_subclasses(m) for m in candidates)
)

return candidates


def get_all_subclasses(cls):
all_subclasses = []

for subclass in cls.__subclasses__():
all_subclasses.append(subclass)
all_subclasses.extend(get_all_subclasses(subclass))

return all_subclasses


from ..models import Model
13 changes: 7 additions & 6 deletions tests/test_polymodeltype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from schematics.models import Model
from schematics.types import StringType
from schematics.types.compound import PolyModelType
from schematics.types.compound import PolyModelType, get_all_subclasses


class A(Model): # fallback model (doesn't define a claim method)
Expand Down Expand Up @@ -42,11 +42,13 @@ class Foo(Model):
cfn = PolyModelType([B, C], claim_function=claim_func, strict=False)


def test_subclass_registry():
def test_get_all_subclasses():
assert A.__subclasses__() == [Aaa, B]
assert B.__subclasses__() == [C]
assert C.__subclasses__() == []

assert get_all_subclasses(A) == [Aaa, B, C]

assert A._subclasses == [Aaa, B, C]
assert B._subclasses == [C]
assert C._subclasses == []

def test_inheritance_based_polymorphic(): # base

Expand Down Expand Up @@ -115,4 +117,3 @@ def test_refuse_unrelated_export():
foo = Foo()
foo.strict = Aaa()
foo.to_primitive()

0 comments on commit 60f8285

Please sign in to comment.