Skip to content

Commit

Permalink
[Lang] Better struct initialization (#5481)
Browse files Browse the repository at this point in the history
* what

* Better StructType initialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Better StructType initialization

* Better StructType initialization

* Better StructType initialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* temp save

* better struct initialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
neozhaoliang and pre-commit-ci[bot] committed Jul 25, 2022
1 parent b6a322a commit 0f44105
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 23 deletions.
31 changes: 16 additions & 15 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numbers
from types import MethodType

from taichi.lang import expr, impl, ops
from taichi.lang import impl, ops
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError
Expand Down Expand Up @@ -648,21 +648,22 @@ def __init__(self, **kwargs):
self.members[k] = cook_dtype(dtype)

def __call__(self, *args, **kwargs):
if len(args) == 0:
if kwargs == {}:
raise TaichiSyntaxError(
"Custom type instances need to be created with an initial value."
)
else:
# initialize struct members by keywords
entries = Struct(kwargs)
elif len(args) == 1:
# fill a single scalar
if isinstance(args[0], (numbers.Number, expr.Expr)):
entries = self.filled_with_scalar(args[0])
d = {}
items = self.members.items()
for index, pair in enumerate(items):
name, dtype = pair
if isinstance(dtype, CompoundType):
if index < len(args):
d[name] = dtype(args[index])
else:
d[name] = kwargs.get(name, dtype(0))
else:
# initialize struct members by dictionary
entries = Struct(args[0])
if index < len(args):
d[name] = args[index]
else:
d[name] = kwargs.get(name, 0)

entries = Struct(d)
struct = self.cast(entries)
return struct

Expand Down
36 changes: 28 additions & 8 deletions tests/python/test_custom_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,19 @@ def run_python_scope():

init_taichi_scope()
for i in range(n):
assert x[i].idx == 1
assert x[i].idx == 0
assert np.allclose(x[i].line.linedir.to_numpy(), 1.0)
assert x[i].line.length == 1.0
assert x[i].line.length == 0.0
run_taichi_scope()
for i in range(n):
assert x[i].idx == i
assert np.allclose(x[i].line.linedir.to_numpy(), 1.0)
assert x[i].line.length == i + 0.5
init_python_scope()
for i in range(n):
assert x[i].idx == 3
assert x[i].idx == 0
assert np.allclose(x[i].line.linedir.to_numpy(), 3.0)
assert x[i].line.length == 3.0
assert x[i].line.length == 0.0
run_python_scope()
for i in range(n):
assert x[i].idx == i
Expand Down Expand Up @@ -293,20 +293,20 @@ def test_compound_type_implicit_cast():

@ti.kernel
def f2i_taichi_scope() -> int:
s = structi(2.5)
s = structi(2.5, (2.5, 2.5))
return s.a + s.b[0] + s.b[1]

def f2i_python_scope():
s = structi(2.5)
s = structi(2.5, (2.5, 2.5))
return s.a + s.b[0] + s.b[1]

@ti.kernel
def i2f_taichi_scope() -> float:
s = structf(2)
s = structf(2, (2, 2))
return s.a + s.b[0] + s.b[1]

def i2f_python_scope():
s = structf(2)
s = structf(2, (2, 2))
return s.a + s.b[0] + s.b[1]

int_value = f2i_taichi_scope()
Expand Down Expand Up @@ -396,3 +396,23 @@ def test():
assert a.b == 3

test()


@test_utils.test(debug=True)
def test_dataclass():
vec3 = ti.types.vector(3, float)

@ti.dataclass
class Foo:
pos: vec3
vel: vec3
mass: float

@ti.kernel
def test():
A = Foo((1, 1, 1), mass=2)
assert all(A.pos == [1.0, 1.0, 1.0])
assert all(A.vel == [0.0, 0.0, 0.0])
assert A.mass == 2.0

test()

0 comments on commit 0f44105

Please sign in to comment.