Skip to content

Commit

Permalink
[Lang] Support syntax sugar for ti.cast (#5515)
Browse files Browse the repository at this point in the history
* [Lang] Support syntax sugar for ti.cast

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] committed Jul 25, 2022
1 parent 0f44105 commit 111178b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
7 changes: 6 additions & 1 deletion docs/lang/articles/reference/language_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,12 @@ positional_item ::= assignment_expression | "*" expression
The `primary` must be evaluated to one of:
- A [Taichi function](../kernels/syntax.md#taichi-function).
- A [Taichi builtin function](./operator.md#other-arithmetic-functions).
- A Taichi primitive type, which serves as a type annotation for a literal. In this case, the `positional_arguments` must be evaluated to a single Python value, and the Python value will be turned into a Taichi value with that annotated type.
- A Taichi primitive type. In this case, the `positional_arguments` must only
contain one item. If the item is evaluated to a Python value, then the
primitive type serves as a type annotation for a literal, and the Python value
will be turned into a Taichi value with that annotated type. Otherwise, the
primitive type serves as a syntax sugar for `ti.cast()`, but the item cannot
have a compound type.
- A Python callable object. If not inside a [static expression](#static-expressions), a warning is produced.

### The power operator
Expand Down
13 changes: 10 additions & 3 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,17 @@ def build_call_if_is_builtin(ctx, node, args, keywords):
def build_call_if_is_type(ctx, node, args, keywords):
func = node.func.ptr
if id(func) in primitive_types.type_ids:
if len(args) != 1 or keywords or isinstance(args[0], expr.Expr):
if len(args) != 1 or keywords:
raise TaichiSyntaxError(
"Type annotation can only be given to a single literal.")
node.ptr = expr.Expr(args[0], dtype=func)
"A primitive type can only decorate a single expression.")
if is_taichi_class(args[0]):
raise TaichiSyntaxError(
"A primitive type cannot decorate an expression with a compound type."
)
if isinstance(args[0], expr.Expr):
node.ptr = ti_ops.cast(args[0], func)
else:
node.ptr = expr.Expr(args[0], dtype=func)
return True
return False

Expand Down
12 changes: 10 additions & 2 deletions tests/python/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ def test_cast_uint_to_float(dtype):
def func(a: dtype) -> ti.f32:
return ti.cast(a, ti.f32)

assert func(255) == 255
@ti.kernel
def func_sugar(a: dtype) -> ti.f32:
return ti.f32(a)

assert func(255) == func_sugar(255) == 255


@pytest.mark.parametrize('dtype', [ti.u8, ti.u16, ti.u32])
Expand All @@ -21,7 +25,11 @@ def test_cast_float_to_uint(dtype):
def func(a: ti.f32) -> dtype:
return ti.cast(a, dtype)

assert func(255) == 255
@ti.kernel
def func_sugar(a: ti.f32) -> dtype:
return dtype(a)

assert func(255) == func_sugar(255) == 255


@test_utils.test()
Expand Down
12 changes: 7 additions & 5 deletions tests/python/test_literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def multi_args_error():

with pytest.raises(
ti.TaichiSyntaxError,
match="Type annotation can only be given to a single literal."):
match="A primitive type can only decorate a single expression."):
multi_args_error()


Expand All @@ -37,20 +37,22 @@ def keywords_error():

with pytest.raises(
ti.TaichiSyntaxError,
match="Type annotation can only be given to a single literal."):
match="A primitive type can only decorate a single expression."):
keywords_error()


@test_utils.test()
def test_literal_expr_error():
def test_literal_compound_error():
@ti.kernel
def expr_error():
a = 1
a = ti.Vector([1])
b = ti.f16(a)

with pytest.raises(
ti.TaichiSyntaxError,
match="Type annotation can only be given to a single literal."):
match=
"A primitive type cannot decorate an expression with a compound type."
):
expr_error()


Expand Down

0 comments on commit 111178b

Please sign in to comment.