[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: #49185

This diff adds a fusion pass that lets us use static runtime for *parts* of a graph. This is useful in cases where fully eliminating control flow from the graph is hard.
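
A minimal sketch of the intended flow, assuming the `torch._C._fuse_to_static_runtime` binding exercised by the Python tests below (the `trivial_graph` helper is copied from the test file):

```
import torch

def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s

tg = torch.jit.script(trivial_graph)
# The pass rewrites eligible regions into prim::StaticSubgraph nodes,
# which are then executed by static runtime.
torch._C._fuse_to_static_runtime(tg.graph)
assert "StaticSubgraph" in str(tg.graph)

s = torch.full((2, 2), 2)
print(tg(s, s, s))  # matches the unfused result
```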

TODO:
- [x] factor out into a separate fusion file
- [x] add a Python test case
- [x] add a test case for a graph that isn't fully lowered
- [x] add a test case for a graph with unusual list/tuple outputs

The loop example looks quite good:
```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %iters.1 : int):
  %12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
  %c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
  %c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
    block0(%i : int, %c.12 : Tensor):
      %c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
      -> (%12, %c.10)
  return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
      %4 : Tensor):
  %5 : int = prim::Constant[value=2]()
  %6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
  %2 : int = prim::Constant[value=1]()
  %c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
  return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
      %7 : Tensor,
      %8 : Tensor):
  %9 : int = prim::Constant[value=1]()
  %c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
  %5 : int = prim::Constant[value=2]()
  %c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
  %2 : int = prim::Constant[value=1]()
  %c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
  return (%c.10)
```
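
For reference, the TorchScript source that produced this IR is the `loop_graph` helper added to `test/test_static_runtime.py` in this diff: `c = a + b * 2` becomes `prim::StaticSubgraph_0`, and the loop body becomes `prim::StaticSubgraph_1`.

```
def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c
```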

(Note: this ignores all push blocking failures!)

Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest

buck test mode/no-gpu caffe2/test:static_runtime
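
The new Python tests can presumably also be run directly outside of Buck, since the test file ends with `run_tests()`:

python test/test_static_runtime.py -k test_fusion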

Reviewed By: bertmaher

Differential Revision: D25385702

fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
bwasti authored and facebook-github-bot committed Dec 10, 2020
1 parent 95a1725 commit f4226b5
Showing 12 changed files with 406 additions and 4 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
```
@@ -39,6 +39,7 @@ namespace c10 {
   _(prim, FunctionalGraph) \
   _(prim, DifferentiableGraph) \
   _(prim, TensorExprGroup) \
+  _(prim, StaticSubgraph) \
   _(prim, If) \
   _(prim, Jump) /* debug */ \
   _(prim, JumpNZ) /* debug */ \
```
32 changes: 32 additions & 0 deletions benchmarks/static_runtime/test_static_runtime.cc
```
@@ -1,4 +1,5 @@
 #include <gtest/gtest.h>
+#include <torch/csrc/jit/runtime/static/fusion.h>
 #include <torch/csrc/jit/runtime/static/impl.h>
 #include "deep_wide_pt.h"
 #include "test_scripts.h"
@@ -249,3 +250,34 @@ TEST(StaticRuntime, CleanUpMemory) {
     }
   }
 }
+
+TEST(StaticRuntime, FusionPass) {
+  const int embedding_size = 32;
+  const int num_features = 50;
+  for (int batch_size : {1, 8, 32}) {
+    for (int i = 0; i < 2; ++i) {
+      torch::jit::Module module = getDeepAndWideSciptModel();
+      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
+      auto user_emb = torch::randn({batch_size, 1, embedding_size});
+      auto wide = torch::randn({batch_size, num_features});
+
+      // run jit graph executor
+      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
+      auto output_1 = getTensor(module.forward(inputs));
+
+      Method method = module.get_method("forward");
+      auto graph = method.graph();
+      fuseStaticSubgraphs(graph);
+      bool hit = false;
+      for (const auto& n : module.get_method("forward").graph()->nodes()) {
+        if (n->kind() == torch::jit::prim::StaticSubgraph) {
+          hit = true;
+        }
+      }
+      EXPECT_TRUE(hit);
+      auto output_2 = getTensor(module.forward(inputs));
+      EXPECT_TRUE(output_1.equal(output_2));
+    }
+  }
+}
```

73 changes: 73 additions & 0 deletions test/test_static_runtime.py
```
@@ -105,6 +105,21 @@ def trivial_graph(a, b, c):
     s = torch.tensor([[3, 3], [3, 3]])
     return a + b * c + s
 
+def loop_graph(a, b, iters : int):
+    c = a + b * 2
+    for i in range(iters):
+        c = c + b
+        c *= 2
+        c -= a
+    return c
+
+def output_graph(a, b, c, iters : int):
+    s = torch.tensor([[3, 3], [3, 3]])
+    k = a + b * c + s
+    d : Dict[int, Tensor] = {}
+    for i in range(iters):
+        d[i] = k + i
+    return d
 
 class TestStaticRuntime(TestCase):
     def test_multihead_attention_layer(self):
@@ -203,5 +218,63 @@ def test_leaky_relu(self):
         o_test = tg_a(s)[0]
         torch.testing.assert_allclose(o_ref, o_test)
 
+    def test_fusion_trivial_graph(self):
+        s = torch.full((2, 2), 2)
+        tg = torch.jit.script(trivial_graph)
+        o_ref = tg(s, s, s)
+        torch._C._fuse_to_static_runtime(tg.graph)
+        assert "StaticSubgraph" in str(tg.graph)
+        o_test = tg(s, s, s)
+        torch.testing.assert_allclose(o_ref, o_test)
+
+    def test_fusion_multihead_attention_layer(self):
+        HID_DIM = 256
+        QUERY_LEN = 8
+        BATCH_SIZE = 128
+        LAYERS = 3
+        HEADS = 8
+        DROPOUT = 0.1
+        device = torch.device("cpu")
+        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
+        with torch.no_grad():
+            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
+            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
+
+        attention.eval()
+        attention = torch.jit.script(attention)
+        attention.eval()
+        o_ref = attention(src, src, src, src_mask)
+
+        torch._C._fuse_to_static_runtime(attention._c)
+        o_test = attention(src, src, src, src_mask)
+
+        for a, b in zip(o_ref, o_test):
+            torch.testing.assert_allclose(a, b)
+
+    def test_fusion_loop(self):
+        a = torch.randn(5, 5)
+        b = torch.randn(5, 5)
+        c = 4
+        lg = torch.jit.script(loop_graph)
+        o_ref = lg(a, b, c)
+        torch._C._fuse_to_static_runtime(lg.graph)
+        assert "StaticSubgraph" in str(lg.graph)
+        o_test = lg(a, b, c)
+        torch.testing.assert_allclose(o_ref, o_test)
+
+    def test_fusion_outputs(self):
+        a = torch.randn(2, 2)
+        b = torch.randn(2, 2)
+        c = 4
+        og = torch.jit.script(output_graph)
+        o_ref = og(a, b, b, c)
+        torch._C._fuse_to_static_runtime(og.graph)
+        assert "StaticSubgraph" in str(og.graph)
+        o_test = og(a, b, b, c)
+        for i in o_ref.keys():
+            torch.testing.assert_allclose(o_ref[i], o_test[i])
+
+
 if __name__ == "__main__":
     run_tests()
```
1 change: 1 addition & 0 deletions tools/build_variables.bzl
```
@@ -266,6 +266,7 @@ core_sources_full_mobile = [
 ]
 
 core_sources_full = core_sources_full_mobile + [
+    "torch/csrc/jit/runtime/static/fusion.cpp",
     "torch/csrc/jit/runtime/static/impl.cpp",
     "torch/csrc/jit/runtime/static/ops.cpp",
     "torch/csrc/jit/runtime/static/passes.cpp",
```
1 change: 1 addition & 0 deletions torch/csrc/jit/ir/alias_analysis.cpp
```
@@ -486,6 +486,7 @@ void AliasDb::analyzeImpl(Node* node) {
       return analyzeGradOf(node);
     // TODO: think more about TensorExpr alias correctness
     case prim::TensorExprGroup:
+    case prim::StaticSubgraph:
     case prim::Constant:
     case prim::AutogradZero:
     case prim::AutogradAdd:
```
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/interpreter.cpp
```
@@ -320,6 +320,7 @@ struct CanEmitInline {
         // by the later BailOut in createBailoutBlock and its jf_index
         // will become invalid.
         v->node()->kind() != prim::TensorExprGroup &&
+        v->node()->kind() != prim::StaticSubgraph &&
         v->node()->kind() != prim::CudaFusionGroup &&
         v->node()->kind() != prim::FusionGroup &&
         v->node()->kind() != prim::BailOut && v->uses().size() == 1 &&
```
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/operator.cpp
```
@@ -239,6 +239,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
       prim::CudaFusionGroup, // optimization pass adds it
       prim::CudaFusionGuard, // optimization pass adds it
       prim::TensorExprGroup, // optimization pass adds it
+      prim::StaticSubgraph, // optimization pass adds it
       prim::Load, // used in interpreter only
       prim::MMTreeReduce, // used as an optimization
      prim::MMBatchSide, // used as an optimization
@@ -276,6 +277,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
       prim::CudaFusionGroup,
       prim::DifferentiableGraph,
       prim::TensorExprGroup,
+      prim::StaticSubgraph,
       prim::FunctionalGraph,
       prim::Constant,
       prim::Uninitialized,
```
