Skip to content

Commit

Permalink
fixes: #2943, allow overriding validates for inheritance
Browse files Browse the repository at this point in the history
Applied the patch mentioned in #2943, to allow overriding the
validates method of a given Model, Added tests for same in
test_validators.
If a Child class overrides the parent class validates method
only child class validator will be invoked unless child class
explicitly invokes parent class validator
  • Loading branch information
indiVar0508 committed Nov 1, 2023
1 parent 527fac5 commit d310bef
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 14 deletions.
4 changes: 4 additions & 0 deletions lib/sqlalchemy/orm/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4333,6 +4333,10 @@ def validates(
modify or replace the value before proceeding. The function should
otherwise return the given value.
Overriding validator method will invoke child validator method, in
order to also invoke parent validator method as well child validator
can explicitly invoke parent class validator(s).
Note that a validator for a collection **cannot** issue a load of that
collection within the validation routine - this usage raises
an assertion to avoid recursion overflows. This is a reentrant
Expand Down
29 changes: 15 additions & 14 deletions lib/sqlalchemy/orm/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,23 @@ def _register_attribute(
impl_class=None,
**kw,
):
listen_hooks = []
pre_validate_hooks = []
post_validate_hooks = []

uselist = useobject and prop.uselist

if useobject and prop.single_parent:
listen_hooks.append(single_parent_validator)

if prop.key in prop.parent.validators:
fn, opts = prop.parent.validators[prop.key]
listen_hooks.append(
lambda desc, prop: orm_util._validator_events(
desc, prop.key, fn, **opts
)
)
pre_validate_hooks.append(single_parent_validator)

if useobject:
listen_hooks.append(unitofwork.track_cascade_events)
post_validate_hooks.append(unitofwork.track_cascade_events)

# need to assemble backref listeners
# after the singleparentvalidator, mapper validator
if useobject:
backref = prop.back_populates
if backref and prop._effective_sync_backref:
listen_hooks.append(
post_validate_hooks.append(
lambda desc, prop: attributes.backref_listeners(
desc, backref, uselist
)
Expand All @@ -114,7 +107,6 @@ def _register_attribute(
# mapper here might not be prop.parent; also, a subclass mapper may
# be called here before a superclass mapper. That is, can't depend
# on mappers not already being set up so we have to check each one.

for m in mapper.self_and_descendants:
if prop is m._props.get(
prop.key
Expand All @@ -140,7 +132,16 @@ def _register_attribute(
**kw,
)

for hook in listen_hooks:
for hook in pre_validate_hooks:
hook(desc, prop)

for super_m in m.iterate_to_root():
if prop.key in super_m.validators:
fn, opts = super_m.validators[prop.key]
orm_util._validator_events(desc, prop.key, fn, **opts)
break

for hook in post_validate_hooks:
hook(desc, prop)


Expand Down
72 changes: 72 additions & 0 deletions test/orm/test_validators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from unittest.mock import call
from unittest.mock import Mock

from sqlalchemy import Column
from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import collections
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import validates
from sqlalchemy.testing import assert_raises
Expand Down Expand Up @@ -447,3 +451,71 @@ def validate_user(self, key, item):
call("user", User(addresses=[])),
],
)

def test_validator_inheritance_override_validator(self):
Base = declarative_base()

class A(Base):
__tablename__ = "a"
id = Column(Integer, primary_key=True)
data = Column(String)
foo = Column(String)

@validates("data")
def validate_data(self, key, value):
return "Call from A : " + value

@validates("foo")
def validate_foo(self, key, value):
ne_(value, "exclude for A")
return value

class B(A):
foo2 = Column(String)
bar = Column(String)

@validates("data")
def validate_data(self, key, value):
return "Call from B : " + value

@validates("foo")
def validate_foo(self, key, value):
# Test Calling both validators
value = super(B, self).validate_foo(key, value)
ne_(value, "exclude for B")
return value

@validates("foo2", "bar")
def validate_foobar(self, key, value):
if key == "foo2":
return value + "_"
return "_" + value

class C(B):
@validates("foo2", "bar")
def validate_foobar(self, key, value):
if key == "foo2":
return value + "-"
return "-" + value

obj = A(data="ed")
eq_(obj.data, "Call from A : ed")
assert_raises(AssertionError, setattr, obj, "foo", "exclude for A")
obj.foo = "exclude for B"

obj = B(data="ed")
eq_(obj.data, "Call from B : ed")
# Should call A's Validator
assert_raises(AssertionError, setattr, obj, "foo", "exclude for A")
# Should call B's Validator
assert_raises(AssertionError, setattr, obj, "foo", "exclude for B")
obj.foo = "Some other value"

obj.foo2 = "foo"
obj.bar = "bar"
eq_(obj.foo2 + obj.bar, "foo__bar")

obj = C(data="ed")
obj.foo2 = "foo"
obj.bar = "bar"
eq_(obj.foo2 + obj.bar, "foo--bar")

0 comments on commit d310bef

Please sign in to comment.