From 3e650cb02ad162d6756ce283a030aea710107642 Mon Sep 17 00:00:00 2001 From: nhamidn Date: Tue, 11 Oct 2022 11:30:03 +0100 Subject: [PATCH 1/9] fix problem rank == 1 --- pyccel/codegen/printing/fcode.py | 43 ++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/pyccel/codegen/printing/fcode.py b/pyccel/codegen/printing/fcode.py index fac76a8490..04f48a24ee 100644 --- a/pyccel/codegen/printing/fcode.py +++ b/pyccel/codegen/printing/fcode.py @@ -914,18 +914,41 @@ def _print_NumpyLinspace(self, expr): if not isinstance(expr.endpoint, LiteralFalse): lhs = expr.get_user_nodes(Assign)[0].lhs + if isinstance(lhs, IndexedElement): + print(lhs) + arr = lhs.base + indices = lhs.indices + print(indices[0]) + print(indices[0].start) + print(indices[0].stop) + if expr.rank > 1: + #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. + lhs_source = expr.get_user_nodes(Assign)[0].lhs + lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) + lhs = self._print(lhs_source) + else: + #Since the expr.rank == 1, we modify the last element in the array. + if indices[0].step: + indx = IndexedElement(arr, PyccelMinus(indices[0].stop, PyccelMod(indices[0].stop, indices[0].step), + simplify = True)) + else: + indx = IndexedElement(arr, PyccelMinus(indices[0].stop, LiteralInteger(1), + simplify = True)) + lhs = self._print(indx) + print(arr) + print(indices) - - if expr.rank > 1: - #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. - lhs_source = expr.get_user_nodes(Assign)[0].lhs - lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) - lhs = self._print(lhs_source) else: - #Since the expr.rank == 1, we modify the last element in the array. - lhs = self._print(IndexedElement(lhs, - PyccelMinus(expr.num, LiteralInteger(1), - simplify = True))) + if expr.rank > 1: + #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. + lhs_source = expr.get_user_nodes(Assign)[0].lhs + lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) + lhs = self._print(lhs_source) + else: + #Since the expr.rank == 1, we modify the last element in the array. + lhs = self._print(IndexedElement(lhs, + PyccelMinus(expr.num, LiteralInteger(1), + simplify = True))) if isinstance(expr.endpoint, LiteralTrue): cond_template = lhs + ' = {stop}' From dbe628d41ea63cbbcd575cb9f38f7a916af4b561 Mon Sep 17 00:00:00 2001 From: nhamidn Date: Tue, 11 Oct 2022 11:40:53 +0100 Subject: [PATCH 2/9] comment prints --- pyccel/codegen/printing/fcode.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyccel/codegen/printing/fcode.py b/pyccel/codegen/printing/fcode.py index 04f48a24ee..348846784b 100644 --- a/pyccel/codegen/printing/fcode.py +++ b/pyccel/codegen/printing/fcode.py @@ -918,9 +918,9 @@ def _print_NumpyLinspace(self, expr): print(lhs) arr = lhs.base indices = lhs.indices - print(indices[0]) - print(indices[0].start) - print(indices[0].stop) + # print(indices[0]) + # print(indices[0].start) + # print(indices[0].stop) if expr.rank > 1: #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. lhs_source = expr.get_user_nodes(Assign)[0].lhs From a066957e1653e536e8081ac3ce9065f14c2022bf Mon Sep 17 00:00:00 2001 From: nhamidn Date: Tue, 11 Oct 2022 15:37:59 +0100 Subject: [PATCH 3/9] simplify the code --- pyccel/codegen/printing/fcode.py | 48 ++++++++++++-------------------- 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/pyccel/codegen/printing/fcode.py b/pyccel/codegen/printing/fcode.py index 348846784b..f409989863 100644 --- a/pyccel/codegen/printing/fcode.py +++ b/pyccel/codegen/printing/fcode.py @@ -914,41 +914,29 @@ def _print_NumpyLinspace(self, expr): if not isinstance(expr.endpoint, LiteralFalse): lhs = expr.get_user_nodes(Assign)[0].lhs - if isinstance(lhs, IndexedElement): - print(lhs) - arr = lhs.base - indices = lhs.indices - # print(indices[0]) - # print(indices[0].start) - # print(indices[0].stop) - if expr.rank > 1: - #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. - lhs_source = expr.get_user_nodes(Assign)[0].lhs - lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) - lhs = self._print(lhs_source) - else: - #Since the expr.rank == 1, we modify the last element in the array. + + if expr.rank > 1: + #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. + lhs_source = expr.get_user_nodes(Assign)[0].lhs + lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) + lhs = self._print(lhs_source) + else: + #Since the expr.rank == 1, we modify the last element in the array. + if isinstance(lhs, IndexedElement): + print(lhs) + arr = lhs.base + indices = lhs.indices if indices[0].step: indx = IndexedElement(arr, PyccelMinus(indices[0].stop, PyccelMod(indices[0].stop, indices[0].step), - simplify = True)) + simplify = True)) else: indx = IndexedElement(arr, PyccelMinus(indices[0].stop, LiteralInteger(1), - simplify = True)) - lhs = self._print(indx) - print(arr) - print(indices) - - else: - if expr.rank > 1: - #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. - lhs_source = expr.get_user_nodes(Assign)[0].lhs - lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) - lhs = self._print(lhs_source) + simplify = True)) else: - #Since the expr.rank == 1, we modify the last element in the array. - lhs = self._print(IndexedElement(lhs, - PyccelMinus(expr.num, LiteralInteger(1), - simplify = True))) + indx = self._print(IndexedElement(lhs, + PyccelMinus(expr.num, LiteralInteger(1), + simplify = True))) + lhs = self._print(indx) if isinstance(expr.endpoint, LiteralTrue): cond_template = lhs + ' = {stop}' From 4450d9a6a2731bfe43dc002bf3c7485c33c85c7a Mon Sep 17 00:00:00 2001 From: nhamidn Date: Thu, 13 Oct 2022 11:23:40 +0100 Subject: [PATCH 4/9] get the correct indice in all the cases lhs is a slice or is the resulted linspace array --- pyccel/ast/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyccel/ast/utilities.py b/pyccel/ast/utilities.py index dca7d425c4..21b6a7104e 100644 --- a/pyccel/ast/utilities.py +++ b/pyccel/ast/utilities.py @@ -474,7 +474,7 @@ def collect_loops(block, indices, new_index, language_has_vectors = False, resul # TODO [NH]: get all indices when adding axis argument to linspace function if isinstance(line.rhs, NumpyLinspace): - line.rhs.ind = indices[0] + line.rhs.ind = indices[len(line.lhs.shape) - len(line.rhs.shape)] # Replace variable expressions with Indexed versions line.substitute(variables, new_vars, excluded_nodes = (FunctionCall, PyccelInternalFunction)) From a7bec511dadc928ed29a13740a3da0efe9bdfe72 Mon Sep 17 00:00:00 2001 From: EmilyBourne Date: Mon, 13 May 2024 10:49:45 +0200 Subject: [PATCH 5/9] Ensure a NotImplementedError is raised for unsupported code --- pyccel/ast/variable.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyccel/ast/variable.py b/pyccel/ast/variable.py index 051cf631b7..348ec4c216 100644 --- a/pyccel/ast/variable.py +++ b/pyccel/ast/variable.py @@ -802,6 +802,9 @@ def __init__(self, base, *indices): if not indices: raise IndexError('Indexed needs at least one index.') + if isinstance(base, IndexedElement) and base._is_slice: + raise NotImplementedError("Can't extract a slice from a slice") + self._label = base self._shape = None if pyccel_stage == 'syntactic': From e8b8a0c0eb492ba5c166332405b203f9dc1de493 Mon Sep 17 00:00:00 2001 From: EmilyBourne Date: Mon, 13 May 2024 10:53:02 +0200 Subject: [PATCH 6/9] Ensure lhs is printed --- pyccel/codegen/printing/fcode.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyccel/codegen/printing/fcode.py b/pyccel/codegen/printing/fcode.py index 9657d03ad1..60c3310687 100644 --- a/pyccel/codegen/printing/fcode.py +++ b/pyccel/codegen/printing/fcode.py @@ -1077,7 +1077,6 @@ def _print_NumpyLinspace(self, expr): else: #Since the expr.rank == 1, we modify the last element in the array. if isinstance(lhs, IndexedElement): - print(lhs) arr = lhs.base indices = lhs.indices if indices[0].step: @@ -1090,7 +1089,7 @@ def _print_NumpyLinspace(self, expr): indx = self._print(IndexedElement(lhs, PyccelMinus(expr.num, LiteralInteger(1), simplify = True))) - lhs = self._print(indx) + lhs = self._print(indx) if isinstance(expr.endpoint, LiteralTrue): cond_template = lhs + ' = {stop}' From 674b6400dcfe2d2b560828873381c899a8ce77b0 Mon Sep 17 00:00:00 2001 From: EmilyBourne Date: Mon, 13 May 2024 10:58:48 +0200 Subject: [PATCH 7/9] Clean code using f-strings --- pyccel/codegen/printing/fcode.py | 44 +++++++++++++------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/pyccel/codegen/printing/fcode.py b/pyccel/codegen/printing/fcode.py index 60c3310687..ce2d1d7302 100644 --- a/pyccel/codegen/printing/fcode.py +++ b/pyccel/codegen/printing/fcode.py @@ -1058,6 +1058,9 @@ def _print_NumpyNorm(self, expr): return code def _print_NumpyLinspace(self, expr): + start = self._print(expr.start) + step = self._print(expr.step) + end = self._print(PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) if expr.stop.dtype != expr.dtype: cast_func = DtypePrecisionToCastFunction[expr.dtype] @@ -1066,6 +1069,15 @@ def _print_NumpyLinspace(self, expr): else: v = self._print(expr.stop) + if expr.rank > 1: + index = self._print(expr.ind) + init_value = f'({start} + {index}*{step})' + else: + zero = self._print(LiteralInteger(0)) + var = self.scope.get_temporary_variable(PythonNativeInt(), 'linspace_index') + index = self._print(var) + init_value = f'[(({start} + {index}*{step}), {index} = {zero},{end})]' + if not isinstance(expr.endpoint, LiteralFalse): lhs = expr.get_user_nodes(Assign)[0].lhs @@ -1073,7 +1085,7 @@ def _print_NumpyLinspace(self, expr): #expr.rank > 1, we need to replace the last index of the loop with the last index of the array. lhs_source = expr.get_user_nodes(Assign)[0].lhs lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True)) - lhs = self._print(lhs_source) + lhs_code = self._print(lhs_source) else: #Since the expr.rank == 1, we modify the last element in the array. if isinstance(lhs, IndexedElement): @@ -1089,35 +1101,15 @@ def _print_NumpyLinspace(self, expr): indx = self._print(IndexedElement(lhs, PyccelMinus(expr.num, LiteralInteger(1), simplify = True))) - lhs = self._print(indx) + lhs_code = self._print(indx) if isinstance(expr.endpoint, LiteralTrue): - cond_template = lhs + ' = {stop}' + init_value += f'\n{lhs_code} = {v}' else: - cond_template = lhs + ' = merge({stop}, {lhs}, ({cond}))' - if expr.rank > 1: - template = '({start} + {index}*{step})' - var = expr.ind - else: - template = '[(({start} + {index}*{step}), {index} = {zero},{end})]' - var = self.scope.get_temporary_variable(PythonNativeInt(), 'linspace_index') - - init_value = template.format( - start = self._print(expr.start), - step = self._print(expr.step), - index = self._print(var), - zero = self._print(LiteralInteger(0)), - end = self._print(PyccelMinus(expr.num, LiteralInteger(1), simplify = True)), - ) - - if isinstance(expr.endpoint, LiteralFalse): - code = init_value - elif isinstance(expr.endpoint, LiteralTrue): - code = init_value + '\n' + cond_template.format(stop=v) - else: - code = init_value + '\n' + cond_template.format(stop=v, lhs=lhs, cond=self._print(expr.endpoint)) + cond=self._print(expr.endpoint) + init_value += f'\n{lhs_code} = merge({v}, {lhs_code}, ({cond}))' - return code + return init_value def _print_NumpyNonZeroElement(self, expr): From 478056ba322ba7cd01b5ceae1bef8c8b91f17db7 Mon Sep 17 00:00:00 2001 From: EmilyBourne Date: Tue, 14 May 2024 23:11:55 +0200 Subject: [PATCH 8/9] Add a test --- .../epyccel/recognised_functions/test_numpy_funcs.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/epyccel/recognised_functions/test_numpy_funcs.py b/tests/epyccel/recognised_functions/test_numpy_funcs.py index 2642ff227b..dbf032539e 100644 --- a/tests/epyccel/recognised_functions/test_numpy_funcs.py +++ b/tests/epyccel/recognised_functions/test_numpy_funcs.py @@ -5777,6 +5777,17 @@ def test_linspace4(start : 'complex128[:,:]', stop : 'complex128[:,:]', out : 'c epyccel_func4(cmplx, cmplx2, out, True) assert np.allclose(arr, out) +def test_linspace_slice_assign(language): + def linspace_assign(n : int): + from numpy import zeros, linspace + p = n//3 + arr = zeros(n) + arr[p+1:n-p] = linspace(0, 1, p) + return arr + + epyccel_func = epyccel(linspace_assign, language=language) + assert np.allclose(linspace_assign(10), epyccel_func(10)) + @pytest.mark.parametrize( 'language', ( pytest.param("fortran", marks = pytest.mark.fortran), pytest.param("c", marks = [ From 2faeba57c947ee931503202c04357c16bf554ca3 Mon Sep 17 00:00:00 2001 From: EmilyBourne Date: Tue, 14 May 2024 23:43:17 +0200 Subject: [PATCH 9/9] Add a test --- .../recognised_functions/test_numpy_funcs.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/epyccel/recognised_functions/test_numpy_funcs.py b/tests/epyccel/recognised_functions/test_numpy_funcs.py index dbf032539e..399cac777d 100644 --- a/tests/epyccel/recognised_functions/test_numpy_funcs.py +++ b/tests/epyccel/recognised_functions/test_numpy_funcs.py @@ -6,7 +6,7 @@ from numpy import isclose, iinfo, finfo, complex64, complex128 import numpy as np -from pyccel.decorators import template, types +from pyccel.decorators import template, types, allow_negative_index from pyccel import epyccel min_int8 = iinfo('int8').min @@ -5785,9 +5785,20 @@ def linspace_assign(n : int): arr[p+1:n-p] = linspace(0, 1, p) return arr + @allow_negative_index('arr') + def linspace_assign_neg_slice(n : int): + from numpy import zeros, linspace + p = n//3 + arr = zeros(n) + arr[-p:] = linspace(0, 1, p) + return arr + epyccel_func = epyccel(linspace_assign, language=language) assert np.allclose(linspace_assign(10), epyccel_func(10)) + epyccel_func = epyccel(linspace_assign_neg_slice, language=language) + assert np.allclose(linspace_assign_neg_slice(10), epyccel_func(10)) + @pytest.mark.parametrize( 'language', ( pytest.param("fortran", marks = pytest.mark.fortran), pytest.param("c", marks = [