[complex32] cat, fill_(partial), item #75010

Closed
4 changes: 2 additions & 2 deletions aten/src/ATen/ScalarOps.cpp
@@ -15,8 +15,8 @@ inline void fill_inplace(Tensor& self, const Scalar& value_scalar) {

 namespace detail {
 Tensor& scalar_fill(Tensor& self, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() {
         fill_inplace<scalar_t>(self, value);
       });
   return self;
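To make the effect of this change concrete, here is a minimal illustrative sketch (not part of the PR), assuming scalar_fill is the CPU fast path that fill_ takes for single-element tensors, as the PR context suggests:

    import torch

    # Illustrative only: with kComplexHalf in the "fill_out" dispatch above,
    # filling a one-element complex32 CPU tensor with a Python scalar should work.
    t = torch.empty((), dtype=torch.complex32)
    t.fill_(1 + 2j)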
4 changes: 2 additions & 2 deletions aten/src/ATen/native/Scalar.cpp
@@ -20,8 +20,8 @@ Scalar item(const Tensor& self) {

 Scalar _local_scalar_dense_cpu(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
         scalar_t value = *self.data_ptr<scalar_t>();
         r = Scalar(value);
       });

Review comment (Collaborator Author): Used by the call to item(), which fill_ invokes for its 0-D Tensor overload.
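A small usage sketch (not from the PR, but mirroring the test_item test added later in this diff): _local_scalar_dense_cpu is what Tensor.item() lowers to on CPU, and, per the author's note above, fill_'s 0-D Tensor overload goes through item() as well.

    import torch

    # Illustrative only: item() on a 0-D complex32 CPU tensor now hits the
    # ComplexHalf case added above and returns a Python complex.
    t = torch.ones((), dtype=torch.complex32)
    value = t.item()   # (1+0j)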
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/CUDAScalar.cu
@@ -15,8 +15,8 @@ namespace native {

 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
         scalar_t value;
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();
         at::cuda::memcpy_and_sync(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/FillKernel.cu
@@ -19,7 +19,7 @@ struct FillFunctor {
 };

 void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "fill_cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kHalf, kBFloat16, iter.dtype(), "fill_cuda", [&]() {
     gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
   });
 }

Review comment (Collaborator Author): Note to self: jiterate this in an upcoming PR.
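Taken together, the CUDAScalar.cu and FillKernel.cu changes cover the CUDA side of the same two operations. A hedged sketch of what should now work, assuming a CUDA build:

    import torch

    if torch.cuda.is_available():
        t = torch.empty(4, dtype=torch.complex32, device="cuda")
        t.fill_(2 - 1j)   # dispatches through fill_kernel_cuda's new ComplexHalf case
        s = t[0].item()   # _local_scalar_dense_cuda copies one element back to the host

The "jiterate" note refers to porting this fill kernel to PyTorch's jiterator (runtime JIT-compiled elementwise CUDA kernels) in a follow-up PR; that is an implementation detail and presumably would not change the behavior sketched here.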
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cuda/Shape.cu
@@ -393,8 +393,8 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
       allContiguous &&
       all32BitIndexable &&
       allSameType) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-        at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+        kComplexHalf, kHalf, kBool, kBFloat16,
         out.scalar_type(), "cat_cuda", [&]() {
           parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(out, inputs, dimension, nDims, memory_format);
         });
@@ -405,8 +405,8 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
       all32BitIndexable &&
       allSameType &&
       memory_format == c10::MemoryFormat::Contiguous) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-        at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+        kComplexHalf, kHalf, kBool, kBFloat16,
         out.scalar_type(), "cat_cuda", [&]() {
           parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(out, inputs, dimension, nDims, memory_format);
         });
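As a hedged usage sketch (not part of the PR): with ComplexHalf in the cat_cuda dispatch, concatenating complex32 tensors on CUDA should no longer raise the "cat_cuda" not implemented for 'ComplexHalf' error referenced in the skip removed from common_methods_invocations.py below.

    import torch

    if torch.cuda.is_available():
        a = torch.zeros(2, 3, dtype=torch.complex32, device="cuda")
        b = torch.zeros(2, 3, dtype=torch.complex32, device="cuda")
        out = torch.cat([a, b], dim=0)   # shape (4, 3), dtype torch.complex32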
22 changes: 20 additions & 2 deletions test/test_ops.py
@@ -67,8 +67,13 @@ def tearDownClass(cls):
     @onlyNativeDeviceTypes
     @ops(op_db, dtypes=OpDTypes.none)
     def test_dtypes(self, device, op):
+        # Check complex32 support only if the op claims to support it.
+        # TODO: once complex32 support is better, check complex32 unconditionally.
+        include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device) else ())
+
         # dtypes to try to backward in
-        allowed_backward_dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16)
+        allowed_backward_dtypes = floating_and_complex_types_and(
+            *((torch.half, torch.bfloat16) + include_complex32))

         # lists for (un)supported dtypes
         supported_dtypes = []

Review comment (Collaborator): Nice updates
@@ -81,7 +86,8 @@ def unsupported(dtype):
             if dtype in allowed_backward_dtypes:
                 unsupported_backward_dtypes.append(dtype)

-        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+        for dtype in all_types_and_complex_and(
+                *((torch.half, torch.bfloat16, torch.bool) + include_complex32)):
             # tries to acquire samples - failure indicates lack of support
             requires_grad = (dtype in allowed_backward_dtypes and op.supports_autograd)
             try:
@@ -704,6 +710,18 @@ def check_tensor_floating_is_differentiable(t):
             for arg in sample.kwargs.values():
                 check_tensor_floating_is_differentiable(arg)

+    # Reference testing for operations in complex32 against complex64.
+    # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
+    @ops(op_db, allowed_dtypes=(torch.complex32,))
+    def test_complex_half_reference_testing(self, device, dtype, op):
+        if not op.supports_dtype(torch.complex32, device):
+            unittest.skip("Does not support complex32")
+
+        for sample in op.sample_inputs(device, dtype):
+            actual = op(sample.input, *sample.args, **sample.kwargs)
+            (inp, args, kwargs) = sample.transform(lambda x: x.to(torch.complex64))
+            expected = op(inp, *args, **kwargs)
+            self.assertEqual(actual, expected, exact_dtype=False)

 class TestCompositeCompliance(TestCase):
     # Checks if the operator (if it is composite) is written to support most

Review comment (Collaborator Author): Couldn't filter in the @ops decorator, as the helper that decides whether a dtype is supported requires the device.

Review comment (Collaborator): I'm confused about this. Isn't setting allowed_dtypes in the decorator sufficient? I thought the decorator accounted for the device?

Review comment (Collaborator Author): You are right, that is sufficient! Thank you.

Review comment (Collaborator): In the future we likely want to update the numpy() method to handle this cast like it handles bfloat16, too:
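The reviewer's last comment above points at the test framework's numpy() helper. As a hedged sketch only, with a hypothetical function name and structure rather than the actual PyTorch internals, such a cast could mirror the existing bfloat16 handling, since NumPy has neither complex32 nor bfloat16:

    import torch

    def to_numpy(t: torch.Tensor):
        # Hypothetical helper: upcast dtypes NumPy cannot represent before converting.
        if t.dtype == torch.complex32:
            t = t.to(torch.complex64)
        elif t.dtype == torch.bfloat16:
            t = t.to(torch.float32)
        return t.cpu().numpy()

This is effectively what test_complex_half_reference_testing does inline via sample.transform(lambda x: x.to(torch.complex64)).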
5 changes: 5 additions & 0 deletions test/test_torch.py
@@ -5194,6 +5194,11 @@ def make_tensor_wrapper(shape, dtype):
             atol = 1e-2
             self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol)

+    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32))
+    def test_item(self, device, dtype):
+        t = torch.ones((), device=device, dtype=dtype)
+        self.assertEqual(1, t.item())
+

 # Tests that compare a device's computation with the (gold-standard) CPU's.
 class TestDevicePrecision(TestCase):

Review comment (Collaborator): Are there any existing item tests?

Review comment (Collaborator Author): I couldn't find any with a simple search. [attached screenshot omitted]
12 changes: 4 additions & 8 deletions torch/testing/_internal/common_methods_invocations.py
@@ -950,7 +950,7 @@ def supported_backward_dtypes(self, device_type):
         else:
             backward_dtypes = self.backward_dtypes

-        allowed_backward_dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16)
+        allowed_backward_dtypes = floating_and_complex_types_and(torch.bfloat16, torch.float16, torch.complex32)
         return set(allowed_backward_dtypes).intersection(backward_dtypes)

     def supports_complex_autograd(self, device_type):
@@ -9198,11 +9198,6 @@ def ref_pairwise_distance(input1, input2):
            skips=(
                # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument
                DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
-               # File "test/test_binary_ufuncs.py", line 334, in test_batch_vs_slicing
-               # actual = torch.stack(actual)
-               # RuntimeError: "cat_cuda" not implemented for 'ComplexHalf'
-               DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_batch_vs_slicing',
-                            device_type='cuda', dtypes=(torch.half,)),
            )),
    BinaryUfuncInfo('copysign',
                    dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),

Review comment (Collaborator): Nice cleanup
@@ -14205,7 +14200,7 @@ def ref_pairwise_distance(input1, input2):
    OpInfo('cat',
           ref=lambda input_seq, dim=0, **kwargs: np.concatenate(input_seq, axis=dim, **kwargs),
           aliases=('concat',),
-          dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+          dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
           sample_inputs_func=sample_inputs_cat_concat,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
@@ -14310,7 +14305,8 @@ def ref_pairwise_distance(input1, input2):
           supports_fwgrad_bwgrad=True,
           # https://github.com/pytorch/pytorch/issues/66357
           check_batched_forward_grad=False,
-          dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+          dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+          backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
           supports_out=False,
           skips=(
               # JIT has issue when op is passed as lambda

Review comment (Collaborator Author): This should start failing once we support sum for complex32 (it is used in the backward).
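As a hedged illustration of the author's note about sum: backward formulas commonly reduce gradients with a sum (for example, to undo broadcasting), which is why backward_dtypes leaves complex32 out until sum supports it. The snippet below only shows, with ordinary floats, where such a sum appears; it makes no claim about current complex32 behavior.

    import torch

    x = torch.randn(1, 3, requires_grad=True)
    y = x.expand(4, 3)            # forward broadcasts x over dim 0
    y.backward(torch.ones(4, 3))  # backward reduces the incoming gradient with a sum over dim 0
    print(x.grad)                 # tensor([[4., 4., 4.]])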