Skip to content

Commit f81d384

Browse files
committedNov 26, 2021
update solidity grammar to support v0.8.0:
1 parent 70598d6 commit f81d384

14 files changed

+4966
-3033
lines changed
 

‎.gitmodules

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[submodule "solidity-antlr4"]
22
path = solidity-antlr4
3-
url = https://github.com/solidityj/solidity-antlr4.git
3+
url = https://github.com/solidity-parser/antlr.git

‎scripts/antlr4.sh

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
set -o errexit
44

55
antlr -Dlanguage=Python3 solidity-antlr4/Solidity.g4 -o src -visitor
6-
mv src/solidity-antlr4/* src/solidity_antlr4
6+
7+
mv src/solidity-antlr4/* solidity_parser/solidity_antlr4
78
rm -rf src/solidity-antlr4
89

9-
touch src/solidity_antlr4/__init__.py
10-
touch src/solidity_antlr4/__AUTOGENERATED__
10+
touch solidity_parser/solidity_antlr4/__init__.py
11+
touch solidity_parser/solidity_antlr4/__AUTOGENERATED__

‎solidity_parser/__main__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
print("=== contract: " + contract_name)
3636
level +=1
3737

38-
print(("\t" * level) + "=== Inherited Contrracts: " + ','.join([bc.baseName.namePath for bc in contract_object._node.baseContracts]))
38+
print(("\t" * level) + "=== Inherited Contracts: " + ','.join([bc.baseName.namePath for bc in contract_object._node.baseContracts]))
3939
## statevars
4040
print(("\t" * level) + "=== Enums")
4141
level += 2

‎solidity_parser/parser.py

+94-81
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#
77

88

9+
from os import fdopen
910
from antlr4 import *
1011
from solidity_parser.solidity_antlr4.SolidityLexer import SolidityLexer
1112
from solidity_parser.solidity_antlr4.SolidityParser import SolidityParser
@@ -124,6 +125,27 @@ def visitEnumValue(self, ctx):
124125
type="EnumValue",
125126
name=ctx.identifier().getText())
126127

128+
def visitTypeDefinition(self, ctx):
129+
return Node(ctx=ctx,
130+
type="TypeDefinition",
131+
typeKeyword=ctx.TypeKeyword().getText(),
132+
elementaryTypeName=self.visit(ctx.elementaryTypeName()))
133+
134+
135+
def visitCustomErrorDefinition(self, ctx):
136+
return Node(ctx=ctx,
137+
type="CustomErrorDefinition",
138+
name=self.visit(ctx.identifier()),
139+
parameterList=self.visit(ctx.parameterList()))
140+
141+
def visitFileLevelConstant(self, ctx):
142+
return Node(ctx=ctx,
143+
type="FileLevelConstant",
144+
name=self.visit(ctx.identifier()),
145+
typeName=self.visit(ctx.typeName()),
146+
ConstantKeyword=self.visit(ctx.ConstantKeyword()))
147+
148+
127149
def visitUsingForDeclaration(self, ctx: SolidityParser.UsingForDeclarationContext):
128150
typename = None
129151
if ctx.getChild(3) != '*':
@@ -138,45 +160,29 @@ def visitInheritanceSpecifier(self, ctx: SolidityParser.InheritanceSpecifierCont
138160
return Node(ctx=ctx,
139161
type="InheritanceSpecifier",
140162
baseName=self.visit(ctx.userDefinedTypeName()),
141-
arguments=self.visit(ctx.expression()))
163+
arguments=self.visit(ctx.expressionList()))
142164

143165
def visitContractPart(self, ctx: SolidityParser.ContractPartContext):
144166
return self.visit(ctx.children[0])
145167

146-
def visitConstructorDefinition(self, ctx: SolidityParser.ConstructorDefinitionContext):
147-
parameters = self.visit(ctx.parameterList())
148-
block = self.visit(ctx.block()) if ctx.block() else []
149-
modifiers = [self.visit(i) for i in ctx.modifierList().modifierInvocation()]
150168

151-
if ctx.modifierList().ExternalKeyword(0):
152-
visibility = "external"
153-
elif ctx.modifierList().InternalKeyword(0):
154-
visibility = "internal"
155-
elif ctx.modifierList().PublicKeyword(0):
156-
visibility = "public"
157-
elif ctx.modifierList().PrivateKeyword(0):
158-
visibility = "private"
159-
else:
160-
visibility = 'default'
161-
162-
if ctx.modifierList().stateMutability(0):
163-
stateMutability = ctx.modifierList().stateMutability(0).getText()
169+
def visitFunctionDefinition(self, ctx: SolidityParser.FunctionDefinitionContext):
170+
isConstructor = isFallback =isReceive = False
171+
172+
fd = ctx.functionDescriptor()
173+
if fd.ConstructorKeyword():
174+
name = fd.ConstructorKeyword().getText()
175+
isConstructor = True
176+
elif fd.FallbackKeyword():
177+
name = fd.FallbackKeyword().getText()
178+
isFallback = True
179+
elif fd.ReceiveKeyword():
180+
name = fd.ReceiveKeyword().getText()
181+
isReceive = True
182+
elif fd.identifier():
183+
name = fd.identifier().getText()
164184
else:
165-
stateMutability = None
166-
167-
return Node(ctx=ctx,
168-
type="FunctionDefinition",
169-
name=None,
170-
parameters=parameters,
171-
returnParameters=None,
172-
body=block,
173-
visibility=visibility,
174-
modifiers=modifiers,
175-
isConstructor=True,
176-
stateMutability=stateMutability)
177-
178-
def visitFunctionDefinition(self, ctx: SolidityParser.ConstructorDefinitionContext):
179-
name = ctx.identifier().getText() if ctx.identifier() else ""
185+
raise Exception("unexpected function descriptor")
180186

181187
parameters = self.visit(ctx.parameterList())
182188
returnParameters = self.visit(ctx.returnParameters()) if ctx.returnParameters() else []
@@ -207,7 +213,9 @@ def visitFunctionDefinition(self, ctx: SolidityParser.ConstructorDefinitionConte
207213
body=block,
208214
visibility=visibility,
209215
modifiers=modifiers,
210-
isConstructor=name == self._currentContract,
216+
isConstructor=isConstructor,
217+
isFallback=isFallback,
218+
isReceive=isReceive,
211219
stateMutability=stateMutability)
212220

213221
def visitReturnParameters(self, ctx: SolidityParser.ReturnParametersContext):
@@ -393,6 +401,21 @@ def visitIfStatement(self, ctx):
393401
TrueBody=TrueBody,
394402
FalseBody=FalseBody)
395403

404+
def visitTryStatement(self, ctx):
405+
return Node(ctx=ctx,
406+
type='TryStatement',
407+
expression=self.visit(ctx.expression()),
408+
block=self.visit(ctx.block()),
409+
returnParameters=self.visit(ctx.returnParameters()),
410+
catchClause=self.visit(ctx.catchClause()))
411+
412+
def visitCatchClause(self, ctx):
413+
return Node(ctx=ctx,
414+
type='CatchClause',
415+
identifier=self.visit(ctx.identifier()),
416+
parameterList=self.visit(ctx.parameterList()),
417+
block=self.visit(ctx.block()))
418+
396419
def visitUserDefinedTypeName(self, ctx):
397420
return Node(ctx=ctx,
398421
type='UserDefinedTypeName',
@@ -428,7 +451,7 @@ def visitNumberLiteral(self, ctx):
428451
def visitMapping(self, ctx):
429452
return Node(ctx=ctx,
430453
type='Mapping',
431-
keyType=self.visit(ctx.elementaryTypeName()),
454+
keyType=self.visit(ctx.mappingKey()),
432455
valueType=self.visit(ctx.typeName()))
433456

434457
def visitModifierDefinition(self, ctx):
@@ -449,6 +472,16 @@ def visitStatement(self, ctx):
449472
def visitSimpleStatement(self, ctx):
450473
return self.visit(ctx.getChild(0))
451474

475+
def visitUncheckedStatement(self, ctx):
476+
return Node(ctx=ctx,
477+
type='UncheckedStatement',
478+
body=self.visit(ctx.block()))
479+
480+
def visitRevertStatement(self, ctx):
481+
return Node(ctx=ctx,
482+
type='RevertStatement',
483+
functionCall=self.visit(ctx.functionCall()))
484+
452485
def visitExpression(self, ctx):
453486

454487
children_length = len(ctx.children)
@@ -641,16 +674,15 @@ def visitPrimaryExpression(self, ctx):
641674
type='BooleanLiteral',
642675
value=ctx.BooleanLiteral().getText() == 'true')
643676

644-
if ctx.HexLiteral():
677+
if ctx.hexLiteral():
645678
return Node(ctx=ctx,
646-
type='HexLiteral',
647-
value=ctx.HexLiteral().getText())
679+
type='hexLiteral',
680+
value=ctx.hexLiteral().getText())
648681

649-
if ctx.StringLiteral():
682+
if ctx.stringLiteral():
650683
text = ctx.getText()
651-
652684
return Node(ctx=ctx,
653-
type='StringLiteral',
685+
type='stringLiteral',
654686
value=text[1: len(text) - 1])
655687

656688
if len(ctx.children) == 3 and ctx.getChild(1).getText() == '[' and ctx.getChild(2).getText() == ']':
@@ -737,32 +769,6 @@ def visitVariableDeclarationStatement(self, ctx):
737769
variables=variables,
738770
initialValue=initialValue)
739771

740-
def visitImportDirective(self, ctx):
741-
pathString = ctx.StringLiteral().getText()
742-
unitAlias = None
743-
symbolAliases = None
744-
745-
impDecLen = len(ctx.importDeclaration())
746-
if impDecLen > 0:
747-
symbolAliases = []
748-
for decl in ctx.importDeclaration():
749-
symbol = decl.identifier(0).getText()
750-
alias = None
751-
if decl.identifier(1):
752-
alias = decl.identifier(1).getText()
753-
754-
symbolAliases.append([symbol, alias])
755-
elif impDecLen == 7:
756-
unitAlias = ctx.getChild(3).getText()
757-
elif impDecLen == 5:
758-
unitAlias = ctx.getChild(3).getText()
759-
760-
return Node(ctx=ctx,
761-
type='ImportDirective',
762-
path=pathString[1: len(pathString) - 1],
763-
unitAlias=unitAlias,
764-
symbolAliases=symbolAliases)
765-
766772
def visitEventDefinition(self, ctx):
767773
return Node(ctx=ctx,
768774
type='EventDefinition',
@@ -792,8 +798,8 @@ def visitEventParameterList(self, ctx):
792798
def visitInlineAssemblyStatement(self, ctx):
793799
language = None
794800

795-
if ctx.StringLiteral():
796-
language = ctx.StringLiteral().getText()
801+
if ctx.StringLiteralFragment():
802+
language = ctx.StringLiteralFragment().getText()
797803
language = language[1: len(language) - 1]
798804

799805
return Node(ctx=ctx,
@@ -810,13 +816,13 @@ def visitAssemblyBlock(self, ctx):
810816

811817
def visitAssemblyItem(self, ctx):
812818

813-
if ctx.HexLiteral():
819+
if ctx.hexLiteral():
814820
return Node(ctx=ctx,
815821
type='HexLiteral',
816-
value=ctx.HexLiteral().getText())
822+
value=ctx.hexLiteral().getText())
817823

818-
if ctx.StringLiteral():
819-
text = ctx.StringLiteral().getText()
824+
if ctx.stringLiteral():
825+
text = ctx.stringLiteral().getText()
820826
return Node(ctx=ctx,
821827
type='StringLiteral',
822828
value=text[1: len(text) - 1])
@@ -834,6 +840,11 @@ def visitAssemblyItem(self, ctx):
834840
def visitAssemblyExpression(self, ctx):
835841
return self.visit(ctx.getChild(0))
836842

843+
def visitAssemblyMember(self, ctx):
844+
return Node(ctx=ctx,
845+
type='AssemblyMember',
846+
name=ctx.identifier().getText())
847+
837848
def visitAssemblyCall(self, ctx):
838849
functionName = ctx.getChild(0).getText()
839850
args = [self.visit(arg) for arg in ctx.assemblyExpression()]
@@ -845,7 +856,7 @@ def visitAssemblyCall(self, ctx):
845856

846857
def visitAssemblyLiteral(self, ctx):
847858

848-
if ctx.StringLiteral():
859+
if ctx.stringLiteral():
849860
text = ctx.getText()
850861
return Node(ctx=ctx,
851862
type='StringLiteral',
@@ -861,7 +872,7 @@ def visitAssemblyLiteral(self, ctx):
861872
type='HexNumber',
862873
value=ctx.getText())
863874

864-
if ctx.HexLiteral():
875+
if ctx.hexLiteral():
865876
return Node(ctx=ctx,
866877
type='HexLiteral',
867878
value=ctx.getText())
@@ -981,7 +992,7 @@ def visitImportDirective(self, ctx):
981992

982993
return Node(ctx=ctx,
983994
type="ImportDirective",
984-
path=ctx.StringLiteral().getText().strip('"'),
995+
path=ctx.importPath().getText().strip('"'),
985996
symbolAliases=symbol_aliases,
986997
unitAlias=unit_alias
987998
)
@@ -1106,10 +1117,6 @@ def visitStructDefinition(self, _node):
11061117
self.structs[_node.name]=_node
11071118
self.names[_node.name]=_node
11081119

1109-
def visitConstructorDefinition(self, _node):
1110-
self.constructor = _node
1111-
1112-
11131120
def visitStateVariableDeclaration(self, _node):
11141121

11151122
class VarDecVisitor(object):
@@ -1150,10 +1157,15 @@ def __init__(self, node):
11501157
if(node.type=="FunctionDefinition"):
11511158
self.visibility = node.visibility
11521159
self.stateMutability = node.stateMutability
1160+
self.isConstructor = node.isConstructor
1161+
self.isFallback = node.isFallback
1162+
self.isReceive = node.isReceive
11531163
self.arguments = {}
11541164
self.returns = {}
11551165
self.declarations = {}
11561166
self.identifiers = []
1167+
1168+
11571169

11581170
class FunctionArgumentVisitor(object):
11591171

@@ -1182,13 +1194,14 @@ def visitIdentifier(self, __node):
11821194
def visitAssemblyCall(self, __node):
11831195
self.idents.append(__node)
11841196

1185-
11861197
current_function = FunctionObject(_node)
11871198
self.names[_node.name] = current_function
11881199
if _definition_type=="ModifierDefinition":
11891200
self.modifiers[_node.name] = current_function
11901201
else:
11911202
self.functions[_node.name] = current_function
1203+
if current_function.isConstructor:
1204+
self.constructor = current_function
11921205

11931206
## get parameters
11941207
funcargvisitor = FunctionArgumentVisitor()

0 commit comments

Comments
 (0)
Failed to load comments.