diff --git a/docs/lang/articles/reference/language_reference.md b/docs/lang/articles/reference/language_reference.md index cd11c69ca237e..55509e63250fe 100644 --- a/docs/lang/articles/reference/language_reference.md +++ b/docs/lang/articles/reference/language_reference.md @@ -286,7 +286,12 @@ positional_item ::= assignment_expression | "*" expression The `primary` must be evaluated to one of: - A [Taichi function](../kernels/syntax.md#taichi-function). - A [Taichi builtin function](./operator.md#other-arithmetic-functions). -- A Taichi primitive type, which serves as a type annotation for a literal. In this case, the `positional_arguments` must be evaluated to a single Python value, and the Python value will be turned into a Taichi value with that annotated type. +- A Taichi primitive type. In this case, the `positional_arguments` must only + contain one item. If the item is evaluated to a Python value, then the + primitive type serves as a type annotation for a literal, and the Python value + will be turned into a Taichi value with that annotated type. Otherwise, the + primitive type serves as a syntax sugar for `ti.cast()`, but the item cannot + have a compound type. - A Python callable object. If not inside a [static expression](#static-expressions), a warning is produced. ### The power operator diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 497c132380d04..891cdf4f2347a 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -426,10 +426,17 @@ def build_call_if_is_builtin(ctx, node, args, keywords): def build_call_if_is_type(ctx, node, args, keywords): func = node.func.ptr if id(func) in primitive_types.type_ids: - if len(args) != 1 or keywords or isinstance(args[0], expr.Expr): + if len(args) != 1 or keywords: raise TaichiSyntaxError( - "Type annotation can only be given to a single literal.") - node.ptr = expr.Expr(args[0], dtype=func) + "A primitive type can only decorate a single expression.") + if is_taichi_class(args[0]): + raise TaichiSyntaxError( + "A primitive type cannot decorate an expression with a compound type." + ) + if isinstance(args[0], expr.Expr): + node.ptr = ti_ops.cast(args[0], func) + else: + node.ptr = expr.Expr(args[0], dtype=func) return True return False diff --git a/tests/python/test_cast.py b/tests/python/test_cast.py index c4475d0c26371..0cfa9c42da37a 100644 --- a/tests/python/test_cast.py +++ b/tests/python/test_cast.py @@ -11,7 +11,11 @@ def test_cast_uint_to_float(dtype): def func(a: dtype) -> ti.f32: return ti.cast(a, ti.f32) - assert func(255) == 255 + @ti.kernel + def func_sugar(a: dtype) -> ti.f32: + return ti.f32(a) + + assert func(255) == func_sugar(255) == 255 @pytest.mark.parametrize('dtype', [ti.u8, ti.u16, ti.u32]) @@ -21,7 +25,11 @@ def test_cast_float_to_uint(dtype): def func(a: ti.f32) -> dtype: return ti.cast(a, dtype) - assert func(255) == 255 + @ti.kernel + def func_sugar(a: ti.f32) -> dtype: + return dtype(a) + + assert func(255) == func_sugar(255) == 255 @test_utils.test() diff --git a/tests/python/test_literal.py b/tests/python/test_literal.py index 6aa17cdf7c0c7..2ff20e5e0709b 100644 --- a/tests/python/test_literal.py +++ b/tests/python/test_literal.py @@ -25,7 +25,7 @@ def multi_args_error(): with pytest.raises( ti.TaichiSyntaxError, - match="Type annotation can only be given to a single literal."): + match="A primitive type can only decorate a single expression."): multi_args_error() @@ -37,20 +37,22 @@ def keywords_error(): with pytest.raises( ti.TaichiSyntaxError, - match="Type annotation can only be given to a single literal."): + match="A primitive type can only decorate a single expression."): keywords_error() @test_utils.test() -def test_literal_expr_error(): +def test_literal_compound_error(): @ti.kernel def expr_error(): - a = 1 + a = ti.Vector([1]) b = ti.f16(a) with pytest.raises( ti.TaichiSyntaxError, - match="Type annotation can only be given to a single literal."): + match= + "A primitive type cannot decorate an expression with a compound type." + ): expr_error()