Update on "Fast path binary ops in fake tensor"
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compare it with the Horace/Voz roofline experiment:

```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
 
+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 10s of trace time improvement on the table (5s of which is attributable to not yet handling relu).

The implementation here is based on #93118, but I modeled the short-circuit logic on TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fast-pathed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but I only allow the fast path through if at least one of the input operands matches the broadcasted shape exactly (the idea being that we will probably use that tensor's layout); see the sketch after this list. I am pretty sure this is not sound, but I need to check tests to see how unsound it is.
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping them into tensors). I think I got this right.
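
As a rough sketch of the shape gate described above (not the exact code in this PR; `binary_fast_path_shape` is a hypothetical helper, `torch.broadcast_shapes` stands in for the broadcast computation, and the wrapped-number handling is simplified to "skip non-tensor operands"):

```
import torch

def binary_fast_path_shape(args):
    # Wrapped numbers / Python scalars do not contribute a shape here;
    # only tensor operands participate in the broadcast computation.
    tensor_shapes = [tuple(a.shape) for a in args if isinstance(a, torch.Tensor)]
    if not tensor_shapes:
        return None  # nothing to fast path; fall back to the slow path
    final_shape = tuple(torch.broadcast_shapes(*tensor_shapes))
    # Only short circuit if some operand already has exactly the broadcast
    # shape, on the theory that the output will probably inherit its layout.
    if any(s == final_shape for s in tensor_shapes):
        return final_shape
    return None  # no operand matches the broadcast shape: take the slow path
```

The caller would then allocate the output as a meta tensor of that shape, as in the `final_shape` change further down.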

I intend to verify whether or not the new algorithm is correct using Z3.
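
For flavor, here is a toy Z3 query of my own (illustrative only; the real verification would have to reason about strides and memory format, not just sizes). It checks, per dimension, that if an operand already matches the broadcast size, then the other operand's size must be 1 or equal to it:

```
from z3 import Ints, Solver, Or, If, Not

a, b = Ints("a b")  # sizes of one dimension of the two operands
s = Solver()
s.add(a >= 1, b >= 1, Or(a == 1, b == 1, a == b))  # valid broadcasting
out = If(a == 1, b, a)  # broadcasted size for this dimension
s.add(out == a)  # hypothesis: operand `a` matches the broadcast shape
s.add(Not(Or(b == 1, b == a)))  # look for a counterexample
print(s.check())  # unsat: no counterexample, so the per-dim property holds
```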

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
ezyang committed Feb 3, 2023
1 parent 1058c6f commit bdea9e7
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/functorch/test_aotdispatch.py
```
@@ -2111,6 +2111,7 @@ def forward(self, x, y):

self.assertExpectedInline(shape_env.format_guards(), """\
- Eq(s1, 20)
- Ne(s0, 30)
- Eq(s2, 30)""")

assert torch.allclose(ref[0], res[0])
```
2 changes: 1 addition & 1 deletion torch/_subclasses/fake_tensor.py
```
@@ -673,7 +673,7 @@ def slow(msg):
         return FakeTensor(
             mode,
             torch.empty(
-                shape,
+                final_shape,
                 dtype=common_dtype,
                 device="meta",
                 memory_format=torch.channels_last,
```
