Skip to content

Commit

Permalink
Improve TypedAstNode to fully describe container types (#1590)
Browse files Browse the repository at this point in the history
Add an attribute `class_type` to `TypedAstNode` (fixes #1530). This
attribute describes the type of the object as it would be reported in
Python. As a result it can be used to store container types (e.g.
ndarray, lists, sets, etc). Add docstrings.

Simplify type deductions using `__add__` to combine types. This function
is called repeatedly. As a simple function, it is therefore cached.
  • Loading branch information
EmilyBourne committed Nov 21, 2023
1 parent 40a8143 commit 79c0c06
Show file tree
Hide file tree
Showing 29 changed files with 1,250 additions and 490 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file.

- #1571 : Add support for the function `tuple`.
- #1493 : Add preliminary support for importing classes.
- \[INTERNALS\] Add `class_type` attribute to `TypedAstNode`.

### Fixed

Expand Down
36 changes: 31 additions & 5 deletions pyccel/ast/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,17 @@ def order(self):
"""
return self._order # pylint: disable=no-member

@property
def class_type(self):
"""
The type of the object.
The Python type of the object. In the case of scalars this is equivalent to
the datatype. For objects in (homogeneous) containers (e.g. list/ndarray/tuple),
this is the type of the container.
"""
return self._class_type # pylint: disable=no-member

@classmethod
def static_rank(cls):
"""
Expand Down Expand Up @@ -618,6 +629,20 @@ class does not have a predetermined order.
"""
return cls._order # pylint: disable=no-member

@classmethod
def static_class_type(cls):
"""
The type of the object.
The Python type of the object. In the case of scalars this is equivalent to
the datatype. For objects in (homogeneous) containers (e.g. list/ndarray/tuple),
this is the type of the container.
This function is static and will return an AttributeError if the
class does not have a predetermined order.
"""
return cls._class_type # pylint: disable=no-member

def copy_attributes(self, x):
"""
Copy the attributes describing a TypedAstNode into this node.
Expand All @@ -631,11 +656,12 @@ def copy_attributes(self, x):
x : TypedAstNode
The node from which the attributes should be copied.
"""
self._shape = x.shape
self._rank = x.rank
self._dtype = x.dtype
self._precision = x.precision
self._order = x.order
self._shape = x.shape
self._rank = x.rank
self._dtype = x.dtype
self._precision = x.precision
self._order = x.order
self._class_type = x.class_type


#------------------------------------------------------------------------------
Expand Down
25 changes: 14 additions & 11 deletions pyccel/ast/bitwise_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
from .builtins import PythonInt
from .datatypes import (NativeBool, NativeInteger, NativeFloat,
NativeComplex, NativeString)
NativeComplex, NativeString, NativeGeneric)
from .internals import max_precision
from .operators import PyccelUnaryOperator, PyccelOperator

Expand Down Expand Up @@ -70,7 +70,7 @@ def _calculate_dtype(self, arg):

self._args = (PythonInt(arg) if arg.dtype is NativeBool() else arg,)
precision = arg.precision
return dtype, precision
return dtype, precision, dtype

def __repr__(self):
return f'~{repr(self.args[0])}'
Expand All @@ -93,7 +93,7 @@ class PyccelBitOperator(PyccelOperator):
_shape = None
_rank = 0
_order = None
__slots__ = ('_dtype','_precision')
__slots__ = ('_dtype','_precision','_class_type')

def __init__(self, arg1, arg2):
super().__init__(arg1, arg2)
Expand Down Expand Up @@ -124,17 +124,20 @@ def _calculate_dtype(self, *args):
precision : integer
The precision of the result of the operation.
"""
integers = [a for a in args if a.dtype in (NativeInteger(),NativeBool())]
floats = [a for a in args if a.dtype is NativeFloat()]
complexes = [a for a in args if a.dtype is NativeComplex()]
strs = [a for a in args if a.dtype is NativeString()]
try:
dtype = sum((a.dtype for a in args), start=NativeGeneric())
class_type = sum((a.class_type for a in args), start=NativeGeneric())
except NotImplementedError:
raise TypeError(f'Cannot determine the type of {args}') #pylint: disable=raise-missing-from

if strs or complexes or floats:
if dtype in (NativeString(), NativeComplex(), NativeFloat()):
raise TypeError(f'unsupported operand type(s): {args}')
elif integers:
return self._handle_integer_type(integers)
elif (dtype in (NativeInteger(), NativeBool())):
if class_type is NativeBool():
class_type = NativeInteger()
return *self._handle_integer_type(args), class_type
else:
raise TypeError(f'cannot determine the type of {args}')
raise TypeError(f'Cannot determine the type of {args}')

def _set_shape_rank(self):
pass
Expand Down

0 comments on commit 79c0c06

Please sign in to comment.