Permalink
Browse files

Fix recently introduced bugs in ast_tools using Scala nodes in C++ co…

…de by being explicit; clean up print output of Asp unit tests; add parse_method to ast_tools
  • Loading branch information...
1 parent f2b9fc1 commit 11892060b7fdaf141dc870f0a3c2f2679e2e68b9 Derrick Coetzee committed Dec 10, 2012
Showing with 100 additions and 98 deletions.
  1. +92 −90 asp/codegen/ast_tools.py
  2. +1 −1 specializers/run_tests.sh
  3. +3 −3 tests/asp_module_tests.py
  4. +4 −4 tests/ast_tools_test.py
View
@@ -1,17 +1,17 @@
-
-from cpp_ast import *
-import cpp_ast
+import cpp_ast as cpp
import python_ast as ast
-import python_ast
import scala_ast as scala
-from scala_ast import *
+
try:
from asp.util import *
except Exception,e:
pass
+def is_python_node(x):
+ return isinstance(x, ast.AST)
+
def is_cpp_node(x):
- return isinstance(x, Generable)
+ return isinstance(x, cpp.Generable)
def is_scala_node(x):
return isinstance(x, scala.Generable)
@@ -114,18 +114,18 @@ def is_node(self, x):
class ConvertAST(ast.NodeTransformer):
"""Class to convert from Python AST to C++ AST"""
def visit_Num(self, node):
- return CNumber(node.n)
+ return cpp.CNumber(node.n)
def visit_Str(self, node):
- return String(node.s)
+ return cpp.String(node.s)
def visit_Name(self, node):
- return CName(node.id)
+ return cpp.CName(node.id)
def visit_BinOp(self, node):
- return BinOp(self.visit(node.left),
- self.visit(node.op),
- self.visit(node.right))
+ return cpp.BinOp(self.visit(node.left),
+ self.visit(node.op),
+ self.visit(node.right))
def visit_Add(self, node):
return "+"
@@ -139,8 +139,8 @@ def visit_Mod(self, node):
return "%"
def visit_UnaryOp(self, node):
- return UnaryOp(self.visit(node.op),
- self.visit(node.operand))
+ return cpp.UnaryOp(self.visit(node.op),
+ self.visit(node.operand))
def visit_Invert(self, node):
return "-"
@@ -152,15 +152,14 @@ def visit_Not(self, node):
return "!"
def visit_Subscript(self, node):
- return Subscript(self.visit(node.value),
- self.visit(node.slice))
+ return cpp.Subscript(self.visit(node.value),
+ self.visit(node.slice))
def visit_Index(self, node):
return self.visit(node.value)
-
def visit_Pass(self, _):
- return Expression()
+ return cpp.Expression()
# by default, only do first statement in a module
def visit_Module(self, node):
@@ -171,35 +170,37 @@ def visit_Expr(self, node):
# only single targets supported
def visit_Assign(self, node):
- if isinstance(node, python_ast.Assign):
- return Assign(self.visit(node.targets[0]),
- self.visit(node.value))
- elif isinstance(node, cpp_ast.Assign):
- return Assign(self.visit(node.lvalue),
- self.visit(node.rvalue))
+ if is_python_node(node):
+ return cpp.Assign(self.visit(node.targets[0]),
+ self.visit(node.value))
+ elif is_cpp_node(node):
+ return cpp.Assign(self.visit(node.lvalue),
+ self.visit(node.rvalue))
+ else:
+ raise Exception ("Unknown Assign node type")
def visit_FunctionDef(self, node):
debug_print("In FunctionDef:")
debug_print(ast.dump(node))
debug_print("----")
- return FunctionBody(FunctionDeclaration(Value("void",
- node.name),
- self.visit(node.args)),
- Block([self.visit(x) for x in node.body]))
+ return cpp.FunctionBody(cpp.FunctionDeclaration(cpp.Value("void",
+ node.name),
+ self.visit(node.args)),
+ cpp.Block([self.visit(x) for x in node.body]))
def visit_arguments(self, node):
"""Only return the basic case: everything is void*, no named args, no default values"""
- return [Pointer(Value("void",self.visit(x))) for x in node.args]
+ return [cpp.Pointer(cpp.Value("void",self.visit(x))) for x in node.args]
def visit_Call(self, node):
"""We only handle calls that are casts; everything else (eventually) will be
translated into callbacks into Python."""
if isinstance(node.func, ast.Name):
if node.func.id == "int":
- return TypeCast(Value('int', ''), self.visit(node.args[0]))
+ return cpp.TypeCast(cpp.Value('int', ''), self.visit(node.args[0]))
if node.func.id == "abs":
- return Call(CName("abs"), [self.visit(x) for x in node.args])
+ return cpp.Call(cpp.CName("abs"), [self.visit(x) for x in node.args])
def visit_Print(self, node):
if len(node.values) > 0:
@@ -208,26 +209,26 @@ def visit_Print(self, node):
text = ''
for fragment in node.values[1:]:
text += ' << \" \" << ' + str(self.visit(fragment))
- return Print(text, node.nl)
+ return cpp.Print(text, node.nl)
def visit_Compare(self, node):
# only handles 1 thing on right side for now (1st op and comparator)
# also currently not handling: Is, IsNot, In, NotIn
ops = {'Eq':'==','NotEq':'!=','Lt':'<','LtE':'<=','Gt':'>','GtE':'>='}
op = ops[node.ops[0].__class__.__name__]
- return Compare(self.visit(node.left), op, self.visit(node.comparators[0]))
+ return cpp.Compare(self.visit(node.left), op, self.visit(node.comparators[0]))
def visit_If(self, node):
test = self.visit(node.test)
- body = Block([self.visit(x) for x in node.body])
+ body = cpp.Block([self.visit(x) for x in node.body])
if node.orelse == []:
orelse = None
else:
- orelse = Block([self.visit(x) for x in node.orelse])
- return IfConv(test, body, orelse)
+ orelse = cpp.Block([self.visit(x) for x in node.orelse])
+ return cpp.IfConv(test, body, orelse)
def visit_Return(self, node):
- return ReturnStatement(self.visit(node.value))
+ return cpp.ReturnStatement(self.visit(node.value))
class ConvertPyAST_ScalaAST(ast.NodeTransformer):
@@ -281,11 +282,11 @@ def visit_Return(self,node):
# only single targets supported
def visit_Assign(self, node):
- if isinstance(node, python_ast.Assign):
+ if is_python_node(node):
return scala.Assign(self.visit(node.targets[0]),
self.visit(node.value))
#below happen ever?
- elif isinstance(node, scala.Assign):
+ elif is_scala_node(node):
return scala.Assign(self.visit(node.lvalue),
self.visit(node.rvalue))
@@ -325,7 +326,8 @@ def visit_Subscript(self,node):
context ='store'
elif type(node.ctx) == ast.Load:
context = 'load'
- else: raise Exception ("Unknown Subscript Context")
+ else:
+ raise Exception ("Unknown Subscript Context")
return scala.Subscript(self.visit(node.value),self.visit(node.slice), context)
def visit_List(self,node):
@@ -402,7 +404,7 @@ def __init__(self, loopvar, increment):
def visit_CName(self, node):
#print "node.name is ", node.name
if node.name == self.loopvar:
- return BinOp(CName(self.loopvar), "+", CNumber(self.increment))
+ return cpp.BinOp(cpp.CName(self.loopvar), "+", cpp.CNumber(self.increment))
else:
return node
@@ -413,12 +415,12 @@ def visit_Block(self, node):
self.in_new_scope = True
#print "visiting block in ", node
contents = [self.visit(x) for x in node.contents]
- retnode = Block(contents=[x for x in contents if x != None])
+ retnode = cpp.Block(contents=[x for x in contents if x != None])
self.in_new_scope = old_scope
else:
self.inside_for = True
contents = [self.visit(x) for x in node.contents]
- retnode = Block(contents=[x for x in contents if x != None])
+ retnode = cpp.Block(contents=[x for x in contents if x != None])
return retnode
@@ -437,22 +439,22 @@ def visit_Pointer(self, node):
# ignore typecast declarators
def visit_TypeCast(self, node):
- return TypeCast(node.tp, self.visit(node.value))
+ return cpp.TypeCast(node.tp, self.visit(node.value))
# make lvalue not a declaration
def visit_Assign(self, node):
if not self.in_new_scope:
- if isinstance(node.lvalue, NestedDeclarator):
+ if isinstance(node.lvalue, cpp.NestedDeclarator):
tp, new_lvalue = node.lvalue.subdecl.get_decl_pair()
rvalue = self.visit(node.rvalue)
- return Assign(CName(new_lvalue), rvalue)
+ return cpp.Assign(cpp.CName(new_lvalue), rvalue)
- if isinstance(node.lvalue, Declarator):
+ if isinstance(node.lvalue, cpp.Declarator):
tp, new_lvalue = node.lvalue.get_decl_pair()
rvalue = self.visit(node.rvalue)
- return Assign(CName(new_lvalue), rvalue)
+ return cpp.Assign(cpp.CName(new_lvalue), rvalue)
- return Assign(self.visit(node.lvalue), self.visit(node.rvalue))
+ return cpp.Assign(self.visit(node.lvalue), self.visit(node.rvalue))
def unroll(self, node, factor):
"""Given a For node, unrolls the loop with a given factor.
@@ -468,49 +470,49 @@ def unroll(self, node, factor):
# we can't precalculate the number of leftover iterations in the case that
# the number of iterations are not known a priori, so we build an Expression
# and let the compiler deal with it
- #leftover_begin = BinOp(CNumber(factor),
- # "*",
- # BinOp(BinOp(node.end, "+", 1), "/", CNumber(factor)))
+ #leftover_begin = cpp.BinOp(cpp.CNumber(factor),
+ # "*",
+ # cpp.BinOp(cpp.BinOp(node.end, "+", 1), "/", cpp.CNumber(factor)))
# we begin leftover iterations at factor*( (end-initial+1) / factor ) + initial
# note that this works due to integer division
- leftover_begin = BinOp(BinOp(BinOp(BinOp(BinOp(node.end, "-", node.initial),
+ leftover_begin = cpp.BinOp(cpp.BinOp(cpp.BinOp(cpp.BinOp(cpp.BinOp(node.end, "-", node.initial),
"+",
- CNumber(1)),
+ cpp.CNumber(1)),
"/",
- CNumber(factor)),
+ cpp.CNumber(factor)),
"*",
- CNumber(factor)),
+ cpp.CNumber(factor)),
"+",
node.initial)
- new_limit = BinOp(node.end, "-", CNumber(factor-1))
+ new_limit = cpp.BinOp(node.end, "-", cpp.CNumber(factor-1))
# debug_print("Loop unroller called with ", node.loopvar)
# debug_print("Number of iterations: ", num_iterations)
# debug_print("Number of unrolls: ", num_unrolls)
# debug_print("Leftover iterations: ", leftover)
- new_increment = BinOp(node.increment, "*", CNumber(factor))
+ new_increment = cpp.BinOp(node.increment, "*", cpp.CNumber(factor))
- new_block = Block(contents=node.body.contents)
+ new_block = cpp.Block(contents=node.body.contents)
for x in xrange(1, factor):
new_extension = copy.deepcopy(node.body)
new_extension = LoopUnroller.UnrollReplacer(node.loopvar, x).visit(new_extension)
new_block.extend(new_extension.contents)
- return_block = UnbracedBlock()
+ return_block = cpp.UnbracedBlock()
- unrolled_for_node = For(
+ unrolled_for_node = cpp.For(
node.loopvar,
node.initial,
new_limit,
#node.end,
new_increment,
new_block)
- leftover_for_node = For(
+ leftover_for_node = cpp.For(
node.loopvar,
leftover_begin,
node.end,
@@ -522,7 +524,7 @@ def unroll(self, node, factor):
# if we *know* this loop has no leftover iterations, then
# we return without the leftover loop
- if not (isinstance(node.initial, CNumber) and isinstance(node.end, CNumber) and
+ if not (isinstance(node.initial, cpp.CNumber) and isinstance(node.end, cpp.CNumber) and
((node.end.num - node.initial.num + 1) % factor == 0)):
return_block.append(leftover_for_node)
@@ -531,24 +533,24 @@ def unroll(self, node, factor):
class LoopBlocker(object):
def loop_block(self, node, block_size):
- outer_incr_name = CName(node.loopvar + node.loopvar)
+ outer_incr_name = cpp.CName(node.loopvar + node.loopvar)
- new_inner_for = For(
+ new_inner_for = cpp.For(
node.loopvar,
outer_incr_name,
- FunctionCall("min", [BinOp(outer_incr_name,
- "+",
- CNumber(block_size-1)),
- node.end]),
- CNumber(1),
+ cpp.FunctionCall("min", [cpp.BinOp(outer_incr_name,
+ "+",
+ cpp.CNumber(block_size-1)),
+ node.end]),
+ cpp.CNumber(1),
node.body)
- new_outer_for = For(
+ new_outer_for = cpp.For(
node.loopvar + node.loopvar,
node.initial,
node.end,
- BinOp(node.increment, "*", CNumber(block_size)),
- Block(contents=[new_inner_for]))
+ cpp.BinOp(node.increment, "*", cpp.CNumber(block_size)),
+ cpp.Block(contents=[new_inner_for]))
debug_print(new_outer_for)
return new_outer_for
@@ -587,27 +589,27 @@ def visit_For(self, node):
new_body = self.visit(node.body)
assert self.second_target < self.current_loop + 1, 'Tried to switch loops %d and %d but only %d loops available' % (self.first_target, self.second_target, self.current_loop + 1)
# replace with the second loop (which has now been saved)
- return For(self.saved_second_loop.loopvar,
- self.saved_second_loop.initial,
- self.saved_second_loop.end,
- self.saved_second_loop.increment,
- new_body)
+ return cpp.For(self.saved_second_loop.loopvar,
+ self.saved_second_loop.initial,
+ self.saved_second_loop.end,
+ self.saved_second_loop.increment,
+ new_body)
if self.current_loop == self.second_target:
# save this
self.saved_second_loop = node
# replace this
debug_print("replacing loop")
- return For(self.saved_first_loop.loopvar,
- self.saved_first_loop.initial,
- self.saved_first_loop.end,
- self.saved_first_loop.increment,
- node.body)
-
-
- return For(node.loopvar,
- node.initial,
- node.end,
- node.increment,
- self.visit(node.body))
+ return cpp.For(self.saved_first_loop.loopvar,
+ self.saved_first_loop.initial,
+ self.saved_first_loop.end,
+ self.saved_first_loop.increment,
+ node.body)
+
+
+ return cpp.For(node.loopvar,
+ node.initial,
+ node.end,
+ node.increment,
+ self.visit(node.body))
@@ -1,3 +1,3 @@
#!/bin/bash
cd array_doubler; ./run_tests.sh; cd ..
-#cd stencil; ./run_tests.sh; cd ..
+cd array_map; ./run_tests.sh; cd ..
Oops, something went wrong.

0 comments on commit 1189206

Please sign in to comment.