Skip to content
Permalink
Browse files

Add pre-glow jit passes to remove exceptions and fuse Linear (#3756)

Summary:
These two passes remove the only control flow in the BERT jit model

Documentation:
doxygen
Pull Request resolved: #3756

Test Plan:
added unit tests
run on bert model and see that all control flow is gone

Differential Revision: D18400461

Pulled By: jackm321

fbshipit-source-id: 4cc213e3a17e8edc2d111b9047f6d694cdcf5cee
  • Loading branch information...
jackm321 authored and facebook-github-bot committed Nov 8, 2019
1 parent 01e89ef commit 50a5445118b1d6b9d44c491c7b0095a231266731
@@ -27,6 +27,77 @@

namespace glow {
namespace {
/// Registers an operator with symbol \p opName but with no implementation.
/// Dummy operators can be used by glow-specific fusion passes prior to loading
/// a glow graph in order to eliminate intermediate values that are unnecessary
/// to Glow such as those created by quantization packing nodes.
void registerDummyOperator(const char *opName) {
auto options = c10::OperatorOptions();
options.setAliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION);

torch::jit::RegisterOperators op({torch::jit::Operator(
at::Symbol::fromQualString(opName),
[](const torch::jit::Node *node) -> torch::jit::Operation {
LOG(FATAL) << "Operator \"" << (*node)
<< "\" has no implementation and is meant only as a "
"placeholder while fusing ops to run with Glow";
},
options)});
}

/// To judge if we can fuse the current node into a prim::FusedConcat.
/// Returns true if \p node is an aten::cat that can be collapsed into a
/// prim::FusedConcat: its dim must be a constant and its tensor list must be
/// a prim::ListConstruct used only by this cat.
bool isFusableConcatNode(torch::jit::Node *node) {
  const bool isConstantDimCat = node->kind() == at::aten::cat &&
                                node->is_constant(torch::jit::attr::dim);
  if (!isConstantDimCat) {
    return false;
  }

  auto *listNode = node->namedInput(torch::jit::attr::tensors)->node();
  if (listNode->kind() != at::prim::ListConstruct) {
    return false;
  }

  // The tensor list must feed only this cat; a shared list cannot be
  // eliminated along with the fused node.
  return listNode->output()->uses().size() <= 1;
}

/// Add one prim::FusedConcat before current node.
/// The current node and the list construct node before it will become dead
/// node.
/// Inserts a prim::FusedConcat that takes over all uses of the aten::cat
/// \p node. The cat and its feeding prim::ListConstruct are left in place as
/// dead nodes (a later DCE pass removes them); the ListConstruct node is
/// returned.
torch::jit::Node *createFusedConcat(torch::jit::Node *node) {
  AT_ASSERT(node->kind() == at::aten::cat);
  auto *graph = node->owningGraph();
  auto *listNode = node->namedInput(torch::jit::attr::tensors)->node();
  // The dim is known to be constant here (see isFusableConcatNode).
  const int64_t dim = node->get<int64_t>(torch::jit::attr::dim).value();

  // The fused node consumes the list's elements directly, bypassing the
  // ListConstruct.
  auto *fusedNode = graph->create(at::prim::FusedConcat, listNode->inputs())
                        ->i_(torch::jit::attr::dim, dim);
  fusedNode->insertBefore(listNode);
  fusedNode->output()->copyMetadata(node->output());
  node->output()->replaceAllUsesWith(fusedNode->output());

  return listNode;
}

/// Recursively destroys every prim::RaiseException node in \p block and in
/// all nested sub-blocks (e.g. the arms of prim::If nodes).
void removeExceptionsImpl(torch::jit::Block *block) {
// Iterate in reverse; destroyCurrent() is used instead of plain erase so the
// loop iterator stays usable after the node is removed.
// NOTE(review): this relies on destroyCurrent() leaving `it` valid for the
// subsequent ++ — confirm against the torch::jit node-list iterator contract.
auto nodes = block->nodes().reverse();
for (auto it = nodes.begin(); it != nodes.end(); it++) {
if (it->kind() == torch::jit::prim::RaiseException) {
it.destroyCurrent();
continue;
}
// Not an exception node: recurse into any sub-blocks it owns.
for (auto *subblock : it->blocks()) {
removeExceptionsImpl(subblock);
}
}
}
} // namespace

namespace detail {
/// This pass fuses the quantized::conv_prepack + quantized::conv2d generated by
/// JIT back to quantized::unpacked_conv2d since we don't have
/// quantized::conv_prepack in glow. However regular packed conv's
@@ -84,87 +155,202 @@ graph(%input):
rewriter.runOnGraph(graph);
}

/// Registers an operator with symbol \p opName but with no implementation.
/// Dummy operators can be used by glow-specific fusion passes prior to loading
/// a glow graph in order to eliminate intermediate values that are unnecessary
/// to Glow such as those created by quantization packing nodes.
void registerDummyOperator(const char *opName) {
auto options = c10::OperatorOptions();
options.setAliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION);
/// Fuses each eligible prim::ListConstruct -> aten::cat pair in the top-level
/// block of \p graph into a single prim::FusedConcat node.
void fuseConcat(std::shared_ptr<torch::jit::Graph> &graph) {
  auto *topBlock = graph->block();
  // Walk back-to-front so nodes inserted during fusion are not revisited.
  for (auto it = topBlock->nodes().rbegin(); it != topBlock->nodes().rend();
       ++it) {
    if (!isFusableConcatNode(*it)) {
      continue;
    }
    // The replaced cat/ListConstruct pair is left dead in the graph; we rely
    // on a later EliminateDeadCode pass to clean it up rather than erasing
    // the nodes here.
    createFusedConcat(*it);
  }
}

torch::jit::RegisterOperators op({torch::jit::Operator(
at::Symbol::fromQualString(opName),
[](const torch::jit::Node *node) -> torch::jit::Operation {
LOG(FATAL) << "Operator \"" << (*node)
<< "\" has no implementation and is meant only as a "
"placeholder while fusing ops to run with Glow";
},
options)});
/// Removes all prim::RaiseException nodes from \p graph, starting at its
/// top-level block and recursing through nested sub-blocks.
void removeExceptions(std::shared_ptr<torch::jit::Graph> &graph) {
  removeExceptionsImpl(graph->block());
}

/// To judge if we can fuse the current node into a prim::FusedConcat.
bool isFusableConcatNode(torch::jit::Node *node) {
if (node->kind() != at::aten::cat ||
!node->is_constant(torch::jit::attr::dim)) {
return false;
}
void fuseBranchedLinearPattern(std::shared_ptr<torch::jit::Graph> &graph) {
// before:
// graph(%input, %weight, %bias, %c %d):
// %1 = aten::dim(%input)
// %2 = aten::eq(%1, %c)
// %3 = prim::If(%2)
// block0():
// %4 = aten::t(%weight)
// %5 = prim::Constant[value=1]()
// %6 = aten::mm(%input, %4)
// %7 = aten::add(%bias, %6, %5)
// -> (%7)
// block1():
// %8 = aten::t(%weight)
// %9 = aten::matmul(%input, %8)
// %10 : Tensor = aten::add_(%9, %bias, %d)
// -> (%10)
// return (%3)";
//
// after:
// graph(%input, %weight, %bias, %c %d):
// %1 = glow::fused_linear(%input, %weight, %bias, %c %d)
// return (%1)";

auto inputNode = node->namedInput(torch::jit::attr::tensors)->node();
if (inputNode->kind() != at::prim::ListConstruct) {
return false;
}
if (inputNode->output()->uses().size() > 1) {
return false;
}
auto nodes = graph->block()->nodes().reverse();
for (auto it = nodes.begin(); it != nodes.end(); it++) {
auto *ifNode = *it;
if (ifNode->kind() != torch::jit::prim::If) {
continue;
}

return true;
}
// Define all Values we need to find.
torch::jit::Value *inputValue = nullptr;
torch::jit::Value *cValue = nullptr;
torch::jit::Value *weightValue = nullptr;
torch::jit::Value *biasValue = nullptr;
torch::jit::Value *dValue = nullptr;

/// Add one prim::FusedConcat before current node.
/// The current node and the list construct node before it will become dead
/// node.
torch::jit::Node *createFusedConcat(torch::jit::Node *node) {
AT_ASSERT(node->kind() == at::aten::cat);
torch::jit::Graph *graph = node->owningGraph();
torch::jit::Node *inputNode =
node->namedInput(torch::jit::attr::tensors)->node();
int64_t dim = node->get<int64_t>(torch::jit::attr::dim).value();
// step 1: walk upwards from if to get values from aten::eq and aten::dim
{
// find aten::eq node input to prim::If node
auto *eqNode = ifNode->input()->node();
if (eqNode->kind() != torch::jit::aten::eq) {
continue;
}

torch::jit::Node *fusedConcatNode =
graph->create(at::prim::FusedConcat, inputNode->inputs())
->i_(torch::jit::attr::dim, dim);
fusedConcatNode->insertBefore(inputNode);
fusedConcatNode->output()->copyMetadata(node->output());
node->output()->replaceAllUsesWith(fusedConcatNode->output());
// find aten::dim node input to aten::eq node
torch::jit::Node *dimNode = nullptr;
auto eqInputs = eqNode->inputs();
if (eqInputs[0]->node()->kind() == torch::jit::aten::dim) {
dimNode = eqInputs[0]->node();
cValue = eqInputs[1];
} else if (eqInputs[1]->node()->kind() == torch::jit::aten::dim) {
dimNode = eqInputs[1]->node();
cValue = eqInputs[0];
} else {
continue;
}

return inputNode;
}
inputValue = dimNode->input();
}

/// Fuse PyTorch ListConstruct and Concat node intp prim::FusedConcat node
void fuseConcat(std::shared_ptr<torch::jit::Graph> &graph) {
auto block = graph->block();
for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) {
if (isFusableConcatNode(*it)) {
// Here we relay on EliminateDeadCode to remove useless node in graph
// and we dont remove them directly
createFusedConcat(*it);
// step 2: walk if-block collecting values and verifying structure
{
torch::jit::Value *tOutputValue = nullptr;
torch::jit::Value *mmOutputValue = nullptr;
torch::jit::Value *constantOutputValue = nullptr;
torch::jit::Value *addOutputValue = nullptr;
size_t numNodes = 0;
for (auto *node : ifNode->blocks()[0]->nodes()) {
numNodes++;
if (node->kind() == torch::jit::aten::t) {
weightValue = node->input();
tOutputValue = node->output();
} else if (node->kind() == torch::jit::prim::Constant) {
// Make sure the constant value is 1
if (node->output()->type()->kind() != torch::jit::TypeKind::IntType) {
continue;
}
if (node->i(torch::jit::attr::value) != 1) {
continue;
}
constantOutputValue = node->output();
} else if (node->kind() == torch::jit::aten::mm) {
// Get inputValue and check that second input is output of the aten::t
inputValue = node->inputs()[0];
if (node->inputs()[1] != tOutputValue) {
continue;
}
mmOutputValue = node->output();
} else if (node->kind() == torch::jit::aten::add) {
// Get biasValue and check that the second input is the output of mm
biasValue = node->inputs()[0];
if (node->inputs()[1] != mmOutputValue) {
continue;
}
addOutputValue = node->output();
} else {
continue;
}
}
if (!(tOutputValue && mmOutputValue && constantOutputValue &&
addOutputValue && numNodes == 4)) {
continue;
}
}

// step 3: walk else-block collecting values and verifying structure
{
torch::jit::Value *tOutputValue = nullptr;
torch::jit::Value *matmulOutputValue = nullptr;
torch::jit::Value *addOutputValue = nullptr;
size_t numNodes = 0;
for (auto *node : ifNode->blocks()[1]->nodes()) {
numNodes++;
if (node->kind() == torch::jit::aten::t) {
if (node->input() != weightValue) {
continue;
}
tOutputValue = node->output();
} else if (node->kind() == torch::jit::aten::matmul) {
if (node->inputs()[0] != inputValue) {
continue;
}
if (node->inputs()[1] != tOutputValue) {
continue;
}
matmulOutputValue = node->output();
} else if (node->kind() == torch::jit::aten::add_) {
if (node->inputs()[0] != matmulOutputValue) {
continue;
}
if (node->inputs()[1] != biasValue) {
continue;
}
dValue = node->inputs()[2];
addOutputValue = node->output();
}
}
if (!(tOutputValue && matmulOutputValue && addOutputValue &&
numNodes == 3)) {
continue;
}
}

// step 4: create a glow::fused_linear
assert(inputValue && weightValue && biasValue && cValue && dValue);

std::vector<torch::jit::Value *> fusedLinearInputs = {
inputValue, weightValue, biasValue, cValue, dValue};

auto *newNode =
graph->create(torch::jit::Symbol::fromQualString("glow::fused_linear"),
fusedLinearInputs, /*num_outputs*/ 1);

newNode->insertAfter(ifNode);
ifNode->replaceAllUsesWith(newNode);
}
}
} // namespace
} // namespace detail

/// Runs the full pre-Glow fusion pipeline on \p graph: register placeholder
/// ops once, strip exceptions, fuse the branched Linear pattern, then run the
/// remaining fusion passes and standard JIT cleanups.
void fuseKnownPatterns(std::shared_ptr<torch::jit::Graph> &graph) {
// Register dummy nodes used by custom fusers.
// call_once guarantees the (process-global) operator registration happens
// exactly once even if this pass runs on many graphs.
static std::once_flag onceFlag;
std::call_once(onceFlag, []() {
registerDummyOperator("glow::unpacked_quantized_linear");
registerDummyOperator("glow::unpacked_quantized_conv2d");
registerDummyOperator("glow::fused_linear");
});

// NOTE(review): the next four unqualified calls duplicate the detail::
// qualified ones below — this looks like residue of a diff rendering (old
// lines interleaved with new); confirm against the real file before relying
// on this listing.
fuseConcat(graph);
fuseConvPrepack(graph);
fuseLinearPrepack(graph);
fuseNumToTensorToNum(graph);
// Exceptions are removed first so the Linear pattern's branches are the only
// remaining control flow before fuseBranchedLinearPattern runs.
detail::removeExceptions(graph);
EliminateDeadCode(graph);

detail::fuseBranchedLinearPattern(graph);
EliminateDeadCode(graph);

detail::fuseConcat(graph);
detail::fuseConvPrepack(graph);
detail::fuseLinearPrepack(graph);
detail::fuseNumToTensorToNum(graph);

EliminateCommonSubexpression(graph);
EliminateDeadCode(graph);
}
@@ -22,6 +22,37 @@
namespace glow {
/// Fuse known node patterns in \p graph to assist the PyTorchModelLoader.
void fuseKnownPatterns(std::shared_ptr<torch::jit::Graph> &graph);

/// Passes in the detail namespace should not be used directly except by
/// unit tests.
namespace detail {
/// Pass that removes all prim::RaiseException nodes from the \p graph.
void removeExceptions(std::shared_ptr<torch::jit::Graph> &graph);

/// Pass that fuses the output pattern of Linear module (which contains a branch
/// based on the dims of the input) in \p graph to a glow::fused_linear op so
/// that it can be loaded by Glow with the control flow happening at graph
/// compile time.
void fuseBranchedLinearPattern(std::shared_ptr<torch::jit::Graph> &graph);

/// Pass that fuses prim::ListConstruct -> aten::cat patterns in \p graph into
/// prim::FusedConcat node so that the number of tensors being concatenated is
/// known at graph compile time.
void fuseConcat(std::shared_ptr<torch::jit::Graph> &graph);

/// Pass that fuses quantized::conv_prepack -> quantized::conv2d patterns in \p
/// graph into glow::unpacked_quantized_conv2d.
void fuseConvPrepack(std::shared_ptr<torch::jit::Graph> &graph);

/// Pass that fuses quantized::linear_prepack -> quantized::linear patterns in
/// \p graph into glow::unpacked_quantized_linear.
void fuseLinearPrepack(std::shared_ptr<torch::jit::Graph> &graph);

/// Pass that eliminates prim::NumToTensor -> aten::Int patterns in
/// \p graph.
void fuseNumToTensorToNum(std::shared_ptr<torch::jit::Graph> &graph);
} // namespace detail

} // namespace glow

#endif // GLOW_TORCH_GLOW_SRC_FUSE_KNOWN_PATERNS_H

0 comments on commit 50a5445

Please sign in to comment.
You can’t perform that action at this time.