@@ -36,16 +36,34 @@ def end(self):
3636 self .unindent ()
3737 self .newLine (")" )
3838
39+ def emit (self ) -> str :
40+ lines = []
41+ for l in self .lines :
42+ if isinstance (l , str ):
43+ if " drop" in l and " i64.const 0" in lines [- 1 ]:
44+ lines [- 1 ] = None
45+ continue
46+ lines .append (l )
47+ else :
48+ lines .append (l .emit ())
49+ lines = [l for l in lines if l is not None ]
50+ return "\n " .join (lines )
51+
52+ def newBlock (self ):
53+ child = WasmBuilder (self .name )
54+ child .indentation = self .indentation
55+ self .lines .append (child )
56+ return child
57+
3958class WasmBackend (CommonVisitor ):
4059 defaultToGlobals = False # treat all vars as global if this is true
4160 localCounter = 0
61+ locals = None
4262
4363 def __init__ (self , main : str , ts : TypeSystem ):
4464 self .builder = WasmBuilder (main )
4565 self .main = main # name of main method
4666 self .ts = ts
47- self .enterScope ()
48-
4967
5068 def currentBuilder (self ):
5169 return self .classes [self .currentClass ]
@@ -68,11 +86,10 @@ def genLocalName(self) -> str:
6886 return f"local_{ self .localCounter } "
6987
7088 def newLocal (self , name : str = None , t : str = "i64" )-> str :
71- # store the top of stack as a new local
89+ # add a new local decl, does not store anything
7290 if name is None :
7391 name = self .genLocalName ()
74- self .instr (f"(local ${ name } { t } )" )
75- self .store (name )
92+ self .locals .newLine (f"(local ${ name } { t } )" )
7693 return name
7794
7895 def visitStmtList (self , stmts : List [Stmt ]):
@@ -86,25 +103,33 @@ def Program(self, node: Program):
86103 func_decls = [d for d in node .declarations if isinstance (d , FuncDef )]
87104 var_decls = [d for d in node .declarations if isinstance (d , VarDef )]
88105 self .builder .module ()
89- self .instr ('(import "console" "log" (func $log_int (param i64)))' )
90- self .instr ('(import "console" "log" (func $log_bool (param i64)))' )
91- self .instr ('(import "console" "assert" (func $assert (param i64)))' )
106+ self .instr ('(import "imports" "logInt" (func $log_int (param i64)))' )
107+ self .instr ('(import "imports" "logBool" (func $log_bool (param i32)))' )
108+ self .instr ('(import "imports" "logString" (func $log_str (param i64)))' )
109+
110+ self .instr ('(import "imports" "assert" (func $assert (param i32)))' )
92111 for v in var_decls :
93- self .instr (f"(global ${ v .var .identifier .name } { v .var .t .getWasmName ()} " )
94- self .builder .indent ()
112+ self .instr (f"(global ${ v .var .identifier .name } (mut { v .var .t .getWasmName ()} )" )
95113 self .visit (v .value )
96- self .builder . end ( )
114+ self .instr ( f")" )
97115 for d in func_decls :
98116 self .visit (d )
117+ module_builder = self .builder
118+ self .builder = module_builder .newBlock ()
119+
99120 self .builder .func ("main" )
100121 self .defaultToGlobals = True
122+ self .locals = self .builder .newBlock ()
101123 self .visitStmtList (node .statements )
102124 self .defaultToGlobals = False
103125 self .builder .end ()
126+
127+ self .builder = module_builder
104128 self .instr (f"(start $main)" )
105129 self .builder .end ()
106130
107131 def FuncDef (self , node : FuncDef ):
132+ self .locals = self .builder .newBlock ()
108133 params = [self .builder .param (p .identifier .name , p .t .getWasmName ()) for p in node .params ]
109134 self .returnType = node .type .returnType
110135 ret = None if self .returnType .isNone () else self .returnType .getWasmName ()
@@ -122,7 +147,8 @@ def VarDef(self, node: VarDef):
122147 raise Exception ("TODO" )
123148 else :
124149 self .visit (node .value )
125- self .newLocal (varName , node .value .inferredType .getWasmName ())
150+ n = self .newLocal (varName , node .value .inferredType .getWasmName ())
151+ self .store (n )
126152
127153 # # STATEMENTS
128154
@@ -147,6 +173,7 @@ def AssignStmt(self, node: AssignStmt):
147173 targets = node .targets [::- 1 ]
148174 if len (targets ) > 1 :
149175 name = self .newLocal (None , node .value .inferredType .getWasmName ())
176+ self .store (name )
150177 for t in targets :
151178 self .load (name )
152179 self .processAssignmentTarget (t )
@@ -213,16 +240,22 @@ def BinaryExpr(self, node: BinaryExpr):
213240 self .instr ("i64.eqz" )
214241 elif operator == "==" :
215242 # TODO: refs
216- self .instr ("i64.eq" )
243+ if leftType == BoolType ():
244+ self .instr ("i32.eq" )
245+ else :
246+ self .instr ("i64.eq" )
217247 elif operator == "!=" :
218- self .instr ("i64.ne" )
248+ if leftType == BoolType ():
249+ self .instr ("i32.ne" )
250+ else :
251+ self .instr ("i64.ne" )
219252 elif operator == "is" :
220253 raise Exception ("TODO" )
221254 # logical operators
222255 elif operator == "and" :
223- self .instr ("i64 .and" )
256+ self .instr ("i32 .and" )
224257 elif operator == "or" :
225- self .instr ("i64 .or" )
258+ self .instr ("i32 .or" )
226259 else :
227260 raise Exception (
228261 f"Internal compiler error: unexpected operator { operator } " )
@@ -234,7 +267,7 @@ def UnaryExpr(self, node: UnaryExpr):
234267 self .instr ("i64.sub" )
235268 elif node .operator == "not" :
236269 self .visit (node .operand )
237- self .instr ("i64 .eqz" )
270+ self .instr ("i32 .eqz" )
238271
239272 def CallExpr (self , node : CallExpr ):
240273 name = node .function .name
@@ -261,7 +294,7 @@ def WhileStmt(self, node: WhileStmt):
261294 self .builder .block (block )
262295 self .builder .loop (loop )
263296 self .visit (node .condition )
264- self .instr (f"i64 .eqz" )
297+ self .instr (f"i32 .eqz" )
265298 self .instr (f"br_if ${ block } " )
266299 for s in node .body :
267300 self .visit (s )
@@ -291,29 +324,34 @@ def Identifier(self, node: Identifier):
291324 self .instr (f"local.get ${ node .name } " )
292325
293326 def IfExpr (self , node : IfExpr ):
327+ n = self .newLocal (None , node .inferredType .getWasmName ())
294328 self .visit (node .condition )
295329 self .instr ("(if" )
296330 self .builder .indent ()
297331 self .instr ("(then" )
298332 self .builder .indent ()
299333 self .visit (node .thenExpr )
334+ self .store (n )
300335 self .builder .end ()
301336 self .instr ("(else" )
302337 self .builder .indent ()
303338 self .visit (node .elseExpr )
339+ self .store (n )
304340 self .builder .end ()
305341 self .builder .end ()
342+ self .load (n )
343+
306344
307345 # # LITERALS
308346
309347 def BooleanLiteral (self , node : BooleanLiteral ):
310348 if node .value :
311- self .instr (f"i64 .const 1" )
349+ self .instr (f"i32 .const 1" )
312350 else :
313- self .instr (f"i64 .const 0" )
351+ self .instr (f"i32 .const 0" )
314352
315353 def IntegerLiteral (self , node : IntegerLiteral ):
316- self .instr (f"i64.const f { node .value } " )
354+ self .instr (f"i64.const { node .value } " )
317355
318356 def NoneLiteral (self , node : NoneLiteral ):
319357 self .instr (f"i64.const 0" )
0 commit comments