Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Deprecate FrontendAtomicStmt #907

Merged
merged 3 commits into from
May 2, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ PER_STATEMENT(FrontendBreakStmt)
PER_STATEMENT(FrontendContinueStmt)
PER_STATEMENT(FrontendAllocaStmt)
PER_STATEMENT(FrontendAssignStmt)
PER_STATEMENT(FrontendAtomicStmt)
PER_STATEMENT(FrontendEvalStmt)
PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,17 @@ Expr Expr::eval() const {

void Expr::operator+=(const Expr &o) {
if (this->atomic) {
current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
AtomicOpType::add, ptr_if_global(*this), load_if_ptr(o)));
(*this) = Expr::make<AtomicOpExpression>(
AtomicOpType::add, ptr_if_global(*this), load_if_ptr(o));
} else {
(*this) = (*this) + o;
}
}

void Expr::operator-=(const Expr &o) {
if (this->atomic) {
current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
AtomicOpType::add, *this, -load_if_ptr(o)));
(*this) = Expr::make<AtomicOpExpression>(
AtomicOpType::add, ptr_if_global(*this), -load_if_ptr(o));
k-ye marked this conversation as resolved.
Show resolved Hide resolved
} else {
(*this) = (*this) - o;
}
Expand Down
28 changes: 22 additions & 6 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,12 +553,6 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
TI_ASSERT(lhs->is_lvalue());
}

FrontendAtomicStmt::FrontendAtomicStmt(AtomicOpType op_type,
const Expr &dest,
const Expr &val)
: op_type(op_type), dest(dest), val(val) {
}

IRNode *FrontendContext::root() {
return static_cast<IRNode *>(root_node.get());
}
Expand Down Expand Up @@ -847,6 +841,28 @@ std::string AtomicOpExpression::serialize() {
}
}

void AtomicOpExpression::flatten(FlattenContext *ctx) {
// replace atomic sub with negative atomic add
if (op_type == AtomicOpType::sub) {
val.set(Expr::make<UnaryOpExpression>(UnaryOpType::neg, val));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this cause an circular reference?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, but nope. If you look into set():

taichi/taichi/ir/expr.h

Lines 47 to 49 in f5373b1

void set(const Expr &o) {
expr = o.expr;
}

It assigns expr to another std::shared_ptr. shared_ptr::operator=() will ref the new object, and deref the old object. https://en.cppreference.com/w/cpp/memory/shared_ptr/operator%3D

op_type = AtomicOpType::add;
}
Comment on lines +843 to +847
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Could you clarify the reason in the comment? I don't understand.. wasn't atomic_sub slimer than neg+atomic_add?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't make any change here. Please compare this with the old code (visit(FrontendAtomicStmt*))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we should ask @yuanming-hu for reason?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason is simply that the codegen assumes that sub is converted to add in the previous passes, see

if (stmt->op_type == AtomicOpType::add) {
if (is_integral(stmt->val->ret_type.data_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest],
llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent);
} else if (stmt->val->ret_type.data_type == DataType::f32) {
old_value =
builder->CreateCall(get_runtime_function("atomic_add_f32"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else if (stmt->val->ret_type.data_type == DataType::f64) {
old_value =
builder->CreateCall(get_runtime_function("atomic_add_f64"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::min) {
. It only handles add, but not sub

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So let's just add handles for sub? It should have no harm but profitable right? Also metal and gl handles sub too, no reason special for llvm.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current approach is fine. Adding handling forsub will have to cover all the backends, not just for LLVM. Again, let’s always focus one PR for one thing.

// expand rhs
auto expr = val;
expr->flatten(ctx);
if (dest.is<IdExpression>()) { // local variable
// emit local store stmt
auto alloca = ctx->current_block->lookup_var(dest.cast<IdExpression>()->id);
ctx->push_back<AtomicOpStmt>(op_type, alloca, expr->stmt);
} else { // global variable
TI_ASSERT(dest.is<GlobalPtrExpression>());
auto global_ptr = dest.cast<GlobalPtrExpression>();
global_ptr->flatten(ctx);
ctx->push_back<AtomicOpStmt>(op_type, ctx->back_stmt(), expr->stmt);
}
stmt = ctx->back_stmt();
}

std::string SNodeOpExpression::serialize() {
if (value.expr) {
return fmt::format("{}({}, [{}], {})", snode_op_type_name(op_type),
Expand Down
23 changes: 2 additions & 21 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1336,16 +1336,6 @@ class Block : public IRNode {
DEFINE_ACCEPT
};

class FrontendAtomicStmt : public Stmt {
public:
AtomicOpType op_type;
Expr dest, val;

FrontendAtomicStmt(AtomicOpType op_type, const Expr &dest, const Expr &val);

DEFINE_ACCEPT
};

class FrontendSNodeOpStmt : public Stmt {
public:
SNodeOpType op_type;
Expand Down Expand Up @@ -1948,11 +1938,8 @@ class IdExpression : public Expression {
}
};

// This is just a wrapper class of FrontendAtomicStmt, so that we can turn
// ti.atomic_op() into an expression (with side effect).
// ti.atomic_*() is an expression with side effect.
class AtomicOpExpression : public Expression {
// TODO(issue#332): Flatten this into AtomicOpStmt directly, then we can
// deprecate FrontendAtomicStmt.
public:
AtomicOpType op_type;
Expr dest, val;
Expand All @@ -1963,13 +1950,7 @@ class AtomicOpExpression : public Expression {

std::string serialize() override;

void flatten(FlattenContext *ctx) override {
// FrontendAtomicStmt is the correct place to flatten sub-exprs like |dest|
// and |val| (See LowerAST). This class only wraps the frontend atomic_op()
// stmt as an expression.
ctx->push_back<FrontendAtomicStmt>(op_type, dest, val);
stmt = ctx->back_stmt();
}
void flatten(FlattenContext *ctx) override;
};

class SNodeOpExpression : public Expression {
Expand Down
6 changes: 0 additions & 6 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,6 @@ class IRPrinter : public IRVisitor {
stmt->op2->name(), stmt->op3->name());
}

void visit(FrontendAtomicStmt *stmt) override {
print("{}{} = atomic {}({}, {})", stmt->type_hint(), stmt->name(),
atomic_op_type_name(stmt->op_type), stmt->dest->serialize(),
stmt->val->serialize());
}

void visit(AtomicOpStmt *stmt) override {
print("{}{} = atomic {}({}, {})", stmt->type_hint(), stmt->name(),
atomic_op_type_name(stmt->op_type), stmt->dest->name(),
Expand Down
25 changes: 0 additions & 25 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,31 +314,6 @@ class LowerAST : public IRVisitor {
throw IRModified();
}

void visit(FrontendAtomicStmt *stmt) override {
// replace atomic sub with negative atomic add
if (stmt->op_type == AtomicOpType::sub) {
stmt->val.set(Expr::make<UnaryOpExpression>(UnaryOpType::neg, stmt->val));
stmt->op_type = AtomicOpType::add;
}
// expand rhs
auto expr = stmt->val;
auto fctx = make_flatten_ctx();
expr->flatten(&fctx);
if (stmt->dest.is<IdExpression>()) { // local variable
// emit local store stmt
auto alloca =
stmt->parent->lookup_var(stmt->dest.cast<IdExpression>()->id);
fctx.push_back<AtomicOpStmt>(stmt->op_type, alloca, expr->stmt);
} else { // global variable
TI_ASSERT(stmt->dest.is<GlobalPtrExpression>());
auto global_ptr = stmt->dest.cast<GlobalPtrExpression>();
global_ptr->flatten(&fctx);
fctx.push_back<AtomicOpStmt>(stmt->op_type, fctx.back_stmt(), expr->stmt);
}
stmt->parent->replace_with(stmt, std::move(fctx.stmts));
throw IRModified();
}

void visit(FrontendSNodeOpStmt *stmt) override {
// expand rhs
Stmt *val_stmt = nullptr;
Expand Down