/
generator.py
3666 lines (2968 loc) · 151 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import logging
import re
import typing as t
from collections import defaultdict
from functools import reduce
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
logger = logging.getLogger("sqlglot")
ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)")
class _Generator(type):
    """Metaclass that prunes JSON path transforms a generator class does not support."""

    def __new__(cls, clsname, bases, attrs):
        """Create the class, then drop the transforms of its unsupported JSONPathPart types."""
        new_class = super().__new__(cls, clsname, bases, attrs)

        # Every JSONPathPart outside the class' supported set loses its TRANSFORMS entry
        unsupported_parts = ALL_JSON_PATH_PARTS - new_class.SUPPORTED_JSON_PATH_PARTS
        for json_path_part in unsupported_parts:
            new_class.TRANSFORMS.pop(json_path_part, None)

        return new_class
class Generator(metaclass=_Generator):
"""
Generator converts a given syntax tree to the corresponding SQL string.
Args:
pretty: Whether to format the produced SQL string.
Default: False.
identify: Determines when an identifier should be quoted. Possible values are:
False (default): Never quote, except in cases where it's mandatory by the dialect.
True or 'always': Always quote.
'safe': Only quote identifiers that are case insensitive.
normalize: Whether to normalize identifiers to lowercase.
Default: False.
pad: The pad size in a formatted string. For example, this affects the indentation of
a projection in a query, relative to its nesting level.
Default: 2.
indent: The indentation size in a formatted string. For example, this affects the
indentation of subqueries and filters under a `WHERE` clause.
Default: 2.
normalize_functions: How to normalize function names. Possible values are:
"upper" or True (default): Convert names to uppercase.
"lower": Convert names to lowercase.
False: Disables function name normalization.
unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
Default ErrorLevel.WARN.
max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
leading_comma: Whether the comma is leading or trailing in select expressions.
This is only relevant when generating in pretty mode.
Default: False
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
comments: Whether to preserve comments in the output SQL code.
Default: True
"""
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
**JSON_PATH_PART_TRANSFORMS,
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda _,
e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self,
e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.ClusteredColumnConstraint: lambda self,
e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda *_: "EXTERNAL",
exp.GlobalProperty: lambda *_: "GLOBAL",
exp.HeapProperty: lambda *_: "HEAP",
exp.IcebergProperty: lambda *_: "ICEBERG",
exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
exp.JSONExtract: lambda self, e: self.func(
"JSON_EXTRACT", e.this, e.expression, *e.expressions
),
exp.JSONExtractScalar: lambda self, e: self.func(
"JSON_EXTRACT_SCALAR", e.this, e.expression, *e.expressions
),
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda _, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda *_: "MATERIALIZED",
exp.NonClusteredColumnConstraint: lambda self,
e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NoPrimaryIndexProperty: lambda *_: "NO PRIMARY INDEX",
exp.NotForReplicationColumnConstraint: lambda *_: "NOT FOR REPLICATION",
exp.OnCommitProperty: lambda _,
e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self,
e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SharingProperty: lambda self, e: f"SHARING={self.sql(e, 'this')}",
exp.SqlReadWriteProperty: lambda _, e: e.name,
exp.SqlSecurityProperty: lambda _,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda _, e: e.name,
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda *_: "TRANSIENT",
exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE",
exp.UnloggedProperty: lambda *_: "UNLOGGED",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
}
# Whether null ordering is supported in order by
# True: Full Support, None: No support, False: No support in window specifications
NULL_ORDERING_SUPPORTED: t.Optional[bool] = True
# Whether ignore nulls is inside the agg or outside.
# FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER
IGNORE_NULLS_IN_FUNC = False
# Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
# Always do union distinct or union all
EXPLICIT_UNION = False
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
# Whether create function uses an AS before the RETURN
CREATE_FUNCTION_RETURN_AS = True
# Whether MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
# Whether the INTERVAL expression works only with values like '1 day'
SINGLE_STRING_INTERVAL = False
# Whether the plural form of date parts like day (i.e. "days") is supported in INTERVALs
INTERVAL_ALLOWS_PLURAL_FORM = True
# Whether limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
# Whether LIMIT and FETCH allow expressions or only literals
LIMIT_ONLY_LITERALS = False
# Whether a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","
# The string used for creating an index on a table
INDEX_ON = "ON"
# Whether join hints should be generated
JOIN_HINTS = True
# Whether table hints should be generated
TABLE_HINTS = True
# Whether query hints should be generated
QUERY_HINTS = True
# What kind of separator to use for query hints
QUERY_HINT_SEP = ", "
# Whether comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True
# Whether to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement
DUPLICATE_KEY_UPDATE_WITH_SET = True
# Whether to generate the limit as TOP <value> instead of LIMIT <value>
LIMIT_IS_TOP = False
# Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
RETURNING_END = True
# Whether to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
# Whether to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
# Whether TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
TZ_TO_WITH_TIME_ZONE = False
# Whether the NVL2 function is supported
NVL2_SUPPORTED = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
# Whether VALUES statements can be used as derived tables.
# MySQL 5 and Redshift do not allow this, so when False, it will convert
# SELECT * VALUES into SELECT UNION
VALUES_AS_TABLE = True
# Whether the word COLUMN is included when adding a column with ALTER TABLE
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True
# UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery)
UNNEST_WITH_ORDINALITY = True
# Whether FILTER (WHERE cond) can be used for conditional aggregation
AGGREGATE_FILTER_SUPPORTED = True
# Whether JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds
SEMI_ANTI_JOIN_WITH_SIDE = True
# Whether to include the type of a computed column in the CREATE DDL
COMPUTED_COLUMN_WITH_TYPE = True
# Whether CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY
SUPPORTS_TABLE_COPY = True
# Whether parentheses are required around the table sample's expression
TABLESAMPLE_REQUIRES_PARENS = True
# Whether a table sample clause's size needs to be followed by the ROWS keyword
TABLESAMPLE_SIZE_IS_ROWS = True
# The keyword(s) to use when generating a sample clause
TABLESAMPLE_KEYWORDS = "TABLESAMPLE"
# Whether the TABLESAMPLE clause supports a method name, like BERNOULLI
TABLESAMPLE_WITH_METHOD = True
# The keyword to use when specifying the seed of a sample clause
TABLESAMPLE_SEED_KEYWORD = "SEED"
# Whether COLLATE is a function instead of a binary operator
COLLATE_IS_FUNC = False
# Whether data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
# Whether conditions require booleans WHERE x = 0 vs WHERE x
ENSURE_BOOLS = False
# Whether the "RECURSIVE" keyword is required when defining recursive CTEs
CTE_RECURSIVE_KEYWORD_REQUIRED = True
# Whether CONCAT can be called with a single argument (False means it requires >1 arguments)
SUPPORTS_SINGLE_ARG_CONCAT = True
# Whether LAST_DAY function supports a date part argument
LAST_DAY_SUPPORTS_DATE_PART = True
# Whether named columns are allowed in table aliases
SUPPORTS_TABLE_ALIAS_COLUMNS = True
# Whether UNPIVOT aliases are Identifiers (False means they're Literals)
UNPIVOT_ALIASES_ARE_IDENTIFIERS = True
# What delimiter to use for separating JSON key/value pairs
JSON_KEY_VALUE_PAIR_SEP = ":"
# INSERT OVERWRITE TABLE x override
INSERT_OVERWRITE = " OVERWRITE TABLE"
# Whether the SELECT .. INTO syntax is used instead of CTAS
SUPPORTS_SELECT_INTO = False
# Whether UNLOGGED tables can be created
SUPPORTS_UNLOGGED_TABLES = False
# Whether the CREATE TABLE LIKE statement is supported
SUPPORTS_CREATE_TABLE_LIKE = True
# Whether the LikeProperty needs to be specified inside of the schema clause
LIKE_PROPERTY_INSIDE_SCHEMA = False
# Whether DISTINCT can be followed by multiple args in an AggFunc. If not, it will be
# transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple consisting of the args
MULTI_ARG_DISTINCT = True
# Whether the JSON extraction operators expect a value of type JSON
JSON_TYPE_REQUIRED_FOR_EXTRACTION = False
# Whether bracketed keys like ["foo"] are supported in JSON paths
JSON_PATH_BRACKETED_KEY_SUPPORTED = True
# Whether to escape keys using single quotes in JSON paths
JSON_PATH_SINGLE_QUOTE_ESCAPE = False
# The JSONPathPart expressions supported by this dialect
SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy()
# Whether any(f(x) for x in array) can be implemented by this dialect
CAN_IMPLEMENT_ARRAY_ANY = False
# Whether the function TO_NUMBER is supported
SUPPORTS_TO_NUMBER = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
exp.DataType.Type.MEDIUMTEXT: "TEXT",
exp.DataType.Type.LONGTEXT: "TEXT",
exp.DataType.Type.TINYTEXT: "TEXT",
exp.DataType.Type.MEDIUMBLOB: "BLOB",
exp.DataType.Type.LONGBLOB: "BLOB",
exp.DataType.Type.TINYBLOB: "BLOB",
exp.DataType.Type.INET: "INET",
}
STAR_MAPPING = {
"except": "EXCEPT",
"replace": "REPLACE",
}
TIME_PART_SINGULARS = {
"MICROSECONDS": "MICROSECOND",
"SECONDS": "SECOND",
"MINUTES": "MINUTE",
"HOURS": "HOUR",
"DAYS": "DAY",
"WEEKS": "WEEK",
"MONTHS": "MONTH",
"QUARTERS": "QUARTER",
"YEARS": "YEAR",
}
AFTER_HAVING_MODIFIER_TRANSFORMS = {
"cluster": lambda self, e: self.sql(e, "cluster"),
"distribute": lambda self, e: self.sql(e, "distribute"),
"qualify": lambda self, e: self.sql(e, "qualify"),
"sort": lambda self, e: self.sql(e, "sort"),
"windows": lambda self, e: (
self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True)
if e.args.get("windows")
else ""
),
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
PARAMETER_TOKEN = "@"
NAMED_PLACEHOLDER_TOKEN = ":"
PROPERTIES_LOCATION = {
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
exp.BackupProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA,
exp.ExternalProperty: exp.Properties.Location.POST_CREATE,
exp.FallbackProperty: exp.Properties.Location.POST_NAME,
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.GlobalProperty: exp.Properties.Location.POST_CREATE,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA,
exp.IcebergProperty: exp.Properties.Location.POST_CREATE,
exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
exp.LikeProperty: exp.Properties.Location.POST_SCHEMA,
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
exp.LockProperty: exp.Properties.Location.POST_SCHEMA,
exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
exp.LogProperty: exp.Properties.Location.POST_NAME,
exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
exp.OnProperty: exp.Properties.Location.POST_SCHEMA,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SampleProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
exp.Set: exp.Properties.Location.POST_SCHEMA,
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA,
exp.SharingProperty: exp.Properties.Location.POST_EXPRESSION,
exp.SequenceProperties: exp.Properties.Location.POST_EXPRESSION,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.UnloggedProperty: exp.Properties.Location.POST_CREATE,
exp.ViewAttributeProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
}
# Keywords that can't be used as unquoted identifier names
RESERVED_KEYWORDS: t.Set[str] = set()
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Create,
exp.Delete,
exp.Drop,
exp.From,
exp.Insert,
exp.Join,
exp.Select,
exp.Union,
exp.Update,
exp.Where,
exp.With,
)
# Expressions that should not have their comments generated in maybe_comment
EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Binary,
exp.Union,
)
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Column,
exp.Literal,
exp.Neg,
exp.Paren,
)
PARAMETERIZABLE_TEXT_TYPES = {
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.VARCHAR,
exp.DataType.Type.CHAR,
exp.DataType.Type.NCHAR,
}
# Expressions that need to have all CTEs under them bubbled up to them
EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
"pretty",
"identify",
"normalize",
"pad",
"_indent",
"normalize_functions",
"unsupported_level",
"max_unsupported",
"leading_comma",
"max_text_width",
"comments",
"dialect",
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
)
def __init__(
    self,
    pretty: t.Optional[bool] = None,
    identify: str | bool = False,
    normalize: bool = False,
    pad: int = 2,
    indent: int = 2,
    normalize_functions: t.Optional[str | bool] = None,
    unsupported_level: ErrorLevel = ErrorLevel.WARN,
    max_unsupported: int = 3,
    leading_comma: bool = False,
    max_text_width: int = 80,
    comments: bool = True,
    dialect: DialectType = None,
):
    """
    Store the generator options and resolve the dialect they apply to.

    Args:
        pretty: Whether to format the output. `None` falls back to `sqlglot.pretty`.
        identify: When identifiers should be quoted.
        normalize: Whether to normalize identifiers.
        pad: The pad size used when formatting.
        indent: The indentation size used when formatting (stored as `_indent`).
        normalize_functions: How to normalize function names. `None` falls back to the
            dialect's `NORMALIZE_FUNCTIONS`.
        unsupported_level: How unsupported expressions are reported.
        max_unsupported: Maximum number of messages included in a raised error.
        leading_comma: Whether commas lead or trail in pretty mode.
        max_text_width: The segment width limit used in pretty mode.
        comments: Whether comments are preserved in the output.
        dialect: The dialect (or dialect name) resolved through `Dialect.get_or_raise`.
    """
    import sqlglot
    from sqlglot.dialects import Dialect

    # Only consult the package-wide flag when the caller expressed no preference
    if pretty is None:
        pretty = sqlglot.pretty

    self.pretty = pretty
    self.identify = identify
    self.normalize = normalize
    self.pad = pad
    self._indent = indent
    self.unsupported_level = unsupported_level
    self.max_unsupported = max_unsupported
    self.leading_comma = leading_comma
    self.max_text_width = max_text_width
    self.comments = comments
    self.dialect = Dialect.get_or_raise(dialect)

    # This is both a Dialect property and a Generator argument, so the argument wins
    if normalize_functions is None:
        normalize_functions = self.dialect.NORMALIZE_FUNCTIONS
    self.normalize_functions = normalize_functions

    self.unsupported_messages: t.List[str] = []

    # Precompute the escaped closing delimiters: first escape character + end delimiter
    tokenizer_class = self.dialect.tokenizer_class
    string_escape = tokenizer_class.STRING_ESCAPES[0]
    self._escaped_quote_end: str = string_escape + self.dialect.QUOTE_END
    identifier_escape = tokenizer_class.IDENTIFIER_ESCAPES[0]
    self._escaped_identifier_end: str = identifier_escape + self.dialect.IDENTIFIER_END
def generate(self, expression: exp.Expression, copy: bool = True) -> str:
    """
    Render a syntax tree as a SQL string.

    Args:
        expression: The root of the syntax tree to render.
        copy: Whether to work on a copy of `expression`. The generator performs
            mutations, so copying protects the caller's tree.

    Returns:
        The stripped SQL text produced for `expression`.

    Raises:
        UnsupportedError: If `unsupported_level` is `ErrorLevel.RAISE` and at least one
            unsupported message was recorded during generation.
    """
    tree = expression.copy() if copy else expression
    tree = self.preprocess(tree)

    # Start from a clean slate so messages from an earlier run are not reported again
    self.unsupported_messages = []
    result = self.sql(tree).strip()

    if self.pretty:
        # Sentinel markers stand in for line breaks until the very end
        result = result.replace(self.SENTINEL_LINE_BREAK, "\n")

    level = self.unsupported_level
    if level == ErrorLevel.IGNORE:
        return result

    if level == ErrorLevel.WARN:
        for message in self.unsupported_messages:
            logger.warning(message)
    elif level == ErrorLevel.RAISE and self.unsupported_messages:
        raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))

    return result
def preprocess(self, expression: exp.Expression) -> exp.Expression:
    """
    Apply the generic preprocessing transformations to `expression`.

    Args:
        expression: The syntax tree about to be generated.

    Returns:
        The (possibly replaced) expression to generate SQL from.
    """
    is_root = not expression.parent
    if is_root and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES:
        # Only rewrite when some WITH clause is not attached directly to the root
        has_nested_with = any(
            with_.parent is not expression for with_ in expression.find_all(exp.With)
        )
        if has_nested_with:
            from sqlglot.transforms import move_ctes_to_top_level

            expression = move_ctes_to_top_level(expression)

    if self.ENSURE_BOOLS:
        from sqlglot.transforms import ensure_bools

        expression = ensure_bools(expression)

    return expression
def unsupported(self, message: str) -> None:
    """
    Report an unsupported construct.

    Args:
        message: The description of what could not be generated.

    Raises:
        UnsupportedError: Immediately, when `unsupported_level` is `ErrorLevel.IMMEDIATE`.
            Otherwise the message is appended to `unsupported_messages`.
    """
    if self.unsupported_level != ErrorLevel.IMMEDIATE:
        self.unsupported_messages.append(message)
        return

    raise UnsupportedError(message)
def sep(self, sep: str = " ") -> str:
    """Return `sep` unchanged, or its stripped form followed by a newline in pretty mode."""
    if not self.pretty:
        return sep
    return f"{sep.strip()}\n"
def seg(self, sql: str, sep: str = " ") -> str:
    """Return `sql` prefixed with the separator produced by `self.sep(sep)`."""
    prefix = self.sep(sep)
    return f"{prefix}{sql}"
def pad_comment(self, comment: str) -> str:
    """
    Pad a comment so its text does not touch the surrounding comment delimiters.

    A single space is added before the comment if it does not already start with
    whitespace, and after it if it does not already end with whitespace.

    Args:
        comment: The raw comment text.

    Returns:
        The padded comment. An empty comment is returned unchanged instead of
        failing on the first/last character lookup.
    """
    if not comment:
        # Nothing to inspect: indexing an empty string would raise IndexError
        return comment

    if comment[0].strip():
        comment = " " + comment
    if comment[-1].strip():
        comment = comment + " "
    return comment
def maybe_comment(
    self,
    sql: str,
    expression: t.Optional[exp.Expression] = None,
    comments: t.Optional[t.List[str]] = None,
) -> str:
    """
    Attach comments to a generated SQL fragment as `/* ... */` blocks.

    Args:
        sql: The SQL fragment generated for `expression`.
        expression: The expression the fragment was generated from. Its `comments` are
            used when `comments` is not given.
        comments: Explicit comments to attach instead of the expression's own.

    Returns:
        `sql` unchanged when comments are disabled, there is nothing to attach, or the
        expression is one of `EXCLUDE_COMMENTS`. Otherwise the comments are placed before
        the fragment for `WITH_SEPARATED_COMMENTS` expressions and after it for all others.
    """
    if not self.comments:
        return sql

    if comments is None:
        comments = expression.comments if expression else None  # type: ignore

    if not comments or isinstance(expression, self.EXCLUDE_COMMENTS):
        return sql

    # Empty comment strings are skipped
    comments_sql = " ".join(
        f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
    )

    if not comments_sql:
        return sql

    if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
        # An empty fragment has no first character to inspect; handle it like a fragment
        # that already starts with whitespace instead of raising IndexError
        if not sql or sql[0].isspace():
            return f"{self.sep()}{comments_sql}{sql}"
        return f"{comments_sql}{self.sep()}{sql}"

    return f"{sql} {comments_sql}"
def wrap(self, expression: exp.Expression | str) -> str:
    """
    Render `expression` inside parentheses.

    Expressions matching `exp.UNWRAPPED_QUERIES` are generated as a whole; for anything
    else only the `this` argument is generated. The result is passed through
    `self.indent` with `level=1` and `pad=0` before being parenthesized.
    """
    if isinstance(expression, exp.UNWRAPPED_QUERIES):
        inner_sql = self.sql(expression)
    else:
        inner_sql = self.sql(expression, "this")

    body = self.indent(inner_sql, level=1, pad=0)
    opening = self.sep("")
    closing = self.seg(")", sep="")
    return f"({opening}{body}{closing}"
def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
    """
    Call `func` with identifier quoting (`self.identify`) temporarily switched off.

    Args:
        func: The callable to invoke while `self.identify` is `False`.
        *args: Positional arguments forwarded to `func`.
        **kwargs: Keyword arguments forwarded to `func`.

    Returns:
        Whatever `func` returns.
    """
    original = self.identify
    self.identify = False
    try:
        return func(*args, **kwargs)
    finally:
        # Restore the setting even if `func` raises, so the generator is not left
        # with identifier quoting permanently disabled
        self.identify = original
def normalize_func(self, name: str) -> str:
    """
    Apply the configured function-name casing to `name`.

    `"upper"` or `True` upper-cases the name, `"lower"` lower-cases it, and any other
    setting leaves it untouched.
    """
    mode = self.normalize_functions
    if mode is True or mode == "upper":
        return name.upper()
    if mode == "lower":
        return name.lower()
    return name
def indent(
    self,
    sql: str,
    level: int = 0,
    pad: t.Optional[int] = None,
    skip_first: bool = False,
    skip_last: bool = False,
) -> str:
    """
    Indent every line of `sql` when pretty mode is on.

    Args:
        sql: The text to indent. Returned unchanged when `self.pretty` is falsy.
        level: The nesting level; each level adds `self._indent` spaces.
        pad: Extra spaces added on top of the level indentation. `None` uses `self.pad`.
        skip_first: Whether to leave the first line untouched.
        skip_last: Whether to leave the last line untouched.

    Returns:
        The text with `level * self._indent + pad` spaces in front of each indented line.
    """
    if not self.pretty:
        return sql

    if pad is None:
        pad = self.pad

    prefix = " " * (level * self._indent + pad)
    lines = sql.split("\n")
    last_index = len(lines) - 1

    indented = []
    for index, line in enumerate(lines):
        untouched = (skip_first and index == 0) or (skip_last and index == last_index)
        indented.append(line if untouched else f"{prefix}{line}")

    return "\n".join(indented)
def sql(
    self,
    expression: t.Optional[str | exp.Expression],
    key: t.Optional[str] = None,
    comment: bool = True,
) -> str:
    """
    Generate the SQL for an expression, or for one of its arguments.

    Args:
        expression: The expression to generate. Falsy values produce an empty string and
            plain strings are returned as they are.
        key: When given, the argument of `expression` to generate instead of the
            expression itself; a missing or falsy argument produces an empty string.
        comment: Whether the expression's comments may be attached to the result.

    Returns:
        The generated SQL.

    Raises:
        ValueError: If `expression` is not an Expression, or no generation strategy
            exists for its type.
    """
    if not expression:
        return ""

    if isinstance(expression, str):
        return expression

    if key:
        child = expression.args.get(key)
        return self.sql(child) if child else ""

    # Lookup order: explicit transform, `<key>_sql` method, then the generic fallbacks
    transform = self.TRANSFORMS.get(expression.__class__)

    if callable(transform):
        generated = transform(self, expression)
    elif not isinstance(expression, exp.Expression):
        raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
    else:
        handler_name = f"{expression.key}_sql"

        if hasattr(self, handler_name):
            generated = getattr(self, handler_name)(expression)
        elif isinstance(expression, exp.Func):
            generated = self.function_fallback_sql(expression)
        elif isinstance(expression, exp.Property):
            generated = self.property_sql(expression)
        else:
            raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")

    if self.comments and comment:
        return self.maybe_comment(generated, expression)
    return generated
def uncache_sql(self, expression: exp.Uncache) -> str:
    """Generate an `UNCACHE TABLE [IF EXISTS] <table>` statement."""
    target = self.sql(expression, "this")
    if_exists = " IF EXISTS" if expression.args.get("exists") else ""
    return "UNCACHE TABLE" + f"{if_exists} {target}"
def cache_sql(self, expression: exp.Cache) -> str:
    """
    Generate a `CACHE [LAZY] TABLE <table> [OPTIONS(<k> = <v>)] [AS <query>]` statement.

    The statement is passed through `self.prepend_ctes` before being returned.
    """
    args = expression.args
    lazy_sql = " LAZY" if args.get("lazy") else ""
    table_sql = self.sql(expression, "this")

    options_sql = ""
    option_pair = args.get("options")
    if option_pair:
        # `options` is read as a pair: element 0 is the key, element 1 the value
        option_key = self.sql(option_pair[0])
        option_value = self.sql(option_pair[1])
        options_sql = f" OPTIONS({option_key} = {option_value})"

    query_sql = self.sql(expression, "expression")
    if query_sql:
        query_sql = f" AS{self.sep()}{query_sql}"

    statement = f"CACHE{lazy_sql} TABLE {table_sql}{options_sql}{query_sql}"
    return self.prepend_ctes(expression, statement)
def characterset_sql(self, expression: exp.CharacterSet) -> str:
    """
    Generate a character set specification.

    Directly under a `Cast` the form is `CHAR CHARACTER SET <name>`; otherwise it is
    `[DEFAULT ]CHARACTER SET=<name>`.
    """
    if not isinstance(expression.parent, exp.Cast):
        prefix = "DEFAULT " if expression.args.get("default") else ""
        return f"{prefix}CHARACTER SET={self.sql(expression, 'this')}"

    return f"CHAR CHARACTER SET {self.sql(expression, 'this')}"
def column_sql(self, expression: exp.Column) -> str:
    """
    Generate a column reference, qualified by whichever of catalog, db and table are set.

    A ` (+)` join mark is appended only when `COLUMN_JOIN_MARKS_SUPPORTED` is on;
    otherwise the mark is dropped and reported through `self.unsupported`.
    """
    args = expression.args

    suffix = ""
    if args.get("join_mark"):
        if self.COLUMN_JOIN_MARKS_SUPPORTED:
            suffix = " (+)"
        else:
            self.unsupported("Outer join syntax using the (+) operator is not supported.")

    # Missing qualifiers are skipped; the rest are joined from outermost to innermost
    qualifiers = [args.get(name) for name in ("catalog", "db", "table", "this")]
    rendered = [self.sql(part) for part in qualifiers if part]
    return ".".join(rendered) + suffix
def columnposition_sql(self, expression: exp.ColumnPosition) -> str:
    """Generate `<position>[ <this>]`, omitting the trailing part when `this` is empty."""
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    if not target:
        return f"{position}"
    return f"{position} {target}"
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
    """Render a column definition.

    Output layout: `[IF NOT EXISTS ]<name>[<sep><type>][ <constraints>][ <position>]`.
    The type is left out when the definition contains a
    `ComputedColumnConstraint` and `COMPUTED_COLUMN_WITH_TYPE` is falsy.

    Args:
        expression: The column definition node.
        sep: Separator placed between the column name and its type.
    """
    name = self.sql(expression, "this")
    type_sql = self.sql(expression, "kind")
    constraint_sql = self.expressions(expression, key="constraints", sep=" ", flat=True)
    position_sql = self.sql(expression, "position")

    if expression.find(exp.ComputedColumnConstraint) and not self.COMPUTED_COLUMN_WITH_TYPE:
        type_sql = ""

    pieces = ["IF NOT EXISTS " if expression.args.get("exists") else "", name]
    if type_sql:
        pieces.append(f"{sep}{type_sql}")
    if constraint_sql:
        pieces.append(f" {constraint_sql}")
    if position_sql:
        pieces.append(f" {position_sql}")

    return "".join(pieces)
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
    """Render a column constraint.

    A named constraint is emitted as `CONSTRAINT <name> <kind>`; an unnamed
    one is just the rendered kind with surrounding whitespace stripped.
    """
    name = self.sql(expression, "this")
    body = self.sql(expression, "kind").strip()

    if not name:
        return body
    return f"CONSTRAINT {name} {body}"
def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
    """Render a computed column constraint as `AS <expr>[ PERSISTED[ NOT NULL]]`.

    A truthy `not_null` arg takes precedence and yields `PERSISTED NOT NULL`;
    otherwise a truthy `persisted` arg yields `PERSISTED`.
    """
    computed = self.sql(expression, "this")

    suffix = ""
    if expression.args.get("not_null"):
        suffix = " PERSISTED NOT NULL"
    elif expression.args.get("persisted"):
        suffix = " PERSISTED"

    return "AS " + computed + suffix
def autoincrementcolumnconstraint_sql(self, _) -> str:
    """Return whatever `token_sql` produces for the `AUTO_INCREMENT` token type.

    The constraint node itself is not inspected.
    """
    auto_increment_token = TokenType.AUTO_INCREMENT
    return self.token_sql(auto_increment_token)
def compresscolumnconstraint_sql(self, expression: exp.CompressColumnConstraint) -> str:
    """Render a `COMPRESS` constraint.

    When `this` is a list, its flat rendering is passed through `self.wrap`;
    otherwise `this` is rendered directly.
    """
    if not isinstance(expression.this, list):
        single_value = self.sql(expression, "this")
        return f"COMPRESS {single_value}"

    value_list = self.expressions(expression, key="this", flat=True)
    return f"COMPRESS {self.wrap(value_list)}"
def generatedasidentitycolumnconstraint_sql(
    self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render a `GENERATED ... AS IDENTITY` (or `AS (<expr>)`) column constraint.

    The generation mode comes from `expression.this`: a truthy value gives
    `ALWAYS`, a falsy non-None value gives `BY DEFAULT` (plus `ON NULL` when
    the `on_null` arg is set), and None gives no mode at all.

    The sequence options (`START WITH`, `INCREMENT BY`, `MINVALUE`,
    `MAXVALUE`, `[NO] CYCLE`) are emitted in that order inside parentheses,
    separated by single spaces. The parenthesised list appears whenever at
    least one option is present, so a constraint that only carries
    `MINVALUE`/`MAXVALUE` keeps them, and `CYCLE` is always separated from
    the option before it.

    Returns:
        The constraint SQL, e.g. `GENERATED ALWAYS AS IDENTITY (START WITH 1)`.
    """
    this = ""
    if expression.this is not None:
        on_null = " ON NULL" if expression.args.get("on_null") else ""
        this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}"

    # Collect only the options that are set; joining them afterwards keeps the
    # spacing right for every combination.
    sequence_parts = []
    for keyword, arg_name in (
        ("START WITH", "start"),
        ("INCREMENT BY", "increment"),
        ("MINVALUE", "minvalue"),
        ("MAXVALUE", "maxvalue"),
    ):
        value = expression.args.get(arg_name)
        if value:
            sequence_parts.append(f"{keyword} {value}")

    # `cycle` is tri-state: None means "not specified", False means NO CYCLE.
    cycle = expression.args.get("cycle")
    if cycle is not None:
        sequence_parts.append("CYCLE" if cycle else "NO CYCLE")

    sequence_opts = f" ({' '.join(sequence_parts)})" if sequence_parts else ""

    expr = self.sql(expression, "expression")
    expr = f"({expr})" if expr else "IDENTITY"

    return f"GENERATED{this} AS {expr}{sequence_opts}"
def generatedasrowcolumnconstraint_sql(
    self, expression: exp.GeneratedAsRowColumnConstraint
) -> str:
    """Render `GENERATED ALWAYS AS ROW START|END [HIDDEN]`.

    `START` is chosen when the `start` arg is truthy, `END` otherwise;
    `HIDDEN` is appended when the `hidden` arg is truthy.
    """
    tokens = ["GENERATED ALWAYS AS ROW"]
    tokens.append("START" if expression.args.get("start") else "END")
    if expression.args.get("hidden"):
        tokens.append("HIDDEN")
    return " ".join(tokens)
def periodforsystemtimeconstraint_sql(
    self, expression: exp.PeriodForSystemTimeConstraint
) -> str:
    """Render `PERIOD FOR SYSTEM_TIME (<this>, <expression>)`."""
    first_column = self.sql(expression, "this")
    second_column = self.sql(expression, "expression")
    return "PERIOD FOR SYSTEM_TIME (" + first_column + ", " + second_column + ")"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
    """Return `NULL` when the `allow_null` arg is truthy, `NOT NULL` otherwise."""
    if expression.args.get("allow_null"):
        return "NULL"
    return "NOT NULL"
def transformcolumnconstraint_sql(self, expression: exp.TransformColumnConstraint) -> str:
    """Render the constraint as `AS <this>`."""
    transform = self.sql(expression, "this")
    return "AS " + transform
def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
    """Render `PRIMARY KEY`, with an explicit direction when `desc` is set.

    `desc` is tri-state: None emits no direction, a truthy value emits
    `DESC`, and any other value emits `ASC`.
    """
    descending = expression.args.get("desc")
    if descending is None:
        return "PRIMARY KEY"

    direction = " DESC" if descending else " ASC"
    return f"PRIMARY KEY{direction}"
def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
    """Render `UNIQUE[ <this>][ USING <index_type>][ <on_conflict>]`.

    Each optional piece is included only when its value is non-empty/truthy.
    """
    tokens = ["UNIQUE"]

    target = self.sql(expression, "this")
    if target:
        tokens.append(target)

    index_type = expression.args.get("index_type")
    if index_type:
        tokens.append(f"USING {index_type}")

    conflict_clause = self.sql(expression, "on_conflict")
    if conflict_clause:
        tokens.append(conflict_clause)

    return " ".join(tokens)
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
    """Render the object being created, i.e. the `this` arg of the `Create` node.

    `locations` is accepted but not used by this implementation.
    """
    created_object = self.sql(expression, "this")
    return created_object
def create_sql(self, expression: exp.Create) -> str:
    """Render a `CREATE` statement.

    Properties attached to the statement are first grouped by location via
    `locate_properties`, then each group is rendered at its slot in the
    statement. The assembled layout is:

        CREATE[ OR REPLACE][ UNIQUE][post-create props] <kind>[ IF NOT EXISTS]
        <this>[post-schema/post-with props][ AS ...][post-expression props]
        [indexes][post-index props][ WITH NO SCHEMA BINDING][ <clone>]

    The finished statement is passed through `prepend_ctes`.
    """
    Location = exp.Properties.Location

    kind = self.sql(expression, "kind")

    properties = expression.args.get("properties")
    if properties:
        locs = self.locate_properties(properties)
    else:
        locs = defaultdict()

    def located_props(location, **render_opts) -> str:
        # Render the properties found at `location`, or "" if there are none.
        if not locs.get(location):
            return ""
        return self.properties(exp.Properties(expressions=locs[location]), **render_opts)

    target = self.createable_sql(expression, locs)

    schema_props = ""
    if locs.get(Location.POST_SCHEMA) or locs.get(Location.POST_WITH):
        combined = [*locs[Location.POST_SCHEMA], *locs[Location.POST_WITH]]
        schema_props = self.sql(exp.Properties(expressions=combined))

    begin = " BEGIN" if expression.args.get("begin") else ""
    end = " END" if expression.args.get("end") else ""

    body = self.sql(expression, "expression")
    if body:
        body = f"{begin}{self.sep()}{body}{end}"

        # `AS` is skipped only for a `Return` body when the generator does not
        # use the `RETURN ... AS` form.
        if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return):
            if locs.get(Location.POST_ALIAS):
                alias_props = self.properties(
                    exp.Properties(expressions=locs[Location.POST_ALIAS]),
                    wrapped=False,
                )
                body = f" AS {alias_props}{body}"
            else:
                body = f" AS{body}"

    index_props = located_props(Location.POST_INDEX, wrapped=False, prefix=" ")
    indexes = self.expressions(expression, key="indexes", indent=False, sep=" ")
    if indexes:
        indexes = f" {indexes}"
    index_sql = indexes + index_props

    replace = " OR REPLACE" if expression.args.get("replace") else ""
    unique = " UNIQUE" if expression.args.get("unique") else ""
    create_props = located_props(Location.POST_CREATE, sep=" ", prefix=" ", wrapped=False)
    modifiers = replace + unique + create_props

    expression_props = located_props(
        Location.POST_EXPRESSION, sep=" ", prefix=" ", wrapped=False
    )

    exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
    no_schema_binding = ""
    if expression.args.get("no_schema_binding"):
        no_schema_binding = " WITH NO SCHEMA BINDING"

    clone = self.sql(expression, "clone")
    if clone:
        clone = f" {clone}"

    statement = "".join(
        (
            f"CREATE{modifiers} {kind}{exists_sql} {target}",
            schema_props,
            body,
            expression_props,
            index_sql,
            no_schema_binding,
            clone,
        )
    )
    return self.prepend_ctes(expression, statement)
def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str:
start = self.sql(expression, "start")