Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Opt] Simplify multiplying/dividing POT #2332

Merged
merged 6 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmake/TaichiTests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ endif()
# TODO(#2195):
# 1. "cpp" -> "cpp_legacy", "cpp_new" -> "cpp"
# 2. Re-implement the legacy CPP tests using googletest
file(GLOB_RECURSE TAICHI_TESTS_SOURCE "tests/cpp/analysis/*.cpp" "tests/cpp/common/*.cpp" "tests/cpp/ir/*.cpp")
file(GLOB_RECURSE TAICHI_TESTS_SOURCE "tests/cpp/analysis/*.cpp" "tests/cpp/common/*.cpp" "tests/cpp/ir/*.cpp"
"tests/cpp/transforms/*.cpp")

include_directories(
${PROJECT_SOURCE_DIR},
Expand Down
209 changes: 147 additions & 62 deletions taichi/transforms/alg_simp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"
#include "taichi/util/bit.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -64,16 +65,144 @@ class AlgSimp : public BasicStmtVisitor {
}
}

// Tries to simplify a `mul` statement. The rewrites are attempted in this
// order, and the first one that applies wins:
//   1. 1 * a, a * 1          -> a
//   2. 0 * a, a * 0          -> 0       (fast_math or integral result only)
//   3. a * pot, pot * a      -> a << log2(pot)   (integral result only)
//   4. 2 * a, a * 2          -> a + a
// Returns true iff the IR is modified.
bool optimize_multiplication(BinaryOpStmt *stmt) {
  auto const_lhs = stmt->lhs->cast<ConstStmt>();
  auto const_rhs = stmt->rhs->cast<ConstStmt>();
  TI_ASSERT(stmt->op_type == BinaryOpType::mul);

  // 1 * a -> a, a * 1 -> a
  const bool lhs_is_one = alg_is_one(const_lhs);
  if (lhs_is_one || alg_is_one(const_rhs)) {
    stmt->replace_with(lhs_is_one ? stmt->rhs : stmt->lhs);
    modifier.erase(stmt);
    return true;
  }

  // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0
  if (fast_math || is_integral(stmt->ret_type)) {
    if (alg_is_zero(const_lhs) || alg_is_zero(const_rhs)) {
      replace_with_zero(stmt);
      return true;
    }
  }

  // a * pot -> a << log2(pot)
  if (is_integral(stmt->ret_type) &&
      (alg_is_pot(const_lhs) || alg_is_pot(const_rhs))) {
    if (alg_is_pot(const_lhs)) {
      // Normalize so that the power-of-two constant sits on the right.
      std::swap(stmt->lhs, stmt->rhs);
      std::swap(const_lhs, const_rhs);
    }
    auto &pot = const_rhs->val[0];
    int shift_amount = is_signed(pot.dt) ? bit::log2int(pot.val_int())
                                         : bit::log2int(pot.val_uint());
    // The shift amount constant is given the type of the shifted operand.
    auto shift_const = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(
        TypedConstant(stmt->lhs->ret_type, shift_amount)));
    auto shifted = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_shl, stmt->lhs,
                                            shift_const.get());
    shifted->ret_type = stmt->ret_type;
    stmt->replace_with(shifted.get());
    modifier.insert_before(stmt, std::move(shift_const));
    modifier.insert_before(stmt, std::move(shifted));
    modifier.erase(stmt);
    return true;
  }

  // 2 * a -> a + a, a * 2 -> a + a
  const bool lhs_is_two = alg_is_two(const_lhs);
  if (lhs_is_two || alg_is_two(const_rhs)) {
    auto operand = lhs_is_two ? stmt->rhs : stmt->lhs;
    cast_to_result_type(operand, stmt);
    auto doubled =
        Stmt::make<BinaryOpStmt>(BinaryOpType::add, operand, operand);
    doubled->ret_type = operand->ret_type;
    stmt->replace_with(doubled.get());
    modifier.insert_before(stmt, std::move(doubled));
    modifier.erase(stmt);
    return true;
  }

  return false;
}

// Tries to simplify a `div` / `floordiv` statement. The rewrites are
// attempted in this order, and the first one that applies wins:
//   1. a / 1                 -> a
//   2. a / a                 -> 1       (fast_math or integral result only)
//   3. a / const             -> a * (1 / const)
//                               (fast_math, real constant, not floordiv;
//                                a zero constant only triggers a warning)
//   4. (unsigned)a / pot     -> a >> log2(pot)
// Returns true iff the IR is modified.
bool optimize_division(BinaryOpStmt *stmt) {
  auto divisor = stmt->rhs->cast<ConstStmt>();
  TI_ASSERT(stmt->op_type == BinaryOpType::div ||
            stmt->op_type == BinaryOpType::floordiv);

  // a / 1 -> a
  if (alg_is_one(divisor)) {
    stmt->replace_with(stmt->lhs);
    modifier.erase(stmt);
    return true;
  }

  // fast_math or integral operands: a / a -> 1
  if (fast_math || is_integral(stmt->ret_type)) {
    if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
      replace_with_one(stmt);
      return true;
    }
  }

  const bool real_const_divisor = fast_math && divisor &&
                                  is_real(divisor->ret_type) &&
                                  stmt->op_type != BinaryOpType::floordiv;
  if (real_const_divisor) {
    if (!alg_is_zero(divisor)) {
      // a / const -> a * (1 / const)
      auto inverse = Stmt::make_typed<ConstStmt>(
          LaneAttribute<TypedConstant>(divisor->ret_type));
      if (divisor->ret_type->is_primitive(PrimitiveTypeID::f64)) {
        inverse->val[0].val_float64() =
            (float64)1.0 / divisor->val[0].val_float64();
      } else if (divisor->ret_type->is_primitive(PrimitiveTypeID::f32)) {
        inverse->val[0].val_float32() =
            (float32)1.0 / divisor->val[0].val_float32();
      } else {
        TI_NOT_IMPLEMENTED
      }
      auto scaled = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, stmt->lhs,
                                             inverse.get());
      scaled->ret_type = stmt->ret_type;
      stmt->replace_with(scaled.get());
      modifier.insert_before(stmt, std::move(inverse));
      modifier.insert_before(stmt, std::move(scaled));
      modifier.erase(stmt);
      return true;
    }
    // Zero constant divisor: leave the statement alone and keep going.
    TI_WARN("Potential division by 0");
  }

  // (unsigned)a / pot -> a >> log2(pot)
  if (is_integral(stmt->lhs->ret_type) && is_unsigned(stmt->lhs->ret_type) &&
      alg_is_pot(divisor)) {
    auto &pot = divisor->val[0];
    int shift_amount = is_signed(pot.dt) ? bit::log2int(pot.val_int())
                                         : bit::log2int(pot.val_uint());
    // The shift amount constant is given the type of the dividend.
    auto shift_const = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(
        TypedConstant(stmt->lhs->ret_type, shift_amount)));
    auto shifted = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar, stmt->lhs,
                                            shift_const.get());
    shifted->ret_type = stmt->ret_type;
    stmt->replace_with(shifted.get());
    modifier.insert_before(stmt, std::move(shift_const));
    modifier.insert_before(stmt, std::move(shifted));
    modifier.erase(stmt);
    return true;
  }

  return false;
}

void visit(BinaryOpStmt *stmt) override {
auto lhs = stmt->lhs->cast<ConstStmt>();
auto rhs = stmt->rhs->cast<ConstStmt>();
if (stmt->width() != 1) {
return;
}
if (stmt->op_type == BinaryOpType::add ||
stmt->op_type == BinaryOpType::sub ||
stmt->op_type == BinaryOpType::bit_or ||
stmt->op_type == BinaryOpType::bit_xor) {
if (stmt->op_type == BinaryOpType::mul) {
optimize_multiplication(stmt);
} else if (stmt->op_type == BinaryOpType::div ||
stmt->op_type == BinaryOpType::floordiv) {
optimize_division(stmt);
} else if (stmt->op_type == BinaryOpType::add ||
stmt->op_type == BinaryOpType::sub ||
stmt->op_type == BinaryOpType::bit_or ||
stmt->op_type == BinaryOpType::bit_xor) {
if (alg_is_zero(rhs)) {
// a +-|^ 0 -> a
stmt->replace_with(stmt->lhs);
Expand All @@ -94,64 +223,6 @@ class AlgSimp : public BasicStmtVisitor {
// fast_math or integral operands: a -^ a -> 0
replace_with_zero(stmt);
}
} else if (stmt->op_type == BinaryOpType::mul ||
stmt->op_type == BinaryOpType::div) {
if (alg_is_one(rhs)) {
// a */ 1 -> a
stmt->replace_with(stmt->lhs);
modifier.erase(stmt);
} else if (stmt->op_type == BinaryOpType::mul && alg_is_one(lhs)) {
// 1 * a -> a
stmt->replace_with(stmt->rhs);
modifier.erase(stmt);
} else if ((fast_math || is_integral(stmt->ret_type)) &&
stmt->op_type == BinaryOpType::mul &&
(alg_is_zero(lhs) || alg_is_zero(rhs))) {
// fast_math or integral operands: 0 * a -> 0, a * 0 -> 0
replace_with_zero(stmt);
} else if ((fast_math || is_integral(stmt->ret_type)) &&
stmt->op_type == BinaryOpType::div &&
irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
// fast_math or integral operands: a / a -> 1
replace_with_one(stmt);
} else if (stmt->op_type == BinaryOpType::mul &&
(alg_is_two(lhs) || alg_is_two(rhs))) {
// 2 * a -> a + a, a * 2 -> a + a
auto a = stmt->lhs;
if (alg_is_two(lhs))
a = stmt->rhs;
cast_to_result_type(a, stmt);
auto sum = Stmt::make<BinaryOpStmt>(BinaryOpType::add, a, a);
sum->ret_type = a->ret_type;
stmt->replace_with(sum.get());
modifier.insert_before(stmt, std::move(sum));
modifier.erase(stmt);
} else if (fast_math && stmt->op_type == BinaryOpType::div && rhs &&
is_real(rhs->ret_type)) {
if (alg_is_zero(rhs)) {
TI_WARN("Potential division by 0");
} else {
// a / const -> a * (1 / const)
auto reciprocal = Stmt::make_typed<ConstStmt>(
LaneAttribute<TypedConstant>(rhs->ret_type));
if (rhs->ret_type->is_primitive(PrimitiveTypeID::f64)) {
reciprocal->val[0].val_float64() =
(float64)1.0 / rhs->val[0].val_float64();
} else if (rhs->ret_type->is_primitive(PrimitiveTypeID::f32)) {
reciprocal->val[0].val_float32() =
(float32)1.0 / rhs->val[0].val_float32();
} else {
TI_NOT_IMPLEMENTED
}
auto product = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, stmt->lhs,
reciprocal.get());
product->ret_type = stmt->ret_type;
stmt->replace_with(product.get());
modifier.insert_before(stmt, std::move(reciprocal));
modifier.insert_before(stmt, std::move(product));
modifier.erase(stmt);
}
}
} else if (rhs && stmt->op_type == BinaryOpType::pow) {
float64 exponent = rhs->val[0].val_cast_to_float64();
if (exponent == 1) {
Expand Down Expand Up @@ -313,6 +384,20 @@ class AlgSimp : public BasicStmtVisitor {
return stmt->val[0].equal_value(-1);
}

// Returns true iff `stmt` is a non-null, width-1, integral constant whose
// value is strictly positive and equal to its own lowest set bit
// (bit::lowbit), i.e. a power of two.
static bool alg_is_pot(ConstStmt *stmt) {
  if (stmt == nullptr || stmt->width() != 1 ||
      !is_integral(stmt->val[0].dt)) {
    return false;
  }
  auto &constant = stmt->val[0];
  // Read the value through the accessor matching its signedness.
  if (is_signed(constant.dt)) {
    auto signed_val = constant.val_int();
    return signed_val > 0 && signed_val == bit::lowbit(signed_val);
  }
  auto unsigned_val = constant.val_uint();
  return unsigned_val > 0 && unsigned_val == bit::lowbit(unsigned_val);
}

static bool run(IRNode *node, bool fast_math) {
AlgSimp simplifier(fast_math);
bool modified = false;
Expand Down
10 changes: 9 additions & 1 deletion taichi/transforms/binary_op_simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,21 @@ class BinaryOpSimp : public BasicStmtVisitor {
(op2 == BinaryOpType::mul ? BinaryOpType::div : BinaryOpType::mul);
return true;
}
// for bit operations it only holds when two ops are the same
// for bit operations it holds when two ops are the same
if ((op1 == BinaryOpType::bit_and || op1 == BinaryOpType::bit_or ||
op1 == BinaryOpType::bit_xor) &&
op1 == op2) {
new_op2 = op2;
return true;
}
if ((op1 == BinaryOpType::bit_shl || op1 == BinaryOpType::bit_shr ||
op1 == BinaryOpType::bit_sar) &&
op1 == op2) {
// (a << b) << c -> a << (b + c)
// (a >> b) >> c -> a >> (b + c)
new_op2 = BinaryOpType::add;
return true;
}
return false;
}

Expand Down
4 changes: 3 additions & 1 deletion taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ class ConstantFold : public BasicStmtVisitor {
// Discussion:
// https://github.com/taichi-dev/taichi/pull/839#issuecomment-625902727
if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::f32) ||
dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u32) ||
dt->is_primitive(PrimitiveTypeID::u64) ||
dt->is_primitive(PrimitiveTypeID::f32) ||
dt->is_primitive(PrimitiveTypeID::f64))
return true;
else
Expand Down
11 changes: 8 additions & 3 deletions taichi/transforms/loop_invariant_code_motion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ TLANG_NAMESPACE_BEGIN

class LoopInvariantCodeMotion : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;

std::stack<Block *> loop_blocks;

const CompileConfig &config;

DelayedIRModifier modifier;

LoopInvariantCodeMotion(const CompileConfig &config) : config(config) {
explicit LoopInvariantCodeMotion(const CompileConfig &config)
: config(config) {
allow_undefined_visitor = true;
}

Expand Down Expand Up @@ -49,7 +52,8 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor {
Stmt *operand_parent = operand;
while (operand_parent && operand_parent->parent) {
operand_parent = operand_parent->parent->parent_stmt;
if (!operand_parent) break;
if (!operand_parent)
break;
// If one of the parents of the operand is the top loop scope,
// then it will not be visible if we move it outside the top loop
// scope
Expand Down Expand Up @@ -118,12 +122,13 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor {
if (stmt->bls_prologue)
stmt->bls_prologue->accept(this);

if (stmt->body)
if (stmt->body) {
if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
stmt->task_type == OffloadedStmt::TaskType::struct_for)
visit_loop(stmt->body.get());
else
stmt->body->accept(this);
}

if (stmt->bls_epilogue)
stmt->bls_epilogue->accept(this);
Expand Down
57 changes: 57 additions & 0 deletions tests/cpp/transforms/binary_op_simplify_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "gtest/gtest.h"

#include "taichi/ir/statements.h"
#include "taichi/ir/ir_builder.h"
#include "taichi/ir/transforms.h"

namespace taichi {
namespace lang {

// GoogleTest fixture for the binary-op simplification tests.
// SetUp() creates a fresh Program and calls materialize_layout() on it, so
// test bodies can hand `prog_.get()` to passes that need a Program.
class BinaryOpSimplifyTest : public ::testing::Test {
protected:
// Runs before each test body: builds the Program owned by this fixture.
void SetUp() override {
prog_ = std::make_unique<Program>();
prog_->materialize_layout();
}

// Program owned by the fixture; assigned in SetUp().
std::unique_ptr<Program> prog_;
};

// Checks that a multiplication by a power of two followed by a left shift is
// collapsed into a single shift:  (x * 32) << 3  ->  x << 8.
//
// Preconditions that later statements rely on (block size before indexing,
// statement kind before `as<>()`) use ASSERT_* so that a failure stops the
// test instead of running into an out-of-range access or a bad cast.
TEST_F(BinaryOpSimplifyTest, MultiplyPOT) {
  IRBuilder builder;
  // (x * 32) << 3
  auto *x = builder.create_arg_load(0, get_data_type<int>(), false);
  auto *product = builder.create_mul(x, builder.get_int32(32));
  auto *result = builder.create_shl(product, builder.get_int32(3));
  builder.create_return(result);
  auto ir = builder.extract_ir();
  ASSERT_TRUE(ir->is<Block>());
  auto *ir_block = ir->as<Block>();
  irpass::type_check(ir_block, CompileConfig());
  EXPECT_EQ(ir_block->size(), 6);

  irpass::alg_simp(ir_block, CompileConfig());
  // -> (x << 5) << 3
  irpass::binary_op_simplify(ir_block, CompileConfig());
  // -> x << (5 + 3)
  irpass::constant_fold(ir_block, CompileConfig(), {prog_.get()});
  // -> x << 8
  irpass::die(ir_block);

  // Expected layout: [x, const 8, x << const, return].
  // Fatal: every index below assumes exactly four statements.
  ASSERT_EQ(ir_block->size(), 4);
  EXPECT_EQ(ir_block->statements[0].get(), x);

  // Fatal: the cast below is only valid for a ConstStmt.
  ASSERT_TRUE(ir_block->statements[1]->is<ConstStmt>());
  auto *const_stmt = ir_block->statements[1]->as<ConstStmt>();
  EXPECT_TRUE(is_integral(const_stmt->val[0].dt));
  EXPECT_TRUE(is_signed(const_stmt->val[0].dt));
  EXPECT_EQ(const_stmt->val[0].val_int(), 8);

  // Fatal: the cast below is only valid for a BinaryOpStmt.
  ASSERT_TRUE(ir_block->statements[2]->is<BinaryOpStmt>());
  auto *bin_op = ir_block->statements[2]->as<BinaryOpStmt>();
  EXPECT_EQ(bin_op->op_type, BinaryOpType::bit_shl);
  EXPECT_EQ(bin_op->rhs, const_stmt);

  // Fatal: the cast below is only valid for a KernelReturnStmt.
  ASSERT_TRUE(ir_block->statements[3]->is<KernelReturnStmt>());
  EXPECT_EQ(ir_block->statements[3]->as<KernelReturnStmt>()->value, bin_op);
}

} // namespace lang
} // namespace taichi