forked from svenkreiss/pysparkling
/
parser.py
78 lines (58 loc) · 2.07 KB
/
parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import antlr4
from antlr4 import InputStream
from antlr4.error.ErrorListener import ErrorListener
from pysparkling.sql.ast.generated.SqlBaseLexer import SqlBaseLexer
from pysparkling.sql.ast.generated.SqlBaseParser import SqlBaseParser
class PostProcessor(antlr4.ParseTreeListener):
    """
    Parse-tree listener registered on the SQL parser by ``build_ast``.

    NOTE(review): the two listener hooks below only build and return a small
    closure and never touch ``ctx``. Listener callbacks are not expected to
    return anything, so as written these hooks presumably leave the parse
    tree untouched -- confirm against the intended behaviour before relying
    on them.
    """

    @staticmethod
    def exitQuotedIdentifier(ctx):
        """
        Return a function that hands its argument back unchanged.

        ``ctx`` is accepted to match the listener signature but is not used.
        """
        def identity(token):
            return token

        return identity

    @staticmethod
    def enterNonReserved(ctx):
        """
        Return a function that wraps its argument in backticks.

        ``ctx`` is accepted to match the listener signature but is not used.
        """
        def add_backtick(token):
            return "`{0}`".format(token)

        return add_backtick

    @staticmethod
    def replace_token_by_identifier(ctx):
        """
        Return a callback that replaces ``ctx`` in its parent by a token node.

        When invoked, the callback drops the last child of ``ctx``'s parent
        context and appends a terminal node built from the payload of
        ``ctx``'s first child.

        The previous body could not run: it read ``ctx.parent`` (the Python
        ANTLR runtime keeps the parent in ``parentCtx``), stored the bound
        method ``getPayload`` instead of calling it, and called ``addChild``
        without a child.

        :param ctx: rule context to be replaced; presumably the last child of
            its parent when the callback runs -- TODO confirm with callers
        :return: a one-argument callable; its argument is ignored (the
            original overwrote it before any use)
        """
        def do_replace_token_by_identifier(token):
            parent = ctx.parentCtx
            # Read the payload before mutating the tree so that a failure
            # here leaves the parent's children untouched.
            payload = ctx.getChild(0).getPayload()
            parent.removeLastChild()
            # addTokenNode wraps the token in a terminal node, appends it and
            # sets the node's parent.
            parent.addTokenNode(payload)

        return do_replace_token_by_identifier
class ParseErrorListener(ErrorListener):
    """
    Error listener that raises as soon as a syntax error is reported
    """

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        """
        Raise an ``Exception`` whose args are ``("Parse error", msg)``.

        The other arguments (recognizer, offending symbol, position and the
        originating exception ``e``) are not used.
        """
        failure = Exception("Parse error", msg)
        raise failure
class UpperCaseCharStream:
    """
    Make SQL token detection case insensitive

    Wraps a character stream: ``LA`` reports upper-cased code points, while
    ``getText`` and every attribute not defined here are forwarded to the
    wrapped stream unchanged.
    """

    def __init__(self, wrapped):
        """
        :param wrapped: the underlying character stream to delegate to
        """
        self.wrapped = wrapped

    def getText(self, interval, *args):
        """
        Return text from the wrapped stream.

        The call is forwarded as-is when extra positional arguments are
        given. With a single ``interval`` argument it is forwarded only if
        the stream is non-empty and ``interval.b >= interval.a``; otherwise
        an empty string is returned.
        """
        # size() is not defined on this class, so it resolves on the wrapped
        # stream through __getattr__.
        if args or (self.size() > 0 and (interval.b - interval.a >= 0)):
            return self.wrapped.getText(interval, *args)
        return ""

    def LA(self, i: int) -> int:
        """
        Return the upper-cased code point the wrapped stream reports for ``i``.

        The values ``0`` and ``-1`` are passed through untouched. Code points
        whose upper-case form is not exactly one character (for example
        U+00DF, which upper-cases to ``"SS"``) are also returned unchanged,
        because ``ord`` only accepts a single character.
        """
        la = self.wrapped.LA(i)
        if la == 0 or la == -1:
            return la
        upper = chr(la).upper()
        if len(upper) != 1:
            return la
        return ord(upper)

    def __getattr__(self, item):
        """
        Forward lookups of attributes not defined here to the wrapped stream.
        """
        # Without this guard, looking up anything before ``wrapped`` is set
        # (e.g. on an instance created via __new__) would recurse forever.
        if item == "wrapped":
            raise AttributeError(item)
        return getattr(self.wrapped, item)
def build_ast(stream):
    """
    Wire up a SQL lexer/parser pair over ``stream``.

    The stream is wrapped in ``UpperCaseCharStream`` before lexing. The
    error listeners of both the lexer and the parser are removed and replaced
    by a ``ParseErrorListener`` each, and a ``PostProcessor`` is registered
    as a parse listener.

    :param stream: character stream holding the SQL text
    :return: the configured ``SqlBaseParser``; no grammar rule is invoked
        here
    """
    case_insensitive_input = UpperCaseCharStream(stream)
    sql_lexer = SqlBaseLexer(case_insensitive_input)
    sql_lexer.removeErrorListeners()
    sql_lexer.addErrorListener(ParseErrorListener())

    sql_parser = SqlBaseParser(antlr4.CommonTokenStream(sql_lexer))
    sql_parser.removeErrorListeners()
    sql_parser.addErrorListener(ParseErrorListener())
    sql_parser.addParseListener(PostProcessor())
    return sql_parser
def ast_parser(string):
    """
    Build a SQL parser for the given text.

    :param string: SQL text, wrapped in an antlr4 ``InputStream``
    :return: whatever ``build_ast`` returns for that stream
    """
    char_stream = InputStream(string)
    return build_ast(char_stream)