Skip to content

Commit

Permalink
Handle scalar optionals as per the standard (#1629)
Browse files Browse the repository at this point in the history
Make the handling of scalar optional arguments compatible with all
compilers (fixes #1628). To this end move the creation of local versions
of the optional variables from the C-codegen stage (where it was not
correctly implemented) to the semantic parser.

Further, add tolerances in unit tests where they are missing for floating
point functions.
  • Loading branch information
EmilyBourne committed Nov 22, 2023
1 parent 6a2348a commit 6b4e07b
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ All notable changes to this project will be documented in this file.
- #1619 : Give priority to imported functions over builtin functions.
- #1614 : Allow relative paths for custom compilation file.
- #1615 : Fixed infinite loop when passing slices while copying arrays.
- #1628 : Fixed segmentation fault when writing to optional scalars.

### Changed

Expand Down
22 changes: 0 additions & 22 deletions pyccel/codegen/printing/ccode.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,6 @@ def __init__(self, filename, prefix_module = None):
self._temporary_args = []
self._current_module = None
self._in_header = False
# Dictionary linking optional variables to their
# temporary counterparts which provide allocated
# memory
# Key is optional variable
self._optional_partners = {}

def get_additional_imports(self):
"""return the additional imports collected in printing stage"""
Expand Down Expand Up @@ -1777,9 +1772,6 @@ def _print_FunctionDef(self, expr):

self.set_scope(expr.scope)

# Reinitialise optional partners
self._optional_partners = {}

arguments = [a.var for a in expr.arguments]
results = [r.var for r in expr.results]
if len(expr.results) > 1:
Expand Down Expand Up @@ -2003,20 +1995,6 @@ def _print_Assign(self, expr):
prefix_code = ''
lhs = expr.lhs
rhs = expr.rhs
if isinstance(lhs, Variable) and lhs.is_optional:
if lhs in self._optional_partners:
# Collect temporary variable which provides
# allocated memory space for this optional variable
tmp_var = self._optional_partners[lhs]
else:
# Create temporary variable to provide allocated
# memory space before assigning to the pointer value
# (may be NULL)
tmp_var = self.scope.get_temporary_variable(lhs,
is_optional = False)
self._optional_partners[lhs] = tmp_var
# Point optional variable at an allocated memory space
prefix_code = self._print(AliasAssign(lhs, tmp_var))
if isinstance(rhs, FunctionCall) and isinstance(rhs.dtype, NativeTuple):
self._temporary_args = [ObjectAddress(a) for a in lhs]
return prefix_code+'{};\n'.format(self._print(rhs))
Expand Down
2 changes: 1 addition & 1 deletion pyccel/codegen/printing/fcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,7 +1443,7 @@ def _print_Declare(self, expr):

# Compute intent string
if intent:
if intent == 'in' and rank == 0 and not (is_static and is_optional) \
if intent == 'in' and rank == 0 and not is_optional \
and not isinstance(expr_dtype, CustomDataType):
intentstr = ', value'
if is_const:
Expand Down
11 changes: 7 additions & 4 deletions pyccel/codegen/wrapper/c_to_python_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,21 +863,24 @@ def _wrap_FunctionDefArgument(self, expr):
else:
cast_func = pyarray_to_ndarray

cast = Assign(arg_var, FunctionCall(cast_func, [collect_arg]))
cast = [Assign(arg_var, FunctionCall(cast_func, [collect_arg]))]
if arg_var.is_optional:
memory_var = self.scope.get_temporary_variable(arg_var, name = arg_var.name + '_memory', is_optional = False)
cast.insert(0, AliasAssign(arg_var, memory_var))

# Create any necessary type checks and errors
if expr.has_default:
check_func, err = self._get_check_function(collect_arg, arg_var, False)
body.append(If( IfSection(check_func, [cast]),
body.append(If( IfSection(check_func, cast),
IfSection(PyccelIsNot(collect_arg, Py_None), [*err, Return([Nil()])])
))
elif not in_interface:
check_func, err = self._get_check_function(collect_arg, arg_var, True)
body.append(If( IfSection(check_func, [cast]),
body.append(If( IfSection(check_func, cast),
IfSection(LiteralTrue(), [*err, Return([Nil()])])
))
else:
body.append(cast)
body.extend(cast)

return body

Expand Down
30 changes: 28 additions & 2 deletions pyccel/parser/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def __init__(self, inputs, *, parents = (), d_parsers = (), **kwargs):
# used to store code split into multiple lines to be reinserted in the CodeBlock
self._additional_exprs = []

# used to store variables if optional parameters are changed
self._optional_params = {}

#
self._code = parser._code
# ...
Expand Down Expand Up @@ -2142,7 +2145,8 @@ def _visit_Pass(self, expr):

def _visit_Variable(self, expr):
name = self.scope.get_python_name(expr.name)
return self.get_variable(name)
var = self.get_variable(name)
return self._optional_params.get(var, var)

def _visit_str(self, expr):
return repr(expr)
Expand Down Expand Up @@ -2198,7 +2202,7 @@ def _visit_PyccelSymbol(self, expr):
errors.report(UNDEFINED_VARIABLE, symbol=name,
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='fatal')
return var
return self._optional_params.get(var, var)

def _visit_AnnotatedPyccelSymbol(self, expr):
# Check if the variable already exists
Expand Down Expand Up @@ -3031,6 +3035,20 @@ def _visit_Assign(self, expr):
new_rhs.extend(r)
# Repeat step to handle tuples of tuples of etc.
unravelling = True
elif isinstance(l, Variable) and l.is_optional:
if l in self._optional_params:
# Collect temporary variable which provides
# allocated memory space for this optional variable
new_lhs.append(self._optional_params[l])
else:
# Create temporary variable to provide allocated
# memory space before assigning to the pointer value
# (may be NULL)
tmp_var = self.scope.get_temporary_variable(l,
name = l.name+'_loc', is_optional = False)
self._optional_params[l] = tmp_var
new_lhs.append(tmp_var)
new_rhs.append(r)
else:
new_lhs.append(l)
new_rhs.append(r)
Expand Down Expand Up @@ -3664,6 +3682,14 @@ def _visit_FunctionDef(self, expr):
symbol=r, severity='error',
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset))

optional_inits = []
for a in arguments:
var = self._optional_params.pop(a.var, None)
if var:
optional_inits.append(If(IfSection(PyccelIsNot(a.var, Nil()),
[Assign(var, a.var)])))
body.insert2body(*optional_inits, back=False)

func_kwargs = {
'global_vars':global_vars,
'cls_name':cls_name,
Expand Down
12 changes: 12 additions & 0 deletions tests/epyccel/modules/Module_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,15 @@ def call_optional_1():

def call_optional_2(b : 'int' = None):
return basic_optional(b)

def change_optional(a : int = None):
if a is None:
a = 4
else:
a += 3
return 5+a

def optional_func_call():
x = 3
y = change_optional(x)
return x,y
6 changes: 4 additions & 2 deletions tests/epyccel/test_epyccel_augassign.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pyccel.epyccel import epyccel

# += tests
RTOL = 1e-12
ATOL = 1e-16

def test_augassign_add_1d(language):
f_int = mod.augassign_add_1d_int
Expand Down Expand Up @@ -230,7 +232,7 @@ def test_augassign_func(language):
z = func(x,y)
z_epyc = func_epyc(x,y)

assert z == z_epyc
assert np.isclose(z, z_epyc, rtol=RTOL, atol=ATOL)
assert isinstance(z, type(z_epyc))

@pytest.mark.parametrize( 'language', (
Expand All @@ -253,4 +255,4 @@ def test_augassign_array_func(language):
func(x,y)
func_epyc(x_epyc,y)

assert np.array_equal(x, x_epyc)
assert np.allclose(x, x_epyc, rtol=RTOL, atol=ATOL)
4 changes: 4 additions & 0 deletions tests/epyccel/test_epyccel_mod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint: disable=missing-function-docstring, missing-module-docstring
import os
import sys
from numpy.random import randint, uniform
from numpy import allclose
Expand All @@ -10,6 +11,9 @@
if sys.platform == 'win32':
RTOL = 1e-13
ATOL = 1e-14
elif os.environ.get('PYCCEL_DEFAULT_COMPILER', 'GNU') == 'intel':
RTOL = 1e-11
ATOL = 1e-14
else:
RTOL = 2e-14
ATOL = 1e-15
Expand Down
3 changes: 2 additions & 1 deletion tests/epyccel/test_epyccel_optional_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,14 @@ def test_f5(language):
def test_f6(language):
import modules.Module_4 as mod

modnew = epyccel(mod, language = language)
modnew = epyccel(mod, language = language, verbose=True)

# ...
assert mod.call_optional_1() == modnew.call_optional_1()
assert mod.call_optional_2(None) == modnew.call_optional_2(None)
assert mod.call_optional_2(0) == modnew.call_optional_2(0)
assert mod.call_optional_2() == modnew.call_optional_2()
assert mod.optional_func_call() == modnew.optional_func_call()
#------------------------------------------------------------------------------
def test_f7(Module_5):
mod, modnew = Module_5
Expand Down

0 comments on commit 6b4e07b

Please sign in to comment.