diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c65363471..7927e57e06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ All notable changes to this project will be documented in this file. - #1787 : Ensure `STC` is installed with Pyccel. - #1656 : Ensure `gFTL` is installed with Pyccel. - #1844 : Add line numbers and code to errors from built-in function calls. +- #1655 : Add the appropriate C language equivalent for declaring a Python `list` container using the `STC` library. +- #1659 : Add the appropriate C language equivalent for declaring a Python `set` container using the `STC` library. - \[INTERNALS\] Added `container_rank` property to `ast.datatypes.PyccelType` objects. - \[DEVELOPER\] Added an improved traceback to the developer-mode errors for errors in function calls. - #1893 : Add Python support for set initialisation with `set()`. diff --git a/pyccel/ast/datatypes.py b/pyccel/ast/datatypes.py index 8d3dfb79ee..56e6754708 100644 --- a/pyccel/ast/datatypes.py +++ b/pyccel/ast/datatypes.py @@ -149,6 +149,16 @@ class PyccelType: is expected when calling a bitwise comparison operator on objects of these types. """ __slots__ = () + _name = None + + @property + def name(self): + """ + Get the name of the pyccel type. + + Get the name of the pyccel type. + """ + return self._name def __init__(self): #pylint: disable=useless-parent-delegation # This __init__ function is required so the ArgumentSingleton can @@ -433,6 +443,12 @@ class GenericType(FixedSizeType): def __add__(self, other): return other + def __eq__(self, other): + return True + + def __hash__(self): + return hash(self.__class__) + class SymbolicType(FixedSizeType): """ Class representing the datatype of a placeholder symbol. @@ -745,6 +761,13 @@ def __init__(self, element_type): self._order = 'C' if (element_type.order == 'C' or element_type.rank == 1) else None super().__init__() + def __eq__(self, other): + return isinstance(other, self.__class__) and self._element_type == other._element_type \ + and self._order == other._order + + def __hash__(self): + return hash((self.__class__, self._element_type, self._order)) + class HomogeneousSetType(HomogeneousContainerType, metaclass = ArgumentSingleton): """ Class representing the homogeneous set type. @@ -767,6 +790,12 @@ def __init__(self, element_type): self._element_type = element_type super().__init__() + def __eq__(self, other): + return isinstance(other, self.__class__) and self._element_type == other._element_type + + def __hash__(self): + return hash((self.__class__, self._element_type)) + #============================================================================== class CustomDataType(ContainerType, metaclass=Singleton): diff --git a/pyccel/ast/utilities.py b/pyccel/ast/utilities.py index 3aeaa25395..1e6c0422ab 100644 --- a/pyccel/ast/utilities.py +++ b/pyccel/ast/utilities.py @@ -17,7 +17,7 @@ Concatenate, Module, PyccelFunctionDef) from .builtins import (builtin_functions_dict, - PythonRange, PythonList, PythonTuple) + PythonRange, PythonList, PythonTuple, PythonSet) from .cmathext import cmath_mod from .datatypes import HomogeneousTupleType, PythonNativeInt from .internals import PyccelFunction, Slice @@ -409,7 +409,7 @@ def collect_loops(block, indices, new_index, language_has_vectors = False, resul if result is None: result = [] current_level = 0 - array_creator_types = (Allocate, PythonList, PythonTuple, Concatenate, Duplicate) + array_creator_types = (Allocate, PythonList, PythonTuple, Concatenate, Duplicate, PythonSet) is_function_call = lambda f: ((isinstance(f, FunctionCall) and not f.funcdef.is_elemental) or (isinstance(f, PyccelFunction) and not f.is_elemental and not hasattr(f, '__getitem__') and not isinstance(f, (NumpyTranspose)))) diff --git a/pyccel/codegen/codegen.py b/pyccel/codegen/codegen.py index a19f07a4ed..daf4559df4 100644 --- a/pyccel/codegen/codegen.py +++ b/pyccel/codegen/codegen.py @@ -216,19 +216,19 @@ def export(self, filename): header_filename = f'{filename}.{header_ext}' filename = f'{filename}.{ext}' + # print module + code = self._printer.doprint(self.ast) + with open(filename, 'w', encoding="utf-8") as f: + for line in code: + f.write(line) + # print module header if header_ext is not None: code = self._printer.doprint(ModuleHeader(self.ast)) - with open(header_filename, 'w') as f: + with open(header_filename, 'w', encoding="utf-8") as f: for line in code: f.write(line) - # print module - code = self._printer.doprint(self.ast) - with open(filename, 'w') as f: - for line in code: - f.write(line) - # print program prog_filename = None if self.is_program and self.language != 'python': diff --git a/pyccel/codegen/pipeline.py b/pyccel/codegen/pipeline.py index efff153aad..c0f8634e03 100644 --- a/pyccel/codegen/pipeline.py +++ b/pyccel/codegen/pipeline.py @@ -21,6 +21,7 @@ from pyccel.codegen.utilities import recompile_object from pyccel.codegen.utilities import copy_internal_library from pyccel.codegen.utilities import internal_libs +from pyccel.codegen.utilities import external_libs from pyccel.codegen.python_wrapper import create_shared_library from pyccel.naming import name_clash_checkers from pyccel.utilities.stage import PyccelStage @@ -334,6 +335,14 @@ def handle_error(stage): mod_obj.add_dependencies(stdlib) + + # Iterate over the external_libs list and determine if the printer + # requires an external lib to be included. + for key in codegen.get_printer_imports(): + lib_name = key.split("/", 1)[0] + if lib_name in external_libs: + lib_dest_path = copy_internal_library(lib_name, pyccel_dirpath) + if convert_only: # Change working directory back to starting point os.chdir(base_dirpath) diff --git a/pyccel/codegen/printing/ccode.py b/pyccel/codegen/printing/ccode.py index 5d8eb665b8..a39a442a83 100644 --- a/pyccel/codegen/printing/ccode.py +++ b/pyccel/codegen/printing/ccode.py @@ -9,9 +9,9 @@ from pyccel.ast.basic import ScopedAstNode -from pyccel.ast.builtins import PythonRange, PythonComplex, DtypePrecisionToCastFunction +from pyccel.ast.builtins import PythonRange, PythonComplex from pyccel.ast.builtins import PythonPrint, PythonType -from pyccel.ast.builtins import PythonList, PythonTuple +from pyccel.ast.builtins import PythonList, PythonTuple, PythonSet from pyccel.ast.core import Declare, For, CodeBlock from pyccel.ast.core import FuncAddressDeclare, FunctionCall, FunctionCallArgument @@ -27,7 +27,7 @@ from pyccel.ast.datatypes import PythonNativeInt, PythonNativeBool, VoidType from pyccel.ast.datatypes import TupleType, FixedSizeNumericType -from pyccel.ast.datatypes import CustomDataType, StringType, HomogeneousTupleType +from pyccel.ast.datatypes import CustomDataType, StringType, HomogeneousTupleType, HomogeneousListType, HomogeneousSetType from pyccel.ast.datatypes import PrimitiveBooleanType, PrimitiveIntegerType, PrimitiveFloatingPointType, PrimitiveComplexType from pyccel.ast.datatypes import HomogeneousContainerType @@ -403,7 +403,7 @@ def copy_NumpyArray_Data(self, expr): order = lhs.order lhs_dtype = lhs.dtype - declare_dtype = self.find_in_dtype_registry(lhs_dtype) + declare_dtype = self.get_c_type(lhs_dtype) if isinstance(lhs.class_type, NumpyNDArrayType): #set dtype to the C struct types dtype = self.find_in_ndarray_type_registry(lhs_dtype) @@ -496,7 +496,7 @@ def arrayFill(self, expr): rhs = expr.rhs lhs = expr.lhs code_init = '' - declare_dtype = self.find_in_dtype_registry(rhs.dtype) + declare_dtype = self.get_c_type(rhs.dtype) if rhs.fill_value is not None: if isinstance(rhs.fill_value, Literal): @@ -524,7 +524,7 @@ def _init_stack_array(self, expr): String containing the rhs of the initialization of a stack array. """ var = expr - dtype = self.find_in_dtype_registry(var.dtype) + dtype = self.get_c_type(var.dtype) if isinstance(var.class_type, NumpyNDArrayType): np_dtype = self.find_in_ndarray_type_registry(var.dtype) elif isinstance(var.class_type, HomogeneousContainerType): @@ -534,7 +534,7 @@ def _init_stack_array(self, expr): shape = ", ".join(self._print(i) for i in var.alloc_shape) tot_shape = self._print(functools.reduce( lambda x,y: PyccelMul(x,y,simplify=True), var.alloc_shape)) - declare_dtype = self.find_in_dtype_registry(NumpyInt64Type()) + declare_dtype = self.get_c_type(NumpyInt64Type()) dummy_array_name = self.scope.get_new_name('array_dummy') buffer_array = "{dtype} {name}[{size}];\n".format( @@ -645,6 +645,56 @@ def _handle_inline_func_call(self, expr): return code + def init_stc_container(self, expr, assignment_var): + """ + Generate the initialization of an STC container in C. + + This method generates and prints the C code for initializing a container using the STC `c_init()` method. + + Parameters + ---------- + expr : TypedAstNode + The object representing the container being printed (e.g., PythonList, PythonSet). + + assignment_var : Assign + The assignment node where the Python container (rhs) is being initialized + and saved into a variable (lhs). + + Returns + ------- + str + The generated C code for the container initialization. + """ + + dtype = self.get_c_type(assignment_var.lhs.class_type) + keyraw = '{' + ', '.join(self._print(a) for a in expr.args) + '}' + container_name = self._print(assignment_var.lhs) + init = f'{container_name} = c_init({dtype}, {keyraw});\n' + return init + + def rename_imported_methods(self, expr): + """ + Rename class methods from user-defined imports. + + This function is responsible for renaming methods of classes from + the imported modules, ensuring that the names are correct + by prefixing them with their class names. + + Parameters + ---------- + expr : iterable[ClassDef] + The ClassDef objects found in the module being renamed. + """ + for classDef in expr: + class_scope = classDef.scope + for method in classDef.methods: + if not method.is_inline: + class_scope.rename_function(method, f"{classDef.name}__{method.name.lstrip('__')}") + for interface in classDef.interfaces: + for func in interface.functions: + if not func.is_inline: + class_scope.rename_function(func, f"{classDef.name}__{func.name.lstrip('__')}") + # ============ Elements ============ # def _print_PythonAbs(self, expr): @@ -707,13 +757,13 @@ def _print_SysExit(self, expr): def _print_PythonFloat(self, expr): value = self._print(expr.arg) - type_name = self.find_in_dtype_registry(expr.dtype) + type_name = self.get_c_type(expr.dtype) return '({0})({1})'.format(type_name, value) def _print_PythonInt(self, expr): self.add_import(c_imports['stdint']) value = self._print(expr.arg) - type_name = self.find_in_dtype_registry(expr.dtype) + type_name = self.get_c_type(expr.dtype) return '({0})({1})'.format(type_name, value) def _print_PythonBool(self, expr): @@ -747,7 +797,7 @@ def _print_PythonComplex(self, expr): else: value = self._print(PyccelAssociativeParenthesis(PyccelAdd(expr.real, PyccelMul(expr.imag, LiteralImaginaryUnit())))) - type_name = self.find_in_dtype_registry(expr.dtype) + type_name = self.get_c_type(expr.dtype) return '({0})({1})'.format(type_name, value) def _print_LiteralImaginaryUnit(self, expr): @@ -772,16 +822,11 @@ def _print_ModuleHeader(self, expr): classes += self._print(classDef.docstring) classes += f"struct {classDef.name} {{\n" classes += ''.join(self._print(Declare(var)) for var in classDef.attributes) - class_scope = classDef.scope for method in classDef.methods: - if not method.is_inline: - class_scope.rename_function(method, f"{classDef.name}__{method.name.lstrip('__')}") funcs += f"{self.function_signature(method)};\n" for interface in classDef.interfaces: for func in interface.functions: - if not func.is_inline: - class_scope.rename_function(func, f"{classDef.name}__{func.name.lstrip('__')}") - funcs += f"{self.function_signature(func)};\n" + funcs += f"{self.function_signature(func)};\n" classes += "};\n" funcs += '\n'.join(f"{self.function_signature(f)};" for f in expr.module.funcs if not f.is_inline) @@ -805,13 +850,17 @@ def _print_ModuleHeader(self, expr): def _print_Module(self, expr): self.set_scope(expr.scope) self._current_module = expr + for item in expr.imports: + if item.source_module and item.source_module is not self._current_module: + self.rename_imported_methods(item.source_module.classes) + self.rename_imported_methods(expr.classes) body = ''.join(self._print(i) for i in expr.body) global_variables = ''.join([self._print(d) for d in expr.declarations]) # Print imports last to be sure that all additional_imports have been collected - imports = [Import(self.scope.get_python_name(expr.name), Module(expr.name,(),())), *self._additional_imports.values()] - imports = ''.join(self._print(i) for i in imports) + imports = Import(self.scope.get_python_name(expr.name), Module(expr.name,(),())) + imports = self._print(imports) code = ('{imports}\n' '{variables}\n' @@ -947,6 +996,15 @@ def _print_Import(self, expr): source = source.name[-1] else: source = self._print(source) + if source.startswith('stc/'): + stc_name, container_type, container_key = source.split("/") + container = container_type.split("_") + return '\n'.join((f'#ifndef _{container_type.upper()}', + f'#define _{container_type.upper()}', + f'#define i_type {container_type}', + f'#define i_key {container_key}', + f'#include "{stc_name + "/" + container[0]}.h"', + '#endif\n')) # Get with a default value is not used here as it is # slower and on most occasions the import will not be in the @@ -954,16 +1012,6 @@ def _print_Import(self, expr): if source in import_dict: # pylint: disable=consider-using-get source = import_dict[source] - if expr.source_module and expr.source_module is not self._current_module: - for classDef in expr.source_module.classes: - class_scope = classDef.scope - for method in classDef.methods: - if not method.is_inline: - class_scope.rename_function(method, f"{classDef.name}__{method.name.lstrip('__')}") - for interface in classDef.interfaces: - for func in interface.functions: - if not func.is_inline: - class_scope.rename_function(func, f"{classDef.name}__{func.name.lstrip('__')}") if source is None: return '' @@ -1134,28 +1182,39 @@ def formatted_args_to_printf(args_format, args, end): code += formatted_args_to_printf(args_format, args, end) return code - def find_in_dtype_registry(self, dtype): + def get_c_type(self, dtype): """ - Find the corresponding C dtype in the dtype_registry. + Find the corresponding C type of the PyccelType. - Find the corresponding C dtype in the dtype_registry. - Raise PYCCEL_RESTRICTION_TODO if not found. + For scalar types, this function searches for the corresponding C data type + in the `dtype_registry`. If the provided type is a container (like + `HomogeneousSetType` or `HomogeneousListType`), it recursively identifies + the type of an element of the container and uses it to calculate the + appropriate type for the `STC` container. + A `PYCCEL_RESTRICTION_TODO` error is raised if the dtype is not found in the registry. Parameters ---------- - dtype : DataType - The data type of the expression. + dtype : PyccelType + The data type of the expression. This can be a fixed-size numeric type, + a primitive type, or a container type. Returns ------- str - The code which declares the datatype in C. + The code which declares the data type in C or the corresponding `STC` container + type. + + Raises + ------ + PyccelCodegenError + If the dtype is not found in the dtype_registry. """ if isinstance(dtype, FixedSizeNumericType): primitive_type = dtype.primitive_type if isinstance(primitive_type, PrimitiveComplexType): self.add_import(c_imports['complex']) - return f'{self.find_in_dtype_registry(dtype.element_type)} complex' + return f'{self.get_c_type(dtype.element_type)} complex' elif isinstance(primitive_type, PrimitiveIntegerType): self.add_import(c_imports['stdint']) elif isinstance(dtype, PythonNativeBool): @@ -1163,6 +1222,12 @@ def find_in_dtype_registry(self, dtype): return 'bool' key = (primitive_type, dtype.precision) + elif isinstance(dtype, (HomogeneousSetType, HomogeneousListType)): + container_type = 'hset_' if dtype.name == 'set' else 'vec_' + element_type = self.get_c_type(dtype.element_type) + i_type = container_type + element_type.replace(' ', '_') + self.add_import(Import(f'stc/{i_type}/{element_type}', Module(f'stc/{i_type}', (), ()))) + return i_type else: key = dtype @@ -1238,7 +1303,10 @@ def get_declare_type(self, expr): rank = expr.rank if rank > 0: - if expr.is_ndarray or isinstance(expr.class_type, HomogeneousContainerType): + if isinstance(expr.class_type, (HomogeneousSetType, HomogeneousListType)): + dtype = self.get_c_type(expr.class_type) + return dtype + if isinstance(expr.class_type,(HomogeneousTupleType, NumpyNDArrayType)): if expr.rank > 15: errors.report(UNSUPPORTED_ARRAY_RANK, symbol=expr, severity='fatal') self.add_import(c_imports['ndarrays']) @@ -1246,7 +1314,7 @@ def get_declare_type(self, expr): else: errors.report(PYCCEL_RESTRICTION_TODO+' (rank>0)', symbol=expr, severity='fatal') elif not isinstance(class_type, CustomDataType): - dtype = self.find_in_dtype_registry(expr.dtype) + dtype = self.get_c_type(expr.dtype) else: dtype = self._print(expr.class_type) @@ -1330,11 +1398,11 @@ def function_signature(self, expr, print_arg_names = True): if n_results == 1: ret_type = self.get_declare_type(result_vars[0]) elif n_results > 1: - ret_type = self.find_in_dtype_registry(PythonNativeInt()) + ret_type = self.get_c_type(PythonNativeInt()) arg_vars.extend(result_vars) self._additional_args.append(result_vars) # Ensure correct result for is_c_pointer else: - ret_type = self.find_in_dtype_registry(VoidType()) + ret_type = self.get_c_type(VoidType()) name = expr.name if not arg_vars: @@ -1432,7 +1500,7 @@ def _cast_to(self, expr, dtype): after using this function. """ if expr.dtype != dtype: - cast=self.find_in_dtype_registry(dtype) + cast=self.get_c_type(dtype) return '({}){{}}'.format(cast) return '{}' @@ -1525,6 +1593,8 @@ def _print_PyccelArrayShapeElement(self, expr): def _print_Allocate(self, expr): free_code = '' variable = expr.variable + if isinstance(variable.class_type, (HomogeneousListType, HomogeneousSetType)): + return '' if variable.rank > 0: #free the array if its already allocated and checking if its not null if the status is unknown if (expr.status == 'unknown'): @@ -1542,7 +1612,7 @@ def _print_Allocate(self, expr): dtype = self.find_in_ndarray_type_registry(numpy_precision_map[(variable.dtype.primitive_type, variable.dtype.precision)]) else: raise NotImplementedError(f"Don't know how to index {variable.class_type} type") - shape_dtype = self.find_in_dtype_registry(NumpyInt64Type()) + shape_dtype = self.get_c_type(NumpyInt64Type()) shape_Assign = "("+ shape_dtype +"[]){" + shape + "}" is_view = 'false' if variable.on_heap else 'true' order = "order_f" if expr.order == "F" else "order_c" @@ -1559,6 +1629,10 @@ def _print_Allocate(self, expr): raise NotImplementedError(f"Allocate not implemented for {variable}") def _print_Deallocate(self, expr): + if isinstance(expr.variable.class_type, (HomogeneousListType, HomogeneousSetType)): + variable_address = self._print(ObjectAddress(expr.variable)) + container_type = self.get_c_type(expr.variable.class_type) + return f'{container_type}_drop({variable_address});\n' if isinstance(expr.variable, InhomogeneousTupleVariable): return ''.join(self._print(Deallocate(v)) for v in expr.variable) variable_address = self._print(ObjectAddress(expr.variable)) @@ -1729,7 +1803,7 @@ def _print_MathFunctionBase(self, expr): args.append(self._print(arg)) code_args = ', '.join(args) if expr.dtype.primitive_type is PrimitiveIntegerType(): - cast_type = self.find_in_dtype_registry(expr.dtype) + cast_type = self.get_c_type(expr.dtype) return f'({cast_type}){func_name}({code_args})' return f'{func_name}({code_args})' @@ -2041,7 +2115,7 @@ def _print_PyccelFloorDiv(self, expr): code = ' / '.join(self._print(a if a.dtype.primitive_type is PrimitiveFloatingPointType() else NumpyFloat(a)) for a in expr.args) if (need_to_cast): - cast_type = self.find_in_dtype_registry(expr.dtype) + cast_type = self.get_c_type(expr.dtype) return "({})floor({})".format(cast_type, code) return "floor({})".format(code) @@ -2105,6 +2179,8 @@ def _print_Assign(self, expr): if isinstance(rhs, (NumpyFull)): return prefix_code+self.arrayFill(expr) lhs = self._print(expr.lhs) + if isinstance(rhs, (PythonList, PythonSet)): + return prefix_code+self.init_stc_container(rhs, expr) rhs = self._print(expr.rhs) return prefix_code+'{} = {};\n'.format(lhs, rhs) diff --git a/pyccel/codegen/utilities.py b/pyccel/codegen/utilities.py index 190044b5f0..ceffc483e3 100644 --- a/pyccel/codegen/utilities.py +++ b/pyccel/codegen/utilities.py @@ -12,17 +12,25 @@ import shutil from filelock import FileLock import pyccel.stdlib as stdlib_folder +import pyccel.extensions as ext_folder from .compiling.basic import CompileObj # get path to pyccel/stdlib/lib_name stdlib_path = os.path.dirname(stdlib_folder.__file__) +# get path to pyccel/extensions/lib_name +ext_path = os.path.dirname(ext_folder.__file__) + __all__ = ['copy_internal_library','recompile_object'] #============================================================================== language_extension = {'fortran':'f90', 'c':'c', 'python':'py'} +#============================================================================== +# map external libraries inside pyccel/extensions with their path +external_libs = {"stc" : "STC/include"} + #============================================================================== # map internal libraries to their folders inside pyccel/stdlib and their compile objects # The compile object folder will be in the pyccel dirpath @@ -101,8 +109,12 @@ def copy_internal_library(lib_folder, pyccel_dirpath, extra_files = None): str The location that the files were copied to. """ - # get lib path (stdlib_path/lib_name) - lib_path = os.path.join(stdlib_path, lib_folder) + # get lib path (stdlib_path/lib_name or ext_path/lib_name) + if lib_folder in external_libs: + lib_path = os.path.join(ext_path, external_libs[lib_folder], lib_folder) + else: + lib_path = os.path.join(stdlib_path, lib_folder) + # remove library folder to avoid missing files and copy # new one from pyccel stdlib lib_dest_path = os.path.join(pyccel_dirpath, lib_folder) diff --git a/pyccel/codegen/wrapper/fortran_to_c_wrapper.py b/pyccel/codegen/wrapper/fortran_to_c_wrapper.py index 6bea034128..e0e1ccd7df 100644 --- a/pyccel/codegen/wrapper/fortran_to_c_wrapper.py +++ b/pyccel/codegen/wrapper/fortran_to_c_wrapper.py @@ -22,6 +22,7 @@ from pyccel.ast.literals import LiteralInteger, Nil, LiteralTrue from pyccel.ast.operators import PyccelIsNot, PyccelMul from pyccel.ast.variable import Variable, IndexedElement, DottedVariable +from pyccel.ast.numpyext import NumpyNDArrayType from pyccel.parser.scope import Scope from .wrapper import Wrapper @@ -451,7 +452,7 @@ def _wrap_Variable(self, expr): """ if isinstance(expr.class_type, FixedSizeNumericType): return expr.clone(expr.name, new_class = BindCVariable) - else: + elif isinstance(expr.class_type, NumpyNDArrayType): scope = self.scope func_name = scope.get_new_name('bind_c_'+expr.name.lower()) func_scope = scope.new_child_scope(func_name) @@ -483,6 +484,8 @@ def _wrap_Variable(self, expr): original_function = expr) return expr.clone(expr.name, new_class = BindCArrayVariable, wrapper_function = func, original_variable = expr) + else: + raise NotImplementedError(f"Objects of type {expr.class_type} cannot be wrapped yet") def _wrap_DottedVariable(self, expr): """ diff --git a/pyccel/extensions/__init__.py b/pyccel/extensions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyccel/parser/semantic.py b/pyccel/parser/semantic.py index 4c6e210836..0bb1a7c0b1 100644 --- a/pyccel/parser/semantic.py +++ b/pyccel/parser/semantic.py @@ -61,10 +61,10 @@ from pyccel.ast.class_defs import NumpyArrayClass, TupleClass, get_cls_base from pyccel.ast.datatypes import CustomDataType, PyccelType, TupleType, VoidType, GenericType -from pyccel.ast.datatypes import PrimitiveIntegerType, HomogeneousListType, StringType, SymbolicType +from pyccel.ast.datatypes import PrimitiveIntegerType, StringType, SymbolicType from pyccel.ast.datatypes import PythonNativeBool, PythonNativeInt, PythonNativeFloat from pyccel.ast.datatypes import DataTypeFactory, PrimitiveFloatingPointType -from pyccel.ast.datatypes import InhomogeneousTupleType, HomogeneousTupleType, HomogeneousSetType +from pyccel.ast.datatypes import InhomogeneousTupleType, HomogeneousTupleType, HomogeneousSetType, HomogeneousListType from pyccel.ast.datatypes import PrimitiveComplexType, FixedSizeNumericType from pyccel.ast.functionalexpr import FunctionalSum, FunctionalMax, FunctionalMin, GeneratorComprehension, FunctionalFor @@ -1925,7 +1925,6 @@ def _get_indexed_type(self, base, args, expr): else: raise errors.report(f"Unknown annotation base {base}\n"+PYCCEL_RESTRICTION_TODO, severity='fatal', symbol=expr) - rank = 1 if len(args) == 2 and args[1] is LiteralEllipsis() or len(args) == 1: syntactic_annotation = args[0] if not isinstance(syntactic_annotation, SyntacticTypeAnnotation): diff --git a/tests/epyccel/modules/array_consts.py b/tests/epyccel/modules/array_consts.py index 157b9e6549..8e38031f57 100644 --- a/tests/epyccel/modules/array_consts.py +++ b/tests/epyccel/modules/array_consts.py @@ -6,7 +6,6 @@ c = np.zeros((2,3), dtype=np.int32) d = np.array([1+2j, 3+4j]) e = np.empty((2,3,4)) -F = [False for _ in range(5)] def update_a(): a[:] = a+1 diff --git a/tests/epyccel/modules/list_comprehension.py b/tests/epyccel/modules/list_comprehension.py new file mode 100644 index 0000000000..13585fb048 --- /dev/null +++ b/tests/epyccel/modules/list_comprehension.py @@ -0,0 +1,4 @@ +# pylint: disable=missing-function-docstring, missing-module-docstring + +A = [False for _ in range(5)] + diff --git a/tests/epyccel/test_arrays.py b/tests/epyccel/test_arrays.py index fdac39bbad..59ddb650ab 100644 --- a/tests/epyccel/test_arrays.py +++ b/tests/epyccel/test_arrays.py @@ -4027,7 +4027,15 @@ def test_array_float_nested_C_array_initialization_3(language): #============================================================================== # NUMPY SUM #============================================================================== - +@pytest.mark.parametrize( 'language', ( + pytest.param("fortran", marks = pytest.mark.fortran), + pytest.param("c", marks = [ + pytest.mark.xfail(reason="List indexing is not yet supported in C, related issue #1876"), + pytest.mark.c] + ), + pytest.param("python", marks = pytest.mark.python) + ) +) def test_arr_bool_sum(language): f1 = arrays.arr_bool_sum f2 = epyccel(f1, language = language) diff --git a/tests/epyccel/test_epyccel_modules.py b/tests/epyccel/test_epyccel_modules.py index 445e8dc457..ad8ae0bd75 100644 --- a/tests/epyccel/test_epyccel_modules.py +++ b/tests/epyccel/test_epyccel_modules.py @@ -1,4 +1,5 @@ # pylint: disable=missing-function-docstring, missing-module-docstring +import pytest import numpy as np from pyccel import epyccel @@ -141,8 +142,6 @@ def test_module_7(language): assert np.array_equal(mod_att, modnew_att) assert mod_att.dtype == modnew_att.dtype - assert np.array_equal(mod.F, modnew.F) - modnew.update_a() mod.update_a() @@ -171,6 +170,25 @@ def test_module_7(language): mod.reset_c() mod.reset_e() +@pytest.mark.parametrize( 'language', ( + pytest.param("fortran", marks = [ + pytest.mark.xfail(reason="List wrapper is not implemented yet, related issue #1911"), + pytest.mark.fortran] + ), + pytest.param("c", marks = [ + pytest.mark.xfail(reason="List indexing is not yet supported in C, related issue #1876"), + pytest.mark.c] + ), + pytest.param("python", marks = pytest.mark.python) + ) +) +def test_module_8(language): + import modules.list_comprehension as mod + + modnew = epyccel(mod, language=language) + + assert np.array_equal(mod.A, modnew.A) + def test_awkward_names(language): import modules.awkward_names as mod diff --git a/tests/epyccel/test_epyccel_variable_annotations.py b/tests/epyccel/test_epyccel_variable_annotations.py index 5927f58373..c8710a59da 100644 --- a/tests/epyccel/test_epyccel_variable_annotations.py +++ b/tests/epyccel/test_epyccel_variable_annotations.py @@ -10,6 +10,17 @@ from pyccel.errors.errors import PyccelSemanticError, Errors from pyccel.decorators import allow_negative_index, stack_array +@pytest.fixture( params=[ + pytest.param("c", marks = pytest.mark.c), + pytest.param("fortran", marks = [ + pytest.mark.xfail(reason="Variable declaration not implemented in fortran, related issues #1657 1658"), + pytest.mark.fortran]), + pytest.param("python", marks = pytest.mark.python) + ], + scope = "module" +) +def stc_language(request): + return request.param def test_local_type_annotation(language): def local_type_annotation(): @@ -166,3 +177,97 @@ def homogeneous_tuple_annotation(): assert epyc_homogeneous_tuple_annotation() == homogeneous_tuple_annotation() assert isinstance(epyc_homogeneous_tuple_annotation(), type(homogeneous_tuple_annotation())) + +def test_homogeneous_set_annotation_int(stc_language): + def homogeneous_set_annotation (): + a : 'set[int]' #pylint: disable=unused-variable + a = {1, 2, 3, 4} + epyc_homogeneous_set_annotation = epyccel(homogeneous_set_annotation, language=stc_language) + assert epyc_homogeneous_set_annotation() == homogeneous_set_annotation() + assert isinstance(epyc_homogeneous_set_annotation(), type(homogeneous_set_annotation())) + +def test_homogeneous_set_without_annotation(stc_language): + def homogeneous_set(): + a = {1, 2, 3, 4} #pylint: disable=unused-variable + epyc_homogeneous_set = epyccel(homogeneous_set, language=stc_language) + assert epyc_homogeneous_set() == homogeneous_set() + assert isinstance(epyc_homogeneous_set(), type(homogeneous_set())) + +def test_homogeneous_set_annotation_float(stc_language): + def homogeneous_set_annotation (): + a : 'set[float]' #pylint: disable=unused-variable + a = {1.5, 2.5, 3.3, 4.3} + epyc_homogeneous_set_annotation = epyccel(homogeneous_set_annotation, language=stc_language) + assert epyc_homogeneous_set_annotation() == homogeneous_set_annotation() + assert isinstance(epyc_homogeneous_set_annotation(), type(homogeneous_set_annotation())) + +def test_homogeneous_set_annotation_bool(stc_language): + def homogeneous_set_annotation (): + a : 'set[bool]' #pylint: disable=unused-variable + a = {False, True, False, True} #pylint: disable=duplicate-value + epyc_homogeneous_set_annotation = epyccel(homogeneous_set_annotation, language=stc_language) + assert epyc_homogeneous_set_annotation() == homogeneous_set_annotation() + assert isinstance(epyc_homogeneous_set_annotation(), type(homogeneous_set_annotation())) + +def test_homogeneous_set_annotation_complex(stc_language): + def homogeneous_set_annotation(): + a: 'set[complex]' # pylint: disable=unused-variable + a = {1+1j, 2+2j, 3+3j, 4+4j} + epyc_homogeneous_set_annotation = epyccel(homogeneous_set_annotation, language=stc_language) + assert epyc_homogeneous_set_annotation() == homogeneous_set_annotation() + assert isinstance(epyc_homogeneous_set_annotation(), type(homogeneous_set_annotation())) + +def test_homogeneous_empty_list_annotation_int(stc_language): + def homogeneous_list_annotation(): + a: 'list[int]' # pylint: disable=unused-variable + a = [] + epyc_homogeneous_list_annotation = epyccel(homogeneous_list_annotation, language=stc_language) + assert epyc_homogeneous_list_annotation() == homogeneous_list_annotation() + assert isinstance(epyc_homogeneous_list_annotation(), type(homogeneous_list_annotation())) + +def test_homogeneous_list_annotation_int(stc_language): + def homogeneous_list_annotation(): + a: 'list[int]' # pylint: disable=unused-variable + a = [1, 2, 3, 4] + epyc_homogeneous_list_annotation = epyccel(homogeneous_list_annotation, language=stc_language) + assert epyc_homogeneous_list_annotation() == homogeneous_list_annotation() + assert isinstance(epyc_homogeneous_list_annotation(), type(homogeneous_list_annotation())) + +def test_homogeneous_list_without_annotation(stc_language): + def homogeneous_list(): + a = [1, 2, 3, 4] # pylint: disable=unused-variable + epyc_homogeneous_list = epyccel(homogeneous_list, language=stc_language) + assert epyc_homogeneous_list() == homogeneous_list() + assert isinstance(epyc_homogeneous_list(), type(homogeneous_list())) + +def test_homogeneous_list_annotation_float(stc_language): + def homogeneous_list_annotation(): + a: 'list[float]' # pylint: disable=unused-variable + a = [1.1, 2.2, 3.3, 4.4] + epyc_homogeneous_list_annotation = epyccel(homogeneous_list_annotation, language=stc_language) + assert epyc_homogeneous_list_annotation() == homogeneous_list_annotation() + assert isinstance(epyc_homogeneous_list_annotation(), type(homogeneous_list_annotation())) + +def test_homogeneous_list_annotation_bool(stc_language): + def homogeneous_list_annotation(): + a: 'list[bool]' # pylint: disable=unused-variable + a = [False, True, True, False] + epyc_homogeneous_list_annotation = epyccel(homogeneous_list_annotation, language=stc_language) + assert epyc_homogeneous_list_annotation() == homogeneous_list_annotation() + assert isinstance(epyc_homogeneous_list_annotation(), type(homogeneous_list_annotation())) + +def test_homogeneous_list_annotation_complex(stc_language): + def homogeneous_list_annotation(): + a: 'list[complex]' # pylint: disable=unused-variable + a = [1+1j, 2+2j, 3+3j, 4+4j] + epyc_homogeneous_list_annotation = epyccel(homogeneous_list_annotation, language=stc_language) + assert epyc_homogeneous_list_annotation() == homogeneous_list_annotation() + assert isinstance(epyc_homogeneous_list_annotation(), type(homogeneous_list_annotation())) + +def test_homogeneous_list_annotation_embedded_complex(stc_language): + def homogeneous_list_annotation(): + a : 'list[complex]' = [1j, 2j] + b = [a] # pylint: disable=unused-variable + epyc_homogeneous_list_annotation = epyccel(homogeneous_list_annotation, language=stc_language) + assert epyc_homogeneous_list_annotation() == homogeneous_list_annotation() + assert isinstance(epyc_homogeneous_list_annotation(), type(homogeneous_list_annotation())) diff --git a/tests/epyccel/test_functionals.py b/tests/epyccel/test_functionals.py index f30c7d2279..01e54521e0 100644 --- a/tests/epyccel/test_functionals.py +++ b/tests/epyccel/test_functionals.py @@ -1,10 +1,24 @@ # pylint: disable=missing-function-docstring, missing-module-docstring from numpy.random import randint from numpy import equal +import pytest + from pyccel import epyccel from modules import functionals +@pytest.fixture( params=[ + pytest.param("fortran", marks = pytest.mark.fortran), + pytest.param("c", marks = [ + pytest.mark.xfail(reason="C does not support list indexing yet, related issue #1876"), + pytest.mark.c]), + pytest.param("python", marks = pytest.mark.python) + ], + scope = "module" +) +def language(request): + return request.param + def compare_epyccel(f, language, *args): f2 = epyccel(f, language=language) out1 = f(*args) diff --git a/tests/epyccel/test_loops.py b/tests/epyccel/test_loops.py index 3b24578f3f..a1176010e9 100644 --- a/tests/epyccel/test_loops.py +++ b/tests/epyccel/test_loops.py @@ -77,6 +77,15 @@ def test_double_loop_on_2d_array_F(language): f2( y ) assert np.array_equal( x, y ) +@pytest.mark.parametrize( 'language', ( + pytest.param("fortran", marks = pytest.mark.fortran), + pytest.param("c", marks = [ + pytest.mark.xfail(reason="C does not support list indexing yet, related issue #1876"), + pytest.mark.c] + ), + pytest.param("python", marks = pytest.mark.python) + ) +) def test_product_loop_on_2d_array_C(language): f1 = loops.product_loop_on_2d_array_C @@ -89,6 +98,15 @@ def test_product_loop_on_2d_array_C(language): f2( y ) assert np.array_equal( x, y ) +@pytest.mark.parametrize( 'language', ( + pytest.param("fortran", marks = pytest.mark.fortran), + pytest.param("c", marks = [ + pytest.mark.xfail(reason="C does not support list indexing yet, related issue #1876"), + pytest.mark.c] + ), + pytest.param("python", marks = pytest.mark.python) + ) +) def test_product_loop_on_2d_array_F(language): f1 = loops.product_loop_on_2d_array_F @@ -101,6 +119,15 @@ def test_product_loop_on_2d_array_F(language): f2( y ) assert np.array_equal( x, y ) +@pytest.mark.parametrize( 'language', ( + pytest.param("fortran", marks = pytest.mark.fortran), + pytest.param("c", marks = [ + pytest.mark.xfail(reason="C does not support list indexing yet, related issue #1876"), + pytest.mark.c] + ), + pytest.param("python", marks = pytest.mark.python) + ) +) def test_product_loop(language): f1 = loops.product_loop @@ -150,6 +177,15 @@ def test_enumerate_on_1d_array_with_start(language): assert np.array_equal( f1(z, 5), f2(z, 5) ) assert np.array_equal( f1(z,-2), f2(z,-2) ) +@pytest.mark.parametrize( 'language', ( + pytest.param("fortran", marks = pytest.mark.fortran), + pytest.param("c", marks = [ + pytest.mark.xfail(reason="C does not support list indexing yet, related issue #1876"), + pytest.mark.c] + ), + pytest.param("python", marks = pytest.mark.python) + ) +) def test_zip_prod(language): f1 = loops.zip_prod