Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Fix assignment of linspace to a slice #1215

Draft
wants to merge 13 commits into
base: devel
Choose a base branch
from
2 changes: 1 addition & 1 deletion pyccel/ast/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def collect_loops(block, indices, new_index, language_has_vectors = False, resul

# TODO [NH]: get all indices when adding axis argument to linspace function
if isinstance(line.rhs, NumpyLinspace):
line.rhs.ind = indices[0]
line.rhs.ind = indices[len(line.lhs.shape) - len(line.rhs.shape)]

# Replace variable expressions with Indexed versions
line.substitute(variables, new_vars,
Expand Down
3 changes: 3 additions & 0 deletions pyccel/ast/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,9 @@ def __init__(self, base, *indices):
if not indices:
raise IndexError('Indexed needs at least one index.')

if isinstance(base, IndexedElement) and base._is_slice:
raise NotImplementedError("Can't extract a slice from a slice")

self._label = base
self._shape = None
if pyccel_stage == 'syntactic':
Expand Down
60 changes: 31 additions & 29 deletions pyccel/codegen/printing/fcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,9 @@ def _print_NumpyNorm(self, expr):
return code

def _print_NumpyLinspace(self, expr):
start = self._print(expr.start)
step = self._print(expr.step)
end = self._print(PyccelMinus(expr.num, LiteralInteger(1), simplify = True))

if expr.stop.dtype != expr.dtype:
cast_func = DtypePrecisionToCastFunction[expr.dtype]
Expand All @@ -1069,48 +1072,47 @@ def _print_NumpyLinspace(self, expr):
else:
v = self._print(expr.stop)

if expr.rank > 1:
index = self._print(expr.ind)
init_value = f'({start} + {index}*{step})'
else:
zero = self._print(LiteralInteger(0))
var = self.scope.get_temporary_variable(PythonNativeInt(), 'linspace_index')
index = self._print(var)
init_value = f'[(({start} + {index}*{step}), {index} = {zero},{end})]'

if not isinstance(expr.endpoint, LiteralFalse):
lhs = expr.get_user_nodes(Assign)[0].lhs


if expr.rank > 1:
#expr.rank > 1, we need to replace the last index of the loop with the last index of the array.
lhs_source = expr.get_user_nodes(Assign)[0].lhs
lhs_source.substitute(expr.ind, PyccelMinus(expr.num, LiteralInteger(1), simplify = True))
lhs = self._print(lhs_source)
lhs_code = self._print(lhs_source)
else:
#Since the expr.rank == 1, we modify the last element in the array.
lhs = self._print(IndexedElement(lhs,
PyccelMinus(expr.num, LiteralInteger(1),
simplify = True)))
if isinstance(lhs, IndexedElement):
arr = lhs.base
indices = lhs.indices
if indices[0].step:
indx = IndexedElement(arr, PyccelMinus(indices[0].stop, PyccelMod(indices[0].stop, indices[0].step),
simplify = True))
else:
indx = IndexedElement(arr, PyccelMinus(indices[0].stop, LiteralInteger(1),
simplify = True))
else:
indx = self._print(IndexedElement(lhs,
PyccelMinus(expr.num, LiteralInteger(1),
simplify = True)))
lhs_code = self._print(indx)

if isinstance(expr.endpoint, LiteralTrue):
cond_template = lhs + ' = {stop}'
init_value += f'\n{lhs_code} = {v}'
else:
cond_template = lhs + ' = merge({stop}, {lhs}, ({cond}))'
if expr.rank > 1:
template = '({start} + {index}*{step})'
var = expr.ind
else:
template = '[(({start} + {index}*{step}), {index} = {zero},{end})]'
var = self.scope.get_temporary_variable(PythonNativeInt(), 'linspace_index')
cond=self._print(expr.endpoint)
init_value += f'\n{lhs_code} = merge({v}, {lhs_code}, ({cond}))'

init_value = template.format(
start = self._print(expr.start),
step = self._print(expr.step),
index = self._print(var),
zero = self._print(LiteralInteger(0)),
end = self._print(PyccelMinus(expr.num, LiteralInteger(1), simplify = True)),
)

if isinstance(expr.endpoint, LiteralFalse):
code = init_value
elif isinstance(expr.endpoint, LiteralTrue):
code = init_value + '\n' + cond_template.format(stop=v)
else:
code = init_value + '\n' + cond_template.format(stop=v, lhs=lhs, cond=self._print(expr.endpoint))

return code
return init_value

def _print_NumpyNonZeroElement(self, expr):

Expand Down
24 changes: 23 additions & 1 deletion tests/epyccel/recognised_functions/test_numpy_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy import isclose, iinfo, finfo, complex64, complex128
import numpy as np

from pyccel.decorators import template, types
from pyccel.decorators import template, types, allow_negative_index
from pyccel import epyccel

min_int8 = iinfo('int8').min
Expand Down Expand Up @@ -5777,6 +5777,28 @@ def test_linspace4(start : 'complex128[:,:]', stop : 'complex128[:,:]', out : 'c
epyccel_func4(cmplx, cmplx2, out, True)
assert np.allclose(arr, out)

def test_linspace_slice_assign(language):
def linspace_assign(n : int):
from numpy import zeros, linspace
p = n//3
arr = zeros(n)
arr[p+1:n-p] = linspace(0, 1, p)
return arr

@allow_negative_index('arr')
def linspace_assign_neg_slice(n : int):
from numpy import zeros, linspace
p = n//3
arr = zeros(n)
arr[-p:] = linspace(0, 1, p)
return arr

epyccel_func = epyccel(linspace_assign, language=language)
assert np.allclose(linspace_assign(10), epyccel_func(10))

epyccel_func = epyccel(linspace_assign_neg_slice, language=language)
assert np.allclose(linspace_assign_neg_slice(10), epyccel_func(10))

@pytest.mark.parametrize( 'language', (
pytest.param("fortran", marks = pytest.mark.fortran),
pytest.param("c", marks = [
Expand Down