diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 6e1ba512b16f4..b0ebc50f0f277 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -1,7 +1,7 @@ import numbers from types import MethodType -from taichi.lang import expr, impl, ops +from taichi.lang import impl, ops from taichi.lang.common_ops import TaichiOperations from taichi.lang.enums import Layout from taichi.lang.exception import TaichiSyntaxError @@ -648,21 +648,22 @@ def __init__(self, **kwargs): self.members[k] = cook_dtype(dtype) def __call__(self, *args, **kwargs): - if len(args) == 0: - if kwargs == {}: - raise TaichiSyntaxError( - "Custom type instances need to be created with an initial value." - ) - else: - # initialize struct members by keywords - entries = Struct(kwargs) - elif len(args) == 1: - # fill a single scalar - if isinstance(args[0], (numbers.Number, expr.Expr)): - entries = self.filled_with_scalar(args[0]) + d = {} + items = self.members.items() + for index, pair in enumerate(items): + name, dtype = pair + if isinstance(dtype, CompoundType): + if index < len(args): + d[name] = dtype(args[index]) + else: + d[name] = kwargs.get(name, dtype(0)) else: - # initialize struct members by dictionary - entries = Struct(args[0]) + if index < len(args): + d[name] = args[index] + else: + d[name] = kwargs.get(name, 0) + + entries = Struct(d) struct = self.cast(entries) return struct diff --git a/tests/python/test_custom_struct.py b/tests/python/test_custom_struct.py index 363148ece5716..260ddf7b21a8b 100644 --- a/tests/python/test_custom_struct.py +++ b/tests/python/test_custom_struct.py @@ -179,9 +179,9 @@ def run_python_scope(): init_taichi_scope() for i in range(n): - assert x[i].idx == 1 + assert x[i].idx == 0 assert np.allclose(x[i].line.linedir.to_numpy(), 1.0) - assert x[i].line.length == 1.0 + assert x[i].line.length == 0.0 run_taichi_scope() for i in range(n): assert x[i].idx == i @@ -189,9 +189,9 @@ def run_python_scope(): assert x[i].line.length == i + 0.5 init_python_scope() for i in range(n): - assert x[i].idx == 3 + assert x[i].idx == 0 assert np.allclose(x[i].line.linedir.to_numpy(), 3.0) - assert x[i].line.length == 3.0 + assert x[i].line.length == 0.0 run_python_scope() for i in range(n): assert x[i].idx == i @@ -293,20 +293,20 @@ def test_compound_type_implicit_cast(): @ti.kernel def f2i_taichi_scope() -> int: - s = structi(2.5) + s = structi(2.5, (2.5, 2.5)) return s.a + s.b[0] + s.b[1] def f2i_python_scope(): - s = structi(2.5) + s = structi(2.5, (2.5, 2.5)) return s.a + s.b[0] + s.b[1] @ti.kernel def i2f_taichi_scope() -> float: - s = structf(2) + s = structf(2, (2, 2)) return s.a + s.b[0] + s.b[1] def i2f_python_scope(): - s = structf(2) + s = structf(2, (2, 2)) return s.a + s.b[0] + s.b[1] int_value = f2i_taichi_scope() @@ -396,3 +396,23 @@ def test(): assert a.b == 3 test() + + +@test_utils.test(debug=True) +def test_dataclass(): + vec3 = ti.types.vector(3, float) + + @ti.dataclass + class Foo: + pos: vec3 + vel: vec3 + mass: float + + @ti.kernel + def test(): + A = Foo((1, 1, 1), mass=2) + assert all(A.pos == [1.0, 1.0, 1.0]) + assert all(A.vel == [0.0, 0.0, 0.0]) + assert A.mass == 2.0 + + test()