New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ScalarTensor or 0dim overload for _foreach_add #111079
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111079
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit cd75529 with merge base 74f6f7a (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: cf86eb1f4263234457269c6c9db3ec92a35cb0d9 Pull Request resolved: #111079
@@ -152,49 +152,6 @@ __device__ __forceinline__ void binary_op_scalar( | |||
} | |||
} | |||
|
|||
template <int res_arg_index, typename Op, typename T, typename scalar_t = T> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ended up inlining this change due to not being able to specify that scalar is of type opmath_t. I cannot figure out where to put using opmath_t = at::opmath_type<T>;
before the function signature but after the template specification.
@@ -778,15 +778,14 @@ def test_tensors_grouping(self): | |||
def test_0dim_tensor_overload_exception(self): | |||
# check exceptions of fast path | |||
tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)] | |||
|
|||
with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test looks to be already tested when d="cuda" in the next few lines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this might be a bit extreme question but could it be possible that we support Tensor alpha
?
haha isn't that just addcmul? |
_foreach_with_tensor_overload = {"_foreach_mul.Tensor"} | ||
_foreach_with_tensor_overload = { | ||
"_foreach_add.Tensor", | ||
"_foreach_mul.Tensor", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we already had this code for mul, why is it so much work to add support for add?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding a Tensor overload will allow us to: - optimize in more cases than before - increase coverage for scalarTensor instead of just scalars in our foreach APIs The main complication in this PR was that add.Tensor has a scalar overload, so I've now built out support for that. cc crcrpar [ghstack-poisoned]
Adding a Tensor overload will allow us to: - optimize in more cases than before - increase coverage for scalarTensor instead of just scalars in our foreach APIs The main complication in this PR was that add.Tensor has a scalar overload, so I've now built out support for that. cc crcrpar [ghstack-poisoned]
Adding a Tensor overload will allow us to: - optimize in more cases than before - increase coverage for scalarTensor instead of just scalars in our foreach APIs The main complication in this PR was that add.Tensor has a scalar overload, so I've now built out support for that. cc crcrpar [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small verification on test but sounds good otherwise!
Adding a Tensor overload will allow us to: - optimize in more cases than before - increase coverage for scalarTensor instead of just scalars in our foreach APIs The main complication in this PR was that add.Tensor has a scalar overload, so I've now built out support for that. cc crcrpar [ghstack-poisoned]
This is the culminated result of #110954 (comment). We are making the code slightly more complicated to gain some perf in minimizing calls to `.copy_()` and `.to()`. ### Code ``` import torch with torch.cuda.device(0): steps = [torch.zeros((), device="cpu", dtype=torch.float32) for i in range(1000)] with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ] ) as p: # New code: # step_device = steps[0].device # one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1 # torch._foreach_add_(steps, one, 1.0) # Old code: torch._foreach_add_(steps, 1) print(p.key_averages().table(sort_by="cpu_time_total")) ``` ### Profiles **with old code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 35.31% 52.089ms 99.99% 147.495ms 147.495ms 1 aten::add_ 25.05% 36.949ms 64.68% 95.406ms 95.406us 1000 aten::to 3.97% 5.852ms 39.63% 58.457ms 58.457us 1000 aten::_to_copy 10.11% 14.917ms 35.66% 52.605ms 52.605us 1000 aten::copy_ 21.65% 31.939ms 21.65% 31.939ms 31.939us 1000 aten::empty_strided 3.90% 5.749ms 3.90% 5.749ms 5.749us 1000 cudaDeviceSynchronize 0.01% 18.000us 0.01% 18.000us 18.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 147.513ms ``` **with new code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 55.06% 49.963ms 99.86% 90.625ms 90.625ms 1 aten::add_ 44.81% 40.662ms 44.81% 40.662ms 40.662us 1000 aten::detach_ 0.01% 8.000us 0.05% 45.000us 45.000us 1 detach_ 0.04% 37.000us 0.04% 37.000us 37.000us 1 aten::empty 0.03% 30.000us 0.03% 30.000us 30.000us 1 aten::to 0.03% 23.000us 0.03% 23.000us 23.000us 1 cudaDeviceSynchronize 0.02% 22.000us 0.02% 22.000us 22.000us 1 aten::lift_fresh 0.01% 6.000us 0.01% 6.000us 6.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 90.751ms ``` Pull Request resolved: #111084 Approved by: https://github.com/albanD ghstack dependencies: #111079
Adding a Tensor overload will allow us to: - optimize in more cases than before - increase coverage for scalarTensor instead of just scalars in our foreach APIs The main complication in this PR was that add.Tensor has a scalar overload, so I've now built out support for that. Pull Request resolved: pytorch#111079 Approved by: https://github.com/albanD
This is the culminated result of pytorch#110954 (comment). We are making the code slightly more complicated to gain some perf in minimizing calls to `.copy_()` and `.to()`. ### Code ``` import torch with torch.cuda.device(0): steps = [torch.zeros((), device="cpu", dtype=torch.float32) for i in range(1000)] with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ] ) as p: # New code: # step_device = steps[0].device # one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1 # torch._foreach_add_(steps, one, 1.0) # Old code: torch._foreach_add_(steps, 1) print(p.key_averages().table(sort_by="cpu_time_total")) ``` ### Profiles **with old code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 35.31% 52.089ms 99.99% 147.495ms 147.495ms 1 aten::add_ 25.05% 36.949ms 64.68% 95.406ms 95.406us 1000 aten::to 3.97% 5.852ms 39.63% 58.457ms 58.457us 1000 aten::_to_copy 10.11% 14.917ms 35.66% 52.605ms 52.605us 1000 aten::copy_ 21.65% 31.939ms 21.65% 31.939ms 31.939us 1000 aten::empty_strided 3.90% 5.749ms 3.90% 5.749ms 5.749us 1000 cudaDeviceSynchronize 0.01% 18.000us 0.01% 18.000us 18.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 147.513ms ``` **with new code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 55.06% 49.963ms 99.86% 90.625ms 90.625ms 1 aten::add_ 44.81% 40.662ms 44.81% 40.662ms 40.662us 1000 aten::detach_ 0.01% 8.000us 0.05% 45.000us 45.000us 1 detach_ 0.04% 37.000us 0.04% 37.000us 37.000us 1 aten::empty 0.03% 30.000us 0.03% 30.000us 30.000us 1 aten::to 0.03% 23.000us 0.03% 23.000us 23.000us 1 cudaDeviceSynchronize 0.02% 22.000us 0.02% 22.000us 22.000us 1 aten::lift_fresh 0.01% 6.000us 0.01% 6.000us 6.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 90.751ms ``` Pull Request resolved: pytorch#111084 Approved by: https://github.com/albanD ghstack dependencies: pytorch#111079
Adding a Tensor overload will allow us to: - optimize in more cases than before - increase coverage for scalarTensor instead of just scalars in our foreach APIs The main complication in this PR was that add.Tensor has a scalar overload, so I've now built out support for that. Pull Request resolved: pytorch#111079 Approved by: https://github.com/albanD
This is the culminated result of pytorch#110954 (comment). We are making the code slightly more complicated to gain some perf in minimizing calls to `.copy_()` and `.to()`. ### Code ``` import torch with torch.cuda.device(0): steps = [torch.zeros((), device="cpu", dtype=torch.float32) for i in range(1000)] with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ] ) as p: # New code: # step_device = steps[0].device # one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1 # torch._foreach_add_(steps, one, 1.0) # Old code: torch._foreach_add_(steps, 1) print(p.key_averages().table(sort_by="cpu_time_total")) ``` ### Profiles **with old code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 35.31% 52.089ms 99.99% 147.495ms 147.495ms 1 aten::add_ 25.05% 36.949ms 64.68% 95.406ms 95.406us 1000 aten::to 3.97% 5.852ms 39.63% 58.457ms 58.457us 1000 aten::_to_copy 10.11% 14.917ms 35.66% 52.605ms 52.605us 1000 aten::copy_ 21.65% 31.939ms 21.65% 31.939ms 31.939us 1000 aten::empty_strided 3.90% 5.749ms 3.90% 5.749ms 5.749us 1000 cudaDeviceSynchronize 0.01% 18.000us 0.01% 18.000us 18.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 147.513ms ``` **with new code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 55.06% 49.963ms 99.86% 90.625ms 90.625ms 1 aten::add_ 44.81% 40.662ms 44.81% 40.662ms 40.662us 1000 aten::detach_ 0.01% 8.000us 0.05% 45.000us 45.000us 1 detach_ 0.04% 37.000us 0.04% 37.000us 37.000us 1 aten::empty 0.03% 30.000us 0.03% 30.000us 30.000us 1 aten::to 0.03% 23.000us 0.03% 23.000us 23.000us 1 cudaDeviceSynchronize 0.02% 22.000us 0.02% 22.000us 22.000us 1 aten::lift_fresh 0.01% 6.000us 0.01% 6.000us 6.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 90.751ms ``` Pull Request resolved: pytorch#111084 Approved by: https://github.com/albanD ghstack dependencies: pytorch#111079
Adding a Tensor overload will allow us to:
The main complication in this PR was that add.Tensor has a scalar overload, so I've now built out support for that.
cc @crcrpar
Stack from ghstack (oldest at bottom):