/
sqlalchemy_django_query.py
248 lines (211 loc) · 8.24 KB
/
sqlalchemy_django_query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# -*- coding: utf-8 -*-
# flake8: noqa
"""
sqlalchemy_django_query
~~~~~~~~~~~~~~~~~~~~~~~
A module that implements a more Django-like interface for SQLAlchemy
query objects. It is still API-compatible with the regular one but
extends it with Djangoisms.
Example queries::

    Post.query.filter_by(pub_date__year=2008)
    Post.query.exclude_by(id=42)
    User.query.filter_by(name__istartswith='e')
    Post.query.filter_by(blog__name__exact='something')
    Post.query.order_by('-blog__name')

:copyright: 2011 by Armin Ronacher, Mike Bayer.
:license: BSD, see LICENSE for more details.
"""
from sqlalchemy import exc
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.base import _entity_descriptor
from sqlalchemy.orm.query import Query
from sqlalchemy.sql import extract
from sqlalchemy.sql import operators
from sqlalchemy.util import to_list
from sqlalchemy_utils.functions import orm
from flask_jsonapi import exceptions
def joinedload_all(column, *more_columns):
    """Build a chained joined-eager-load option for a relationship path.

    ``joinedload_all('blog.owner')`` is equivalent to
    ``joinedload('blog').joinedload('owner')``. Extra positional arguments
    extend the same path, so ``joinedload_all('blog', 'owner')`` builds the
    same option (like the legacy SQLAlchemy ``joinedload_all(*keys)``).
    This lets ``select_related`` pass several columns without a TypeError.

    :param column: first relationship name or dotted path.
    :param more_columns: further names or dotted paths appended to the path.
    :return: the chained loader option.
    """
    elements = []
    for path in (column,) + more_columns:
        elements.extend(path.split('.'))
    joined = joinedload(elements[0])
    # Chain each further hop so every relationship on the path is eager.
    for element in elements[1:]:
        joined = joined.joinedload(element)
    return joined
class DjangoQueryMixin(object):
    """Can be mixed into any Query class of SQLAlchemy and extends it to
    implement more Django-like behavior:

    - `filter_by` supports implicit joining and subitem accessing with
      double underscores.
    - `exclude_by` works like `filter_by` just that every expression is
      automatically negated.
    - `order_by` supports ordering by field name with an optional `-`
      in front.
    """

    # Lookup suffix -> callable(column, *values) building a SQL expression.
    # The values are unpacked from ``to_list(value)`` in _filter_or_exclude.
    _underscore_operators = {
        'eq': operators.eq,
        'ne': operators.ne,
        'gt': operators.gt,
        'lt': operators.lt,
        'gte': operators.ge,
        'lte': operators.le,
        'contains': operators.contains_op,
        'notcontains': operators.notcontains_op,
        'exact': operators.eq,
        'iexact': operators.ilike_op,
        'startswith': operators.startswith_op,
        # NOTE(review): doubling '%' does not escape it inside a LIKE
        # pattern ('%%' still matches anything); kept as-is for compatibility.
        'istartswith': lambda c, x: c.ilike(x.replace('%', '%%') + '%'),
        'iendswith': lambda c, x: c.ilike('%' + x.replace('%', '%%')),
        'endswith': operators.endswith_op,
        # Django semantics: isnull=True -> IS NULL, isnull=False -> IS NOT NULL.
        # A conditional expression is required here: the former
        # ``x and c != None or c == None`` called bool() on a SQL clause
        # (TypeError for x=True) and produced IS NULL for x=False.
        'isnull': lambda c, x: (c == None) if x else (c != None),
        'range': operators.between_op,
        'year': lambda c, x: extract('year', c) == x,
        'month': lambda c, x: extract('month', c) == x,
        'day': lambda c, x: extract('day', c) == x
    }

    # Lookups whose callable receives the whole value list as one argument.
    _underscore_list_operators = {
        'in': operators.in_op,
        'notin': operators.notin_op,
    }

    @staticmethod
    def _remember_join(joins_needed, column):
        """Append ``column`` to ``joins_needed`` unless already present.

        Identity is used rather than ``==`` because comparing SQLAlchemy
        attributes with ``==`` builds SQL expressions.

        :return: True if the relationship was newly added.
        """
        if any(join is column for join in joins_needed):
            return False
        joins_needed.append(column)
        return True

    def filter_by(self, **kwargs):
        """Filter using Django-style ``field__lookup=value`` arguments."""
        return self._filter_or_exclude(False, kwargs)

    def exclude_by(self, **kwargs):
        """Like ``filter_by`` but every generated expression is negated."""
        return self._filter_or_exclude(True, kwargs)

    def select_related(self, *columns, **options):
        """Eagerly load the given relationships with joined loading.

        ``__`` in a column name is treated as the path separator ``.``.
        The only accepted option is ``depth``, which may be ``None``
        (eager-load every hop of the path) or ``1``; any dotted path
        forces the full-path behaviour.

        :raises TypeError: on an unexpected option or unsupported depth.
        """
        depth = options.pop('depth', None)
        if options:
            # next() works on Python 3; iterator.next() was Python 2 only.
            raise TypeError('Unexpected argument %r' % next(iter(options)))
        if depth not in (None, 1):
            raise TypeError('Depth can only be 1 or None currently')
        columns = [column.replace('__', '.') for column in columns]
        need_all = depth is None or any('.' in column for column in columns)
        func = joinedload_all if need_all else joinedload
        return self.options(func(*columns))

    def order_by(self, *args):
        """Order by SQL expressions or Django-style field names.

        String arguments may start with ``-`` (descending) or ``+``
        (ascending) and may traverse relationships with ``__``, e.g.
        ``'-blog__name'``; traversed relationships are joined after the
        ordering is applied. Non-string arguments pass through unchanged.

        :raises exceptions.InvalidSort: if a string does not resolve to a
            sortable column.
        """
        args = list(args)
        joins_needed = []
        for idx, arg in enumerate(args):
            if not isinstance(arg, str):
                continue
            # Slicing instead of arg[0] so '' reaches InvalidSort below
            # rather than raising IndexError.
            desc = arg[:1] == '-'
            if arg[:1] in ('+', '-'):
                arg = arg[1:]
            column = None
            for token in arg.split('__'):
                column = get_column(self._filter_by_zero(), token, joins_needed)
                if column is None:
                    # Stop at an unknown name so later tokens cannot silently
                    # resolve against the base entity instead.
                    break
                if column.impl.uses_objects:
                    # Relationship hop: join it later and keep walking.
                    self._remember_join(joins_needed, column)
                    column = None
            if column is None:
                raise exceptions.InvalidSort(
                    "You can't sort on {}, {}".format(token, str(arg)))
            if desc:
                column = column.desc()
            args[idx] = column
        q = super(DjangoQueryMixin, self).order_by(*args)
        for join in joins_needed:
            q = q.join(join)
        return q

    def _filter_or_exclude(self, negate, kwargs):
        """Apply Django-style lookups from ``kwargs`` to the query.

        Each key is split on ``__``. Tokens are resolved as columns or
        relationships (relationships are joined); a token following a
        column must be a lookup from ``_underscore_operators`` or
        ``_underscore_list_operators``. A path ending on a column with no
        lookup is compared with ``==``.

        :param negate: wrap every generated expression in ``~`` when true.
        :param kwargs: mapping of lookup path to value.
        :raises ValueError: if a token is neither a resolvable field nor a
            known lookup.
        """
        q = self

        def negate_if(expr):
            return ~expr if negate else expr

        joins_needed = []
        for arg, value in kwargs.items():
            column = None
            for token in arg.split('__'):
                if column is None:
                    column = get_column(
                        self._filter_by_zero(), token, joins_needed)
                    if column is None:
                        # Previously an unknown field silently dropped the
                        # filter and returned unfiltered rows.
                        raise ValueError(
                            'Unknown field %r in lookup %r' % (token, arg))
                    if column.impl.uses_objects:
                        if self._remember_join(joins_needed, column):
                            q = q.join(column)
                        column = None
                elif token in self._underscore_operators:
                    op = self._underscore_operators[token]
                    q = q.filter(negate_if(op(column, *to_list(value))))
                    column = None
                elif token in self._underscore_list_operators:
                    op = self._underscore_list_operators[token]
                    q = q.filter(negate_if(op(column, to_list(value))))
                    column = None
                else:
                    raise ValueError('No idea what to do with %r' % token)
            if column is not None:
                q = q.filter(negate_if(column == value))
            # NOTE(review): a path ending on a relationship (e.g. blog=obj)
            # only adds the join; no filter is applied for it.
            q = q.reset_joinpoint()
        return q
class DjangoQuery(DjangoQueryMixin, Query):
    """SQLAlchemy ``Query`` with the Django-style extensions of
    ``DjangoQueryMixin`` (``filter_by``/``exclude_by`` lookups,
    string-based ``order_by`` and ``select_related``) mixed in.
    """
def get_column(joinpoint, token, joins_needed):
    """Resolve ``token`` to a mapped attribute.

    The name is looked up first on the mapper of ``joinpoint``; plain Python
    ``property`` objects found there are skipped. Otherwise each relationship
    in ``joins_needed`` is tried in order.

    :param joinpoint: anything accepted by ``get_mapper``.
    :param token: attribute name to resolve.
    :param joins_needed: relationship attributes joined so far.
    :return: the resolved attribute, or ``None`` if it cannot be resolved.
    """
    try:
        result = _entity_descriptor(get_mapper(joinpoint), token)
    except exc.InvalidRequestError:
        pass
    else:
        # Python properties (including subclasses) are not SQL-queryable.
        if not isinstance(result, property):
            return result
    for join in joins_needed:
        try:
            return _entity_descriptor(join, token)
        except exc.InvalidRequestError:
            continue
    return None
def get_mapper(mixed):
    """
    Return related SQLAlchemy Mapper for given SQLAlchemy object.

    :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object

    ::

        get_mapper(User)
        get_mapper(User())
        get_mapper(User.__table__)
        get_mapper(User.__mapper__)
        get_mapper(sa.orm.aliased(User))
        get_mapper(sa.orm.aliased(User.__table__))

    Raises:
        ValueError: if no mapper, or more than one mapper, is found for a
            given Table.

    Appears adapted from ``sqlalchemy_utils.get_mapper`` (versionadded
    0.26.1 there), extended to search SQLAlchemy 1.4+ mapper registries.
    """
    if isinstance(mixed, orm._MapperEntity):
        mixed = mixed.expr
    elif isinstance(mixed, orm.sa.Column):
        mixed = mixed.table
    elif isinstance(mixed, orm._ColumnEntity):
        mixed = mixed.expr

    if isinstance(mixed, orm.sa.orm.Mapper):
        return mixed
    if isinstance(mixed, orm.sa.orm.util.AliasedClass):
        return orm.sa.inspect(mixed).mapper
    if isinstance(mixed, orm.sa.sql.selectable.Alias):
        mixed = mixed.element
    if isinstance(mixed, orm.AliasedInsp):
        return mixed.mapper
    if isinstance(mixed, orm.sa.orm.attributes.InstrumentedAttribute):
        mixed = mixed.class_
    if isinstance(mixed, orm.sa.Table):
        if hasattr(orm.mapperlib, '_all_registries'):
            all_mappers = set()
            for mapper_registry in orm.mapperlib._all_registries():
                all_mappers.update(mapper_registry.mappers)
        else:  # SQLAlchemy <1.4
            all_mappers = orm.mapperlib._mapper_registry
        mappers = [
            mapper for mapper in all_mappers
            if mapper.local_table is mixed
        ]
        if len(mappers) > 1:
            # ValueError as documented (was a bare Exception); existing
            # ``except Exception`` handlers still catch it.
            raise ValueError(
                "Multiple mappers found for table '%s': %s"
                % (mixed.name, mappers)
            )
        if not mappers:
            raise ValueError(
                "Could not get mapper for table '%s'." % mixed.name
            )
        return mappers[0]
    if not orm.isclass(mixed):
        mixed = type(mixed)
    return orm.sa.inspect(mixed)