Skip to content

Commit

Permalink
Use PyccelAstNode.pyccel_staging in SemanticParser (#1839)
Browse files Browse the repository at this point in the history
Use the `pyccel_staging` property of `PyccelAstNode` objects
to facilitate the creation of partially syntactic object (e.g. when
unravelling an assignment where the `rhs` has been visited and
the `lhs` is an inhomogeneous tuple).

This is done by exiting `SemanticParser._visit` immediately if the
object passed to the function is already a semantic object.

As the property is now always used it is obvious when it has not
been correctly set. These cases are therefore fixed.

---------

Co-authored-by: Yaman Güçlü <yaman.guclu@gmail.com>
  • Loading branch information
EmilyBourne and yguclu committed Apr 19, 2024
1 parent 046e93f commit 2728d40
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -56,6 +56,8 @@ All notable changes to this project will be documented in this file.
- \[INTERNALS\] Remove the `order` argument from the `pyccel.ast.core.Allocate` constructor.
- \[INTERNALS\] Remove `rank` and `order` arguments from `pyccel.ast.variable.Variable` constructor.
- \[INTERNALS\] Ensure `SemanticParser.infer_type` returns all documented information.
- \[INTERNALS\] Enforce correct value for `pyccel_staging` property of `PyccelAstNode`.
- \[INTERNALS\] Allow visiting objects containing both syntactic and semantic elements in `SemanticParser`.

### Deprecated

Expand Down
57 changes: 48 additions & 9 deletions pyccel/parser/semantic.py
Expand Up @@ -775,9 +775,11 @@ def is_int(a):
else:
return self._visit(var[indices[0]][indices[1:]])
else:
pyccel_stage.set_stage('syntactic')
tmp_var = PyccelSymbol(self.scope.get_new_name())
assign = Assign(tmp_var, var)
assign.set_current_ast(expr.python_ast)
pyccel_stage.set_stage('semantic')
self._additional_exprs[-1].append(self._visit(assign))
var = self._visit(tmp_var)

Expand Down Expand Up @@ -1074,9 +1076,11 @@ def _handle_function(self, expr, func, args, is_method = False):

func_results = func.results if isinstance(func, FunctionDef) else func.functions[0].results
if not parent_assign and len(func_results) == 1 and func_results[0].var.rank > 0:
pyccel_stage.set_stage('syntactic')
tmp_var = PyccelSymbol(self.scope.get_new_name())
assign = Assign(tmp_var, expr)
assign.set_current_ast(expr.python_ast)
pyccel_stage.set_stage('semantic')
self._additional_exprs[-1].append(self._visit(assign))
return self._visit(tmp_var)

Expand Down Expand Up @@ -1714,7 +1718,10 @@ def _assign_GeneratorComprehension(self, lhs_name, expr):
# Use _visit_Assign to create the requested iterator with the correct type
# The result of this operation is not stored, it is just used to declare
# iterator with the correct dtype to allow correct dtype deductions later
self._visit(Assign(iterator, iterator_rhs, python_ast=expr.python_ast))
pyccel_stage.set_stage('syntactic')
syntactic_assign = Assign(iterator, iterator_rhs, python_ast=expr.python_ast)
pyccel_stage.set_stage('semantic')
self._visit(syntactic_assign)

loop_elem = loop.body.body[0]

Expand All @@ -1724,9 +1731,13 @@ def _assign_GeneratorComprehension(self, lhs_name, expr):
gens = set(loop_elem.get_attribute_nodes(GeneratorComprehension))
if len(gens)==1:
gen = gens.pop()
pyccel_stage.set_stage('syntactic')
assert isinstance(gen.lhs, PyccelSymbol) and gen.lhs.is_temp
gen_lhs = self.scope.get_new_name() if gen.lhs.is_temp else gen.lhs
assign = self._visit(Assign(gen_lhs, gen, python_ast=gen.python_ast))
syntactic_assign = Assign(gen_lhs, gen, python_ast=gen.python_ast)
pyccel_stage.set_stage('semantic')
assign = self._visit(syntactic_assign)

new_expr.append(assign)
loop.substitute(gen, assign.lhs)
loop_elem = loop.body.body[0]
Expand Down Expand Up @@ -1857,7 +1868,9 @@ def _get_indexed_type(self, base, args, expr):
if isinstance(base, PyccelFunctionDef) and base.cls_name is TypingFinal:
syntactic_annotation = args[0]
if not isinstance(syntactic_annotation, SyntacticTypeAnnotation):
pyccel_stage.set_stage('syntactic')
syntactic_annotation = SyntacticTypeAnnotation(dtype=syntactic_annotation)
pyccel_stage.set_stage('semantic')
annotation = self._visit(syntactic_annotation)
for t in annotation.type_list:
t.is_const = True
Expand Down Expand Up @@ -1887,7 +1900,9 @@ def _get_indexed_type(self, base, args, expr):
if len(args) == 2 and args[1] is LiteralEllipsis():
syntactic_annotation = args[0]
if not isinstance(syntactic_annotation, SyntacticTypeAnnotation):
pyccel_stage.set_stage('syntactic')
syntactic_annotation = SyntacticTypeAnnotation(dtype=syntactic_annotation)
pyccel_stage.set_stage('semantic')
internal_datatypes = self._visit(syntactic_annotation)
type_annotations = []
if dtype_cls is PythonTupleFunction:
Expand Down Expand Up @@ -1925,14 +1940,16 @@ def _visit(self, expr):
Parameters
----------
expr : pyccel.ast.basic.PyccelAstNode
expr : pyccel.ast.basic.PyccelAstNode | PyccelSymbol
Object to visit of type X.
Returns
-------
pyccel.ast.basic.PyccelAstNode
AST object which is the semantic equivalent of expr.
"""
if getattr(expr, 'pyccel_staging', 'syntactic') == 'semantic':
return expr

# TODO - add settings to Errors
# - line and column
Expand Down Expand Up @@ -2203,8 +2220,12 @@ def _visit_FunctionCallArgument(self, expr):
value = self._visit(expr.value)
a = FunctionCallArgument(value, expr.keyword)
def generate_and_assign_temp_var():
pyccel_stage.set_stage('syntactic')
tmp_var = self.scope.get_new_name()
assign = self._visit(Assign(tmp_var, expr.value, python_ast = expr.value.python_ast))
syntactic_assign = Assign(tmp_var, expr.value, python_ast = expr.value.python_ast)
pyccel_stage.set_stage('semantic')

assign = self._visit(syntactic_assign)
self._additional_exprs[-1].append(assign)
return FunctionCallArgument(self._visit(tmp_var))
if isinstance(value, (PyccelArithmeticOperator, PyccelInternalFunction)) and value.rank:
Expand Down Expand Up @@ -3663,12 +3684,15 @@ def _visit_FunctionalFor(self, expr):
def _visit_GeneratorComprehension(self, expr):
lhs = self.check_for_variable(expr.lhs)
if lhs is None:
pyccel_stage.set_stage('syntactic')
if expr.lhs.is_temp:
lhs = PyccelSymbol(self.scope.get_new_name(), is_temp=True)
else:
lhs = expr.lhs
syntactic_assign = Assign(lhs, expr, python_ast=expr.python_ast)
pyccel_stage.set_stage('semantic')

creation = self._visit(Assign(lhs, expr, python_ast=expr.python_ast))
creation = self._visit(syntactic_assign)
self._additional_exprs[-1].append(creation)
return self.get_variable(lhs)
else:
Expand Down Expand Up @@ -3774,9 +3798,13 @@ def _visit_Return(self, expr):
v = o.var
if not (isinstance(r, PyccelSymbol) and r == (v.name if isinstance(v, Variable) else v)):
# Create a syntactic object to visit
pyccel_stage.set_stage('syntactic')
if isinstance(v, Variable):
v = PyccelSymbol(v.name)
a = self._visit(Assign(v, r, python_ast=expr.python_ast))
syntactic_assign = Assign(v, r, python_ast=expr.python_ast)
pyccel_stage.set_stage('semantic')

a = self._visit(syntactic_assign)
assigns.append(a)
if isinstance(a, ConstructorCall):
a.cls_variable.is_temp = False
Expand Down Expand Up @@ -3871,6 +3899,7 @@ def unpack(ann):
templates = {t: v for t,v in templates.items() if t in used_type_names}

# Create new temparary templates for the arguments with a Union data type.
pyccel_stage.set_stage('syntactic')
tmp_templates = {}
new_expr_args = []
for a in expr.arguments:
Expand All @@ -3889,6 +3918,7 @@ def unpack(ann):
value=a.value, kwonly=a.is_kwonly, annotation=dtype_symb))
else:
new_expr_args.append(a)
pyccel_stage.set_stage('semantic')

templates.update(tmp_templates)
template_combinations = list(product(*[v.type_list for v in templates.values()]))
Expand Down Expand Up @@ -4513,10 +4543,15 @@ def _visit_MacroFunction(self, expr):
for hd in header:
for i,_ in enumerate(hd.dtypes):
self.scope.insert_symbol(f'arg_{i}')
arguments = [FunctionDefArgument(self._visit(AnnotatedPyccelSymbol(f'arg_{i}', annotation = arg))[0]) \
pyccel_stage.set_stage('syntactic')
syntactic_args = [AnnotatedPyccelSymbol(f'arg_{i}', annotation = arg) \
for i, arg in enumerate(hd.dtypes)]
results = [FunctionDefResult(self._visit(AnnotatedPyccelSymbol(f'out_{i}', annotation = arg))[0]) \
syntactic_results = [AnnotatedPyccelSymbol(f'out_{i}', annotation = arg) \
for i, arg in enumerate(hd.results)]
pyccel_stage.set_stage('semantic')

arguments = [FunctionDefArgument(self._visit(a)[0]) for a in syntactic_args]
results = [FunctionDefResult(self._visit(r)[0]) for r in syntactic_results]
interfaces.append(FunctionDef(f_name, arguments, results, []))

# TODO -> Said: must handle interface
Expand Down Expand Up @@ -4600,8 +4635,12 @@ def _visit_NumpyNonZero(self, func_call):
# expr is a FunctionCall
arg = func_call_args[0].value
if not isinstance(arg, Variable):
pyccel_stage.set_stage('syntactic')
new_symbol = PyccelSymbol(self.scope.get_new_name())
creation = self._visit(Assign(new_symbol, arg, python_ast=func_call.python_ast))
syntactic_assign = Assign(new_symbol, arg, python_ast=func_call.python_ast)
pyccel_stage.set_stage('semantic')

creation = self._visit(syntactic_assign)
self._additional_exprs[-1].append(creation)
arg = self._visit(new_symbol)
return NumpyWhere(arg)
Expand Down

0 comments on commit 2728d40

Please sign in to comment.