[fusion] Migrate away from CustomFuseGraph
gh-metadata: pytorch tvm 72 gh/bwasti/46/head

Pull Request resolved: #72
bwasti committed Aug 6, 2019
1 parent 8139fc1 commit 92f99a1
Showing 5 changed files with 170 additions and 15 deletions.
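
At a glance: rather than routing fusion through the JIT's generic CustomFuseGraph helper, the extension now ships its own pass, FuseSupportedOps, which greedily merges TVM-supported nodes into tvm::CompilationGroup subgraphs, applies its own alias-analysis checks, and clones prim::Constant producers into the subgraph instead of treating constants as fusible ops. A condensed sketch of the registration site after this change, mirroring register.cpp below (the fusion_enabled guard is omitted here):

```cpp
#include <torch/csrc/jit/pass_manager.h>

#include "fuse_linear.h"
#include "fusion_pass.h"

// Runs on every graph the JIT compiles once the extension is loaded.
static torch::jit::RegisterPass pass([](std::shared_ptr<torch::jit::Graph>& g) {
  FuseLinear(g);       // pre-pass: recover aten::linear patterns
  FuseSupportedOps(g); // replaces CustomFuseGraph(g, isSupported, tvm_sym)
});
```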
24 changes: 22 additions & 2 deletions torch_tvm/compiler.cpp
@@ -32,7 +32,16 @@ tvm::relay::DataType scalarTypeToTVMType(at::ScalarType pt_type) {
 tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
   auto optional_ivalue = toIValue(val);
   if (optional_ivalue.has_value()) {
-    val->inferTypeFrom(optional_ivalue.value().toTensor());
+    if (optional_ivalue.value().isTensor()) {
+      val->inferTypeFrom(optional_ivalue.value().toTensor());
+    } else {
+      auto expr = convertToRelay(optional_ivalue.value(), ctx)
+                      .as<tvm::relay::ConstantNode>();
+      return tvm::relay::VarNode::make(
+          val->debugName() +
+              std::to_string(reinterpret_cast<std::uintptr_t>(val)),
+          expr->tensor_type());
+    }
   }
   if (val->isCompleteTensor()) {
     auto pt_t = val->type()->cast<CompleteTensorType>();
@@ -247,7 +256,18 @@ void TVMCompiler::run(Stack& stack) {

   if (cache_.find(spec) == cache_.end()) {
     for (auto& kv : value_to_ivalue) {
-      kv.first->inferTypeFrom(kv.second.toTensor());
+      if (kv.second.isTensor()) {
+        kv.first->inferTypeFrom(kv.second.toTensor());
+      } else if (kv.second.isInt()) {
+        kv.first->setType(IntType::get());
+      } else {
+        AT_CHECK(
+            0,
+            "Cannot handle this type yet ",
+            kv.second,
+            "\nGraph:\n",
+            *subgraph_);
+      }
     }
     // bail-out mechanism: try to convert to Relay; if the graph fails to
     // convert for any reason (e.g. an op mismatch), then depending on the user preference,
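
The two hunks above share one idea: stop assuming every IValue holds a tensor. A minimal standalone sketch of the tag-dispatch pattern, using a hypothetical helper name (the real logic lives in TVMCompiler::convertToRelay and TVMCompiler::run):

```cpp
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>

// Hypothetical helper mirroring the guards added above: tensors expose
// shape/dtype via toTensor(), ints only carry their kind, and anything
// else is rejected loudly instead of crashing inside toTensor().
void sketchTypeDispatch(const c10::IValue& iv) {
  if (iv.isTensor()) {
    auto t = iv.toTensor(); // safe: the IValue really holds a Tensor
    (void)t;
  } else if (iv.isInt()) {
    int64_t v = iv.toInt(); // scalar: no shape to infer, only IntType
    (void)v;
  } else {
    AT_ERROR("Cannot handle this type yet ", iv);
  }
}
```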
134 changes: 134 additions & 0 deletions torch_tvm/fusion_pass.cpp
@@ -0,0 +1,134 @@
#include "fusion_pass.h"
#include "operators.h"

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

using namespace torch::jit;

value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* block) {
  value_list result;
  for (auto i : inputs) {
    if (i->node()->owningBlock() == block) {
      result.push_back(i);
    }
  }
  // Sort in reverse topological order
  std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
    return a->node()->isAfter(b->node());
  });
  return result;
}

bool canHandle(Block* block, AliasDb& aliasDb);
bool canHandle(Node* node, AliasDb& aliasDb) {
  if (node->kind() == prim::Constant) {
    return true;
  }
  if (node->kind() == prim::Loop) {
    return false; // TODO
    // Unreachable until loop fusion is enabled; kept as a marker:
    Block* body = node->blocks().at(0);
    return canHandle(body, aliasDb);
  }
  return isSupported(node);
}

bool canHandle(Block* block, AliasDb& aliasDb) {
  for (Node* node : block->nodes()) {
    if (!canHandle(node, aliasDb)) {
      return false;
    }
  }
  return true;
}

#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return c10::nullopt;                    \
  }

c10::optional<Node*> tryMerge(
    Node* consumer,
    Node* producer,
    AliasDb& aliasDb) {
  GRAPH_DEBUG(
      "Trying producer ",
      producer->kind().toQualString(),
      " and consumer ",
      consumer->kind().toQualString(),
      ":\n");

  // Symbolic checks
  REQ(canHandle(producer, aliasDb));
  REQ((canHandle(consumer, aliasDb) || consumer->kind() == getTVMSymbol()));

  // Alias checks
  // Requirement:
  // - moveAfterTopologicallyValid(consumer, producer)
  // - One of:
  //   1) Both are in-place ops
  //   2) Consumer is in-place, producer !hasInputWriters
  //   3) Producer is in-place, consumer !hasOutputWriters
  REQ(aliasDb.moveAfterTopologicallyValid(consumer, producer));

  // 1)
  if (!(aliasDb.isMutable(consumer) && aliasDb.isMutable(producer))) {
    // 2)
    if (aliasDb.isMutable(consumer)) {
      REQ(!aliasDb.hasInputWriters(producer));
      // 3)
    } else if (aliasDb.isMutable(producer)) {
      REQ(!aliasDb.hasOutputWriters(consumer));
    }
  }

  if (!consumer->hasAttribute(attr::Subgraph) &&
      consumer->kind() != getTVMSymbol()) {
    consumer = SubgraphUtils::createSingletonSubgraph(consumer, getTVMSymbol());
  }
  if (producer->kind() == prim::Constant) {
    auto& subgraph = consumer->g(attr::Subgraph);
    Node* in_const = subgraph->createClone(producer, [](Value*) -> Value* {
      throw std::runtime_error("unexpected input");
    });
    subgraph->insertNode(in_const);
  } else {
    SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
  }

  return consumer;
}
#undef REQ

graph_node_list::iterator scanNode(
    Node* consumer,
    AliasDb& aliasDb,
    Block* block) {
  auto inputs = sortReverseTopological(consumer->inputs(), block);
  for (auto input : inputs) {
    if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
      // we successfully merged, so the new group's `inputs` may have
      // changed. So rescan the new group for more merging opportunities.
      return group.value()->reverseIterator();
    }
  }
  return ++consumer->reverseIterator();
}

void FuseSupportedOps(std::shared_ptr<Graph> graph) {
  AliasDb aliasDb(graph);
  auto block = graph->block();

  for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
    it = scanNode(*it, aliasDb, block);
  }
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
}

const torch::jit::Symbol& getTVMSymbol() {
  static torch::jit::Symbol tvm_sym =
      torch::jit::Symbol::fromQualString("tvm::CompilationGroup");
  return tvm_sym;
}
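
For reference, a minimal way to exercise the new pass by hand; this assumes the 2019-era header layout and the script::parseIR test helper of that period, so treat it as a sketch rather than a supported entry point (in normal use the pass fires automatically through the RegisterPass hook in register.cpp below):

```cpp
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>

#include "torch_tvm/fusion_pass.h"

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  // Two chained element-wise ops; both are registered TVM operators, so
  // the fuser should collapse them into a single tvm::CompilationGroup.
  torch::jit::script::parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = aten::mul(%a, %b)
  %d : Tensor = aten::mul(%c, %b)
  return (%d))IR",
      graph.get());

  FuseSupportedOps(graph);
  graph->dump(); // expect one tvm::CompilationGroup node wrapping both muls
  return 0;
}
```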
7 changes: 7 additions & 0 deletions torch_tvm/fusion_pass.h
@@ -0,0 +1,7 @@
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

void FuseSupportedOps(std::shared_ptr<torch::jit::Graph> graph);

const torch::jit::Symbol& getTVMSymbol();
10 changes: 3 additions & 7 deletions torch_tvm/operators.cpp
@@ -2,6 +2,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
 #include "compiler.h"
+#include "fusion_pass.h" // tvm_sym

 #include <torch/csrc/autograd/record_function.h>
 #include <torch/csrc/jit/custom_operator.h>
@@ -44,8 +45,7 @@ RegisterTVMOperator::RegisterTVMOperator(std::vector<TVMOpMap> ops) {
   wrapper_graph.appendNode(node);
   wrapper_graph.registerOutput(node->output());

-  Symbol tvm_sym = Symbol::fromQualString("tvm::CompilationGroup");
-  node = SubgraphUtils::createSingletonSubgraph(node, tvm_sym);
+  node = SubgraphUtils::createSingletonSubgraph(node, getTVMSymbol());
   auto cc = std::make_shared<TVMCompiler>(node);

   // NB: We assume all relay ops are pure
@@ -424,11 +424,7 @@ RegisterTVMOperator reg({

 bool isSupported(Node* node) {
   auto map = getTVMOperatorMap();
-  auto can_handle = map.find(node->kind()) != map.end();
-  if (node->kind() == prim::Constant) {
-    can_handle = true;
-  }
-  return can_handle;
+  return map.find(node->kind()) != map.end();
 }

tvm::relay::Expr getOperator(Node* node, tvm::Array<tvm::relay::Expr> inputs) {
10 changes: 4 additions & 6 deletions torch_tvm/register.cpp
@@ -3,12 +3,11 @@
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/operator_options.h>
 #include <torch/csrc/jit/pass_manager.h>
-#include <torch/csrc/jit/passes/graph_fuser.h>
 #include <torch/csrc/jit/pybind_utils.h>

 #include "compiler.h"
 #include "operators.h"
 #include "fuse_linear.h"
+#include "fusion_pass.h"

 namespace py = pybind11;
 using namespace torch::jit;
@@ -23,7 +22,6 @@ static int opt_level = 2;
 static std::string device_type = "cpu";
 static std::string device = "llvm -mcpu=core-avx2";
 static std::string host = "llvm -mcpu=core-avx2";
-static auto tvm_sym = Symbol::fromQualString("tvm::CompilationGroup");

 static std::unordered_map<size_t, tvm::relay::Expr> relay_exprs;
 static size_t relay_exprs_uuid = 0;
@@ -33,7 +31,7 @@ PYBIND11_MODULE(_torch_tvm, m) {
   auto options = c10::OperatorOptions();
   options.setAliasAnalysis(AliasAnalysisKind::PURE);
   RegisterOperators op({Operator(
-      tvm_sym,
+      getTVMSymbol(),
       [](const Node* node) {
         auto cc = std::make_shared<TVMCompiler>(
             node, opt_level, strict, device_type, device, host);
@@ -50,7 +48,7 @@ PYBIND11_MODULE(_torch_tvm, m) {
   RegisterPass pass([](std::shared_ptr<Graph>& g) {
     if (fusion_enabled) {
       FuseLinear(g);
-      CustomFuseGraph(g, isSupported, tvm_sym);
+      FuseSupportedOps(g);
     }
   });

@@ -88,7 +86,7 @@ PYBIND11_MODULE(_torch_tvm, m) {
         count == 1,
         "This program cannot be exported as a single Relay expression.");
     for (auto node : g->nodes()) {
-      if (node->kind() == tvm_sym) {
+      if (node->kind() == getTVMSymbol()) {
         std::vector<Value*> v;
         auto subgraph = node->g(attr::Subgraph);
         TORCH_CHECK(
