[complex32] cat, fill_(partial), item #75010

Closed

Conversation

@kshitij12345 (Collaborator) commented Mar 31, 2022:

Reference: #74537

The backward of `cat` (on CUDA) requires support for `fill`, so this PR adds support for `fill` (`fill` in turn requires `item` support).

The backward of `fill` requires `sum`, which will be added in a later PR.
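
For context, a rough usage sketch of what this targets for complex32 (illustrative only; the exact device and backward coverage is governed by the dispatch and OpInfo changes below):

```python
import torch

# Hedged sketch: the complex32 (chalf) operations touched by this PR.
a = torch.empty(2, 2, dtype=torch.complex32)
b = torch.empty(2, 2, dtype=torch.complex32)
c = torch.cat([a, b])                         # cat with complex32 inputs

v = torch.tensor(1 + 2j, dtype=torch.complex32)
print(v.item())                               # item() on complex32 (CPU dispatch added here)

if torch.cuda.is_available():
    x = torch.empty(4, dtype=torch.complex32, device="cuda")
    x.fill_(0.5 + 0.5j)                       # fill_ on CUDA (dispatch added here)
```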

@facebook-github-bot (Contributor) commented Mar 31, 2022:


💊 CI failures summary and remediations

As of commit 45ad3b9 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.


@@ -14301,8 +14304,10 @@ def ref_pairwise_distance(input1, input2):
            supports_fwgrad_bwgrad=True,
            # https://github.com/pytorch/pytorch/issues/66357
            check_batched_forward_grad=False,
-           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
@kshitij12345 (Collaborator Author):
This should start failing once we support `sum` for complex32 (as it is used in the backward).
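
To see why `sum` enters the picture, here is a hedged illustration (assuming the usual definition of the fill value's gradient as the reduction of the incoming gradient over all filled elements):

```python
import torch

# Hedged illustration: the gradient w.r.t. a 0-dim fill value is the sum of the
# output gradient over every filled element, so a complex32 backward needs sum().
value = torch.tensor(1.0, requires_grad=True)
out = torch.zeros(3, 3).fill_(value)
out.sum().backward()
print(value.grad)  # tensor(9.) -- one contribution per filled element
```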

@@ -20,8 +20,8 @@ Scalar item(const Tensor& self) {

 Scalar _local_scalar_dense_cpu(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
@kshitij12345 (Collaborator Author):
Used by the call to `item`, which `fill_` makes for the 0-D Tensor overload.
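
A hedged sketch of the user-visible effect of this dispatch change (illustrative only):

```python
import torch

# With ComplexHalf added to the _local_scalar_dense_cpu dispatch, item() works
# for complex32 tensors on CPU.
t = torch.tensor(1 + 2j, dtype=torch.complex32)
assert isinstance(t.item(), complex)
```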

@kshitij12345 kshitij12345 changed the title [complex32] cat, fill_(partial) [complex32] cat, fill_(partial), item Mar 31, 2022
@@ -19,7 +19,7 @@ struct FillFunctor {
 };

 void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "fill_cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kHalf, kBFloat16, iter.dtype(), "fill_cuda", [&]() {
     gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
@kshitij12345 (Collaborator Author):
note to self: jiterate in upcoming PR
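
A hedged usage sketch of the newly dispatched CUDA path (requires a CUDA device; illustrative only):

```python
import torch

if torch.cuda.is_available():
    # fill_ now dispatches for complex32 (kComplexHalf) on CUDA.
    x = torch.empty(8, dtype=torch.complex32, device="cuda")
    x.fill_(0.5 - 0.25j)
    print(x)
```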

@ops(op_db, allowed_dtypes=(torch.complex32,))
def test_complex_half_reference_testing(self, device, dtype, op):
    if not op.supports_dtype(torch.complex32, device):
        unittest.skip("Does not support complex32")
@kshitij12345 (Collaborator Author):
Couldn't filter this in the @ops decorator, as the helper that decides whether a dtype is supported requires the device.

Collaborator:
I'm confused about this. Isn't setting allowed_dtypes in the decorator sufficient? I thought the decorator accounted for the device?

@kshitij12345 (Collaborator Author):
You are right. That is sufficient! Thank you

@kshitij12345 kshitij12345 marked this pull request as ready for review March 31, 2022 17:53
@kshitij12345 (Collaborator Author) commented:
@anjali411
I discussed with @mruberry adding a new piece of metadata, supports_complex32, to handle the failures, but we decided instead to skip complex32 by updating the test logic.


for sample in op.sample_inputs(device, dtype):
    actual = op(sample.input, *sample.args, **sample.kwargs)
    (inp, args, kwargs) = sample.transform(lambda x: x.to(torch.complex64))
Collaborator:
In the future we likely want to update the numpy() method to handle this cast like it handles bfloat16, too:
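
A hedged sketch of that idea (an illustrative helper, not the actual test-suite code):

```python
import torch

def to_numpy(t: torch.Tensor):
    # Upcast dtypes NumPy cannot represent before converting, as is already done
    # for bfloat16; complex32 would map to complex64.
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    elif t.dtype == torch.complex32:
        t = t.to(torch.complex64)
    return t.detach().cpu().numpy()
```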

@@ -5194,6 +5194,11 @@ def make_tensor_wrapper(shape, dtype):
         atol = 1e-2
         self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol)

+    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32))
+    def test_item(self, device, dtype):
Collaborator:
Are there any existing item tests?

@kshitij12345 (Collaborator Author):
I couldn't find any with a simple search (screenshot of the search results attached).
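
For context, a hedged sketch of what a minimal item() round-trip check of this shape could look like (a standalone illustration, not the PR's device-generic test body):

```python
import unittest
import torch

class TestItemSketch(unittest.TestCase):
    def test_item(self):
        # Includes complex32 alongside the other non-default dtypes.
        for dtype in (torch.bool, torch.half, torch.bfloat16, torch.complex32):
            t = torch.ones((), dtype=dtype)
            self.assertEqual(t.item(), 1)

if __name__ == "__main__":
    unittest.main()
```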

@@ -9198,11 +9198,6 @@ def ref_pairwise_distance(input1, input2):
            skips=(
-               # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument
-               DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
-               # File "test/test_binary_ufuncs.py", line 334, in test_batch_vs_slicing
Collaborator:
Nice cleanup

@@ -67,8 +67,13 @@ def tearDownClass(cls):
     @onlyNativeDeviceTypes
     @ops(op_db, dtypes=OpDTypes.none)
     def test_dtypes(self, device, op):
+        # Check complex32 support only if the op claims.
+        # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
+        include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device) else ())
Collaborator:
Nice updates
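
A hedged sketch of how such a gate might feed into the dtype computation (hypothetical helper; the exact use inside test_dtypes is not shown in this hunk):

```python
import torch

def dtypes_to_check(op, device):
    # complex32 is only added to the set of dtypes to verify when the op
    # claims to support it on this device (op is an OpInfo-like object).
    include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device) else ())
    base = (torch.float32, torch.float64, torch.complex64, torch.complex128)
    return base + include_complex32
```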

@dagitses added the triaged label on Apr 1, 2022.
@anjali411 (Contributor) left a comment:
@pytorchbot merge this please

pytorch-bot bot commented Apr 1, 2022:

+1

github-actions bot commented Apr 1, 2022:

Hey @kshitij12345.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@kshitij12345 added the 'module: complex' and 'module: half' labels on Apr 1, 2022.
facebook-github-bot pushed a commit that referenced this pull request Apr 4, 2022
Summary:
Reference: #74537

The backward of `cat` (on CUDA) requires support for `fill`, so this PR adds support for `fill` (`fill` in turn requires `item` support).

The backward of `fill` requires `sum`, which will be added in a later PR.

Pull Request resolved: #75010
Approved by: https://github.com/anjali411

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/65b65af2367c451a88672e9cffd40bd839d7ccb7

Reviewed By: atalman

Differential Revision: D35317407

fbshipit-source-id: f735f5090539a4f598b413a3c40181a10bd0ad58
Labels: cla signed, module: complex, module: half, open source, triaged