-
Notifications
You must be signed in to change notification settings - Fork 22.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fast path binary ops in fake tensor #94047
Conversation
Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94047
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 79 PendingAs of commit df2e401: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you give more details on the kind of gains we get from this? This does add significant complexity.
PR description updated! |
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup. Before: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:53.97591 backend_compile:33.60832 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` After: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:40.18931 backend_compile:25.28828 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit# This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment: ``` diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e3bf545f3b8..395942c6ffe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} + with no_dispatch(): + if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}: + return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda') + if func == torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: ``` I am still leaving about 10s of trace time improvement on the table (5s of which is attributable to not yet handling relu.) The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences: * Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path through if at least one of the input operands matches the broadcasted shape exactly (the idea being that we will probably use that tensor's layout.) I am pretty sure this is not sound, but I need to check tests to see how unsound it is. * I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. I intend to verify whether or not the new algorithm is correct using Z3. Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup. Before: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:53.97591 backend_compile:33.60832 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` After: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:40.18931 backend_compile:25.28828 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit# This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment: ``` diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e3bf545f3b8..395942c6ffe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} + with no_dispatch(): + if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}: + return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda') + if func == torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: ``` I am still leaving about 10s of trace time improvement on the table (5s of which is attributable to not yet handling relu.) The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences: * Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path through if at least one of the input operands matches the broadcasted shape exactly (the idea being that we will probably use that tensor's layout.) I am pretty sure this is not sound, but I need to check tests to see how unsound it is. * I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. I intend to verify whether or not the new algorithm is correct using Z3. Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: 5ad698daf07101c437eb7868736c7f77c1b2e4f7 Pull Request resolved: #94047
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
have you tested short circuiting but reusing more of the methods from prims / elsewhere for this ? Would be nice to cut down on some of the duplication. I think a lot of the speedup would still be applicable
try: | ||
return self.dispatch(func, types, args, kwargs) | ||
except TypeError: | ||
log.exception("fake tensor raised TypeError") | ||
raise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what are these changes for ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I was working on this PR I sometimes messed up my short circuit logic and triggered a TypeError. This TypeError was silently swallowed. Now I get a log for it.
) | ||
if is_contiguous: | ||
# do contiguous | ||
count_label("fast is_contiguous") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imo it's a little strange to have this on by default and only have telemetry on such a small part of the model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TBH we should do more telemetry. I'm happy to remove this but this was also very useful for understand perf characteristics here.
|
||
simple_call_counter = collections.OrderedDict() | ||
simple_call_counter: OrderedDict[str, int] = collections.OrderedDict() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need OrderedDict ? isn't dict already ordered ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ask @voznesenskym
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like preserving key order for things like this.
Now updated with some evidence the heuristic is OK; see bottom of PR desc= |
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup. Before: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:53.97591 backend_compile:33.60832 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` After: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:40.18931 backend_compile:25.28828 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit# This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment: ``` diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e3bf545f3b8..395942c6ffe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} + with no_dispatch(): + if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}: + return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda') + if func == torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: ``` I am still leaving about 10s of trace time improvement on the table (5s of which is attributable to not yet handling relu.) The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences: * Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or if all the tensors are channels last). * I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. Some evidence that this heuristic is correct is here in: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758 I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm. In fact, there ARE differences between this algorithm and PrimTorch, but in fact this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)) Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: 1c56d3d7930c3cf970742ea069f17941c2e18f5f Pull Request resolved: #94047
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Debug code needs to go but sounds ok otherwise.
|
||
def count_label(label): | ||
prev = simple_call_counter.setdefault(label, 0) | ||
simple_call_counter[label] = prev + 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could use a defaultdict and just += 1 here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IDK why @voznesenskym didn't make this a defaultdict, I was minimizing changes here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, that's fine.
btw the debug code I said need to go is the one raised by Elias above that you said you'll remove before merging. |
OK, to be clear, @eellison do you want it removed? I would prefer it to stay but if someone says "please remove" I will remove. |
I think the worse error message is worth fixing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with it I guess but if we get more telemetry we should make them not on by default. inductor has various loggings but I don't think any of them are on by default.
I don't particularly care in either direction. Alban's concern about readability is true but we can cross that bridge when we get to it.
return tuple(expandedSizes) | ||
|
||
|
||
def make_fast_binary_impl(slow_ref): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it be worth moving this to fake_utils ? idk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it only needs to be used here, so let's keep it here. If we want, we could make a separate module for "fake tensor op implementations"
with mode: | ||
return slow_ref(*args, **kwargs) | ||
|
||
count_label("attempt fast") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would we even need the count_labels if these were factored out into functions ? how slow are python profilers? or maybe just annoying to use
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it's a combination of slowness (for non-sampling profilers) and also annoyance (the function call will be lost in a sea of other function calls.)
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup. Before: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:53.97591 backend_compile:33.60832 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` After: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:40.18931 backend_compile:25.28828 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit# This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment: ``` diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e3bf545f3b8..395942c6ffe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} + with no_dispatch(): + if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}: + return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda') + if func == torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: ``` I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu.) The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences: * Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or if all the tensors are channels last). * I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. Some evidence that this heuristic is correct is here in: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758 I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm. In fact, there ARE differences between this algorithm and PrimTorch, but in fact this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)) Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: 07fb7ed8733856f090a02cacb25000a27d08145d Pull Request resolved: #94047
The telemetry goes into some counters which don't get printed by default. That's the same as how dynamo collects counter telemetry too. |
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup. Before: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:53.97591 backend_compile:33.60832 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` After: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:40.18931 backend_compile:25.28828 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit# This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment: ``` diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e3bf545f3b8..395942c6ffe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} + with no_dispatch(): + if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}: + return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda') + if func == torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: ``` I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu.) The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences: * Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or if all the tensors are channels last). * I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. Some evidence that this heuristic is correct is here in: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758 I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm. In fact, there ARE differences between this algorithm and PrimTorch, but in fact this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)) Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 mandatory check(s) failed (Rule Dig deeper by viewing the failures on hud Details for Dev Infra teamRaised by workflow job |
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup. Before: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:53.97591 backend_compile:33.60832 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` After: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:40.18931 backend_compile:25.28828 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit# This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment: ``` diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e3bf545f3b8..395942c6ffe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} + with no_dispatch(): + if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}: + return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda') + if func == torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: ``` I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu.) The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences: * Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or if all the tensors are channels last). * I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. Some evidence that this heuristic is correct is here in: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758 I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm. In fact, there ARE differences between this algorithm and PrimTorch, but in fact this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)) Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
@pytorchbot merge |
Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: 982f51e9932173556f9f3aee7825beca8527d953 Pull Request resolved: #94047
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing
python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18
, I get the following trace speedup.Before:
After:
My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#
This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment:
I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu.)
The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:
Some evidence that this heuristic is correct is here in: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758 I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm. In fact, there ARE differences between this algorithm and PrimTorch, but in fact this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1))
Signed-off-by: Edward Z. Yang ezyang@meta.com