Skip to content

Commit

Permalink
Fix translation for conflicting module name (#1854)
Browse files Browse the repository at this point in the history
Fix the translation of a file whose name conflicts with Fortran keywords
by ensuring that the original name is correctly extracted from the
scope. Fixes #1853.

**Commit Summary**

- Ensure Python name of module exists in the scope by adding it in the
syntactic stage.
- Don't use `AsName` for module name
- Save the Python name of a module as the name of the `PyModule`
- Ensure Python names are used for imports
- Remove hacky `set_name` function.
- Remove unused `assign_to` argument of `CodePrinter.doprint`.

---------

Co-authored-by: Yaman Güçlü <yaman.guclu@gmail.com>
  • Loading branch information
EmilyBourne and yguclu committed May 7, 2024
1 parent d422cac commit 8445dd7
Show file tree
Hide file tree
Showing 15 changed files with 97 additions and 40 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ All notable changes to this project will be documented in this file.
- #1830 : Fix missing allocation when returning an annotated array expression.
- #1821 : Ensure an error is raised when creating an ambiguous interface.
- #1842 : Fix homogeneous tuples incorrectly identified as inhomogeneous.
- #1853 : Fix translation of a file whose name conflicts with Fortran keywords.

### Changed

Expand Down Expand Up @@ -83,6 +84,8 @@ All notable changes to this project will be documented in this file.
- \[INTERNALS\] Remove `pyccel.ast.utilities.builtin_functions`.
- \[INTERNALS\] Remove unused/unnecessary functions in `pyccel.parser.utilities` : `read_file`, `header_statement`, `accelerator_statement`, `get_module_name`, `view_tree`.
- \[INTERNALS\] Remove unused functions `Errors.unset_target`, and `Errors.reset_target`.
- \[INTERNALS\] Remove function `Module.set_name`.
- \[INTERNALS\] Remove unused `assign_to` argument of `CodePrinter.doprint`.

## \[1.11.2\] - 2024-03-05

Expand Down
13 changes: 6 additions & 7 deletions pyccel/ast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,8 +1023,8 @@ def __init__(
imports=(),
scope = None
):
if not isinstance(name, (str, AsName)):
raise TypeError('name must be a string or an AsName')
if not isinstance(name, str):
raise TypeError('name must be a string')

if not iterable(variables):
raise TypeError('variables must be an iterable')
Expand Down Expand Up @@ -1179,11 +1179,6 @@ def body(self):
"""
return self.interfaces + self.funcs + self.classes

def set_name(self, new_name):
""" Function for changing the name of a module
"""
self._name = new_name

def __getitem__(self, arg):
assert isinstance(arg, str)
args = arg.split('.')
Expand Down Expand Up @@ -3721,6 +3716,10 @@ def __init__(self, source, target = None, ignore_at_print = False, mod = None):
self._target = set()
self._source_mod = mod
self._ignore_at_print = ignore_at_print

if mod is None and isinstance(target, Module):
self._source_mod = target

if target is None:
if pyccel_stage == "syntactic":
target = []
Expand Down
12 changes: 9 additions & 3 deletions pyccel/codegen/printing/ccode.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ def _print_Header(self, expr):

def _print_ModuleHeader(self, expr):
self.set_scope(expr.module.scope)
self._current_module = expr.module
self._in_header = True
name = expr.module.name
if isinstance(name, AsName):
Expand Down Expand Up @@ -792,6 +793,7 @@ def _print_ModuleHeader(self, expr):

self._in_header = False
self.exit_scope()
self._current_module = None
return (f"#ifndef {name.upper()}_H\n \
#define {name.upper()}_H\n\n \
{imports}\n \
Expand All @@ -802,13 +804,13 @@ def _print_ModuleHeader(self, expr):

def _print_Module(self, expr):
self.set_scope(expr.scope)
self._current_module = expr.name
self._current_module = expr
body = ''.join(self._print(i) for i in expr.body)

global_variables = ''.join([self._print(d) for d in expr.declarations])

# Print imports last to be sure that all additional_imports have been collected
imports = [Import(expr.name, Module(expr.name,(),())), *self._additional_imports.values()]
imports = [Import(self.scope.get_python_name(expr.name), Module(expr.name,(),())), *self._additional_imports.values()]
imports = ''.join(self._print(i) for i in imports)

code = ('{imports}\n'
Expand All @@ -819,6 +821,7 @@ def _print_Module(self, expr):
body = body)

self.exit_scope()
self._current_module = None
return code

def _print_Break(self, expr):
Expand Down Expand Up @@ -951,7 +954,7 @@ def _print_Import(self, expr):
if source in import_dict: # pylint: disable=consider-using-get
source = import_dict[source]

if expr.source_module:
if expr.source_module and expr.source_module is not self._current_module:
for classDef in expr.source_module.classes:
class_scope = classDef.scope
for method in classDef.methods:
Expand Down Expand Up @@ -2426,6 +2429,8 @@ def _print_Omp_End_Clause(self, expr):
#=====================================

def _print_Program(self, expr):
mod = expr.get_direct_user_nodes(lambda x: isinstance(x, Module))[0]
self._current_module = mod
self.set_scope(expr.scope)
body = self._print(expr.body)
variables = self.scope.variables.values()
Expand All @@ -2435,6 +2440,7 @@ def _print_Program(self, expr):
imports = ''.join(self._print(i) for i in imports)

self.exit_scope()
self._current_module = None
return ('{imports}'
'int main()\n{{\n'
'{decs}'
Expand Down
28 changes: 14 additions & 14 deletions pyccel/codegen/printing/codeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pyccel.ast.basic import PyccelAstNode

from pyccel.ast.core import Assign
from pyccel.ast.core import Module, ModuleHeader, Program
from pyccel.ast.internals import PyccelSymbol

from pyccel.errors.errors import Errors
Expand All @@ -22,31 +22,31 @@
class CodePrinter:
"""
The base class for code-printing subclasses.
The base class from which code printers inherit. The sub-classes should define a language
and `_print_X` functions.
"""
language = None
def __init__(self):
self._scope = None

def doprint(self, expr, assign_to=None):
def doprint(self, expr):
"""
Print the expression as code.
Print the expression as code.
Parameters
----------
expr : Expression
The expression to be printed.
assign_to : PyccelSymbol, MatrixSymbol, or string (optional)
If provided, the printed code will set the expression to a
variable with name ``assign_to``.
Returns
-------
str
The generated code.
"""

if isinstance(assign_to, str):
assign_to = PyccelSymbol(assign_to)
elif not isinstance(assign_to, (PyccelAstNode, type(None))):
raise TypeError("{0} cannot assign to object of type {1}".format(
type(self).__name__, type(assign_to)))

if assign_to:
expr = Assign(assign_to, expr)
assert isinstance(expr, (Module, ModuleHeader, Program))

# Do the actual printing
lines = self._print(expr).splitlines(True)
Expand Down
6 changes: 5 additions & 1 deletion pyccel/codegen/printing/cwrappercode.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _print_PyModule_Create(self, expr):

def _print_ModuleHeader(self, expr):
mod = expr.module
self._current_module = expr.module
name = mod.name

# Print imports last to be sure that all additional_imports have been collected
Expand Down Expand Up @@ -293,6 +294,7 @@ def _print_ModuleHeader(self, expr):
static_import_decs = self._print(Declare(API_var, static=True))
import_func = self._print(mod.import_func)

self._current_module = None
header_id = f'{name.upper()}_WRAPPER'
header_guard = f'{header_id}_H'
return (f"#ifndef {header_guard}\n \
Expand All @@ -311,6 +313,7 @@ def _print_ModuleHeader(self, expr):
def _print_PyModule(self, expr):
scope = expr.scope
self.set_scope(scope)
self._current_module = expr

# Insert declared objects into scope
variables = expr.original_module.variables if isinstance(expr, BindCModule) else expr.variables
Expand All @@ -322,7 +325,7 @@ def _print_PyModule(self, expr):

funcs = []

self._module_name = self.get_python_name(scope, expr)
self._module_name = expr.name
sep = self._print(SeparatorComment(40))

interface_funcs = [f.name for i in expr.interfaces for f in i.functions]
Expand Down Expand Up @@ -373,6 +376,7 @@ def _print_PyModule(self, expr):
imports = ''.join(self._print(i) for i in imports)

self.exit_scope()
self._current_module = None

return '\n'.join(['#define PY_ARRAY_UNIQUE_SYMBOL CWRAPPER_ARRAY_API',
f'#define {pymod_name.upper()}\n',
Expand Down
2 changes: 1 addition & 1 deletion pyccel/codegen/printing/fcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def _print_Import(self, expr):
return ''

if expr.source_module:
source = expr.source_module.scope.get_expected_name(source)
source = expr.source_module.name

if 'mpi4py' == str(getattr(expr.source,'name',expr.source)):
return 'use mpi\n' + 'use mpiext\n'
Expand Down
5 changes: 4 additions & 1 deletion pyccel/codegen/printing/pycode.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,10 @@ def _print_Return(self, expr):
def _print_Program(self, expr):
mod_scope = self.scope
self.set_scope(expr.scope)
imports = ''.join(self._print(i) for i in expr.imports)
modules = expr.get_direct_user_nodes(lambda m: isinstance(m, Module))
assert len(modules) == 1
module = modules[0]
imports = ''.join(self._print(i) for i in expr.imports if i.source_module is not module)
body = self._print(expr.body)
imports += ''.join(self._print(i) for i in self.get_additional_imports())

Expand Down
3 changes: 0 additions & 3 deletions pyccel/codegen/python_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def create_shared_library(codegen,
# Print code specific cwrapper
#---------------------------------------
module_old_name = codegen.ast.name
codegen.ast.set_name(sharedlib_modname)
wrapper_codegen = CWrapperCodePrinter(codegen.parser.filename, language)
Scope.name_clash_checker = name_clash_checkers['c']
wrapper = CToPythonWrapper(base_dirpath)
Expand All @@ -185,8 +184,6 @@ def create_shared_library(codegen,
if errors.has_errors():
return

codegen.ast.set_name(module_old_name)

with open(wrapper_filename, 'w', encoding="utf-8") as f:
f.writelines(wrapper_code)
timings['Wrapper printing'] = time.time() - start_print_cwrapper
Expand Down
11 changes: 6 additions & 5 deletions pyccel/codegen/wrapper/c_to_python_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _build_module_init_function(self, expr, imports):
The initialisation function.
"""

mod_name = getattr(expr, 'original_module', expr).name
mod_name = self.scope.get_python_name(getattr(expr, 'original_module', expr).name)
# Initialise the scope
func_name = self.scope.get_new_name(f'PyInit_{mod_name}')
func_scope = self.scope.new_child_scope(func_name)
Expand Down Expand Up @@ -633,7 +633,7 @@ def _build_module_init_function(self, expr, imports):
ok_code = LiteralInteger(0)

# Save Capsule describing types (needed for dependent modules)
body.append(AliasAssign(capsule_obj, PyCapsule_New(API_var, self.scope.get_python_name(mod_name))))
body.append(AliasAssign(capsule_obj, PyCapsule_New(API_var, mod_name)))
body.extend(self._add_object_to_mod(module_var, capsule_obj, '_C_API', initialised))

body.append(FunctionCall(import_array, ()))
Expand Down Expand Up @@ -1052,9 +1052,10 @@ def _wrap_Module(self, expr):

imports += cwrapper_ndarray_imports if self._wrapping_arrays else []
if not isinstance(expr, BindCModule):
imports.append(Import(expr.name, expr))
imports.append(Import(mod_scope.get_python_name(expr.name), expr))
original_mod = getattr(expr, 'original_module', expr)
return PyModule(original_mod.name, [API_var], funcs, imports = imports,
original_mod_name = mod_scope.get_python_name(original_mod.name)
return PyModule(original_mod_name, [API_var], funcs, imports = imports,
interfaces = interfaces, classes = classes, scope = mod_scope,
init_func = init_func, import_func = import_func)

Expand Down Expand Up @@ -2121,7 +2122,7 @@ def _wrap_Import(self, expr):

if import_wrapper:
wrapper_name = f'{expr.source}_wrapper'
mod_spoof = PyModule(expr.source_module.name.name, (), (), scope = Scope())
mod_spoof = PyModule(expr.source_module.name, (), (), scope = Scope())
return Import(wrapper_name, AsName(mod_spoof, expr.source), mod = mod_spoof)
else:
return None
6 changes: 3 additions & 3 deletions pyccel/codegen/wrapper/fortran_to_c_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ def _wrap_Module(self, expr):
classes = [self._wrap(f) for f in expr.classes]
variables = [self._wrap(v) for v in expr.variables if not v.is_private]
variable_getters = [v for v in variables if isinstance(v, BindCArrayVariable)]
imports = [Import(expr.name, target = expr, mod=expr)]
imports = [Import(self.scope.get_python_name(expr.name), target = expr, mod=expr)]

name = mod_scope.get_new_name(f'bind_c_{expr.name.target}')
self._wrapper_names_dict[expr.name.target] = name
name = mod_scope.get_new_name(f'bind_c_{expr.name}')
self._wrapper_names_dict[expr.name] = name

self.exit_scope()

Expand Down
2 changes: 1 addition & 1 deletion pyccel/parser/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,7 @@ def _visit_Module(self, expr):

if expr.program:
container = prog_scope.imports
container['imports'][mod_name] = Import(mod_name, mod)
container['imports'][mod_name] = Import(self.scope.get_python_name(mod_name), mod)

if init_func:
import_init = FunctionCall(init_func, [], [])
Expand Down
3 changes: 2 additions & 1 deletion pyccel/parser/syntactic.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ def _visit_Module(self, stmt):
# Define the name of the module
# The module name allows it to be correctly referenced from an import command
mod_name = os.path.splitext(os.path.basename(self._filename))[0]
name = AsName(mod_name, self.scope.get_new_name(mod_name))
name = self.scope.get_new_name(mod_name)
self.scope.python_names[name] = mod_name

body = [b for i in body for b in (i.body if isinstance(i, CodeBlock) else [i])]
return Module(name, [], functions, init_func = CodeBlock(body), scope = self.scope,
Expand Down
21 changes: 21 additions & 0 deletions tests/pyccel/scripts/endif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# pylint: disable=missing-function-docstring, missing-module-docstring

def matmul(a: 'float[:,:](order=C)',
b: 'float[:,:](order=F)',
c: 'float[:,:](order=C)'):

m, p = a.shape
q, n = b.shape
r, s = c.shape

if p != q or m != r or n != s:
return -1

for i in range(m):
for j in range(n):
c[i, j] = 0.0
for k in range(p):
c[i, j] += a[i, k] * b[k, j]

return 0

10 changes: 10 additions & 0 deletions tests/pyccel/scripts/runtest_badly_named_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# pylint: disable=missing-function-docstring, missing-module-docstring
import numpy as np
from endif import matmul

if __name__ == '__main__':
a = np.ones((3,4))
b = np.ones((4,3), order='F')
c = np.empty((3,3))
matmul(a, b, c)
print(c)
12 changes: 12 additions & 0 deletions tests/pyccel/test_pyccel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,3 +1190,15 @@ def test_time_execution_flag():
assert 'Total' in result_lines[-2]
for l in result_lines[1:-1]:
assert ' : ' in l

#------------------------------------------------------------------------------
def test_module_name_containing_conflict(language):
base_dir = os.path.dirname(os.path.realpath(__file__))
path_dir = os.path.join(base_dir, "scripts")
compile_pyccel(path_dir, get_abs_path("scripts/endif.py"), options = f"--language={language}")

test_file = get_abs_path("scripts/runtest_badly_named_module.py")
out1 = get_python_output(test_file)
out2 = get_python_output(test_file)

assert out1 == out2

0 comments on commit 8445dd7

Please sign in to comment.