Add ScalarTensor or 0dim overload for _foreach_add #111079
Changes from commits 2560ccd, 32f68ba, 792ff30, 9b19da0, and cd75529.
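To make the change concrete, here is a small usage sketch (not taken from the PR itself) of what the new overload allows: adding a single 0-dim "scalar" tensor to every tensor in a list with one fused `_foreach` call. It assumes a PyTorch build that already contains the `_foreach_add` Tensor overload; the names `params` and `step` are illustrative.

```python
import torch

# A hypothetical parameter list and a 0-dim scalar tensor (e.g. a step size
# kept as a tensor rather than a Python float).
params = [torch.randn(2, 2) for _ in range(4)]
step = torch.tensor(0.5)

# One fused call over the whole list ...
fused = torch._foreach_add(params, step)

# ... matching the per-tensor loop it replaces.
looped = [p + step for p in params]
assert all(torch.equal(a, b) for a, b in zip(fused, looped))
```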
```diff
@@ -778,15 +778,14 @@ def test_tensors_grouping(self):
    def test_0dim_tensor_overload_exception(self):
        # check exceptions of fast path
        tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
            torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
            torch._foreach_mul(tensors, torch.tensor(1.0, device="cpu"))

        tensors = [make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
            torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
            torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
```

Review comment (on the first `assertRaisesRegex` check): This looks to be already tested by the `d="cuda"` case in the next few lines.
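For contrast with the failing calls exercised by the test, here is a hedged sketch of the accepted form: the scalar tensor must be 0-dim, and a 1-D tensor is rejected. A CUDA device is assumed to be available, and the values are illustrative.

```python
import torch

tensors = [torch.ones(2, 2, device="cuda") for _ in range(2)]

# Accepted: a 0-dim scalar tensor on the same device as the tensor list.
out = torch._foreach_add(tensors, torch.tensor(1.0, device="cuda"))

# Rejected (as the test asserts): a 1-D tensor is not a scalar tensor.
try:
    torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))
except RuntimeError as err:
    print(err)  # "scalar tensor expected to be 0 dim but ..."
```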
In the derivative-codegen helper around `is_foreach_func`:

```diff
@@ -316,7 +316,10 @@ def is_foreach_func(f: NativeFunction) -> bool:
 # is functional for their backward derivatives (and forward derivatives in the future), i.e.,
 # they would find such one in `functional_info_by_signature`. There however are some exceptions:
 _foreach_with_inplace_ref = {"_foreach_zero_"}
-_foreach_with_tensor_overload = {"_foreach_mul.Tensor"}
+_foreach_with_tensor_overload = {
+    "_foreach_add.Tensor",
+    "_foreach_mul.Tensor",
+}


 # Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
```

Review comment: If we already had this code for mul, why is it so much work to add support for add?
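A rough sketch of what this codegen entry buys (illustrative, assuming a build where `_foreach_add.Tensor` has derivatives registered): gradients can flow through the fused call just as they would through a loop of ordinary adds.

```python
import torch

xs = [torch.randn(2, 2, requires_grad=True) for _ in range(3)]
alpha = torch.tensor(0.5)  # 0-dim tensor scalar

ys = torch._foreach_add(xs, alpha)   # fused forward over the whole list
sum(y.sum() for y in ys).backward()  # autograd through the foreach op

print(xs[0].grad)  # all ones, since d(x + alpha)/dx == 1
```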
Author comment: I ended up inlining this change because I could not specify that `scalar` is of type `opmath_t`: I cannot figure out where to put `using opmath_t = at::opmath_type<T>;` so that it comes before the function signature but after the template parameter list.