From af3964a8725236c78ce969b827fdeee1c5c54110 Mon Sep 17 00:00:00 2001 From: Dmytro Dzhulgakov Date: Sun, 3 Dec 2017 23:55:09 -0800 Subject: [PATCH] Backport transposes optimization to v0.3.0 (#3994) * Optimizer: optimize transposes in variety of circumstances (#3509) * Optimizer: Optimize transposes in variety of circumstances - No-op transposes - Consecutive transposes (fuse them) - Transposes into Gemm (fuse them into transA/transB parameter) * touch up out of date comment * Backporting optimizer changes --- torch/csrc/jit/attributes.h | 2 +- torch/csrc/jit/interned_strings.h | 2 + torch/csrc/jit/passes/onnx/peephole.cpp | 79 +++++++++++++++++++++++++ torch/csrc/jit/passes/peephole.cpp | 1 + 4 files changed, 83 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h index e66e91d2d3035..52775a5515493 100644 --- a/torch/csrc/jit/attributes.h +++ b/torch/csrc/jit/attributes.h @@ -78,7 +78,7 @@ using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>,AttributeKind::gs // CRTP so that Node which inherits Attributes can be return for // method chaining e.g: -// Node * n = g->create(kSelect)->set_i(kOffset,3)->set_f(kValue,3.5); +// Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5); // we return Derived* pointers because Nodes are normally held as pointers.
template<typename Derived> struct Attributes { diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index e53a93a65f7a2..42d01310cdfe7 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -74,6 +74,8 @@ _(shape) \ _(axes) \ _(group) \ _(inplace) \ +_(transA) \ +_(transB) \ _(other) enum BuiltinSymbol { diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 34187d28e6973..5758ec921c856 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -15,6 +15,27 @@ std::unordered_set<NodeKind> broadcasting = { kGemm, }; +bool isNopTranspose(const std::vector<int64_t> & perm) { + for (size_t i = 0; i < perm.size(); i++) + if (perm[i] != i) + return false; + return true; +} + +// returns a vector `ret` such that transposing by `ret` is equivalent +// to transposing by `t1` and then by `t2` +std::vector<int64_t> composeTransposes(const std::vector<int64_t> & t1, + const std::vector<int64_t> & t2) { + JIT_ASSERT(t1.size() == t2.size()); + std::vector<int64_t> ret; + for (size_t i = 0; i < t1.size(); i++) { + JIT_ASSERT( t1[i] < t2.size()); + JIT_ASSERT(t2[t1[i]] < t2.size()); + ret.push_back(t2[t1[i]]); + } + return ret; +} + bool isBroadcasting(Node *node) { return broadcasting.count(node->kind()); } @@ -93,6 +114,58 @@ void fuseBroadcast(std::shared_ptr<Graph>& graph) { } } +void fuseConsecutiveTransposes(std::shared_ptr<Graph>& graph) { + for (auto it = graph->begin(); it != graph->end(); ++it) { + auto* n = *it; + + if (n->kind() == kTranspose && n->input()->kind() == kTranspose) { + auto origInput = n->input(); + n->is_(kperm, composeTransposes(origInput->is(kperm), n->is(kperm))); + n->replaceInput(0, origInput->input()); + if (origInput->uses().size() == 0) { + origInput->destroy(); + } + continue; + } + } +} + +void eliminateNopTranspose(std::shared_ptr<Graph>& graph) { + for (auto it = graph->begin(); it != graph->end(); ++it) { + auto* n = *it; + + if (n->kind() == kTranspose) { + if 
(isNopTranspose(n->is(kperm))) { + n->replaceAllUsesWith(n->input()); + it.destroyCurrent(); + continue; + } + } + } +} + +void fuseTransposeIntoGemm(std::shared_ptr<Graph>& graph) { + static const std::vector<int64_t> simpleTransPerm({1,0}); + + for (auto it = graph->begin(); it != graph->end(); ++it) { + auto* n = *it; + + if (n->kind() == kGemm) { + for (size_t i : {0,1}) { + auto inp = n->inputs()[i]; + auto trans = i == 0 ? ktransA : ktransB; + if (inp->kind() == kTranspose && inp->is(kperm) == simpleTransPerm) { + n->replaceInput(i, inp->input()); + n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1); + if (inp->uses().size() == 0) { + inp->destroy(); + } + } + } + } + } +} + // This optimization does ONNX-specific peephole optimizations. // // At the moment, here are the optimizations it does: @@ -100,6 +173,9 @@ void fuseBroadcast(std::shared_ptr<Graph>& graph) { // easier for non-strided backends to more efficiently do broadcasts if this is // local information. This optimization is not useful for PyTorch as 'expand' // is free. +// - Fusing of consecutive transposes +// - Elimination of NOP transposes +// - Fusing of transposes into Gemm // // Before you write an optimization here, ask yourself, "Could I do this // optimization on ATen operators"? 
If so, you should seriously consider @@ -111,6 +187,9 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) { // TODO: make it easier not to do O(k) iterations over the graph, where // k is the number of distinct peephole optimizations fuseBroadcast(graph); + fuseConsecutiveTransposes(graph); + eliminateNopTranspose(graph); + fuseTransposeIntoGemm(graph); } }} diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index f7832dbeac890..d9e5d4ab81be7 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -13,6 +13,7 @@ void PeepholeOptimize(std::shared_ptr<Graph>& graph) { for (auto it = graph->begin(); it != graph->end(); ++it) { auto* n = *it; + // eliminate redundant expand if (n->kind() == kexpand) { if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) { n->replaceAllUsesWith(n->input());