[complex32] cat, fill_(partial), item #75010
```diff
@@ -19,7 +19,7 @@ struct FillFunctor {
 };

 void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "fill_cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kHalf, kBFloat16, iter.dtype(), "fill_cuda", [&]() {
     gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
   });
 }
```

Review comment: note to self: jiterate in upcoming PR
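This dispatch change is what makes `fill_` reachable for complex32 on CUDA. A minimal sketch of what it enables, assuming a build that includes this PR (the device string is illustrative):

```python
import torch

# fill_ on a complex32 (chalf) CUDA tensor now dispatches via
# AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 instead of failing with
# "fill_cuda" not implemented for 'ComplexHalf'.
t = torch.empty(4, dtype=torch.complex32, device="cuda")
t.fill_(1 + 2j)
print(t.dtype)  # torch.complex32
```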
```diff
@@ -67,8 +67,13 @@ def tearDownClass(cls):
     @onlyNativeDeviceTypes
     @ops(op_db, dtypes=OpDTypes.none)
     def test_dtypes(self, device, op):
+        # Check complex32 support only if the op claims it.
+        # TODO: Once complex32 support is better, we should add the check for complex32 unconditionally.
+        include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device) else ())
+
         # dtypes to try to backward in
-        allowed_backward_dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16)
+        allowed_backward_dtypes = floating_and_complex_types_and(
+            *((torch.half, torch.bfloat16) + include_complex32))

         # lists for (un)supported dtypes
         supported_dtypes = []
@@ -81,7 +86,8 @@ def unsupported(dtype):
             if dtype in allowed_backward_dtypes:
                 unsupported_backward_dtypes.append(dtype)

-        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+        for dtype in all_types_and_complex_and(
+                *((torch.half, torch.bfloat16, torch.bool) + include_complex32)):
             # tries to acquire samples - failure indicates lack of support
             requires_grad = (dtype in allowed_backward_dtypes and op.supports_autograd)
             try:
```

Review comment: Nice updates
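Because complex32 coverage is gated on each OpInfo's claim, it is easy to query which operators are opted in. A hedged sketch using the same `op_db` and `OpInfo.supports_dtype` helpers the test relies on (the device string and the `complex32_ops` helper name are illustrative):

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

# List the ops whose OpInfo entries claim complex32 support on a device,
# mirroring the include_complex32 gating in test_dtypes above.
def complex32_ops(device_type="cuda"):
    return sorted(op.name for op in op_db
                  if op.supports_dtype(torch.complex32, device_type))

print(complex32_ops())  # after this PR, expected to include 'cat'
```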
```diff
@@ -704,6 +710,18 @@ def check_tensor_floating_is_differentiable(t):
             for arg in sample.kwargs.values():
                 check_tensor_floating_is_differentiable(arg)

+    # Reference testing for operations in complex32 against complex64.
+    # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
+    @ops(op_db, allowed_dtypes=(torch.complex32,))
+    def test_complex_half_reference_testing(self, device, dtype, op):
+        if not op.supports_dtype(torch.complex32, device):
+            unittest.skip("Does not support complex32")
+
+        for sample in op.sample_inputs(device, dtype):
+            actual = op(sample.input, *sample.args, **sample.kwargs)
+            (inp, args, kwargs) = sample.transform(lambda x: x.to(torch.complex64))
+            expected = op(inp, *args, **kwargs)
+            self.assertEqual(actual, expected, exact_dtype=False)

 class TestCompositeCompliance(TestCase):
     # Checks if the operator (if it is composite) is written to support most
```

Review thread on the explicit skip inside the test:
- "Couldn't filter in [...]"
- "I'm confused about this. Isn't setting [...] sufficient?"
- "You are right. That is sufficient! Thank you"

Review comment on `sample.transform`: "In the future we likely want to update the [...]"
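The reference test compares each op's complex32 result against the same computation in complex64, since NumPy has no complex32 dtype to test against. A minimal standalone version of the pattern, using `torch.cat` as the op under test and assuming a build where this PR's complex32 kernels are available (the helper name and inputs are illustrative):

```python
import torch

# Compute in complex32, recompute in complex64, and compare with
# check_dtype=False so the dtype mismatch itself is not an error.
def check_against_complex64(op, *chalf_inputs):
    actual = op(*chalf_inputs)
    expected = op(*[x.to(torch.complex64) for x in chalf_inputs])
    torch.testing.assert_close(actual, expected, check_dtype=False)

a = torch.randn(2, 3, dtype=torch.complex64).to(torch.complex32)
b = torch.randn(2, 3, dtype=torch.complex64).to(torch.complex32)
check_against_complex64(lambda x, y: torch.cat((x, y)), a, b)
```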
```diff
@@ -5194,6 +5194,11 @@ def make_tensor_wrapper(shape, dtype):
                 atol = 1e-2
             self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol)

+    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32))
+    def test_item(self, device, dtype):
+        t = torch.ones((), device=device, dtype=dtype)
+        self.assertEqual(1, t.item())
+

 # Tests that compare a device's computation with the (gold-standard) CPU's.
 class TestDevicePrecision(TestCase):
```

Review comment: Are there any existing item tests?
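For a 0-D complex32 tensor, `item()` should round-trip the value as a Python `complex`. A quick illustration of what the new test asserts, assuming a build with this PR:

```python
import torch

t = torch.ones((), dtype=torch.complex32)
x = t.item()
print(x, type(x))  # (1+0j) <class 'complex'>
assert x == 1      # matches the test's assertEqual(1, t.item())
```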
```diff
@@ -950,7 +950,7 @@ def supported_backward_dtypes(self, device_type):
         else:
             backward_dtypes = self.backward_dtypes

-        allowed_backward_dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16)
+        allowed_backward_dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16, torch.complex32)
         return set(allowed_backward_dtypes).intersection(backward_dtypes)

     def supports_complex_autograd(self, device_type):
```
```diff
@@ -9198,11 +9198,6 @@ def ref_pairwise_distance(input1, input2):
                    skips=(
                        # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
-                       # File "test/test_binary_ufuncs.py", line 334, in test_batch_vs_slicing
-                       #   actual = torch.stack(actual)
-                       # RuntimeError: "cat_cuda" not implemented for 'ComplexHalf'
-                       DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_batch_vs_slicing',
-                                    device_type='cuda', dtypes=(torch.half,)),
                    )),
        BinaryUfuncInfo('copysign',
                        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
```

Review comment: Nice cleanup
```diff
@@ -14205,7 +14200,7 @@ def ref_pairwise_distance(input1, input2):
     OpInfo('cat',
            ref=lambda input_seq, dim=0, **kwargs: np.concatenate(input_seq, axis=dim, **kwargs),
            aliases=('concat',),
-           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
            sample_inputs_func=sample_inputs_cat_concat,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -14310,7 +14305,8 @@ def ref_pairwise_distance(input1, input2):
            supports_fwgrad_bwgrad=True,
            # https://github.com/pytorch/pytorch/issues/66357
            check_batched_forward_grad=False,
-           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            supports_out=False,
            skips=(
                # JIT has issue when op is passed as lambda
```

Review comment on the explicit `backward_dtypes`: "This should start failing once we support [...]"
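With `cat` registered for complex32, concatenating chalf tensors works on CUDA, where it previously raised `"cat_cuda" not implemented for 'ComplexHalf'` (the error quoted in the removed skip above). A quick sketch, assuming a build that includes this PR:

```python
import torch

a = torch.ones(2, dtype=torch.complex32, device="cuda")
b = torch.zeros(2, dtype=torch.complex32, device="cuda")
out = torch.cat((a, b))       # previously a RuntimeError on CUDA
print(out.dtype, out.shape)   # torch.complex32 torch.Size([4])
```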
Review comment: Used by call to `item`, which `fill_` calls for the 0-D Tensor overload.
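In other words, complex32 `item()` support matters here because the Tensor overload of `fill_` extracts its scalar argument via `item()`. A sketch of that path, assuming a build with this PR (the device choice and values are illustrative):

```python
import torch

t = torch.empty(3, dtype=torch.complex32, device="cuda")
v = torch.tensor(2 + 1j, dtype=torch.complex32, device="cuda")  # 0-D value
t.fill_(v)  # Tensor overload of fill_ reads the scalar out of v via item()
```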