-
Notifications
You must be signed in to change notification settings - Fork 136
/
mongo.py
950 lines (834 loc) · 36.7 KB
/
mongo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
import copy
import random
from datetime import datetime
from .._compat import basestring, long
from ..exceptions import NotOnNOSQLError
from ..helpers.classes import (
FakeCursor, Reference, SQLALL, ConnectionConfigurationMixin)
from ..helpers.methods import use_common_filters, xorify
from ..objects import Field, Row, Query, Expression
from .base import NoSQLAdapter
from . import adapters
try:
    from bson import Binary
    from bson.binary import USER_DEFINED_SUBTYPE
except ImportError:
    # bson (shipped with pymongo) is not installed: define placeholders so
    # this module can still be imported without the MongoDB driver.
    class Binary(object):
        pass

    USER_DEFINED_SUBTYPE = 0
@adapters.register_for('mongodb')
class Mongo(ConnectionConfigurationMixin, NoSQLAdapter):
    """pyDAL adapter for MongoDB, implemented on top of ``pymongo``."""

    dbengine = 'mongodb'
    drivers = ('pymongo',)

    def find_driver(self):
        """Locate the driver and require pymongo >= 3.0.

        ``driver_args['fake_version']``, when present, replaces the detected
        pymongo version string (presumably for tests).

        Raises:
            RuntimeError: if the pymongo major version is below 3.
        """
        super(Mongo, self).find_driver()
        #: ensure pymongo version >= 3.0
        if 'fake_version' in self.driver_args:
            version = self.driver_args['fake_version']
        else:
            from pymongo import version
        if int(version.split('.')[0]) < 3:
            raise RuntimeError(
                "pydal requires pymongo version >= 3.0, found '%s'" % version)

    def _initialize_(self, do_connect):
        """Parse the URI, import bson helpers and read adapter options.

        Raises:
            SyntaxError: if the URI does not name a database.
        """
        super(Mongo, self)._initialize_(do_connect)
        #: uri parse
        from pymongo import uri_parser
        m = uri_parser.parse_uri(self.uri)
        if isinstance(m, tuple):
            m = {"database": m[1]}
        if m.get('database') is None:
            raise SyntaxError("Database is required!")
        self._driver_db = m['database']

        #: mongodb imports and utils
        from bson.objectid import ObjectId
        from bson.son import SON
        from pymongo.write_concern import WriteConcern

        self.epoch = datetime.fromtimestamp(0)
        self.SON = SON
        self.ObjectId = ObjectId
        self.WriteConcern = WriteConcern

        #: options
        self.db_codec = 'UTF-8'
        # this is the minimum amount of replicates that it should wait
        # for on insert/update
        self.minimumreplication = self.adapter_args.get(
            'minimumreplication', 0)
        # writes are acknowledged (w=1) unless adapter_args['safe'] is falsy;
        # individual calls may override this through their `safe` argument
        self.safe = 1 if self.adapter_args.get('safe', True) else 0
        self._mock_reconnect()

    def connector(self):
        """Return the pymongo database object, patched with the no-op
        ``cursor``/``close``/``commit`` hooks the DAL expects."""
        conn = self.driver.MongoClient(self.uri, w=self.safe)[self._driver_db]
        conn.cursor = lambda: FakeCursor()
        conn.close = lambda: None
        conn.commit = lambda: None
        return conn

    def _configure_on_first_reconnect(self):
        """Record the server version string, its integer tuple, and a
        ``major + minor / 10`` float (e.g. 2.6) used for feature checks."""
        #: server version
        self._server_version = self.connection.command(
            "serverStatus")['version']
        self.server_version = tuple(
            [int(x) for x in self._server_version.split('.')])
        self.server_version_major = (
            self.server_version[0] + self.server_version[1] / 10.0)

    def object_id(self, arg=None):
        """Convert input to a valid Mongodb ObjectId instance.

        Accepts an ObjectId (returned unchanged), a Row/Reference (its
        ``'id'``), an integer, a decimal digit string, a hex string (with or
        without ``0x``), or ``"<random>"``, which yields a random and NOT
        guaranteed-unique id. Falsy input maps to id 0.

        Raises:
            ValueError: for strings that are neither digits nor hex.
            TypeError: for any other unsupported argument type.
        """
        if not arg:
            arg = 0
        if isinstance(arg, basestring):
            # we assume an integer as default input
            rawhex = len(arg.replace("0x", "").replace("L", "")) == 24
            if arg.isdigit() and (not rawhex):
                arg = int(arg)
            elif arg == "<random>":
                arg = int("0x%s" % "".join([
                    random.choice("0123456789abcdef")
                    for x in range(24)]), 0)
            elif arg.isalnum():
                if not arg.startswith("0x"):
                    arg = "0x%s" % arg
                try:
                    arg = int(arg, 0)
                except ValueError as e:
                    raise ValueError(
                        "invalid objectid argument string: %s" % e)
            else:
                raise ValueError("Invalid objectid argument string. " +
                                 "Requires an integer or base 16 value")
        elif isinstance(arg, self.ObjectId):
            return arg
        elif isinstance(arg, (Row, Reference)):
            return self.object_id(long(arg['id']))
        elif not isinstance(arg, (int, long)):
            raise TypeError(
                "object_id argument must be of type ObjectId or an objectid " +
                "representable integer (type %s)" % type(arg))
        # ObjectId wants exactly 24 hex digits (drop py2 'L' suffix, pad left)
        hexvalue = hex(arg)[2:].rstrip('L').zfill(24)
        return self.ObjectId(hexvalue)

    def _get_collection(self, tablename, safe=None):
        """Return the collection, with a per-call write concern when `safe`
        differs from the adapter default."""
        ctable = self.connection[tablename]
        if safe is not None and safe != self.safe:
            wc = self.WriteConcern(w=self._get_safe(safe))
            ctable = ctable.with_options(write_concern=wc)
        return ctable

    def _get_safe(self, val=None):
        """Normalise a `safe` flag to 1/0, defaulting to the adapter's."""
        if val is None:
            return self.safe
        return 1 if val else 0

    def _regex_select_as_parser(self, colname):
        return self.dialect.REGEX_SELECT_AS_PARSER.search(colname)

    @staticmethod
    def _parse_data(expression, attribute, value=None):
        """Read (or, when `value` is given, first set) a key of the shared
        ``_parse_data`` dict attached to an annotated expression.

        For a list/tuple the results are OR-ed across elements. Returns None
        when the expression carries no ``_parse_data``.
        """
        if isinstance(expression, (list, tuple)):
            ret = False
            for e in expression:
                ret = Mongo._parse_data(e, attribute, value) or ret
            return ret
        if value is not None:
            try:
                expression._parse_data[attribute] = value
            except AttributeError:
                return None
        try:
            return expression._parse_data[attribute]
        except (AttributeError, TypeError):
            return None

    def _expand(self, expression, field_type=None, query_env={}):
        """Translate a DAL expression tree into MongoDB query structures.

        `query_env` is only forwarded to dialect operators, never mutated
        here, so the shared default is harmless; it is kept for signature
        compatibility with the base adapter.
        """
        if isinstance(expression, Field):
            if expression.type == 'id':
                result = "_id"
            else:
                result = expression.name
            if self._parse_data(expression, 'pipeline'):
                # field names as part of expressions need to start with '$'
                result = '$' + result

        elif isinstance(expression, (Expression, Query)):
            first = expression.first
            second = expression.second

            if isinstance(first, Field) and "reference" in first.type:
                # cast to Mongo ObjectId
                if isinstance(second, (tuple, list, set)):
                    second = [
                        self.object_id(item) for item in expression.second]
                else:
                    second = self.object_id(expression.second)

            op = expression.op
            # copy so query_env is not written into the expression's own dict
            optional_args = dict(expression.optional_args or {})
            optional_args['query_env'] = query_env
            if second is not None:
                result = op(first, second, **optional_args)
            elif first is not None:
                result = op(first, **optional_args)
            elif isinstance(op, str):
                result = op
            else:
                result = op(**optional_args)

        elif isinstance(expression, Expansion):
            expression.query = (self.expand(expression.query, field_type,
                                query_env=query_env))
            result = expression

        elif isinstance(expression, (list, tuple)):
            result = [self.represent(item, field_type) for item in expression]
        elif field_type:
            result = self.represent(expression, field_type)
        else:
            result = expression
        return result

    def represent(self, obj, field_type):
        """Pass ObjectIds through untouched; defer everything else."""
        if isinstance(obj, self.ObjectId):
            return obj
        return super(Mongo, self).represent(obj, field_type)

    def truncate(self, table, mode, safe=None):
        """Delete every document of `table` (`mode`/`safe` are ignored)."""
        ctable = self.connection[table._tablename]
        ctable.delete_many({})

    def count(self, query, distinct=None, snapshot=True):
        """Count documents matching `query`.

        `distinct` may be True (all non-id fields of the query's table), a
        Field, or an expression chaining several Fields.

        Raises:
            SyntaxError: if `query` is not a Query.
        """
        if not isinstance(query, Query):
            raise SyntaxError("Type '%s' not supported in count" % type(query))

        distinct_fields = []
        if distinct is True:
            distinct_fields = [x for x in query.first.table if x.name != 'id']
        elif distinct:
            if isinstance(distinct, Field):
                distinct_fields = [distinct]
            else:
                # walk a left-nested expression collecting its Fields
                while (isinstance(distinct, Expression) and
                       isinstance(distinct.second, Field)):
                    distinct_fields += [distinct.second]
                    distinct = distinct.first
                if isinstance(distinct, Field):
                    distinct_fields += [distinct]
            distinct = True

        expanded = Expansion(
            self, 'count', query, fields=distinct_fields, distinct=distinct)
        ctable = expanded.get_collection()
        if not expanded.pipeline:
            return ctable.count(filter=expanded.query_dict)
        # the count pipeline ends in a single {'count': n} document
        for record in ctable.aggregate(expanded.pipeline):
            return record['count']
        return 0

    def select(self, query, fields, attributes, snapshot=False):
        """Run a select; the caller's `attributes` dict is not modified."""
        options = dict(attributes, snapshot=snapshot)
        return self.__select(query, fields, **options)

    def __select(self, query, fields, left=False, join=False, distinct=False,
                 orderby=False, groupby=False, having=False, limitby=False,
                 orderby_on_limitby=True, for_update=False, outer_scoped=[],
                 required=None, cache=None, cacheable=None, processor=None,
                 snapshot=False):
        """Execute a select via ``find`` or, when needed, ``aggregate``.

        `limitby` is ``(start, stop)`` and selects rows[start:stop].

        Raises:
            NotOnNOSQLError: if a join is requested.
            RuntimeError: if both `snapshot` and an ordering are requested.
        """
        new_fields = []
        for item in fields:
            if isinstance(item, SQLALL):
                new_fields += item._table
            else:
                new_fields.append(item)
        fields = new_fields
        tablename = self.get_table(query, *fields)._tablename

        if for_update:
            self.db.logger.warning(
                "Attribute 'for_update' unsupported by MongoDB")
        if join or left:
            raise NotOnNOSQLError("Joins not supported on NoSQL databases")
        if required or cache or cacheable:
            self.db.logger.warning(
                "Attributes 'required', 'cache' and 'cacheable' are" +
                " unsupported by MongoDB")

        if limitby and orderby_on_limitby and not orderby:
            # paging needs a stable order
            if groupby:
                orderby = groupby
            else:
                table = self.db[tablename]
                orderby = [
                    table[x] for x in (
                        hasattr(table, '_primarykey') and
                        table._primarykey or ['_id'])]

        if not orderby:
            mongosort_list = []
        else:
            if snapshot:
                raise RuntimeError(
                    "snapshot and orderby are mutually exclusive")
            if isinstance(orderby, (list, tuple)):
                orderby = xorify(orderby)

            if str(orderby) == '<random>':
                # !!!! need to add 'random'
                mongosort_list = self.dialect.random
            else:
                mongosort_list = []
                for f in self.expand(orderby).split(','):
                    include = 1
                    if f.startswith('-'):
                        include = -1
                        f = f[1:]
                    if f.startswith('$'):
                        f = f[1:]
                    mongosort_list.append((f, include))

        expanded = Expansion(
            self, 'select', query, fields or self.db[tablename],
            groupby=groupby, distinct=distinct, having=having)
        ctable = self.connection[tablename]
        modifiers = {'snapshot': snapshot}

        if not expanded.pipeline:
            if limitby:
                # rows[start:stop] -> skip `start`, return `stop - start`
                limitby_skip = int(limitby[0])
                limitby_limit = int(limitby[1]) - limitby_skip
            else:
                # pymongo treats limit=0 as "no limit"
                limitby_skip = limitby_limit = 0
            mongo_list_dicts = ctable.find(
                expanded.query_dict, expanded.field_dicts, skip=limitby_skip,
                limit=limitby_limit, sort=mongosort_list, modifiers=modifiers)
            null_rows = []
        else:
            if mongosort_list:
                sortby_dict = self.SON()
                for f in mongosort_list:
                    sortby_dict[f[0]] = f[1]
                expanded.pipeline.append({'$sort': sortby_dict})
            # $limit to `stop` followed by $skip of `start` yields
            # rows[start:stop]
            if limitby and limitby[1]:
                expanded.pipeline.append({'$limit': limitby[1]})
            if limitby and limitby[0]:
                expanded.pipeline.append({'$skip': limitby[0]})

            mongo_list_dicts = ctable.aggregate(expanded.pipeline)
            null_rows = [(None,)]

        rows = []
        # populate row in proper order
        # Here we replace ._id with .id to follow the standard naming
        colnames = []
        newnames = []
        for field in expanded.fields:
            if hasattr(field, "tablename"):
                if field.name in ('id', '_id'):
                    # Mongodb reserved uuid key
                    colname = (tablename + '.' + 'id', '_id')
                else:
                    colname = (field.longname, field.name)
            elif not isinstance(query, Expression):
                colname = (field.name, field.name)
            colnames.append(colname[1])
            newnames.append(colname[0])

        for record in mongo_list_dicts:
            row = []
            for colname in colnames:
                try:
                    value = record[colname]
                except KeyError:
                    value = None
                if self.server_version_major < 2.6:
                    # '$size' not present in server versions < 2.6
                    if isinstance(value, list) and '$addToSet' in colname:
                        value = len(value)
                row.append(value)
            rows.append(row)
        if not rows:
            rows = null_rows

        processor = processor or self.parse
        result = processor(rows, fields, newnames, blob_decode=True)
        return result

    def check_notnull(self, table, values):
        """Raise if any NOT NULL field of `table` is missing or None."""
        for fieldname in table._notnulls:
            if fieldname not in values or values[fieldname] is None:
                raise Exception("NOT NULL constraint failed: %s" % fieldname)

    def check_unique(self, table, values):
        """Raise if any UNIQUE field value of `values` already exists.

        Missing values are checked against the field's default.
        """
        if len(table._uniques) > 0:
            db = table._db
            unique_queries = []
            for fieldname in table._uniques:
                if fieldname in values:
                    value = values[fieldname]
                else:
                    value = table[fieldname].default
                unique_queries.append(
                    Query(db, self.dialect.eq, table[fieldname], value))

            if len(unique_queries) > 0:
                unique_query = unique_queries[0]

                # if more than one field, build a query of ORs
                for query in unique_queries[1:]:
                    unique_query = Query(
                        db, self.dialect._or, unique_query, query)

                if self.count(unique_query, distinct=False) != 0:
                    for query in unique_queries:
                        if self.count(query, distinct=False) != 0:
                            # one of the 'OR' queries failed, see which one
                            raise Exception(
                                "NOT UNIQUE constraint failed: %s" %
                                query.first.name)

    def insert(self, table, fields, safe=None):
        """Safe determines whether a asynchronous request is done or a
        synchronous action is done
        For safety, we use by default synchronous requests

        Returns a Reference to the new id, or None when the write was not
        acknowledged. Constraint failures go to ``table._on_insert_error``
        when defined, otherwise they are re-raised.
        """
        values = {}
        safe = self._get_safe(safe)
        ctable = self._get_collection(table._tablename, safe)

        for k, v in fields:
            if k.name not in ["id", "safe"]:
                fieldname = k.name
                fieldtype = table[k.name].type
                values[fieldname] = self.represent(v, fieldtype)

        # validate notnulls
        try:
            self.check_notnull(table, values)
        except Exception as e:
            if hasattr(table, '_on_insert_error'):
                return table._on_insert_error(table, fields, e)
            raise e

        # validate uniques
        try:
            self.check_unique(table, values)
        except Exception as e:
            if hasattr(table, '_on_insert_error'):
                return table._on_insert_error(table, fields, e)
            raise e

        # perform the insert
        result = ctable.insert_one(values)

        if result.acknowledged:
            Oid = result.inserted_id
            rid = Reference(long(str(Oid), 16))
            (rid._table, rid._record) = (table, None)
            return rid
        else:
            return None

    def update(self, table, query, fields, safe=None):
        """Update documents matching `query` with `fields`.

        Returns the number of matched documents: counted from the write
        results when acknowledged, or counted up front otherwise.

        Raises:
            RuntimeError: for non-Query input or any driver failure.
        """
        # return amount of adjusted rows or zero, but no exceptions
        # @ related not finding the result
        if not isinstance(query, Query):
            raise RuntimeError("Not implemented")

        safe = self._get_safe(safe)
        if safe:
            amount = 0
        else:
            # unacknowledged writes report no counts, so count beforehand
            # and skip the update entirely when nothing matches
            amount = self.count(query, distinct=False)
            if amount == 0:
                return amount

        expanded = Expansion(self, 'update', query, fields)
        ctable = expanded.get_collection(safe)
        if expanded.pipeline:
            # computed values: build each new document via aggregation,
            # then replace the stored document with it
            try:
                for doc in ctable.aggregate(expanded.pipeline):
                    result = ctable.replace_one({'_id': doc['_id']}, doc)
                    if safe and result.acknowledged:
                        amount += result.matched_count
                return amount
            except Exception as e:
                # TODO Reverse update query to verify that the query succeeded
                raise RuntimeError(
                    "uncaught exception when updating rows: %s" % e)

        try:
            result = ctable.update_many(
                filter=expanded.query_dict,
                update={'$set': expanded.field_dicts})
            if safe and result.acknowledged:
                amount = result.matched_count
            return amount
        except Exception as e:
            # TODO Reverse update query to verify that the query succeeded
            raise RuntimeError(
                "uncaught exception when updating rows: %s" % e)

    def delete(self, table, query, safe=None):
        """Delete documents matching `query` and apply ON DELETE rules
        (CASCADE / SET NULL) of referencing fields.

        Returns the number of deleted documents.

        Raises:
            RuntimeError: if `query` is not a Query.
        """
        if not isinstance(query, Query):
            raise RuntimeError("query type %s is not supported" % type(query))

        safe = self._get_safe(safe)
        expanded = Expansion(self, 'delete', query)
        ctable = expanded.get_collection(safe)
        if expanded.pipeline:
            deleted = [x['_id'] for x in ctable.aggregate(expanded.pipeline)]
        else:
            deleted = [x['_id'] for x in ctable.find(expanded.query_dict)]

        # find references to deleted items
        db = self.db
        cascade = []
        set_null = []
        for field in table._referenced_by:
            if field.type == 'reference ' + table._tablename:
                if field.ondelete == 'CASCADE':
                    cascade.append(field)
                if field.ondelete == 'SET NULL':
                    set_null.append(field)
        cascade_list = []
        set_null_list = []
        for field in table._referenced_by_list:
            if field.type == 'list:reference ' + table._tablename:
                if field.ondelete == 'CASCADE':
                    cascade_list.append(field)
                if field.ondelete == 'SET NULL':
                    set_null_list.append(field)

        # perform delete
        result = ctable.delete_many({"_id": {"$in": deleted}})
        if result.acknowledged:
            amount = result.deleted_count
        else:
            amount = len(deleted)

        # clean up any references
        if amount and deleted:
            # ::TODO:: test if deleted references cascade
            def remove_from_list(field, deleted, safe):
                # $pull each deleted id out of the list field
                for delete in deleted:
                    modify = {field.name: delete}
                    dtable = self._get_collection(field.tablename, safe)
                    dtable.update_many(
                        filter=modify, update={'$pull': modify})

            # for cascaded items, if the reference is the only item in the
            # list, then remove the entire record, else delete reference
            # from the list
            for field in cascade_list:
                for delete in deleted:
                    modify = {field.name: [delete]}
                    dtable = self._get_collection(field.tablename, safe)
                    dtable.delete_many(filter=modify)
                remove_from_list(field, deleted, safe)
            for field in set_null_list:
                remove_from_list(field, deleted, safe)
            for field in cascade:
                db(field.belongs(deleted)).delete()
            for field in set_null:
                db(field.belongs(deleted)).update(**{field.name: None})
        return amount

    def bulk_insert(self, table, items):
        """Insert each item individually; returns the list of results."""
        return [self.insert(table, item) for item in items]
class Expansion(object):
    """
    Class to encapsulate a pydal expression and track the parse
    expansion and its results.

    Two different MongoDB mechanisms are targeted here. If the query
    is sufficiently simple, then simple queries are generated. The
    bulk of the complexity here is however to support more complex
    queries that are targeted to the MongoDB Aggregation Pipeline.

    This class supports four operations: 'count', 'select', 'update'
    and 'delete'.

    Behavior varies somewhat for each operation type. However
    building each pipeline stage is shared where the behavior is the
    same (or similar) for the different operations.

    In general an attempt is made to build the query without using the
    pipeline, and if that fails then the query is rebuilt with the
    pipeline.

    QUERY constructed in _build_pipeline_query():
      $project : used to calculate expressions if needed
      $match: filters out records

    FIELDS constructed in _expand_fields():
        FIELDS:COUNT
          $group : filter for distinct if needed
          $group: count the records remaining

        FIELDS:SELECT
          $group : implement aggregations if needed
          $project: implement expressions (etc) for select

        FIELDS:UPDATE
          $project: implement expressions (etc) for update

    HAVING constructed in _add_having():
      $project : used to calculate expressions
      $match: filters out records
      $project : used to filter out previous expression fields
    """

    def __init__(self, adapter, crud, query, fields=(), tablename=None,
                 groupby=None, distinct=False, having=None):
        self.adapter = adapter
        # matches every document; a $match on it is dropped as redundant
        self.NULL_QUERY = {'_id': {
            '$gt': self.adapter.ObjectId('000000000000000000000000')}}
        # shared by every annotated node; 'pipeline' may be flipped to True
        # while expanding, which switches to the aggregation engine
        self._parse_data = {'pipeline': False, 'need_group':
                            bool(groupby or distinct or having)}
        self.crud = crud
        self.having = having
        self.distinct = distinct
        if not groupby and distinct:
            if distinct is True:
                # groupby gets all fields
                self.groupby = fields
            else:
                self.groupby = distinct
        else:
            self.groupby = groupby

        if crud == 'update':
            # `fields` holds (field, value) pairs for updates
            self.values = [(f[0], self.annotate_expression(f[1]))
                           for f in (fields or [])]
            self.fields = [f[0] for f in self.values]
        else:
            self.fields = [self.annotate_expression(f)
                           for f in (fields or [])]

        self.tablename = (tablename or
                          adapter.get_table(query, *self.fields)._tablename)
        if use_common_filters(query):
            query = adapter.common_filter(query, [self.tablename])
        self.query = self.annotate_expression(query)

        # expand the query
        self.pipeline = []
        self.query_dict = adapter.expand(self.query)
        self.field_dicts = adapter.SON()
        self.field_groups = adapter.SON()
        self.field_groups['_id'] = adapter.SON()

        if self._parse_data['pipeline']:
            # if the query needs the aggregation engine, set that up
            self._build_pipeline_query()

            # expand the fields for the aggregation engine
            self._expand_fields(None)
        else:
            # expand the fields
            try:
                if not self._parse_data['need_group']:
                    self._expand_fields(self._fields_loop_abort)
                else:
                    self._parse_data['pipeline'] = True
                    raise StopIteration
            except StopIteration:
                # if the fields needs the aggregation engine, set that up
                self.field_dicts = adapter.SON()
                if self.query_dict:
                    if self.query_dict != self.NULL_QUERY:
                        self.pipeline = [{'$match': self.query_dict}]
                    self.query_dict = {}
                # expand the fields for the aggregation engine
                self._expand_fields(None)

        if not self._parse_data['pipeline']:
            if crud == 'update':
                # do not update id fields
                for fieldname in ("_id", "id"):
                    if fieldname in self.field_dicts:
                        del self.field_dicts[fieldname]
        else:
            if crud == 'update':
                self._add_all_fields_projection(self.field_dicts)
                self.field_dicts = adapter.SON()

            elif crud == 'select':
                if self._parse_data['need_group']:
                    if not self.groupby:
                        # no groupby, aggregate all records
                        self.field_groups['_id'] = None
                    # id has no value after aggregations
                    self.field_dicts['_id'] = False
                    self.pipeline.append({'$group': self.field_groups})
                if self.field_dicts:
                    self.pipeline.append({'$project': self.field_dicts})
                    self.field_dicts = adapter.SON()
                self._add_having()

            elif crud == 'count':
                if self._parse_data['need_group']:
                    self.pipeline.append({'$group': self.field_groups})
                self.pipeline.append(
                    {'$group': {"_id": None, 'count': {"$sum": 1}}})

            # 'delete' needs no extra stages

    @property
    def dialect(self):
        return self.adapter.dialect

    def _build_pipeline_query(self):
        """Split the expanded query into a '$project' + '$match' stage pair
        and append them to the pipeline; clears ``query_dict``."""
        # search for anything needing the $match stage.
        # currently only '$regex' requires the match stage
        def parse_need_match_stage(items, parent, parent_key):
            need_match = False
            non_matched_indices = []
            if isinstance(items, list):
                indices = range(len(items))
            elif isinstance(items, dict):
                indices = items.keys()
            else:
                return

            for i in indices:
                if parse_need_match_stage(items[i], items, i):
                    need_match = True

                elif i not in [self.dialect.REGEXP_MARK1,
                               self.dialect.REGEXP_MARK2]:
                    non_matched_indices.append(i)

                if i == self.dialect.REGEXP_MARK1:
                    need_match = True
                    self.query_dict['project'].update(items[i])
                    parent[parent_key] = items[self.dialect.REGEXP_MARK2]

            if need_match:
                # compute the sibling sub-expressions in $project and
                # reference their (named) results from $match
                for i in non_matched_indices:
                    name = str(items[i])
                    self.query_dict['project'][name] = items[i]
                    items[i] = {name: True}

            if parent is None and self.query_dict['project']:
                self.query_dict['match'] = items
            return need_match

        expanded = self.adapter.expand(self.query)

        if self.dialect.REGEXP_MARK1 in expanded:
            # the REGEXP_MARK is at the top of the tree, so can just split
            # the regex over a '$project' and a '$match'
            self.query_dict = None
            match = expanded[self.dialect.REGEXP_MARK2]
            project = expanded[self.dialect.REGEXP_MARK1]

        else:
            self.query_dict = {'project': {}, 'match': {}}
            if parse_need_match_stage(expanded, None, None):
                project = self.query_dict['project']
                match = self.query_dict['match']
            else:
                # evaluate the whole query as one projected boolean
                project = {'__query__': expanded}
                match = {'__query__': True}

        if self.crud in ['select', 'update']:
            self._add_all_fields_projection(project)
        else:
            self.pipeline.append({'$project': project})
        self.pipeline.append({'$match': match})
        self.query_dict = None

    def _expand_fields(self, mid_loop):
        """Expand every field into ``field_dicts``; `mid_loop` post-processes
        each expansion (defaults to the pipeline handler for the crud)."""
        if self.crud == 'update':
            mid_loop = mid_loop or self._fields_loop_update_pipeline
            for field, value in self.values:
                self._expand_field(field, value, mid_loop)
        elif self.crud in ['select', 'count']:
            mid_loop = mid_loop or self._fields_loop_select_pipeline
            for field in self.fields:
                self._expand_field(field, field, mid_loop)
        elif self.fields:
            raise RuntimeError(self.crud + " not supported with fields")

    def _expand_field(self, field, value, mid_loop):
        """Expand one field (or expression) and store it in ``field_dicts``
        under ``field.name``; expressions are named by their expansion."""
        if isinstance(field, Field):
            expanded = self.adapter.expand(value, field.type)
        elif isinstance(field, (Expression, Query)):
            expanded = self.adapter.expand(field)
            field.name = str(expanded)
        else:
            raise RuntimeError("%s not supported with fields" % type(field))

        if mid_loop:
            expanded = mid_loop(expanded, field, value)
        self.field_dicts[field.name] = expanded

    def _fields_loop_abort(self, expanded, *args):
        """Raise StopIteration once an expansion required the pipeline."""
        # if we need the aggregation engine, then start over
        if self._parse_data['pipeline']:
            raise StopIteration()
        return expanded

    def _fields_loop_update_pipeline(self, expanded, field, value):
        """Wrap literal update values so '$project' emits them verbatim."""
        if not isinstance(value, Expression):
            if self.adapter.server_version_major >= 2.6:
                expanded = {'$literal': expanded}

            # '$literal' not present in server versions < 2.6
            elif field.type in ['string', 'text', 'password']:
                expanded = {'$concat': [expanded]}
            elif field.type in ['integer', 'bigint', 'float', 'double']:
                expanded = {'$add': [expanded]}
            elif field.type == 'boolean':
                expanded = {'$and': [expanded]}
            elif field.type in ['date', 'time', 'datetime']:
                expanded = {'$add': [expanded]}
            else:
                raise RuntimeError(
                    "updating with expressions not supported for field type " +
                    "'%s' in MongoDB version < 2.6" % field.type)
        return expanded

    def _fields_loop_select_pipeline(self, expanded, field, value):
        """Move aggregation sub-expressions into ``field_groups`` ('$group')
        and leave '$'-references to them in the '$project' expansion."""
        # search for anything needing $group
        def parse_groups(items, parent, parent_key):
            for item in items:
                if isinstance(items[item], list):
                    # enumerate gives the real position; list.index() would
                    # search by value equality on every element
                    for position, list_item in enumerate(items[item]):
                        if isinstance(list_item, dict):
                            parse_groups(list_item, items[item], position)

                elif isinstance(items[item], dict):
                    parse_groups(items[item], items, item)

                if item == self.dialect.GROUP_MARK:
                    name = str(items)
                    self.field_groups[name] = items[item]
                    parent[parent_key] = '$' + name
            return items

        if self.dialect.AS_MARK in field.name:
            # The AS_MARK in the field name is used by base to alias the
            # result, we don't actually need the AS_MARK in the parse tree
            # so we remove it here.
            if isinstance(expanded, list):
                # AS mark is first element in list, drop it
                expanded = expanded[1]

            elif self.dialect.AS_MARK in expanded:
                # AS mark is element in dict, drop it
                del expanded[self.dialect.AS_MARK]

            else:
                # ::TODO:: should be possible to do this...
                raise SyntaxError("AS() not at top of parse tree")

        if self.dialect.GROUP_MARK in expanded:
            # the GROUP_MARK is at the top of the tree, so can just pass
            # the group result straight through the '$project' stage
            self.field_groups[field.name] = expanded[self.dialect.GROUP_MARK]
            expanded = 1

        elif self.dialect.GROUP_MARK in field.name:
            # the GROUP_MARK is not at the top of the tree, so we need to
            # pass the group results through to a '$project' stage.
            expanded = parse_groups(expanded, None, None)

        elif self._parse_data['need_group']:
            if field in self.groupby:
                # this is a 'groupby' field
                self.field_groups['_id'][field.name] = expanded
                expanded = '$_id.' + field.name
            else:
                raise SyntaxError("field '%s' not in groupby" % field)

        return expanded

    def _add_all_fields_projection(self, fields):
        """Append a '$project' of `fields` plus every (non-id) table field."""
        for fieldname in self.adapter.db[self.tablename].fields:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1
        self.pipeline.append({'$project': fields})

    def _add_having(self):
        """Append project/match/project stages implementing HAVING; no-op
        when there is no having clause."""
        if not self.having:
            return
        self._expand_field(
            self.having, None, self._fields_loop_select_pipeline)
        fields = {'__having__': self.field_dicts[self.having.name]}
        for fieldname in self.pipeline[-1]['$project']:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1

        self.pipeline.append({'$project': copy.copy(fields)})
        self.pipeline.append({'$match': {'__having__': True}})
        # final projection drops the temporary '__having__' flag
        del fields['__having__']
        self.pipeline.append({'$project': fields})

    def annotate_expression(self, expression):
        """Attach parse bookkeeping (root, parent, depth, shared
        ``_parse_data``, ``has_field``) to every node of `expression`."""
        def mark_has_field(expression):
            if not isinstance(expression, (Expression, Query)):
                return False
            first_has_field = mark_has_field(expression.first)
            second_has_field = mark_has_field(expression.second)
            expression.has_field = (isinstance(expression, Field) or
                                    first_has_field or second_has_field)
            return expression.has_field

        def add_parse_data(child, parent):
            if isinstance(child, (Expression, Query)):
                child.parse_root = parent.parse_root
                child.parse_parent = parent
                child.parse_depth = parent.parse_depth + 1
                child._parse_data = parent._parse_data
                add_parse_data(child.first, child)
                add_parse_data(child.second, child)
            elif isinstance(child, (list, tuple)):
                for c in child:
                    add_parse_data(c, parent)

        if isinstance(expression, (Expression, Query)):
            expression.parse_root = expression
            expression.parse_depth = -1
            expression._parse_data = self._parse_data
            add_parse_data(expression, expression)

        mark_has_field(expression)
        return expression

    def get_collection(self, safe=None):
        return self.adapter._get_collection(self.tablename, safe)
class MongoBlob(Binary):
    """Adapt blob values for storage in MongoDB.

    Constructing a MongoBlob does not always produce a MongoBlob: ``None``
    and existing ``Binary`` values come back unchanged, UTF-8 encodable
    strings come back as the same string, and other non-string values
    become a plain ``Binary``. Only bytearrays and non-UTF-8 strings are
    wrapped with one of the custom subtypes below.
    """

    # custom BSON binary subtypes recording the original Python type
    MONGO_BLOB_BYTES = USER_DEFINED_SUBTYPE
    MONGO_BLOB_NON_UTF8_STR = USER_DEFINED_SUBTYPE + 1

    def __new__(cls, value):
        if value is None or isinstance(value, Binary):
            return value

        if isinstance(value, bytearray):
            # tagged so decode() can hand back a bytearray
            payload = bytes(value)
            return Binary.__new__(cls, payload, MongoBlob.MONGO_BLOB_BYTES)

        if isinstance(value, basestring):
            try:
                value.encode('utf-8')
            except UnicodeDecodeError:
                # not UTF-8 encodable (e.g. pickled data): store tagged
                return Binary.__new__(
                    cls, value, MongoBlob.MONGO_BLOB_NON_UTF8_STR)
            # UTF-8 encodable: keep it as an ordinary string
            return value

        # any other non-string value (e.g. PY3 bytes) -> untagged Binary
        return Binary(value)

    def __repr__(self):
        decoded = MongoBlob.decode(self)
        return repr(decoded)

    @staticmethod
    def decode(value):
        """Undo the subtype tagging; untagged values are returned as-is."""
        if not isinstance(value, Binary):
            return value
        subtype = value.subtype
        if subtype == MongoBlob.MONGO_BLOB_BYTES:
            return bytearray(value)
        if subtype == MongoBlob.MONGO_BLOB_NON_UTF8_STR:
            return str(value)
        return value