From d7cce6c5f6064dce46d13334255548033583be35 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 21 Mar 2017 15:35:06 +0300 Subject: [PATCH] Fix bug when discriminator column is used as a part of a primary key: http://stackoverflow.com/questions/42860579/return-none-as-classtype-in-entity-inheritance-on-ponyorm --- pony/orm/core.py | 11 ++++++++--- pony/orm/tests/test_inheritance.py | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 46ec8208e..85e47d295 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3881,19 +3881,24 @@ def _set_rbits(entity, objects, attrs): obj._rbits_ |= rbits & ~wbits def _parse_row_(entity, row, attr_offsets): discr_attr = entity._discriminator_attr_ - if not discr_attr: real_entity_subclass = entity + if not discr_attr: + discr_value = None + real_entity_subclass = entity else: discr_offset = attr_offsets[discr_attr][0] discr_value = discr_attr.validate(row[discr_offset], None, entity, from_db=True) real_entity_subclass = discr_attr.code2cls[discr_value] + discr_value = real_entity_subclass._discriminator_ # To convert unicode to str in Python 2.x avdict = {} for attr in real_entity_subclass._attrs_: offsets = attr_offsets.get(attr) if offsets is None or attr.is_discriminator: continue avdict[attr] = attr.parse_value(row, offsets) - if not entity._pk_is_composite_: pkval = avdict.pop(entity._pk_attrs_[0], None) - else: pkval = tuple(avdict.pop(attr, None) for attr in entity._pk_attrs_) + + pkval = tuple(avdict.pop(attr, discr_value) for attr in entity._pk_attrs_) + assert None not in pkval + if not entity._pk_is_composite_: pkval = pkval[0] return real_entity_subclass, pkval, avdict def _load_many_(entity, objects): database = entity._database_ diff --git a/pony/orm/tests/test_inheritance.py b/pony/orm/tests/test_inheritance.py index 4b4173528..6538d4abf 100644 --- a/pony/orm/tests/test_inheritance.py +++ b/pony/orm/tests/test_inheritance.py @@ -257,5 +257,29 @@ class Entity3(Entity1): result = select(e for e in Entity1 if e.b == 30 or e.c == 50) self.assertEqual([ e.id for e in result ], [ 2, 3 ]) + def test_discriminator_1(self): + db = Database('sqlite', ':memory:') + class Entity1(db.Entity): + a = Discriminator(str) + b = Required(int) + PrimaryKey(a, b) + class Entity2(db.Entity1): + c = Required(int) + db.generate_mapping(create_tables=True) + with db_session: + x = Entity1(b=10) + y = Entity2(b=20, c=30) + with db_session: + obj = Entity1.get(b=20) + self.assertEqual(obj.a, 'Entity2') + self.assertEqual(obj.b, 20) + self.assertEqual(obj._pkval_, ('Entity2', 20)) + with db_session: + obj = Entity1['Entity2', 20] + self.assertEqual(obj.a, 'Entity2') + self.assertEqual(obj.b, 20) + self.assertEqual(obj._pkval_, ('Entity2', 20)) + + if __name__ == '__main__': unittest.main()