Skip to content

Commit

Permalink
Prevent JIT from overspecializing to every single size configuration (#…
Browse files Browse the repository at this point in the history
…10844)

Summary:
Please review the expects carefully to make sure there are no regressions. I tried to go over them one by one when they changed, but it's sometimes easy to miss finer details.

Summary of changes:

- Renamed `TensorType` to `CompleteTensorType`. Added a new `TensorType` which records only the scalar type, number of dimensions, and device of a value. The argument behind the rename is to encourage people to use `CompleteTensorType` less, as most passes will only have limited information available. To make transition easier `complete_type->cast<TensorType>()` works, and makes our passes work with both kinds of specialization if they don't need extra the extra detail.
- Renamed `ArgumentSpec` to `CompleteArgumentSpec`. Added a new `ArgumentSpec`, which matches argument only at the level of the new `TensorType`.
- Shape analysis can process graphs with both `CompleteTensorType` and `TensorType`.
- Fuser was a part that heavily relied on full shape information being available. Now, we simply try to fuse the largest possible graphs, and have to do run-time checks to make sure they match the code we generate. If they don't, we fall back to regular interpretation. The shape checks are implementing using an optimized method exploiting algebraic properties of shapes with broadcasting, and the relations of broadcasting with pointwise ops. A full written proof of correctness of the shape checking algorithm is included in a comment in `graph_fuser.cpp`.

zdevito ezyang mruberry ngimel csarofeen
Pull Request resolved: #10844

Differential Revision: D9498705

Pulled By: apaszke

fbshipit-source-id: 0c53c2fcebd871cc2a29c260f8d012276479cc61
  • Loading branch information
apaszke authored and facebook-github-bot committed Aug 26, 2018
1 parent 9679fc5 commit c8b246a
Show file tree
Hide file tree
Showing 53 changed files with 1,746 additions and 1,057 deletions.
24 changes: 10 additions & 14 deletions test/expect/TestJit.test_broadcast_fusion_cuda.expect
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
graph(%0 : Float(4, 4)
%1 : Float(4)
%2 : Float(4)) {
%3 : int[] = prim::Constant[value=[4, 4]]()
%4 : int = prim::Constant[value=0]()
%5 : Float(4!, 4) = aten::expand(%1, %3, %4)
%6 : Float(4!, 4) = aten::expand(%2, %3, %4)
%7 : Float(4, 4) = prim::FusionGroup_0[device=0](%6, %0, %5)
return (%7);
graph(%0 : Float(*, *)
%1 : Float(*)
%2 : Float(*)) {
%3 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %0, %1)
return (%3);
}
with prim::FusionGroup_0 = graph(%1 : Float(4!, 4)
%4 : Float(4, 4)
%5 : Float(4!, 4)) {
%6 : Float(4, 4) = aten::mul(%4, %5)
with prim::FusionGroup_0 = graph(%1 : Float(*)
%4 : Float(*, *)
%5 : Float(*)) {
%6 : Float(*, *) = aten::mul(%4, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(4, 4) = aten::add(%6, %1, %2)
%3 : Float(*, *) = aten::add(%6, %1, %2)
return (%3);
}
16 changes: 8 additions & 8 deletions test/expect/TestJit.test_concat_fusion.expect
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
graph(%0 : Float(3, 20)
%1 : Float(3, 20)) {
%2 : Float(6, 20) = prim::FusionGroup_0[device=0](%0, %1)
graph(%0 : Float(*, *)
%1 : Float(*, *)) {
%2 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1)
return (%2);
}
with prim::FusionGroup_0 = graph(%3 : Float(3, 20)
%4 : Float(3, 20)) {
with prim::FusionGroup_0 = graph(%3 : Float(*, *)
%4 : Float(*, *)) {
%6 : int = prim::Constant[value=1]()
%7 : Float(3, 20) = aten::add(%3, %4, %6)
%5 : Float(3, 20) = aten::mul(%3, %4)
%2 : Float(6, 20) = prim::FusedConcat[dim=0](%7, %5)
%7 : Float(*, *) = aten::add(%3, %4, %6)
%5 : Float(*, *) = aten::mul(%3, %4)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%7, %5)
return (%2);
}
20 changes: 10 additions & 10 deletions test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
graph(%0 : Float(2, 2)
%1 : Float(2, 2)
%2 : Float(4, 2)) {
graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)) {
%3 : int = prim::Constant[value=1]()
%4 : Float(4, 2) = prim::FusionGroup_0[device=0](%0, %1)
%5 : Float(4, 2) = aten::add(%4, %2, %3)
%4 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1)
%5 : Float(*, *) = aten::add(%4, %2, %3)
return (%5);
}
with prim::FusionGroup_0 = graph(%3 : Float(2, 2)
%4 : Float(2, 2)) {
with prim::FusionGroup_0 = graph(%3 : Float(*, *)
%4 : Float(*, *)) {
%7 : int = prim::Constant[value=1]()
%8 : Float(2, 2) = aten::add(%3, %4, %7)
%8 : Float(*, *) = aten::add(%3, %4, %7)
%5 : int = prim::Constant[value=1]()
%6 : Float(2, 2) = aten::sub(%3, %4, %5)
%2 : Float(4, 2) = prim::FusedConcat[dim=0](%8, %6)
%6 : Float(*, *) = aten::sub(%3, %4, %5)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%8, %6)
return (%2);
}
20 changes: 10 additions & 10 deletions test/expect/TestJit.test_fuse_last_device.expect
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
graph(%0 : Float(1)
%1 : Float(1)) {
%2 : Float(1) = prim::FusionGroup_0[device=1](%0, %1)
graph(%0 : Float(*)
%1 : Float(*)) {
%2 : Float(*) = prim::FusionGroup_0[device=1](%0, %1)
return (%2);
}
with prim::FusionGroup_0 = graph(%5 : Float(1)
%10 : Float(1)) {
with prim::FusionGroup_0 = graph(%5 : Float(*)
%10 : Float(*)) {
%11 : int = prim::Constant[value=1]()
%12 : Float(1) = aten::add(%5, %10, %11)
%9 : Float(1) = aten::mul(%5, %12)
%12 : Float(*) = aten::add(%5, %10, %11)
%9 : Float(*) = aten::mul(%5, %12)
%6 : int = prim::Constant[value=1]()
%7 : Float(1) = aten::add(%9, %5, %6)
%3 : Float(1) = aten::tanh(%7)
%1 : Float(1) = aten::sigmoid(%3)
%7 : Float(*) = aten::add(%9, %5, %6)
%3 : Float(*) = aten::tanh(%7)
%1 : Float(*) = aten::sigmoid(%3)
return (%1);
}
20 changes: 11 additions & 9 deletions test/expect/TestJit.test_fusion_distribute.expect
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
graph(%0 : Float(4, 4)
%1 : Float(4, 4)) {
%2 : Float(4, 2) = prim::FusionGroup_0[device=0](%0, %1)
return (%2);
graph(%0 : Float(*, *)
%1 : Float(*, *)) {
%2 : Dynamic[] = prim::ListConstruct(%0, %1)
%3 : Dynamic, %4 : Dynamic = aten::broadcast_tensors(%2)
%5 : Float(*, *) = prim::FusionGroup_0[device=0](%3, %4)
return (%5);
}
with prim::FusionGroup_0 = graph(%11 : Float(4, 4)
%14 : Float(4, 4)) {
with prim::FusionGroup_0 = graph(%11 : Dynamic
%14 : Dynamic) {
%15 : Dynamic, %16 : Dynamic = prim::FusedChunk[chunks=2, dim=1](%14)
%12 : Dynamic, %13 : Dynamic = prim::FusedChunk[chunks=2, dim=1](%11)
%9 : int = prim::Constant[value=1]()
%10 : Float(4, 2) = aten::add(%13, %16, %9)
%10 : Float(*, *) = aten::add(%13, %16, %9)
%5 : int = prim::Constant[value=1]()
%6 : Float(4, 2) = aten::add(%12, %15, %5)
%2 : Float(4, 2) = aten::mul(%6, %10)
%6 : Float(*, *) = aten::add(%12, %15, %5)
%2 : Float(*, *) = aten::mul(%6, %10)
return (%2);
}
62 changes: 32 additions & 30 deletions test/expect/TestJit.test_lstm_fusion_concat.expect
Original file line number Diff line number Diff line change
@@ -1,41 +1,43 @@
graph(%0 : Float(3, 10)
%1 : Float(3, 20)
%2 : Float(3, 20)
%3 : Float(80, 10)
%4 : Float(80, 20)
%5 : Float(80)
%6 : Float(80)) {
%7 : Float(10!, 80!) = aten::t(%3)
graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*)
%6 : Float(*)) {
%7 : Float(*, *) = aten::t(%3)
%8 : int = prim::Constant[value=1]()
%9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8)
%10 : Float(20!, 80!) = aten::t(%4)
%11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8)
%12 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %9, %11)
return (%12);
%9 : Float(*, *) = aten::addmm(%5, %0, %7, %8, %8)
%10 : Float(*, *) = aten::t(%4)
%11 : Float(*, *) = aten::addmm(%6, %1, %10, %8, %8)
%12 : Dynamic[] = prim::ListConstruct(%9, %11)
%13 : Dynamic, %14 : Dynamic = aten::broadcast_tensors(%12)
%15 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %13, %14)
return (%15);
}
with prim::FusionGroup_0 = graph(%15 : Float(3, 20)
%41 : Float(3, 80)
%46 : Float(3, 80)) {
with prim::FusionGroup_0 = graph(%15 : Float(*, *)
%41 : Dynamic
%46 : Dynamic) {
%47 : Dynamic, %48 : Dynamic, %49 : Dynamic, %50 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%46)
%42 : Dynamic, %43 : Dynamic, %44 : Dynamic, %45 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%41)
%39 : int = prim::Constant[value=1]()
%40 : Float(3, 20) = aten::add(%42, %47, %39)
%40 : Float(*, *) = aten::add(%42, %47, %39)
%35 : int = prim::Constant[value=1]()
%36 : Float(3, 20) = aten::add(%43, %48, %35)
%36 : Float(*, *) = aten::add(%43, %48, %35)
%31 : int = prim::Constant[value=1]()
%32 : Float(3, 20) = aten::add(%44, %49, %31)
%32 : Float(*, *) = aten::add(%44, %49, %31)
%27 : int = prim::Constant[value=1]()
%28 : Float(3, 20) = aten::add(%45, %50, %27)
%24 : Float(3, 20) = aten::sigmoid(%40)
%22 : Float(3, 20) = aten::sigmoid(%36)
%20 : Float(3, 20) = aten::tanh(%32)
%18 : Float(3, 20) = aten::sigmoid(%28)
%16 : Float(3, 20) = aten::mul(%22, %15)
%13 : Float(3, 20) = aten::mul(%24, %20)
%28 : Float(*, *) = aten::add(%45, %50, %27)
%24 : Float(*, *) = aten::sigmoid(%40)
%22 : Float(*, *) = aten::sigmoid(%36)
%20 : Float(*, *) = aten::tanh(%32)
%18 : Float(*, *) = aten::sigmoid(%28)
%16 : Float(*, *) = aten::mul(%22, %15)
%13 : Float(*, *) = aten::mul(%24, %20)
%9 : int = prim::Constant[value=1]()
%10 : Float(3, 20) = aten::add(%16, %13, %9)
%6 : Float(3, 20) = aten::tanh(%10)
%5 : Float(3, 20) = aten::mul(%18, %6)
%2 : Float(6, 20) = prim::FusedConcat[dim=0](%5, %10)
%10 : Float(*, *) = aten::add(%16, %13, %9)
%6 : Float(*, *) = aten::tanh(%10)
%5 : Float(*, *) = aten::mul(%18, %6)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%5, %10)
return (%2);
}
60 changes: 31 additions & 29 deletions test/expect/TestJit.test_lstm_fusion_cuda.expect
Original file line number Diff line number Diff line change
@@ -1,40 +1,42 @@
graph(%0 : Float(3, 10)
%1 : Float(3, 20)
%2 : Float(3, 20)
%3 : Float(80, 10)
%4 : Float(80, 20)
%5 : Float(80)
%6 : Float(80)) {
%7 : Float(10!, 80!) = aten::t(%3)
graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*)
%6 : Float(*)) {
%7 : Float(*, *) = aten::t(%3)
%8 : int = prim::Constant[value=1]()
%9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8)
%10 : Float(20!, 80!) = aten::t(%4)
%11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8)
%12 : Float(3, 20), %13 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %9, %11)
return (%12, %13);
%9 : Float(*, *) = aten::addmm(%5, %0, %7, %8, %8)
%10 : Float(*, *) = aten::t(%4)
%11 : Float(*, *) = aten::addmm(%6, %1, %10, %8, %8)
%12 : Dynamic[] = prim::ListConstruct(%9, %11)
%13 : Dynamic, %14 : Dynamic = aten::broadcast_tensors(%12)
%15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %13, %14)
return (%15, %16);
}
with prim::FusionGroup_0 = graph(%13 : Float(3, 20)
%39 : Float(3, 80)
%44 : Float(3, 80)) {
with prim::FusionGroup_0 = graph(%13 : Float(*, *)
%39 : Dynamic
%44 : Dynamic) {
%45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44)
%40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39)
%37 : int = prim::Constant[value=1]()
%38 : Float(3, 20) = aten::add(%40, %45, %37)
%38 : Float(*, *) = aten::add(%40, %45, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(3, 20) = aten::add(%41, %46, %33)
%34 : Float(*, *) = aten::add(%41, %46, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%42, %47, %29)
%30 : Float(*, *) = aten::add(%42, %47, %29)
%25 : int = prim::Constant[value=1]()
%26 : Float(3, 20) = aten::add(%43, %48, %25)
%22 : Float(3, 20) = aten::sigmoid(%38)
%20 : Float(3, 20) = aten::sigmoid(%34)
%18 : Float(3, 20) = aten::tanh(%30)
%16 : Float(3, 20) = aten::sigmoid(%26)
%14 : Float(3, 20) = aten::mul(%20, %13)
%11 : Float(3, 20) = aten::mul(%22, %18)
%26 : Float(*, *) = aten::add(%43, %48, %25)
%22 : Float(*, *) = aten::sigmoid(%38)
%20 : Float(*, *) = aten::sigmoid(%34)
%18 : Float(*, *) = aten::tanh(%30)
%16 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%20, %13)
%11 : Float(*, *) = aten::mul(%22, %18)
%7 : int = prim::Constant[value=1]()
%8 : Float(3, 20) = aten::add(%14, %11, %7)
%4 : Float(3, 20) = aten::tanh(%8)
%2 : Float(3, 20) = aten::mul(%16, %4)
%8 : Float(*, *) = aten::add(%14, %11, %7)
%4 : Float(*, *) = aten::tanh(%8)
%2 : Float(*, *) = aten::mul(%16, %4)
return (%2, %8);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 5)) {
%2 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
%3 : Double(3, 5) = aten::neg(%2), scope: TracedModule/ScriptModule
%3 : Double(*, *) = aten::neg(%2), scope: TracedModule/ScriptModule
return (%3);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
graph(%0 : Double(3, 4)) {
%1 : Double(3, 4) = aten::neg(%0), scope: ScriptModule
%1 : Double(*, *) = aten::neg(%0), scope: ScriptModule
%2 : Long() = prim::Constant[value={1}]()
%3 : int = prim::Constant[value=1]()
%4 : Double(3, 4) = aten::add(%1, %2, %3)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)) {
%1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: ScriptMod
%2 : Double(3, 3) = aten::mm(%0, %1), scope: ScriptMod
%2 : Double(*, *) = aten::mm(%0, %1), scope: ScriptMod
%3 : Long() = prim::Constant[value={1}]()
%4 : int = prim::Constant[value=1]()
%5 : Double(3, 3) = aten::add(%2, %3, %4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ graph(%0 : Double(3, 4)
%1 : Double(4, 5)
%2 : Double(5, 7)) {
%3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
%4 : Double(3, 7) = aten::mm(%3, %2), scope: TracedModule/ScriptMod[mod]
%4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/ScriptMod[mod]
%5 : Long() = prim::Constant[value={1}](), scope: TracedModule
%6 : int = prim::Constant[value=1](), scope: TracedModule
%7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule
Expand Down
10 changes: 5 additions & 5 deletions test/expect/TestScript.test_chunk_fusion_cuda.expect
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
graph(%x : Float(10, 6)) {
%1 : Float(10, 2) = prim::FusionGroup_0[device=0](%x)
graph(%x : Float(*, *)) {
%1 : Float(*, *) = prim::FusionGroup_0[device=0](%x)
return (%1);
}
with prim::FusionGroup_0 = graph(%7 : Float(10, 6)) {
with prim::FusionGroup_0 = graph(%7 : Float(*, *)) {
%8 : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::FusedChunk[chunks=3, dim=1](%7)
%6 : Float(10, 2) = aten::mul(%8, %9)
%6 : Float(*, *) = aten::mul(%8, %9)
%2 : int = prim::Constant[value=1]()
%3 : Float(10, 2) = aten::add(%6, %10, %2)
%3 : Float(*, *) = aten::add(%6, %10, %2)
return (%3);
}
32 changes: 16 additions & 16 deletions test/expect/TestScript.test_chunk_multiple_fusion_cuda.expect
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
graph(%s : Float(5, 2, 3)
%x : Float(5, 6, 3)
%y : Float(10, 2, 3)
%z : Float(5, 2, 6)) {
%4 : Float(5, 2, 3) = prim::FusionGroup_0[device=0](%s, %y, %x, %z)
graph(%s : Float(*, *, *)
%x : Float(*, *, *)
%y : Float(*, *, *)
%z : Float(*, *, *)) {
%4 : Float(*, *, *) = prim::FusionGroup_0[device=0](%s, %y, %x, %z)
return (%4);
}
with prim::FusionGroup_0 = graph(%24 : Float(5, 2, 3)
%28 : Float(10, 2, 3)
%31 : Float(5, 6, 3)
%35 : Float(5, 2, 6)) {
with prim::FusionGroup_0 = graph(%24 : Float(*, *, *)
%28 : Float(*, *, *)
%31 : Float(*, *, *)
%35 : Float(*, *, *)) {
%36 : Dynamic, %37 : Dynamic = prim::FusedChunk[chunks=2, dim=2](%35)
%32 : Dynamic, %33 : Dynamic, %34 : Dynamic = prim::FusedChunk[chunks=3, dim=1](%31)
%29 : Dynamic, %30 : Dynamic = prim::FusedChunk[chunks=2, dim=0](%28)
%26 : int = prim::Constant[value=1]()
%27 : Float(5, 2, 3) = aten::add(%24, %32, %26)
%27 : Float(*, *, *) = aten::add(%24, %32, %26)
%22 : int = prim::Constant[value=1]()
%23 : Float(5, 2, 3) = aten::add(%27, %33, %22)
%23 : Float(*, *, *) = aten::add(%27, %33, %22)
%18 : int = prim::Constant[value=1]()
%19 : Float(5, 2, 3) = aten::add(%23, %34, %18)
%19 : Float(*, *, *) = aten::add(%23, %34, %18)
%14 : int = prim::Constant[value=1]()
%15 : Float(5, 2, 3) = aten::add(%19, %29, %14)
%15 : Float(*, *, *) = aten::add(%19, %29, %14)
%10 : int = prim::Constant[value=1]()
%11 : Float(5, 2, 3) = aten::add(%15, %30, %10)
%11 : Float(*, *, *) = aten::add(%15, %30, %10)
%6 : int = prim::Constant[value=1]()
%7 : Float(5, 2, 3) = aten::add(%11, %36, %6)
%7 : Float(*, *, *) = aten::add(%11, %36, %6)
%2 : int = prim::Constant[value=1]()
%3 : Float(5, 2, 3) = aten::add(%7, %37, %2)
%3 : Float(*, *, *) = aten::add(%7, %37, %2)
return (%3);
}
15 changes: 15 additions & 0 deletions test/expect/TestScript.test_if_list.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
graph(%x : Double(*, *)) {
%1 : int = prim::Constant[value=1]()
%c : Dynamic[] = prim::If(%1)
block0() {
%c.1 : Dynamic[] = prim::ListConstruct(%x, %x)
-> (%c.1)
}
block1() {
%c.2 : Dynamic[] = prim::ListConstruct(%x, %x, %x)
-> (%c.2)
}
%5 : int = prim::Constant[value=0]()
%6 : Dynamic = aten::cat(%c, %5)
return (%6);
}
Loading

0 comments on commit c8b246a

Please sign in to comment.