Skip to content

Commit

Permalink
Remove HomogeneousTupleVariable (#1623)
Browse files Browse the repository at this point in the history
Remove the class `HomogeneousTupleVariable`. This is no longer needed
now that we have `TypedAstNode.class_type`. Removing this class allows
us to store homogeneous tuples in classes. Fixes #1582.
  • Loading branch information
EmilyBourne committed Nov 21, 2023
1 parent 3a08f76 commit 3db9b8c
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 125 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ All notable changes to this project will be documented in this file.
- #1576 : Remove inline class method definition.
- Ensure an error is raised when if conditions are used in comprehension statements.
- #1553 : Fix `np.sign` when using the `ifort` compiler.
- #1582 : Allow homogeneous tuples in classes.

### Changed

- \[INTERNALS\] #1593 : Rename `PyccelAstNode.fst` to the `PyccelAstNode.ast`.
- \[INTERNALS\] #1593 : Use a setter instead of a method to update `PyccelAstNode.ast`.
- \[INTERNALS\] #1593 : Rename `BasicParser._current_fst_node` to the `BasicParser._current_ast_node`.
- \[INTERNALS\] #1390 : Remove dead code handling a `CodeBlock` in an assignment.
- \[INTERNALS\] #1582 : Remove the `HomogeneousTupleVariable` type.

### Deprecated

Expand Down
20 changes: 11 additions & 9 deletions pyccel/ast/numpyext.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .core import Module, Import, PyccelFunctionDef, FunctionCall

from .datatypes import (dtype_and_precision_registry as dtype_registry,
from .datatypes import (dtype_and_precision_registry as dtype_registry, NativeHomogeneousTuple,
default_precision, NativeInteger, DataType, NativeNumericTypes,
NativeFloat, NativeComplex, NativeBool, NativeNumeric)

Expand All @@ -35,7 +35,7 @@
from .literals import Nil
from .mathext import MathCeil
from .operators import broadcast, PyccelMinus, PyccelDiv, PyccelMul, PyccelAdd
from .variable import Variable, Constant, HomogeneousTupleVariable
from .variable import Variable, Constant

errors = Errors()
pyccel_stage = PyccelStage()
Expand Down Expand Up @@ -650,7 +650,7 @@ def __init__(self, arg, dtype=None, order='K', ndmin=None):
if not isinstance(arg, (PythonTuple, PythonList, Variable)):
raise TypeError('Unknown type of %s.' % type(arg))

is_homogeneous_tuple = isinstance(arg, (PythonTuple, HomogeneousTupleVariable)) and arg.is_homogeneous
is_homogeneous_tuple = isinstance(arg.class_type, NativeHomogeneousTuple)
is_array = isinstance(arg, Variable) and arg.is_ndarray

# TODO: treat inhomogenous lists and tuples when they have mixed ordering
Expand Down Expand Up @@ -1245,7 +1245,7 @@ class NumpyFull(NumpyNewArray):
shape : TypedAstNode
Shape of the new array, e.g., ``(2, 3)`` or ``2``.
For a 1D array this is either a `LiteralInteger` or an expression.
For a ND array this is a `PythonTuple` or a `HomogeneousTupleVariable`.
For a ND array this is a `TypedAstNode` with the class type NativeHomogeneousTuple.
fill_value : TypedAstNode
Fill value.
Expand Down Expand Up @@ -1405,7 +1405,7 @@ class constructor always returns an object of type `NumpyFull`.
shape : PythonTuple of TypedAstNode
Overrides the shape of the array.
For a 1D array this is either a `LiteralInteger` or an expression.
For a ND array this is a `PythonTuple` or a `HomogeneousTupleVariable`.
For a ND array this is a `TypedAstNode` with the class type NativeHomogeneousTuple.
See Also
--------
Expand Down Expand Up @@ -1449,7 +1449,7 @@ class constructor always returns an object of type `NumpyEmpty`.
shape : PythonTuple of TypedAstNode
Overrides the shape of the array.
For a 1D array this is either a `LiteralInteger` or an expression.
For a ND array this is a `PythonTuple` or a `HomogeneousTupleVariable`.
For a ND array this is a `TypedAstNode` with the class type NativeHomogeneousTuple.
See Also
--------
Expand Down Expand Up @@ -1495,7 +1495,7 @@ class constructor always returns an object of type `NumpyOnes`.
shape : PythonTuple of TypedAstNode
Overrides the shape of the array.
For a 1D array this is either a `LiteralInteger` or an expression.
For a ND array this is a `PythonTuple` or a `HomogeneousTupleVariable`.
For a ND array this is a `TypedAstNode` with the class type NativeHomogeneousTuple.
See Also
--------
Expand Down Expand Up @@ -1540,7 +1540,7 @@ class constructor always returns an object of type `NumpyZeros`.
shape : PythonTuple of TypedAstNode
Overrides the shape of the array.
For a 1D array this is either a `LiteralInteger` or an expression.
For a ND array this is a `PythonTuple` or a `HomogeneousTupleVariable`.
For a ND array this is a `TypedAstNode` with the class type NativeHomogeneousTuple.
See Also
--------
Expand Down Expand Up @@ -2092,7 +2092,7 @@ def dim(self):
"""
return self._dim

class NumpyNonZero(NumpyNewArray):
class NumpyNonZero(PyccelInternalFunction):
"""
Class representing a call to the function `numpy.nonzero`.
Expand All @@ -2117,6 +2117,8 @@ class NumpyNonZero(NumpyNewArray):
_precision = 8
_rank = 2
_order = 'C'
_class_type = NativeHomogeneousTuple()

def __init__(self, a):
if (a.rank > 1):
raise NotImplementedError("Non-Zero function is only implemented for 1D arrays")
Expand Down
27 changes: 17 additions & 10 deletions pyccel/ast/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .builtins import (builtin_functions_dict,
PythonRange, PythonList, PythonTuple)
from .cmathext import cmath_mod
from .datatypes import NativeHomogeneousTuple
from .internals import PyccelInternalFunction, Slice
from .itertoolsext import itertools_mod
from .literals import LiteralInteger, Nil
Expand All @@ -29,8 +30,7 @@
NumpyTranspose, NumpyLinspace)
from .operators import PyccelAdd, PyccelMul, PyccelIs, PyccelArithmeticOperator
from .scipyext import scipy_mod
from .variable import (Variable, IndexedElement, InhomogeneousTupleVariable,
HomogeneousTupleVariable )
from .variable import (Variable, IndexedElement, InhomogeneousTupleVariable )

from .c_concepts import ObjectAddress

Expand Down Expand Up @@ -654,12 +654,18 @@ def insert_fors(blocks, indices, scope, level = 0):
#==============================================================================
def expand_inhomog_tuple_assignments(block, language_has_vectors = False):
"""
Simplify expressions in a CodeBlock by unravelling tuple assignments into multiple lines
Simplify expressions in a CodeBlock by unravelling tuple assignments into multiple lines.
Simplify expressions in a CodeBlock by unravelling tuple assignments into multiple lines.
These changes are carried out in-place.
Parameters
==========
block : CodeBlock
The expression to be modified
----------
block : CodeBlock
The expression to be modified.
language_has_vectors : bool, default=False
Indicates whether the target language has built-in support for vector operations.
Examples
--------
Expand All @@ -677,13 +683,14 @@ def expand_inhomog_tuple_assignments(block, language_has_vectors = False):
"""
if not language_has_vectors:
allocs_to_unravel = [a for a in block.get_attribute_nodes(Assign) \
if isinstance(a.lhs, HomogeneousTupleVariable) \
and isinstance(a.rhs, (HomogeneousTupleVariable, Duplicate, Concatenate))]
if isinstance(a.lhs, Variable) \
and isinstance(a.lhs.class_type, NativeHomogeneousTuple) \
and isinstance(a.rhs.class_type, NativeHomogeneousTuple)]
new_allocs = [(Assign(a.lhs, NumpyEmpty(a.lhs.shape,
dtype=a.lhs.dtype,
order=a.lhs.order)
), a) if a.lhs.on_stack
else (a) if a.lhs.on_heap
), a) if getattr(a.lhs, 'on_stack', False)
else (a) if getattr(a.lhs, 'on_heap', False)
else (Allocate(a.lhs,
shape=a.lhs.shape,
order = a.lhs.order,
Expand Down
91 changes: 19 additions & 72 deletions pyccel/ast/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
'Constant',
'DottedName',
'DottedVariable',
'HomogeneousTupleVariable',
'IndexedElement',
'InhomogeneousTupleVariable',
'TupleVariable',
Expand Down Expand Up @@ -587,12 +586,18 @@ def __reduce_ex__(self, i):

def __getitem__(self, *args):

if len(args) == 1 and isinstance(args[0], (tuple, list)):
args = args[0]

if self.rank < len(args):
raise IndexError('Rank mismatch.')

if len(args) == 1:
arg0 = args[0]
if isinstance(arg0, (tuple, list)):
args = arg0
elif isinstance(arg0, int):
self_len = self.shape[0]
if isinstance(self_len, LiteralInteger) and arg0 >= int(self_len):
raise StopIteration

return IndexedElement(self, *args)

def invalidate_node(self):
Expand Down Expand Up @@ -664,74 +669,7 @@ def __ne__(self, other):
def __hash__(self):
return hash(str(self))

class TupleVariable(Variable):

"""Represents a tuple variable in the code.
Parameters
----------
arg_vars: Iterable
Multiple variables contained within the tuple
Examples
--------
>>> from pyccel.ast.core import TupleVariable, Variable
>>> v1 = Variable('int','v1')
>>> v2 = Variable('bool','v2')
>>> n = TupleVariable([v1, v2],'n')
>>> n
n
"""
__slots__ = ()

@property
def is_ndarray(self):
return False

class HomogeneousTupleVariable(TupleVariable):
"""
Represents a homogeneous tuple variable in the code.
Represents a homogeneous tuple variable in the code.
Parameters
----------
dtype : DataType
The data type of the elements of the tuple.
*args : tuple
See Variable.
**kwargs : dict
See Variable.
Examples
--------
>>> from pyccel.ast.core import TupleVariable, Variable
>>> v1 = Variable('int','v1')
>>> v2 = Variable('bool','v2')
>>> n = TupleVariable([v1, v2],'n')
>>> n
n
"""
__slots__ = ()
is_homogeneous = True

def __init__(self, dtype, *args, **kwargs):
super().__init__(dtype, *args, **kwargs)

def shape_can_change(self, i):
"""
Indicates if the shape can change in the i-th dimension
"""
return self.is_alias and i == (self.rank-1)

def __len__(self):
return self.shape[0]

def __iter__(self):
assert isinstance(self.shape[0], LiteralInteger)
return (self[i] for i in range(self.shape[0]))

class InhomogeneousTupleVariable(TupleVariable):
class InhomogeneousTupleVariable(Variable):
"""
Represents an inhomogeneous tuple variable in the code.
Expand Down Expand Up @@ -839,6 +777,15 @@ def is_target(self, is_target):
if var.rank > 0:
var.is_target = is_target

@property
def is_ndarray(self):
"""
Helper function to determine whether the variable is a NumPy array.
Helper function to determine whether the variable is a NumPy array.
"""
return False

class Constant(Variable):

"""
Expand Down
8 changes: 4 additions & 4 deletions pyccel/codegen/printing/ccode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pyccel.ast.datatypes import NativeInteger, NativeBool, NativeComplex, NativeVoid
from pyccel.ast.datatypes import NativeFloat, NativeTuple, datatype, default_precision
from pyccel.ast.datatypes import CustomDataType, NativeString
from pyccel.ast.datatypes import CustomDataType, NativeString, NativeHomogeneousTuple

from pyccel.ast.internals import Slice, PrecomputedCode, get_final_precision, PyccelArrayShapeElement

Expand All @@ -46,7 +46,7 @@
from pyccel.ast.variable import Variable
from pyccel.ast.variable import DottedName
from pyccel.ast.variable import DottedVariable
from pyccel.ast.variable import InhomogeneousTupleVariable, HomogeneousTupleVariable
from pyccel.ast.variable import InhomogeneousTupleVariable

from pyccel.ast.c_concepts import ObjectAddress, CMacro, CStringExpression

Expand Down Expand Up @@ -1209,7 +1209,7 @@ def get_declare_type(self, expr):
rank = expr.rank

if rank > 0:
if expr.is_ndarray or isinstance(expr, HomogeneousTupleVariable):
if expr.is_ndarray or isinstance(expr.class_type, NativeHomogeneousTuple):
if expr.rank > 15:
errors.report(UNSUPPORTED_ARRAY_RANK, symbol=expr, severity='fatal')
self.add_import(c_imports['ndarrays'])
Expand Down Expand Up @@ -1350,7 +1350,7 @@ def _print_IndexedElement(self, expr):
#set dtype to the C struct types
dtype = self.find_in_ndarray_type_registry(expr.dtype, expr.precision)
base_name = self._print(base)
if getattr(base, 'is_ndarray', False) or isinstance(base, HomogeneousTupleVariable):
if getattr(base, 'is_ndarray', False) or isinstance(base.class_type, NativeHomogeneousTuple):
if expr.rank > 0:
#managing the Slice input
for i , ind in enumerate(inds):
Expand Down
5 changes: 3 additions & 2 deletions pyccel/codegen/printing/pycode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from pyccel.ast.builtins import PythonMin, PythonMax, PythonType
from pyccel.ast.core import CodeBlock, Import, Assign, FunctionCall, For, AsName, FunctionAddress
from pyccel.ast.core import IfSection, FunctionDef, Module, DottedFunctionCall, PyccelFunctionDef
from pyccel.ast.datatypes import NativeHomogeneousTuple
from pyccel.ast.functionalexpr import FunctionalFor
from pyccel.ast.literals import LiteralTrue, LiteralString
from pyccel.ast.literals import LiteralInteger, LiteralFloat, LiteralComplex
from pyccel.ast.numpyext import numpy_target_swap
from pyccel.ast.numpyext import NumpyArray, NumpyNonZero, NumpyResultType
from pyccel.ast.numpyext import DtypePrecisionToCastFunction
from pyccel.ast.variable import DottedName, HomogeneousTupleVariable, Variable
from pyccel.ast.variable import DottedName, Variable
from pyccel.ast.utilities import builtin_import_registry as pyccel_builtin_import_registry
from pyccel.ast.utilities import decorators_mod

Expand Down Expand Up @@ -290,7 +291,7 @@ def _print_IndexedElement(self, expr):
indices = indices[0]

indices = [self._print(i) for i in indices]
if isinstance(expr.base, HomogeneousTupleVariable):
if isinstance(expr.base.class_type, NativeHomogeneousTuple):
indices = ']['.join(i for i in indices)
else:
indices = ','.join(i for i in indices)
Expand Down

0 comments on commit 3db9b8c

Please sign in to comment.