Skip to content

Commit

Permalink
Set default relation query_class to BaseQuery.
Browse files Browse the repository at this point in the history
Set default query_class for `db.relation`, `db.relationship`, and
  `db.dynamic_loader` to Flask-SQLAlchemy's BaseQuery.
  • Loading branch information
dplepage committed Mar 14, 2011
1 parent f81a428 commit 9cbe5bc
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES
Expand Up @@ -7,6 +7,8 @@ release.
Version 0.12
````````````

- Set default query_class for `db.relation`, `db.relationship`, and
`db.dynamic_loader` to Flask-SQLAlchemy's BaseQuery.

Version 0.11
````````````
Expand Down
21 changes: 21 additions & 0 deletions flaskext/sqlalchemy.py
Expand Up @@ -12,6 +12,7 @@
import re
import sys
import time
import functools
import sqlalchemy
from math import ceil
from functools import partial
Expand Down Expand Up @@ -59,13 +60,33 @@ def _make_table(*args, **kwargs):
return _make_table


def _set_default_query_class(d):
if 'query_class' not in d:
d['query_class'] = BaseQuery


def _wrap_with_default_query_class(fn):
@functools.wraps(fn)
def newfn(*args, **kwargs):
_set_default_query_class(kwargs)
if "backref" in kwargs:
backref = kwargs['backref']
if isinstance(backref, basestring):
backref = (backref, {})
_set_default_query_class(backref[1])
return fn(*args, **kwargs)
return newfn


def _include_sqlalchemy(obj):
for module in sqlalchemy, sqlalchemy.orm:
for key in module.__all__:
if not hasattr(obj, key):
setattr(obj, key, getattr(module, key))
obj.Table = _make_table(obj)
obj.relationship = _wrap_with_default_query_class(obj.relationship)
obj.relation = _wrap_with_default_query_class(obj.relation)
obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader)


class _DebugQueryTuple(tuple):
Expand Down
23 changes: 23 additions & 0 deletions test_sqlalchemy.py
Expand Up @@ -187,5 +187,28 @@ def test_basic_pagination(self):
[1, 2, None, 8, 9, 10, 11, 12, 13, 14, None, 24, 25])


class DefaultQueryClassTestCase(unittest.TestCase):

def test_default_query_class(self):
app = flask.Flask(__name__)
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
app.config['TESTING'] = True
db = sqlalchemy.SQLAlchemy(app)

class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
children = db.relationship("Child", backref = db.backref("parents", lazy='dynamic'), lazy='dynamic')
class Child(db.Model):
id = db.Column(db.Integer, primary_key=True)
parent_id = db.Column(db.Integer, db.ForeignKey('parent.id'))
p = Parent()
c = Child()
c.parent = p
self.assertEqual(type(Parent.query), sqlalchemy.BaseQuery)
self.assertEqual(type(Child.query), sqlalchemy.BaseQuery)
self.assert_(isinstance(p.children, sqlalchemy.BaseQuery))
self.assert_(isinstance(c.parents, sqlalchemy.BaseQuery))


if __name__ == '__main__':
unittest.main()

0 comments on commit 9cbe5bc

Please sign in to comment.