Skip to content

Commit

Permalink
Add C support for a class containing only scalar data (#1472)
Browse files Browse the repository at this point in the history
- Visit the lhs of a `DottedVariable`.
- Add `is_argument = True` for the `self` variable when it's in a method
argument of a class.
- Add the declaration of the attributes of the class in the `struct def`.

---------

Co-authored-by: EmilyBourne <louise.bourne@gmail.com>
  • Loading branch information
sboof911 and EmilyBourne committed Sep 6, 2023
1 parent d44fbdd commit c6d7613
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ All notable changes to this project will be documented in this file.

### Added

- #1472 : Add C printing support for a class containing scalar data.

### Fixed

- #1484 : Use scope for classes to avoid name clashes.
Expand Down
3 changes: 2 additions & 1 deletion pyccel/codegen/printing/ccode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pyccel.ast.builtins import PythonList, PythonTuple

from pyccel.ast.core import Declare, For, CodeBlock
from pyccel.ast.core import FuncAddressDeclare, FunctionCall, FunctionCallArgument, FunctionDef
from pyccel.ast.core import FuncAddressDeclare, FunctionCall, FunctionCallArgument, ClassDef
from pyccel.ast.core import Allocate, Deallocate
from pyccel.ast.core import FunctionAddress, FunctionDefArgument
from pyccel.ast.core import Assign, Import, AugAssign, AliasAssign
Expand Down Expand Up @@ -736,6 +736,7 @@ def _print_ModuleHeader(self, expr):
funcs = ""
for classDef in expr.module.classes:
classes += f"struct {classDef.name} {{\n"
classes += ''.join(self._print(Declare(var.dtype,var)) for var in classDef.attributes)
for method in classDef.methods:
method.rename(classDef.name + ("__" + method.name if not method.name.startswith("__") else method.name))
funcs += f"{self.function_signature(method)};\n"
Expand Down
4 changes: 2 additions & 2 deletions pyccel/parser/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,7 +1267,7 @@ def _assign_lhs_variable(self, lhs, d_var, rhs, new_expressions, is_augassign,ar
# the following is a small fix, since lhs must be already
# declared
if isinstance(lhs, DottedName):
lhs = var.clone(var.name, new_class = DottedVariable, lhs = lhs.name[0])
lhs = var.clone(var.name, new_class = DottedVariable, lhs = self._visit(lhs.name[0]))
else:
lhs = var
else:
Expand Down Expand Up @@ -3281,7 +3281,7 @@ def _visit_FunctionDef(self, expr):
dt = self.get_class_construct(cls_name)()
cls_base = self.scope.find(cls_name, 'classes')
cls_base.scope.insert_symbols(expr.scope.local_used_symbols.copy())
var = Variable(dt, 'self', cls_base=cls_base)
var = Variable(dt, 'self', cls_base=cls_base, is_argument=True)
self.scope.insert_variable(var)

if arguments:
Expand Down
24 changes: 24 additions & 0 deletions tests/pyccel/scripts/classes/classes_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring
# coding: utf-8

#$ header class Point1(public)
#$ header method __init__(Point1, double)
#$ header class Point2(public)
#$ header method __init__(Point2, double)
#$ header method test_func(Point2)

class Point1:
def __init__(self, x):
self.x = x

class Point2:
def __init__(self, y):
self.y = y

def test_func(self):
p = Point1(self.y)
print(p.x)

if __name__ == '__main__':
j = Point2(2.2)
j.test_func()
5 changes: 3 additions & 2 deletions tests/pyccel/test_pyccel.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,9 +863,8 @@ def test_basic_header():

#------------------------------------------------------------------------------
@pytest.mark.parametrize( "test_file", ["scripts/classes/classes.py",
"scripts/classes/classes_1.py",
"scripts/classes/classes_5.py",
"scripts/classes/generic_methods.py",
"scripts/classes/classes_1.py",
] )
@pytest.mark.parametrize( 'language', (
pytest.param("python", marks = pytest.mark.python),
Expand All @@ -882,6 +881,8 @@ def test_classes_f_only( test_file , language):
#------------------------------------------------------------------------------
@pytest.mark.xdist_incompatible
@pytest.mark.parametrize( "test_file", ["scripts/classes/classes_2_C.py",
"scripts/classes/classes_5.py",
"scripts/classes/classes_3.py",
] )
@pytest.mark.parametrize( 'language', (
pytest.param("python", marks = pytest.mark.python),
Expand Down

0 comments on commit c6d7613

Please sign in to comment.