
Fast path binary ops in fake tensor #94047

Closed · wants to merge 7 commits

Conversation

@ezyang (Contributor) commented Feb 3, 2023

Stack from ghstack (oldest at bottom):

Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compare with the Horace/Voz roofline experiment:

```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
 
+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu).

The implementation here is based off of #93118, but I modeled the short-circuit logic on TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* Traditional fast setup in TensorIterator only short-circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fast-pathed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or all the tensors are channels last). See the sketch after this list.
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right.
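As a rough illustration of the heuristic (not the code in this PR; `broadcast_shape` and `fast_path_eligible` are names made up for this sketch), the eligibility check amounts to:

```python
import torch

def broadcast_shape(a, b):
    # Standard right-aligned broadcasting, as in TensorIterator's shape inference.
    ndim = max(len(a), len(b))
    out = []
    for i in range(ndim):
        x = a[len(a) - ndim + i] if len(a) - ndim + i >= 0 else 1
        y = b[len(b) - ndim + i] if len(b) - ndim + i >= 0 else 1
        assert x == y or x == 1 or y == 1, "shapes are not broadcastable"
        out.append(max(x, y))
    return tuple(out)

def fast_path_eligible(operands):
    # Wrapped numbers (Python scalars) don't participate in layout decisions.
    tensors = [t for t in operands if isinstance(t, torch.Tensor)]
    out_shape = tuple(tensors[0].shape)
    for t in tensors[1:]:
        out_shape = broadcast_shape(out_shape, tuple(t.shape))
    # (1) some input already has exactly the output shape, and
    # (2) every input is contiguous, or every input is channels-last.
    some_full_size = any(tuple(t.shape) == out_shape for t in tensors)
    all_contig = all(t.is_contiguous() for t in tensors)
    all_cl = all(t.is_contiguous(memory_format=torch.channels_last) for t in tensors)
    return some_full_size and (all_contig or all_cl)
```

If the check fails, dispatch falls back to the slow reference implementation.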

Some evidence that this heuristic is correct is in https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758: I exhaustively test all dim=3 tensors with sizes in {1, 2} and show that PrimTorch and the new algorithm produce the same significant strides. In fact, there ARE differences between this algorithm and PrimTorch, but this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: two inputs with size=(1, 1, 2) and stride=(1, 1, 1)).
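A sketch of the kind of exhaustive comparison the gist runs (a hypothetical harness, not the gist itself; `significant_strides` is this sketch's invention):

```python
import itertools
import torch

def significant_strides(t):
    # Size-1 dimensions can carry arbitrary strides, so compare only the rest.
    return tuple(s for s, d in zip(t.stride(), t.shape) if d > 1)

# Enumerate every dim=3 shape with sizes in {1, 2}, plus every dimension
# permutation, so both contiguous and non-contiguous layouts are covered.
shapes = list(itertools.product((1, 2), repeat=3))
perms = list(itertools.permutations(range(3)))
for sa, sb, pa, pb in itertools.product(shapes, shapes, perms, perms):
    a = torch.empty(sa).permute(pa)
    b = torch.empty(sb).permute(pb)
    ref = torch.add(a, b)  # TensorIterator decides the reference output layout
    # ...compare significant_strides(ref) against the candidate algorithm here.
```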

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
@pytorch-bot bot commented Feb 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94047

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 79 Pending

As of commit df2e401:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@albanD (Collaborator) left a comment:

Could you give more details on the kind of gains we get from this? This does add significant complexity.

@ezyang requested a review from @ngimel, February 3, 2023.
@ezyang (Contributor, Author) commented Feb 3, 2023:

PR description updated!

@ezyang changed the title from "Fastpath binary ops" to "Fast path binary ops in fake tensor", Feb 3, 2023.
ezyang added a commit that referenced this pull request, Feb 3, 2023.
@ezyang added labels: "release notes: composability", "topic: not user facing" (Feb 3, 2023).
@eellison (Contributor) left a comment:

Have you tested short-circuiting but reusing more of the methods from prims / elsewhere for this? It would be nice to cut down on some of the duplication. I think a lot of the speedup would still be applicable.

Comment on lines +952 to +956:

```python
try:
    return self.dispatch(func, types, args, kwargs)
except TypeError:
    log.exception("fake tensor raised TypeError")
    raise
```
Contributor:

What are these changes for?

Contributor (Author):

When I was working on this PR I sometimes messed up my short circuit logic and triggered a TypeError. This TypeError was silently swallowed. Now I get a log for it.

```python
)
if is_contiguous:
    # do contiguous
    count_label("fast is_contiguous")
```
Contributor:

IMO it's a little strange to have this on by default and only have telemetry on such a small part of the model.

Contributor (Author):

TBH we should do more telemetry. I'm happy to remove this, but it was also very useful for understanding the perf characteristics here.

(Resolved review thread on torch/_subclasses/fake_tensor.py.)

```diff
-simple_call_counter = collections.OrderedDict()
+simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
```
Contributor:

Why do we need OrderedDict? Isn't dict already ordered?


Collaborator:

I like preserving key order for things like this.

(Review thread on torch/_subclasses/fake_tensor.py.)
@ezyang added the ciflow/trunk label (Trigger trunk jobs on your pull request), Feb 3, 2023.
@ezyang commented Feb 3, 2023:

Now updated with some evidence the heuristic is OK; see the bottom of the PR description.

@ezyang commented Feb 4, 2023:

Here are the most improved models with this change:

[image]

and the least improved:

[image]

eca_halonext26ts is also interesting: a long-running model that isn't helped much.

[image]

ezyang added a commit that referenced this pull request, Feb 4, 2023.
@albanD (Collaborator) left a comment:

Debug code needs to go, but sounds OK otherwise.

(Resolved review thread on test/functorch/test_aotdispatch.py.)

```python
def count_label(label):
    prev = simple_call_counter.setdefault(label, 0)
    simple_call_counter[label] = prev + 1
```
Collaborator:

You could use a defaultdict and just += 1 here.

Contributor (Author):

IDK why @voznesenskym didn't make this a defaultdict; I was minimizing changes here.

Collaborator:

Sure, that's fine.
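For reference, a minimal sketch of the suggested alternative (assuming nothing else relies on OrderedDict-specific behavior):

```python
import collections

# defaultdict(int) returns 0 for missing keys, so the setdefault dance goes away;
# plain dicts (and defaultdicts) preserve insertion order in Python 3.7+.
simple_call_counter = collections.defaultdict(int)

def count_label(label):
    simple_call_counter[label] += 1
```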

(Resolved review thread on torch/_subclasses/fake_tensor.py.)
@ezyang requested a review from @albanD, February 6, 2023.
@albanD (Collaborator) commented Feb 6, 2023:

BTW, the debug code I said needs to go is the one raised by Elias above, which you said you'd remove before merging.

@ezyang commented Feb 6, 2023:

OK, to be clear, @eellison, do you want it removed? I would prefer it to stay, but if someone says "please remove" I will remove it.

@albanD commented Feb 6, 2023:

I think the worse error message is worth fixing. The others are all metrics collection and logging, so they're fine.

@eellison (Contributor) left a comment:

I'm fine with it, I guess, but if we add more telemetry we should make it off by default. Inductor has various loggings, but I don't think any of them are on by default.

I don't particularly care either way. Alban's concern about readability is valid, but we can cross that bridge when we get to it.

```python
    return tuple(expandedSizes)


def make_fast_binary_impl(slow_ref):
```
Contributor:

Would it be worth moving this to fake_utils? idk

Contributor (Author):

It only needs to be used here, so let's keep it here. If we want, we could make a separate module for "fake tensor op implementations".

(Two review threads on torch/_subclasses/fake_tensor.py.)
```python
with mode:
    return slow_ref(*args, **kwargs)

count_label("attempt fast")
```
Contributor:

Would we even need the count_labels if these were factored out into functions? How slow are Python profilers? Or maybe they're just annoying to use.

Contributor (Author):

Yeah, it's a combination of slowness (for non-sampling profilers) and also annoyance (the function call will be lost in a sea of other function calls.)

ezyang added a commit that referenced this pull request, Feb 7, 2023.
@ezyang commented Feb 7, 2023:

> I'm fine with it I guess but if we get more telemetry we should make them not on by default. inductor has various loggings but I don't think any of them are on by default.

The telemetry goes into some counters which don't get printed by default. That's the same way Dynamo collects counter telemetry too.
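For illustration, one way to dump those counters after a trace. The import path is this sketch's assumption (it guesses `simple_call_counter` lives next to `count_label` in torch/_subclasses/fake_tensor.py, as the snippets above suggest):

```python
from torch._subclasses import fake_tensor

# Counters accumulate silently during tracing; print them only on demand.
for label, count in fake_tensor.simple_call_counter.items():
    print(f"{label}: {count}")
```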

@ezyang commented Feb 7, 2023:

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

@pytorchmergebot (Collaborator):

Merge failed. Reason: 1 mandatory check(s) failed (Rule superuser). Dig deeper by viewing the failures on hud.

@ezyang commented Feb 7, 2023:

@pytorchbot merge

ezyang added a commit that referenced this pull request, Feb 7, 2023.
@pytorchmergebot (Collaborator):

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).

facebook-github-bot deleted the gh/ezyang/1778/head branch, June 8, 2023.
Labels: ciflow/inductor, ciflow/trunk, Merged, release notes: composability, topic: not user facing