[refactor] Move IRBank to a separate file #1897

Merged (3 commits) on Sep 26, 2020
taichi/program/async_engine.cpp (0 additions, 334 deletions)
@@ -13,189 +13,6 @@

TLANG_NAMESPACE_BEGIN

namespace {

uint64 hash(IRNode *stmt) {
TI_ASSERT(stmt);
// TODO: upgrade this using IR comparisons
std::string serialized;
irpass::re_id(stmt);
irpass::print(stmt, &serialized);
// TODO: separate kernel from IR template
serialized += stmt->get_kernel()->name;
uint64 ret = 0;
for (uint64 i = 0; i < serialized.size(); i++) {
ret = ret * 100000007UL + (uint64)serialized[i];
}
return ret;
}

} // namespace
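Aside: hash() above is a plain polynomial rolling hash over the printed IR plus the kernel name. A minimal standalone sketch of the same scheme (poly_hash is a hypothetical name, not part of this diff):

#include <cstdint>
#include <string>

// Same accumulation as hash() above: h <- h * 100000007 + byte.
// Arithmetic wraps mod 2^64; collisions are possible but unlikely.
uint64_t poly_hash(const std::string &s) {
  uint64_t h = 0;
  for (unsigned char c : s)
    h = h * 100000007ULL + c;
  return h;
}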

uint64 IRBank::get_hash(IRNode *ir) {
auto result_iterator = hash_bank_.find(ir);
if (result_iterator == hash_bank_.end()) {
auto result = hash(ir);
set_hash(ir, result);
return result;
}
return result_iterator->second;
}

void IRBank::set_hash(IRNode *ir, uint64 hash) {
hash_bank_[ir] = hash;
}

bool IRBank::insert(std::unique_ptr<IRNode> &&ir, uint64 hash) {
IRHandle handle(ir.get(), hash);
auto insert_place = ir_bank_.find(handle);
if (insert_place == ir_bank_.end()) {
ir_bank_.emplace(handle, std::move(ir));
return true;
}
insert_to_trash_bin(std::move(ir));
return false;
}

void IRBank::insert_to_trash_bin(std::unique_ptr<IRNode> &&ir) {
trash_bin.push_back(std::move(ir));
}

IRNode *IRBank::find(IRHandle ir_handle) {
auto result = ir_bank_.find(ir_handle);
if (result == ir_bank_.end())
return nullptr;
return result->second.get();
}

IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) {
auto &result = fuse_bank_[std::make_pair(handle_a, handle_b)];
if (!result.empty()) {
// Assume the kernel is always the same when the IR handles are the same.
return result;
}

TI_TRACE("Begin uncached fusion");
// We are about to change both |task_a| and |task_b|. Clone them first.
auto cloned_task_a = handle_a.clone();
auto cloned_task_b = handle_b.clone();
auto task_a = cloned_task_a->as<OffloadedStmt>();
auto task_b = cloned_task_b->as<OffloadedStmt>();
// TODO: in certain cases this optimization can be wrong!
// Fuse |task_b| into |task_a|.
for (int j = 0; j < (int)task_b->body->size(); j++) {
task_a->body->insert(std::move(task_b->body->statements[j]));
}
task_b->body->statements.clear();

// Replace all references to the offloaded statement B with A.
irpass::replace_all_usages_with(task_a, task_b, task_a);

irpass::full_simplify(task_a, /*after_lower_access=*/false, kernel);
// For now, re_id is necessary for the hash to be correct.
irpass::re_id(task_a);

auto h = get_hash(task_a);
result = IRHandle(task_a, h);
insert(std::move(cloned_task_a), h);

// TODO: since cloned_task_b->body is empty, can we remove this (i.e.,
// simply delete cloned_task_b here)?
insert_to_trash_bin(std::move(cloned_task_b));

return result;
}

// TODO: make this an IR pass
class ConstExprPropagation {
public:
static std::unordered_set<Stmt *> run(
Block *block,
const std::function<bool(Stmt *)> &is_const_seed) {
std::unordered_set<Stmt *> const_stmts;

auto is_const = [&](Stmt *stmt) {
if (is_const_seed(stmt)) {
return true;
} else {
return const_stmts.find(stmt) != const_stmts.end();
}
};

for (auto &s : block->statements) {
if (is_const(s.get())) {
const_stmts.insert(s.get());
} else if (auto binary = s->cast<BinaryOpStmt>()) {
if (is_const(binary->lhs) && is_const(binary->rhs)) {
const_stmts.insert(s.get());
}
} else if (auto unary = s->cast<UnaryOpStmt>()) {
if (is_const(unary->operand)) {
const_stmts.insert(s.get());
}
} else {
// TODO: ...
}
}

return const_stmts;
}
};

IRHandle IRBank::demote_activation(IRHandle handle) {
auto &result = demote_activation_bank_[handle];
if (!result.empty()) {
return result;
}

std::unique_ptr<IRNode> new_ir = handle.clone();

OffloadedStmt *offload = new_ir->as<OffloadedStmt>();
Block *body = offload->body.get();

auto snode = offload->snode;
TI_ASSERT(snode != nullptr);

// TODO: for now we only deal with the top level. Is there an easy way to
// extend this part?
auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) {
return stmt->is<ConstStmt>() || stmt->is<LoopIndexStmt>();
});

bool demoted = false;
for (int k = 0; k < (int)body->statements.size(); k++) {
Stmt *stmt = body->statements[k].get();
if (auto ptr = stmt->cast<GlobalPtrStmt>(); ptr && ptr->activate) {
bool can_demote = true;
// TODO: test input mask?
for (auto ind : ptr->indices) {
if (consts.find(ind) == consts.end()) {
// non-constant index
can_demote = false;
}
}
if (can_demote) {
ptr->activate = false;
demoted = true;
}
}
}

if (!demoted) {
// Nothing demoted. Simply delete new_ir when this function returns.
result = handle;
return result;
}

result = IRHandle(new_ir.get(), get_hash(new_ir.get()));
insert(std::move(new_ir), result.hash());
return result;
}

ParallelExecutor::ParallelExecutor(int num_threads)
: num_threads(num_threads),
status(ExecutorStatus::uninitialized),
@@ -381,157 +198,6 @@ void AsyncEngine::launch(Kernel *kernel, Context &context) {
}
}

TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
TI_AUTO_PROF
// TODO: this function should ideally take only an IRNode
static std::mutex mut;

std::lock_guard<std::mutex> guard(mut);

auto &meta_bank = ir_bank->meta_bank_;

if (meta_bank.find(t.ir_handle) != meta_bank.end()) {
return &meta_bank[t.ir_handle];
}

using namespace irpass::analysis;
TaskMeta meta;
// TODO: this is an abuse since it gathers nothing...
auto *root_stmt = t.stmt();
meta.name = t.kernel->name + "_" +
OffloadedStmt::task_type_name(root_stmt->task_type);
meta.type = root_stmt->task_type;
gather_statements(root_stmt, [&](Stmt *stmt) {
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (auto ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
}
}
}

// Note: since global store may only partially modify a value state, the
// result (which contains the modified and unmodified part) actually needs a
// read from the previous version of the value state.
//
// I.e.,
// output_value_state = merge(input_value_state, written_part)
//
// Therefore we include the value state in input_states.
//
// The only exception is that the task may completely overwrite the value
// state (e.g., for i in x: x[i] = 0). However, for now we are not yet
// able to detect that case, so we are being conservative here. (A toy
// model of this merge is sketched right after this function.)

if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
meta.output_states.emplace(snode, AsyncState::Type::value);
}
}
}
if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
if (auto ptr = global_atomic->dest->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.input_states.emplace(snode, AsyncState::Type::value);
meta.output_states.emplace(snode, AsyncState::Type::value);
}
}
}

if (auto ptr = stmt->cast<GlobalPtrStmt>()) {
if (ptr->activate) {
for (auto &snode : ptr->snodes.data) {
auto s = snode;
while (s) {
if (!s->is_path_all_dense) {
meta.input_states.emplace(s, AsyncState::Type::mask);
meta.output_states.emplace(s, AsyncState::Type::mask);
}
s = s->parent;
}
}
}
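// A snode is marked element-wise only while every access seen so far is
// element-wise; a single non-element-wise access clears the flag for good.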
for (auto &snode : ptr->snodes.data) {
if (ptr->is_element_wise(snode)) {
if (meta.element_wise.find(snode) == meta.element_wise.end()) {
meta.element_wise[snode] = true;
}
} else {
meta.element_wise[snode] = false;
}
}
}
if (auto clear_list = stmt->cast<ClearListStmt>()) {
meta.output_states.emplace(clear_list->snode, AsyncState::Type::list);
}
// TODO: handle SNodeOpStmt etc.
return false;
});
if (root_stmt->task_type == OffloadedStmt::listgen) {
TI_ASSERT(root_stmt->snode->parent);
meta.snode = root_stmt->snode;
meta.input_states.emplace(root_stmt->snode->parent, AsyncState::Type::list);
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list);
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::mask);
meta.output_states.emplace(root_stmt->snode, AsyncState::Type::list);
} else if (root_stmt->task_type == OffloadedStmt::struct_for) {
meta.snode = root_stmt->snode;
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list);
}

meta_bank[t.ir_handle] = meta;
return &meta_bank[t.ir_handle];
}
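Aside: the merge argument in the comment block inside get_task_meta() can be made concrete with a toy model (ToyValueState is an illustrative stand-in, not a Taichi type):

#include <map>

// Toy stand-in: a value state maps element index -> stored value.
using ToyValueState = std::map<int, int>;

// output = merge(input, written_part): elements the task does not write
// carry over from the previous version, which is why a global store also
// registers the value state as an input.
ToyValueState merge(ToyValueState input, const ToyValueState &written) {
  for (const auto &[idx, val] : written)
    input[idx] = val;  // written part overwrites; untouched entries persist
  return input;
}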

TaskFusionMeta get_task_fusion_meta(IRBank *bank, const TaskLaunchRecord &t) {
TI_AUTO_PROF
// TODO: this function should ideally take only an IRNode
auto &fusion_meta_bank = bank->fusion_meta_bank_;
if (fusion_meta_bank.find(t.ir_handle) != fusion_meta_bank.end()) {
return fusion_meta_bank[t.ir_handle];
}

TaskFusionMeta meta{};
if (t.kernel->is_accessor) {
// SNode accessors can't be fused.
// TODO: just avoid snode accessors going into the async engine
return fusion_meta_bank[t.ir_handle] = TaskFusionMeta();
}
meta.kernel = t.kernel;
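// Kernels with arguments or return values presumably pin their tasks to
// that kernel for fusion purposes; clearing the field when there are no
// args/rets lets tasks from different kernels share fusion metadata.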
if (t.kernel->args.empty() && t.kernel->rets.empty()) {
meta.kernel = nullptr;
}

auto *task = t.stmt();
meta.type = task->task_type;
if (task->task_type == OffloadedStmt::struct_for) {
meta.snode = task->snode;
meta.block_dim = task->block_dim;
} else if (task->task_type == OffloadedStmt::range_for) {
// TODO: a few problems with the range-for test condition:
// 1. This could incorrectly fuse two range-for kernels that have
// different sizes but whose loop ranges get padded to the same
// power of two (e.g., as a side effect of demoting a struct-for
// to a range-for).
// 2. It has also fused range-fors that cover the same linear range
// but differ in the dimensionality of their loop indices, e.g.,
// (16,) and (4, 4).
if (!task->const_begin || !task->const_end) {
// Do not fuse range-for tasks with variable ranges for now.
return fusion_meta_bank[t.ir_handle] = TaskFusionMeta();
}
meta.begin_value = task->begin_value;
meta.end_value = task->end_value;
} else if (task->task_type != OffloadedStmt::serial) {
// Do not fuse gc/listgen tasks.
return fusion_meta_bank[t.ir_handle] = TaskFusionMeta();
}
meta.fusible = true;
return fusion_meta_bank[t.ir_handle] = meta;
}

void AsyncEngine::enqueue(const TaskLaunchRecord &t) {
sfg->insert_task(t);
task_queue.push_back(t);
taichi/program/async_engine.h (1 addition, 26 deletions)
@@ -13,37 +13,13 @@
#include "taichi/program/context.h"
#undef TI_RUNTIME_HOST
#include "taichi/program/async_utils.h"
#include "taichi/program/ir_bank.h"
#include "taichi/program/state_flow_graph.h"

TLANG_NAMESPACE_BEGIN

// TODO(yuanming-hu): split into multiple files

class IRBank {
public:
uint64 get_hash(IRNode *ir);
void set_hash(IRNode *ir, uint64 hash);

bool insert(std::unique_ptr<IRNode> &&ir, uint64 hash);
void insert_to_trash_bin(std::unique_ptr<IRNode> &&ir);
IRNode *find(IRHandle ir_handle);

// Fuse handle_b into handle_a
IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel);

IRHandle demote_activation(IRHandle handle);

std::unordered_map<IRHandle, TaskMeta> meta_bank_;
std::unordered_map<IRHandle, TaskFusionMeta> fusion_meta_bank_;

private:
std::unordered_map<IRNode *, uint64> hash_bank_;
std::unordered_map<IRHandle, std::unique_ptr<IRNode>> ir_bank_;
std::vector<std::unique_ptr<IRNode>> trash_bin;  // prevents IR from being deleted
std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
std::unordered_map<IRHandle, IRHandle> demote_activation_bank_;
};
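Per the PR title, the declaration above now lives in the new taichi/program/ir_bank.h (included from async_engine.h above). A sketch of what that header plausibly contains, reassembled from the deleted block (the actual file may differ in includes and ordering):

// taichi/program/ir_bank.h (sketch reconstructed from this diff,
// not the actual file contents)
#pragma once

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "taichi/program/async_utils.h"

TLANG_NAMESPACE_BEGIN

class IRBank {
 public:
  uint64 get_hash(IRNode *ir);
  void set_hash(IRNode *ir, uint64 hash);

  bool insert(std::unique_ptr<IRNode> &&ir, uint64 hash);
  void insert_to_trash_bin(std::unique_ptr<IRNode> &&ir);
  IRNode *find(IRHandle ir_handle);

  // Fuse handle_b into handle_a.
  IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel);

  IRHandle demote_activation(IRHandle handle);

  std::unordered_map<IRHandle, TaskMeta> meta_bank_;
  std::unordered_map<IRHandle, TaskFusionMeta> fusion_meta_bank_;

 private:
  std::unordered_map<IRNode *, uint64> hash_bank_;
  std::unordered_map<IRHandle, std::unique_ptr<IRNode>> ir_bank_;
  std::vector<std::unique_ptr<IRNode>> trash_bin;  // prevents IR from being deleted
  std::unordered_map<std::pair<IRHandle, IRHandle>, IRHandle> fuse_bank_;
  std::unordered_map<IRHandle, IRHandle> demote_activation_bank_;
};

TLANG_NAMESPACE_END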

class ParallelExecutor {
public:
using TaskType = std::function<void()>;
@@ -185,7 +161,6 @@ class AsyncEngine {
std::vector<IRHandle> ir_handle_cached;
};

TaskMeta create_task_meta(const TaskLaunchRecord &t);
std::unordered_map<const Kernel *, KernelMeta> kernel_metas_;
// How many times we have synchronized
int sync_counter_{0};