Skip to content

Commit

Permalink
Fix bug when discriminator column is used as a part of a primary key:
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Mar 21, 2017
1 parent 40e3b07 commit d7cce6c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pony/orm/core.py
Expand Up @@ -3881,19 +3881,24 @@ def _set_rbits(entity, objects, attrs):
obj._rbits_ |= rbits & ~wbits
def _parse_row_(entity, row, attr_offsets):
discr_attr = entity._discriminator_attr_
if not discr_attr: real_entity_subclass = entity
if not discr_attr:
discr_value = None
real_entity_subclass = entity
else:
discr_offset = attr_offsets[discr_attr][0]
discr_value = discr_attr.validate(row[discr_offset], None, entity, from_db=True)
real_entity_subclass = discr_attr.code2cls[discr_value]
discr_value = real_entity_subclass._discriminator_ # To convert unicode to str in Python 2.x

avdict = {}
for attr in real_entity_subclass._attrs_:
offsets = attr_offsets.get(attr)
if offsets is None or attr.is_discriminator: continue
avdict[attr] = attr.parse_value(row, offsets)
if not entity._pk_is_composite_: pkval = avdict.pop(entity._pk_attrs_[0], None)
else: pkval = tuple(avdict.pop(attr, None) for attr in entity._pk_attrs_)

pkval = tuple(avdict.pop(attr, discr_value) for attr in entity._pk_attrs_)
assert None not in pkval
if not entity._pk_is_composite_: pkval = pkval[0]
return real_entity_subclass, pkval, avdict
def _load_many_(entity, objects):
database = entity._database_
Expand Down
24 changes: 24 additions & 0 deletions pony/orm/tests/test_inheritance.py
Expand Up @@ -257,5 +257,29 @@ class Entity3(Entity1):
result = select(e for e in Entity1 if e.b == 30 or e.c == 50)
self.assertEqual([ e.id for e in result ], [ 2, 3 ])

def test_discriminator_1(self):
db = Database('sqlite', ':memory:')
class Entity1(db.Entity):
a = Discriminator(str)
b = Required(int)
PrimaryKey(a, b)
class Entity2(db.Entity1):
c = Required(int)
db.generate_mapping(create_tables=True)
with db_session:
x = Entity1(b=10)
y = Entity2(b=20, c=30)
with db_session:
obj = Entity1.get(b=20)
self.assertEqual(obj.a, 'Entity2')
self.assertEqual(obj.b, 20)
self.assertEqual(obj._pkval_, ('Entity2', 20))
with db_session:
obj = Entity1['Entity2', 20]
self.assertEqual(obj.a, 'Entity2')
self.assertEqual(obj.b, 20)
self.assertEqual(obj._pkval_, ('Entity2', 20))


if __name__ == '__main__':
unittest.main()

0 comments on commit d7cce6c

Please sign in to comment.