[FakeTensor] Reuse flat_args throughout FakeTensorMode.dispatch #112418
Conversation
This function repeatedly flattens and unflattens the `args, kwargs` pair, so we get a significant perf improvement from saving the `flat_args` and operating directly on those. I see a 15% improvement in dispatch for `empty_strided`.
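For context, a minimal sketch of the pattern the PR moves to, using `torch.utils._pytree` (the same pytree module `fake_tensor.py` uses). The `to_meta` helper and `dispatch_sketch` function are illustrative stand-ins, not the actual `FakeTensorMode.dispatch` code:

```python
import torch
import torch.utils._pytree as pytree

def to_meta(a):
    # Stand-in for FakeTensor conversion: move tensors to the meta device.
    return a.to("meta") if isinstance(a, torch.Tensor) else a

def dispatch_sketch(func, args, kwargs):
    # Flatten once; flat_args is a plain list, spec records the nesting.
    flat_args, spec = pytree.tree_flatten((args, kwargs))
    # Helpers operate on the flat list directly, instead of each helper
    # re-running tree_flatten/tree_map on the nested args/kwargs.
    flat_args = [to_meta(a) for a in flat_args]
    # Unflatten a single time, only when the nested structure is needed.
    args, kwargs = pytree.tree_unflatten(flat_args, spec)
    return func(*args, **kwargs)

out = dispatch_sketch(torch.add, (torch.ones(2), torch.ones(2)), {})
print(out.device)  # meta
```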
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112418. Note: links to docs will display an error until the doc builds have completed. ✅ No failures as of commit 2cafa24 with merge base 29f3d39. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Two optional nits
torch/_subclasses/fake_tensor.py (Outdated)

```diff
-    def wrap_meta_outputs_with_default_device_logic(self, r, func, args, kwargs):
-        wrap = self.gen_wrap_fn(func, args, kwargs)
+    def wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, kwargs):
```
Nit: perhaps just pass `device` rather than passing `flat_args` and `kwargs`. I found this a bit confusing.
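A generic sketch of the suggested shape; the names below are illustrative, not the actual FakeTensor helpers. The idea is to derive the one value the helper needs (the common device) at the call site and pass that, instead of threading `flat_args` and `kwargs` through:

```python
import torch

def find_common_device(flat_args):
    # Simplified: take the device of the first tensor argument. The real
    # FakeTensor logic also validates that the tensor devices agree.
    for a in flat_args:
        if isinstance(a, torch.Tensor):
            return a.device
    return torch.device("cpu")

def wrap_outputs(outputs, common_device):
    # With the device passed in, this helper no longer needs flat_args/kwargs.
    return [o.to(common_device) if isinstance(o, torch.Tensor) else o
            for o in outputs]

flat_args = [torch.ones(2), 3]
outs = wrap_outputs([torch.ones(2)], find_common_device(flat_args))
```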
torch/_subclasses/fake_tensor.py (Outdated)

```diff
-        return tree_map(map_out, r)
+        flat_out = [map_out(o) for o in flat_out]
+        return pytree.tree_unflatten(flat_out, out_spec)
```
Here perhaps we just want to `tree_map` all the transformations on `r` rather than unpacking and packing again?
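For reference, the two forms give the same result when the container structure of `r` is unchanged; a small self-contained comparison (a standalone sketch, not the PR's code):

```python
import torch.utils._pytree as pytree

r = {"out": [1, 2], "aux": 3}
f = lambda x: x * 10

# Reviewer's suggestion: one tree_map over the nested result.
a = pytree.tree_map(f, r)

# The PR's form: flatten, map over the flat list, unflatten.
flat_out, out_spec = pytree.tree_flatten(r)
b = pytree.tree_unflatten([f(o) for o in flat_out], out_spec)

assert a == b  # {'out': [10, 20], 'aux': 30}
```

The flatten/map/unflatten form only pays off when the flat list is reused across several passes, which is the point of this PR.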
…patch" This function repeatedly flattens and unflattens the `args, kwargs` pair so we get a quite significant perf improvement from saving the `flat_args` and operating directly on those. I see a 15% improvement in dispatch for `empty_strided`. [ghstack-poisoned]
…patch" This function repeatedly flattens and unflattens the `args, kwargs` pair so we get a quite significant perf improvement from saving the `flat_args` and operating directly on those. I see a 15% improvement in dispatch for `empty_strided`. [ghstack-poisoned]
`ShapeEnv` has tons of functionality that is conditioned on this `translation_validation_enabled()` check, to the point where 8% of the time in `empty_strided` is spent just in that function. However, it doesn't really make sense for the value of `translation_validation_enabled()` to change throughout the life of a `ShapeEnv`, so we might as well run the check once and store it in the `ShapeEnv`. Pull Request resolved: #112493. Approved by: https://github.com/lezcano. ghstack dependencies: #112418
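A minimal sketch of the caching pattern described above; the class and the stand-in check below are illustrative, not the real `ShapeEnv` internals:

```python
import os

def translation_validation_enabled():
    # Stand-in for the real check, which consults torch config and was
    # measured at ~8% of empty_strided dispatch time when called repeatedly.
    return os.environ.get("TORCH_TV", "0") == "1"

class ShapeEnvSketch:
    def __init__(self):
        # Run the check once at construction and cache the result; the value
        # is assumed not to change over the lifetime of a ShapeEnv.
        self._translation_validation_enabled = translation_validation_enabled()

    def create_symbol(self, val):
        # The hot path reads the cached bool instead of re-running the check.
        if self._translation_validation_enabled:
            print("recording translation-validation info for", val)
        return val
```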
[FakeTensor] Reuse flat_args throughout FakeTensorMode.dispatch (pytorch#112418): This function repeatedly flattens and unflattens the `args, kwargs` pair, so we get a significant perf improvement from saving the `flat_args` and operating directly on those. I see a 15% improvement in dispatch for `empty_strided`. Pull Request resolved: pytorch#112418. Approved by: https://github.com/lezcano
Stack from ghstack (oldest at bottom):
* #112493 (`ShapeEnv`)
* → #112418 (this PR)