Skip to content

Commit

Permalink
[Lang] Add ti.round op (#3541)
Browse files Browse the repository at this point in the history
* add round op

* add round for other bankend

* include vulkan bankend for unary op test

* format cpp

* pylint skip redefined-builtin check at round op
  • Loading branch information
gaoxinge committed Nov 23, 2021
1 parent 26b833e commit cbdf5dd
Show file tree
Hide file tree
Showing 14 changed files with 37 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/lang/articles/basic/type.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ ti.sqrt(x)
ti.rsqrt(x) # A fast version for `1 / ti.sqrt(x)`.
ti.exp(x)
ti.log(x)
ti.round(x)
ti.floor(x)
ti.ceil(x)
```
Expand Down
13 changes: 13 additions & 0 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,19 @@ def _rsqrt(a):
return _unary_operation(_ti_core.expr_rsqrt, _rsqrt, a)


@unary
def round(a): # pylint: disable=redefined-builtin
"""The round function.
Args:
a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix.
Returns:
The nearest integer of `a`.
"""
return _unary_operation(_ti_core.expr_round, builtins.round, a)


@unary
def floor(a):
"""The floor function.
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/metal/data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ std::string metal_unary_op_type_symbol(UnaryOpType type) {
return "-";
case UnaryOpType::sqrt:
return "sqrt";
case UnaryOpType::round:
return "round";
case UnaryOpType::floor:
return "floor";
case UnaryOpType::ceil:
Expand Down
1 change: 1 addition & 0 deletions taichi/backends/vulkan/codegen_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ class TaskCodegen : public IRVisitor {
TI_NOT_IMPLEMENTED \
} \
}
UNARY_OP_TO_SPIRV(round, Round, 1, 64)
UNARY_OP_TO_SPIRV(floor, Floor, 8, 64)
UNARY_OP_TO_SPIRV(ceil, Ceil, 9, 64)
UNARY_OP_TO_SPIRV(sin, Sin, 13, 32)
Expand Down
1 change: 1 addition & 0 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
llvm_val[stmt] = builder->CreateNeg(input, "neg");
}
}
UNARY_INTRINSIC(round)
UNARY_INTRINSIC(floor)
UNARY_INTRINSIC(ceil)
else {
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/unary_op.inc.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
PER_UNARY_OP(neg)
PER_UNARY_OP(sqrt)
PER_UNARY_OP(round)
PER_UNARY_OP(floor)
PER_UNARY_OP(ceil)
PER_UNARY_OP(cast_value)
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/expression_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#endif

DEFINE_EXPRESSION_OP_UNARY(sqrt)
DEFINE_EXPRESSION_OP_UNARY(round)
DEFINE_EXPRESSION_OP_UNARY(floor)
DEFINE_EXPRESSION_OP_UNARY(ceil)
DEFINE_EXPRESSION_OP_UNARY(abs)
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ void UnaryOpExpression::type_check() {
throw std::runtime_error(
fmt::format("TypeError: unsupported operand type(s) for '{}': '{}'",
unary_op_type_name(type), operand->ret_type->to_string()));
if ((type == UnaryOpType::floor || type == UnaryOpType::ceil ||
is_trigonometric(type)) &&
if ((type == UnaryOpType::round || type == UnaryOpType::floor ||
type == UnaryOpType::ceil || is_trigonometric(type)) &&
!is_real(operand->ret_type))
throw std::runtime_error(fmt::format(
"TypeError: '{}' takes real inputs only, however '{}' is provided",
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ UnaryOpStmt *IRBuilder::create_logical_not(Stmt *value) {
return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::logic_not, value));
}

UnaryOpStmt *IRBuilder::create_round(Stmt *value) {
return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::round, value));
}

UnaryOpStmt *IRBuilder::create_floor(Stmt *value) {
return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::floor, value));
}
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class IRBuilder {
UnaryOpStmt *create_neg(Stmt *value);
UnaryOpStmt *create_not(Stmt *value); // bitwise
UnaryOpStmt *create_logical_not(Stmt *value);
UnaryOpStmt *create_round(Stmt *value);
UnaryOpStmt *create_floor(Stmt *value);
UnaryOpStmt *create_ceil(Stmt *value);
UnaryOpStmt *create_abs(Stmt *value);
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ void export_lang(py::module &m) {

m.def("expr_neg", [&](const Expr &e) { return -e; });
DEFINE_EXPRESSION_OP_UNARY(sqrt)
DEFINE_EXPRESSION_OP_UNARY(round)
DEFINE_EXPRESSION_OP_UNARY(floor)
DEFINE_EXPRESSION_OP_UNARY(ceil)
DEFINE_EXPRESSION_OP_UNARY(abs)
Expand Down
7 changes: 4 additions & 3 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,11 @@ class TypeCheck : public IRVisitor {
if (is_trigonometric(stmt->op_type)) {
TI_ERROR("[{}] Trigonometric operator takes real inputs only, at {}",
stmt->name(), stmt->tb);
} else if (stmt->op_type == UnaryOpType::floor ||
} else if (stmt->op_type == UnaryOpType::round ||
stmt->op_type == UnaryOpType::floor ||
stmt->op_type == UnaryOpType::ceil) {
TI_ERROR("[{}] floor/ceil takes real inputs only at {}", stmt->name(),
stmt->tb);
TI_ERROR("[{}] round/floor/ceil takes real inputs only at {}",
stmt->name(), stmt->tb);
} else if (stmt->op_type == UnaryOpType::sqrt ||
stmt->op_type == UnaryOpType::exp ||
stmt->op_type == UnaryOpType::log) {
Expand Down
6 changes: 4 additions & 2 deletions tests/python/test_element_wise.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ def func():
assert allclose(x[11], np.maximum(y, z))


@ti.test(exclude=[ti.vulkan])
@ti.test()
def test_unary():
xi = ti.Matrix.field(3, 2, ti.i32, 4)
yi = ti.Matrix.field(3, 2, ti.i32, ())
xf = ti.Matrix.field(3, 2, ti.f32, 14)
xf = ti.Matrix.field(3, 2, ti.f32, 15)
yf = ti.Matrix.field(3, 2, ti.f32, ())

yi.from_numpy(np.array([[3, 2], [9, 0], [7, 4]], np.int32))
Expand All @@ -272,6 +272,7 @@ def func():
xf[11] = ti.exp(yf[None])
xf[12] = ti.log(yf[None])
xf[13] = ti.rsqrt(yf[None])
xf[14] = ti.round(yf[None])

func()
xi = xi.to_numpy()
Expand All @@ -295,6 +296,7 @@ def func():
assert allclose(xf[11], np.exp(yf), rel=1e-5)
assert allclose(xf[12], np.log(yf), rel=1e-5)
assert allclose(xf[13], 1 / np.sqrt(yf), rel=1e-5)
assert allclose(xf[14], np.round(yf), rel=1e-5)


@pytest.mark.parametrize('is_mat', [(True, True, True), (True, False, False),
Expand Down
1 change: 1 addition & 0 deletions tests/python/test_scalar_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
(ti.asin, np.arcsin),
(ti.acos, np.arccos),
(ti.tanh, np.tanh),
(ti.round, np.round),
(ti.floor, np.floor),
(ti.ceil, np.ceil),
]
Expand Down

0 comments on commit cbdf5dd

Please sign in to comment.