6
6
#
7
7
8
8
9
+ from os import fdopen
9
10
from antlr4 import *
10
11
from solidity_parser .solidity_antlr4 .SolidityLexer import SolidityLexer
11
12
from solidity_parser .solidity_antlr4 .SolidityParser import SolidityParser
@@ -124,6 +125,27 @@ def visitEnumValue(self, ctx):
124
125
type = "EnumValue" ,
125
126
name = ctx .identifier ().getText ())
126
127
128
+ def visitTypeDefinition (self , ctx ):
129
+ return Node (ctx = ctx ,
130
+ type = "TypeDefinition" ,
131
+ typeKeyword = ctx .TypeKeyword ().getText (),
132
+ elementaryTypeName = self .visit (ctx .elementaryTypeName ()))
133
+
134
+
135
+ def visitCustomErrorDefinition (self , ctx ):
136
+ return Node (ctx = ctx ,
137
+ type = "CustomErrorDefinition" ,
138
+ name = self .visit (ctx .identifier ()),
139
+ parameterList = self .visit (ctx .parameterList ()))
140
+
141
+ def visitFileLevelConstant (self , ctx ):
142
+ return Node (ctx = ctx ,
143
+ type = "FileLevelConstant" ,
144
+ name = self .visit (ctx .identifier ()),
145
+ typeName = self .visit (ctx .typeName ()),
146
+ ConstantKeyword = self .visit (ctx .ConstantKeyword ()))
147
+
148
+
127
149
def visitUsingForDeclaration (self , ctx : SolidityParser .UsingForDeclarationContext ):
128
150
typename = None
129
151
if ctx .getChild (3 ) != '*' :
@@ -138,45 +160,29 @@ def visitInheritanceSpecifier(self, ctx: SolidityParser.InheritanceSpecifierCont
138
160
return Node (ctx = ctx ,
139
161
type = "InheritanceSpecifier" ,
140
162
baseName = self .visit (ctx .userDefinedTypeName ()),
141
- arguments = self .visit (ctx .expression ()))
163
+ arguments = self .visit (ctx .expressionList ()))
142
164
143
165
def visitContractPart (self , ctx : SolidityParser .ContractPartContext ):
144
166
return self .visit (ctx .children [0 ])
145
167
146
- def visitConstructorDefinition (self , ctx : SolidityParser .ConstructorDefinitionContext ):
147
- parameters = self .visit (ctx .parameterList ())
148
- block = self .visit (ctx .block ()) if ctx .block () else []
149
- modifiers = [self .visit (i ) for i in ctx .modifierList ().modifierInvocation ()]
150
168
151
- if ctx .modifierList ().ExternalKeyword (0 ):
152
- visibility = "external"
153
- elif ctx .modifierList ().InternalKeyword (0 ):
154
- visibility = "internal"
155
- elif ctx .modifierList ().PublicKeyword (0 ):
156
- visibility = "public"
157
- elif ctx .modifierList ().PrivateKeyword (0 ):
158
- visibility = "private"
159
- else :
160
- visibility = 'default'
161
-
162
- if ctx .modifierList ().stateMutability (0 ):
163
- stateMutability = ctx .modifierList ().stateMutability (0 ).getText ()
169
+ def visitFunctionDefinition (self , ctx : SolidityParser .FunctionDefinitionContext ):
170
+ isConstructor = isFallback = isReceive = False
171
+
172
+ fd = ctx .functionDescriptor ()
173
+ if fd .ConstructorKeyword ():
174
+ name = fd .ConstructorKeyword ().getText ()
175
+ isConstructor = True
176
+ elif fd .FallbackKeyword ():
177
+ name = fd .FallbackKeyword ().getText ()
178
+ isFallback = True
179
+ elif fd .ReceiveKeyword ():
180
+ name = fd .ReceiveKeyword ().getText ()
181
+ isReceive = True
182
+ elif fd .identifier ():
183
+ name = fd .identifier ().getText ()
164
184
else :
165
- stateMutability = None
166
-
167
- return Node (ctx = ctx ,
168
- type = "FunctionDefinition" ,
169
- name = None ,
170
- parameters = parameters ,
171
- returnParameters = None ,
172
- body = block ,
173
- visibility = visibility ,
174
- modifiers = modifiers ,
175
- isConstructor = True ,
176
- stateMutability = stateMutability )
177
-
178
- def visitFunctionDefinition (self , ctx : SolidityParser .ConstructorDefinitionContext ):
179
- name = ctx .identifier ().getText () if ctx .identifier () else ""
185
+ raise Exception ("unexpected function descriptor" )
180
186
181
187
parameters = self .visit (ctx .parameterList ())
182
188
returnParameters = self .visit (ctx .returnParameters ()) if ctx .returnParameters () else []
@@ -207,7 +213,9 @@ def visitFunctionDefinition(self, ctx: SolidityParser.ConstructorDefinitionConte
207
213
body = block ,
208
214
visibility = visibility ,
209
215
modifiers = modifiers ,
210
- isConstructor = name == self ._currentContract ,
216
+ isConstructor = isConstructor ,
217
+ isFallback = isFallback ,
218
+ isReceive = isReceive ,
211
219
stateMutability = stateMutability )
212
220
213
221
def visitReturnParameters (self , ctx : SolidityParser .ReturnParametersContext ):
@@ -393,6 +401,21 @@ def visitIfStatement(self, ctx):
393
401
TrueBody = TrueBody ,
394
402
FalseBody = FalseBody )
395
403
404
+ def visitTryStatement (self , ctx ):
405
+ return Node (ctx = ctx ,
406
+ type = 'TryStatement' ,
407
+ expression = self .visit (ctx .expression ()),
408
+ block = self .visit (ctx .block ()),
409
+ returnParameters = self .visit (ctx .returnParameters ()),
410
+ catchClause = self .visit (ctx .catchClause ()))
411
+
412
+ def visitCatchClause (self , ctx ):
413
+ return Node (ctx = ctx ,
414
+ type = 'CatchClause' ,
415
+ identifier = self .visit (ctx .identifier ()),
416
+ parameterList = self .visit (ctx .parameterList ()),
417
+ block = self .visit (ctx .block ()))
418
+
396
419
def visitUserDefinedTypeName (self , ctx ):
397
420
return Node (ctx = ctx ,
398
421
type = 'UserDefinedTypeName' ,
@@ -428,7 +451,7 @@ def visitNumberLiteral(self, ctx):
428
451
def visitMapping (self , ctx ):
429
452
return Node (ctx = ctx ,
430
453
type = 'Mapping' ,
431
- keyType = self .visit (ctx .elementaryTypeName ()),
454
+ keyType = self .visit (ctx .mappingKey ()),
432
455
valueType = self .visit (ctx .typeName ()))
433
456
434
457
def visitModifierDefinition (self , ctx ):
@@ -449,6 +472,16 @@ def visitStatement(self, ctx):
449
472
def visitSimpleStatement (self , ctx ):
450
473
return self .visit (ctx .getChild (0 ))
451
474
475
+ def visitUncheckedStatement (self , ctx ):
476
+ return Node (ctx = ctx ,
477
+ type = 'UncheckedStatement' ,
478
+ body = self .visit (ctx .block ()))
479
+
480
+ def visitRevertStatement (self , ctx ):
481
+ return Node (ctx = ctx ,
482
+ type = 'RevertStatement' ,
483
+ functionCall = self .visit (ctx .functionCall ()))
484
+
452
485
def visitExpression (self , ctx ):
453
486
454
487
children_length = len (ctx .children )
@@ -641,16 +674,15 @@ def visitPrimaryExpression(self, ctx):
641
674
type = 'BooleanLiteral' ,
642
675
value = ctx .BooleanLiteral ().getText () == 'true' )
643
676
644
- if ctx .HexLiteral ():
677
+ if ctx .hexLiteral ():
645
678
return Node (ctx = ctx ,
646
- type = 'HexLiteral ' ,
647
- value = ctx .HexLiteral ().getText ())
679
+ type = 'hexLiteral ' ,
680
+ value = ctx .hexLiteral ().getText ())
648
681
649
- if ctx .StringLiteral ():
682
+ if ctx .stringLiteral ():
650
683
text = ctx .getText ()
651
-
652
684
return Node (ctx = ctx ,
653
- type = 'StringLiteral ' ,
685
+ type = 'stringLiteral ' ,
654
686
value = text [1 : len (text ) - 1 ])
655
687
656
688
if len (ctx .children ) == 3 and ctx .getChild (1 ).getText () == '[' and ctx .getChild (2 ).getText () == ']' :
@@ -737,32 +769,6 @@ def visitVariableDeclarationStatement(self, ctx):
737
769
variables = variables ,
738
770
initialValue = initialValue )
739
771
740
- def visitImportDirective (self , ctx ):
741
- pathString = ctx .StringLiteral ().getText ()
742
- unitAlias = None
743
- symbolAliases = None
744
-
745
- impDecLen = len (ctx .importDeclaration ())
746
- if impDecLen > 0 :
747
- symbolAliases = []
748
- for decl in ctx .importDeclaration ():
749
- symbol = decl .identifier (0 ).getText ()
750
- alias = None
751
- if decl .identifier (1 ):
752
- alias = decl .identifier (1 ).getText ()
753
-
754
- symbolAliases .append ([symbol , alias ])
755
- elif impDecLen == 7 :
756
- unitAlias = ctx .getChild (3 ).getText ()
757
- elif impDecLen == 5 :
758
- unitAlias = ctx .getChild (3 ).getText ()
759
-
760
- return Node (ctx = ctx ,
761
- type = 'ImportDirective' ,
762
- path = pathString [1 : len (pathString ) - 1 ],
763
- unitAlias = unitAlias ,
764
- symbolAliases = symbolAliases )
765
-
766
772
def visitEventDefinition (self , ctx ):
767
773
return Node (ctx = ctx ,
768
774
type = 'EventDefinition' ,
@@ -792,8 +798,8 @@ def visitEventParameterList(self, ctx):
792
798
def visitInlineAssemblyStatement (self , ctx ):
793
799
language = None
794
800
795
- if ctx .StringLiteral ():
796
- language = ctx .StringLiteral ().getText ()
801
+ if ctx .StringLiteralFragment ():
802
+ language = ctx .StringLiteralFragment ().getText ()
797
803
language = language [1 : len (language ) - 1 ]
798
804
799
805
return Node (ctx = ctx ,
@@ -810,13 +816,13 @@ def visitAssemblyBlock(self, ctx):
810
816
811
817
def visitAssemblyItem (self , ctx ):
812
818
813
- if ctx .HexLiteral ():
819
+ if ctx .hexLiteral ():
814
820
return Node (ctx = ctx ,
815
821
type = 'HexLiteral' ,
816
- value = ctx .HexLiteral ().getText ())
822
+ value = ctx .hexLiteral ().getText ())
817
823
818
- if ctx .StringLiteral ():
819
- text = ctx .StringLiteral ().getText ()
824
+ if ctx .stringLiteral ():
825
+ text = ctx .stringLiteral ().getText ()
820
826
return Node (ctx = ctx ,
821
827
type = 'StringLiteral' ,
822
828
value = text [1 : len (text ) - 1 ])
@@ -834,6 +840,11 @@ def visitAssemblyItem(self, ctx):
834
840
def visitAssemblyExpression (self , ctx ):
835
841
return self .visit (ctx .getChild (0 ))
836
842
843
+ def visitAssemblyMember (self , ctx ):
844
+ return Node (ctx = ctx ,
845
+ type = 'AssemblyMember' ,
846
+ name = ctx .identifier ().getText ())
847
+
837
848
def visitAssemblyCall (self , ctx ):
838
849
functionName = ctx .getChild (0 ).getText ()
839
850
args = [self .visit (arg ) for arg in ctx .assemblyExpression ()]
@@ -845,7 +856,7 @@ def visitAssemblyCall(self, ctx):
845
856
846
857
def visitAssemblyLiteral (self , ctx ):
847
858
848
- if ctx .StringLiteral ():
859
+ if ctx .stringLiteral ():
849
860
text = ctx .getText ()
850
861
return Node (ctx = ctx ,
851
862
type = 'StringLiteral' ,
@@ -861,7 +872,7 @@ def visitAssemblyLiteral(self, ctx):
861
872
type = 'HexNumber' ,
862
873
value = ctx .getText ())
863
874
864
- if ctx .HexLiteral ():
875
+ if ctx .hexLiteral ():
865
876
return Node (ctx = ctx ,
866
877
type = 'HexLiteral' ,
867
878
value = ctx .getText ())
@@ -981,7 +992,7 @@ def visitImportDirective(self, ctx):
981
992
982
993
return Node (ctx = ctx ,
983
994
type = "ImportDirective" ,
984
- path = ctx .StringLiteral ().getText ().strip ('"' ),
995
+ path = ctx .importPath ().getText ().strip ('"' ),
985
996
symbolAliases = symbol_aliases ,
986
997
unitAlias = unit_alias
987
998
)
@@ -1106,10 +1117,6 @@ def visitStructDefinition(self, _node):
1106
1117
self .structs [_node .name ]= _node
1107
1118
self .names [_node .name ]= _node
1108
1119
1109
- def visitConstructorDefinition (self , _node ):
1110
- self .constructor = _node
1111
-
1112
-
1113
1120
def visitStateVariableDeclaration (self , _node ):
1114
1121
1115
1122
class VarDecVisitor (object ):
@@ -1150,10 +1157,15 @@ def __init__(self, node):
1150
1157
if (node .type == "FunctionDefinition" ):
1151
1158
self .visibility = node .visibility
1152
1159
self .stateMutability = node .stateMutability
1160
+ self .isConstructor = node .isConstructor
1161
+ self .isFallback = node .isFallback
1162
+ self .isReceive = node .isReceive
1153
1163
self .arguments = {}
1154
1164
self .returns = {}
1155
1165
self .declarations = {}
1156
1166
self .identifiers = []
1167
+
1168
+
1157
1169
1158
1170
class FunctionArgumentVisitor (object ):
1159
1171
@@ -1182,13 +1194,14 @@ def visitIdentifier(self, __node):
1182
1194
def visitAssemblyCall (self , __node ):
1183
1195
self .idents .append (__node )
1184
1196
1185
-
1186
1197
current_function = FunctionObject (_node )
1187
1198
self .names [_node .name ] = current_function
1188
1199
if _definition_type == "ModifierDefinition" :
1189
1200
self .modifiers [_node .name ] = current_function
1190
1201
else :
1191
1202
self .functions [_node .name ] = current_function
1203
+ if current_function .isConstructor :
1204
+ self .constructor = current_function
1192
1205
1193
1206
## get parameters
1194
1207
funcargvisitor = FunctionArgumentVisitor ()
0 commit comments