Skip to content

Commit

Permalink
[IR] Add type_check for Atomic/SNodeOpExpression (taichi-dev#3444)
Browse files Browse the repository at this point in the history
* Type check for atomic op

* Type check for snode op

* Auto Format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
2 people authored and sjwsl committed Nov 11, 2021
1 parent 254c138 commit fa0537b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
33 changes: 33 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,31 @@ void IdExpression::flatten(FlattenContext *ctx) {
}
}

void AtomicOpExpression::type_check() {
// TODO: assert no unknowns after type_check for all expressions are
// implemented
if (dest->ret_type == PrimitiveType::unknown ||
val->ret_type == PrimitiveType::unknown)
return;
auto error = [&]() {
throw std::runtime_error(fmt::format(
"TypeError: unsupported operand type(s) for 'atomic_{}': '{}' and '{}'",
atomic_op_type_name(op_type), dest->ret_type->to_string(),
val->ret_type->to_string()));
};
if (!val->ret_type->is<PrimitiveType>())
error();
if (auto cit = dest->ret_type->cast<CustomIntType>()) {
ret_type = cit->get_compute_type();
} else if (auto cft = dest->ret_type->cast<CustomFloatType>()) {
ret_type = cft->get_compute_type();
} else if (dest->ret_type->is<PrimitiveType>()) {
ret_type = dest->ret_type;
} else {
error();
}
}

void AtomicOpExpression::serialize(std::ostream &ss) {
if (op_type == AtomicOpType::add) {
ss << "atomic_add(";
Expand Down Expand Up @@ -566,6 +591,14 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void SNodeOpExpression::type_check() {
if (op_type == SNodeOpType::get_addr) {
ret_type = PrimitiveType::u64;
} else {
ret_type = PrimitiveType::i32;
}
}

void SNodeOpExpression::serialize(std::ostream &ss) {
ss << snode_op_type_name(op_type);
ss << '(';
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,8 @@ class AtomicOpExpression : public Expression {
: op_type(op_type), dest(dest), val(val) {
}

void type_check() override;

void serialize(std::ostream &ss) override;

void flatten(FlattenContext *ctx) override;
Expand All @@ -660,6 +662,8 @@ class SNodeOpExpression : public Expression {
: snode(snode), op_type(op_type), indices(indices), value(value) {
}

void type_check() override;

void serialize(std::ostream &ss) override;

void flatten(FlattenContext *ctx) override;
Expand Down
22 changes: 22 additions & 0 deletions tests/cpp/ir/frontend_type_inference_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,28 @@ TEST(FrontendTypeInference, TensorElement) {
EXPECT_EQ(load_tensor_element->ret_type, PrimitiveType::u32);
}

TEST(FrontendTypeInference, AtomicOp) {
auto const_i32 = Expr::make<ConstExpression, int32>(-(1 << 20));
const_i32->type_check();
auto const_f32 = Expr::make<ConstExpression, float32>(5.0);
const_f32->type_check();
auto atomic_add_i32 =
Expr::make<AtomicOpExpression>(AtomicOpType::add, const_i32, const_f32);
atomic_add_i32->type_check();
EXPECT_EQ(atomic_add_i32->ret_type, PrimitiveType::i32);
}

TEST(FrontendTypeInference, SNodeOp) {
auto snode = std::make_unique<SNode>(0, SNodeType::root);
snode->dt = PrimitiveType::u8;
auto index = Expr::make<ConstExpression, int32>(2);
index->type_check();
auto snode_op = Expr::make<SNodeOpExpression>(
snode.get(), SNodeOpType::get_addr, ExprGroup(index));
snode_op->type_check();
EXPECT_EQ(snode_op->ret_type, PrimitiveType::u64);
}

TEST(FrontendTypeInference, ExternalTensorShapeAlongAxis) {
auto external_tensor =
Expr::make<ExternalTensorExpression>(PrimitiveType::u64, 1, 0, 0);
Expand Down

0 comments on commit fa0537b

Please sign in to comment.