Add 0dim Tensor overload for _foreach_div #113688
Closed
Commits (5), all by janeyx99:
- 50dfbee  Add 0dim Tensor overload for _foreach_div
- 2e6bb1a  Update on "Add 0dim Tensor overload for _foreach_div"
- 97fd79f  Update on "Add 0dim Tensor overload for _foreach_div"
- f46ca55  Update on "Add 0dim Tensor overload for _foreach_div"
- ed39357  Update on "Add 0dim Tensor overload for _foreach_div"
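For orientation before the diff: the overload added here lets torch._foreach_div accept a single 0-dim scalar Tensor for a whole tensor list (torch._foreach_mul already behaves this way in the tests below), and for mul and div that scalar may live on the CPU even when the list is on CUDA. A minimal sketch of the call the new tests exercise, not part of the diff; values are illustrative and a CUDA device is assumed:

import torch

# Sketch only: divide every tensor in a list by one 0-dim scalar tensor.
# For mul and div, the 0-dim tensor may stay on the CPU while the list is on CUDA.
tensors = [torch.ones(3, device="cuda") for _ in range(2)]
scalar = torch.tensor(4.0, device="cpu")  # 0-dim divisor

out = torch._foreach_div(tensors, scalar)
assert all(torch.equal(o, t.div(scalar)) for o, t in zip(out, tensors))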
Changes from all commits:
@@ -48,7 +48,7 @@ def __init__(self, func):
         # Some foreach functions don't have in-place implementations.
         self.is_inplace = False if func is None else func.__name__.endswith('_')

-    def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
+    def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
         actual = None
         zero_size = kwargs.pop("zero_size", False)
         if (
@@ -60,7 +60,7 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
                actual = self.func(*inputs, **kwargs)
            keys = tuple([e.key for e in p.key_averages()])
            mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
-            assert mta_called == (is_fastpath and (not zero_size))
+            assert mta_called == (expect_fastpath and (not zero_size))
        else:
            actual = self.func(*inputs, **kwargs)
        # note(mkozuki): inplace foreach functions are void functions.
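The wrapper above decides whether the fused path actually ran by profiling the call and searching the recorded event keys for "multi_tensor_apply_kernel". A stripped-down sketch of that check outside the test harness (not part of the diff; CUDA assumed, tensors and the 0-dim divisor are illustrative):

import torch
from torch.profiler import profile, ProfilerActivity

# Sketch: confirm the foreach fast path launched the fused multi_tensor_apply kernel.
tensors = [torch.ones(4, device="cuda") for _ in range(3)]
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
    torch._foreach_div(tensors, torch.tensor(4.0, device="cuda"))
keys = [evt.key for evt in p.key_averages()]
print(any("multi_tensor_apply_kernel" in k for k in keys))  # True when the fast path ran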
@@ -127,9 +127,9 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):

        for sample in op.sample_zero_size_inputs(device, dtype):
            if not op.has_no_out_of_place:
-                wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, is_fastpath=True, zero_size=True)
+                wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)
            with InplaceForeachVersionBumpCheck(self, sample.input):
-                inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, is_fastpath=True, zero_size=True)
+                inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)

    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
    @ops(
@@ -150,7 +150,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace):
        for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous):
            ref_kwargs = sample.kwargs
            kwargs = ref_kwargs.copy()
-            expect_fastpath = not (noncontiguous or sample.disable_fastpath)
+            # div promotes ints to floats, so we cannot go on the fastpath there
+            div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div'
+            expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath)
            if op in foreach_pointwise_op_db:
                values = kwargs.pop("values", None)
                if values is not None:
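The new div_slowpath flag reflects ordinary type promotion: true division of integral (or bool) inputs produces floating-point outputs, so input and output dtypes differ and the fused fast path cannot be used for _foreach_div there. A small illustration of the promotion itself, not part of the diff:

import torch

# Illustration: true division promotes integral inputs to a floating dtype,
# for the per-tensor op and the foreach variant alike.
ints = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
print(torch.div(ints[0], 2).dtype)                     # torch.float32
print([t.dtype for t in torch._foreach_div(ints, 2)])  # [torch.float32, torch.float32]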
@@ -242,7 +244,7 @@ def clone(arg):
            (rhs_arg,) = transformed_sample.args
            ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
            sum(wrapped_op(
-                [rhs_arg, tensors], is_cuda=False, is_fastpath=False
+                [rhs_arg, tensors], is_cuda=False, expect_fastpath=False
            )).mean().backward()
            sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
            self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
@@ -628,7 +630,7 @@ def test_foreach_l2_large_value_input(self, device, dtype, op):
        # make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
        self.assertTrue(scaler * scaler * N > max_value)
        fn, ref_fn, *_ = self._get_funcs(op)
-        actual = fn(inputs, is_cuda=True, is_fastpath=True, ord=ord, zero_size=False)
+        actual = fn(inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
        expect = ref_fn(inputs, ord=ord)

        if dtype == torch.float16:
@@ -692,7 +694,7 @@ def test_outplace_with_invalid_grads(self, device, dtype, op):
        self.assertTrue(all(t.requires_grad for t in sample.input))
        if func.func in foreach_pointwise_op_db:
            sample.kwargs.pop("values", None)
-        (out1, out2) = func([sample.input, *sample.args], is_cuda=False, is_fastpath=False, **sample.kwargs)
+        (out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs)
        out1.backward(torch.ones_like(out1))
        self.assertIsNotNone(sample.input[0].grad)
        self.assertIsNone(sample.input[1].grad)
@@ -710,7 +712,7 @@ def get_ref(func, sample):
        class Foo:
            pass

-        out = func((sample.input, *sample.args), is_cuda=False, is_fastpath=False, **sample.kwargs)
+        out = func((sample.input, *sample.args), is_cuda=False, expect_fastpath=False, **sample.kwargs)
        foo = Foo()
        meta_dict = out[0].grad_fn.metadata
        meta_dict[0] = foo
@@ -772,12 +774,24 @@ def test_tensors_grouping(self):
                self.assertEqual(l3[i], list3[index])
        self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)

+    @onlyCUDA
+    def test_0dim_tensor_overload_cpu_ok(self):
+        tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
+        scalar_cpu_tensor = torch.tensor(4.0, device="cpu")
+
+        # For mul and div, the scalar is allowed to be on CPU too
+        actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
+        self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
+        actual = torch._foreach_div(tensors, scalar_cpu_tensor)
+        self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])
+
    @onlyCUDA
    def test_0dim_tensor_overload_exception(self):
        # check exceptions of fast path
        tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
-            torch._foreach_mul(tensors, torch.tensor(1.0, device="cpu"))
+            torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)

        tensors = [make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):

Review comment on this change: "This works now, but made me realize I didn't add a case for _foreach_add when I added the overload. Adding that now."
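For contrast with the mul/div case, a sketch of the device check that test_0dim_tensor_overload_exception covers: for _foreach_add the 0-dim scalar tensor must live on the same device as the tensor list (not part of the diff; CUDA assumed, as in the test, and the error wording is taken from the assertRaisesRegex above):

import torch

# Sketch of the rejected case: a CPU 0-dim scalar tensor with a CUDA tensor list.
tensors = [torch.rand(2, 2, device="cuda") for _ in range(2)]
try:
    torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)
except RuntimeError as err:
    print(err)  # message includes "scalar tensor expected to be on"

The contrast is intentional: mul and div accept a CPU 0-dim scalar (per the cpu_ok test above), while add still requires the scalar tensor to be on the same device on the fast path.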
Review comment: rename to expect_fastpath, which is what it really is and is less confusing.