
Fix Error with torch.flip() for cuda tensors when dims=() #50325

Closed
wants to merge 4 commits

Conversation

@dheerajgattupalli (Contributor) commented Jan 9, 2021

Fixes #49982

The method flip_check_errors, which is called from the CUDA file, had a condition that threw an exception when the dims size was <= 0. Changed that to < 0 and added a separate condition that returns from the method when the size is equal to zero; the early return is needed because, after that point, the method performs checks that expect a non-empty dims.

Also removed the comment and condition that were written to point to this issue.
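
For context, a minimal Python repro of the behavior this PR fixes (a sketch only; the tensor shape is arbitrary, and the CUDA call runs only when a GPU is available):

import torch

a = torch.ones((4, 3, 2, 2))

# dims=() has always been accepted on CPU and simply returns an unflipped copy
torch.flip(a, dims=())

# Before this fix, the same call on a CUDA tensor raised IndexError (see #49982);
# with the early return in flip_check_errors it behaves like the CPU path.
if torch.cuda.is_available():
    torch.flip(a.cuda(), dims=())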

@mruberry @kshitij12345 please review this once

@facebook-github-bot (Contributor) commented Jan 9, 2021

💊 CI failures summary and remediations

As of commit 9c1ed86 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your pull requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

This comment has been revised 8 times.

@dheerajgattupalli (Contributor, Author) commented Jan 9, 2021

Hi,
Is it OK if I make one more commit to fix the tab issue, or is it better to close this and create a new, cleaner PR?

I think the test_tensorexpr.py file was also updated, so I have to remove that change too. Sorry, maybe this needs to be closed; I will try to create a cleaner PR.

@kshitij12345 (Collaborator) left a comment

Hi @dheerajgattupalli,

Changes look good.

I think the test_tensorexpr.py file was also updated, so I have to remove that change too. Sorry, maybe this needs to be closed; I will try to create a cleaner PR.

I don't think you need to make any changes in test_tensorexpr.py (maybe I could be wrong).
It is fine to push fixes in this PR itself. No worries about that.

I forgot to mention that you'll also have to update the code here.

def sample_inputs_flip(op_info, device, dtype, requires_grad):
    tensors = (
        make_tensor((S, M, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
        make_tensor((S, 0, M), device, dtype, low=None, high=None, requires_grad=requires_grad)
    )
    dims = ((0, 1, 2), (0,), (0, 2), (-1,))
    # On CUDA, `dims=()` errors out with IndexError
    # Reference: https://github.com/pytorch/pytorch/issues/49982
    if device == 'cpu':
        dims = dims + ((),)  # type: ignore
    samples = [SampleInput(tensor, kwargs={'dims': dim}) for tensor, dim in product(tensors, dims)]
    return samples
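
For illustration, a sketch of what the updated sample function might look like once dims=() works on both devices (not the exact final diff; it just folds the empty-dims case into the dims tuple, as suggested further below, and reuses the helpers from the snippet above):

def sample_inputs_flip(op_info, device, dtype, requires_grad):
    tensors = (
        make_tensor((S, M, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
        make_tensor((S, 0, M), device, dtype, low=None, high=None, requires_grad=requires_grad)
    )
    # dims=() is now valid on CUDA as well, so it is included unconditionally
    dims = ((0, 1, 2), (0,), (0, 2), (-1,), ())
    return [SampleInput(tensor, kwargs={'dims': dim}) for tensor, dim in product(tensors, dims)]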

After fixing that part, you can run the flip test in test_ops.py.

aten/src/ATen/native/TensorTransformations.h (outdated review thread, resolved)
@dheerajgattupalli (Contributor, Author)

Hi @kshitij12345,

Ignore the part about test_tensorexpr.py; it's not needed here. I was confusing this with another issue.

I made the changes you mentioned, and the quick-checks test now passes as well. Hopefully everything is correct now.

Thanks!

@codecov bot commented Jan 9, 2021

Codecov Report

Merging #50325 (9c1ed86) into master (d4c1684) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #50325      +/-   ##
==========================================
- Coverage   80.71%   80.71%   -0.01%     
==========================================
  Files        1904     1904              
  Lines      206686   206684       -2     
==========================================
- Hits       166830   166827       -3     
- Misses      39856    39857       +1     

@kshitij12345 (Collaborator) left a comment

LGTM! All the CI tests are passing, too. A few minor updates.

@mruberry will review and shepherd the PR.

@@ -10,8 +10,11 @@ namespace at {
namespace native {

static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntArrayRef dims) {
if (flip_dims_size==0){

Collaborator:
nit: Formatting if (flip_dims_size == 0) {

# Reference: https://github.com/pytorch/pytorch/issues/49982
if device == 'cpu':
    dims = dims + ((),)  # type: ignore
dims = dims + ((),)  # type: ignore

Collaborator:
You can just do:
dims = ((0, 1, 2), (0,), (0, 2), (-1,), ())

@mruberry self-requested a review on January 10, 2021, 10:33.
@mruberry added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jan 10, 2021.

@mstfbl (Collaborator) commented Jan 10, 2021

I'm wondering, is there a specific reason why flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntArrayRef dims) is called in Tensor flip_cuda(const Tensor& self, IntArrayRef dims) (line 74), but not in Tensor flip_cpu(const Tensor& self, IntArrayRef dims)? It doesn't make sense to me why we're not verifying the number of axes in dim when calling Tensor.flip(dim) on CPU.

// Flip tensor given a list of dims
Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
  auto in_tensor = self;
  const int64_t flip_dims_size = dims.size(), total_dims = in_tensor.dim(), N = in_tensor.numel();
  flip_check_errors(total_dims, flip_dims_size, dims);
  int64_t block_size = 512;
  dim3 dim_block(block_size);
  dim3 dim_grid((N + block_size - 1) / block_size);
  auto out_tensor = at::empty_like(in_tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  if (out_tensor.numel() == 0) {
    return out_tensor;
  }

Tensor flip_cpu(const Tensor& self, IntArrayRef dims) {
  auto in_tensor = self;
  const int64_t total_dims = in_tensor.dim();
  auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);
  Tensor out_tensor = at::empty_like(in_tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  // create contiguous strides for input tensor
  auto stride_contiguous_v = std::vector<int64_t>(total_dims);
  for (int64_t i = total_dims - 1; i >= 0; i--) {
    if (i == total_dims - 1) {
      stride_contiguous_v[i] = 1;
    } else {
      stride_contiguous_v[i] = std::max<int64_t>(in_tensor.size(i + 1), 1) * stride_contiguous_v[i + 1];
    }
  }

If this omission is a mistake, I suggest flip_check_errors is called in Tensor flip_cpu as well. This should be an easy change:

Tensor flip_cpu(const Tensor& self, IntArrayRef dims) {
  auto in_tensor = self;
  const int64_t total_dims = in_tensor.dim(), flip_dims_size = dims.size();
  auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);
  Tensor out_tensor = at::empty_like(in_tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  flip_check_errors(total_dims, flip_dims_size, dims);

  // create contiguous strides for input tensor
  auto stride_contiguous_v = std::vector<int64_t>(total_dims);
  for (int64_t i = total_dims - 1; i >= 0; i--) {
    if (i == total_dims - 1) {
      stride_contiguous_v[i] = 1;
    } else {
      stride_contiguous_v[i] = std::max<int64_t>(in_tensor.size(i + 1), 1) * stride_contiguous_v[i + 1];
    }
  }

CC @mruberry @kshitij12345 @dheerajgattupalli

@dheerajgattupalli (Contributor, Author) commented Jan 10, 2021

Hi @mstfbl,

Yeah, I agree: adding the flip_check_errors call in the CPU version as well would make it more consistent. I ran the available tests for torch.flip with that change and it didn't cause any issues. I will push it if @kshitij12345 and @mruberry also agree.

@kshitij12345 (Collaborator)

@mstfbl Good question! Thanks for looking into it.

@dheerajgattupalli Thanks for trying the change.

From the sample below, we can see that the CPU side of the code does similar checks.

>>> import torch
>>> a = torch.ones((4,3,2,2))
>>> torch.flip(a, (0, 0))  # CPU Repeated Dims
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: dim 0 appears multiple times in the list of dims
>>> torch.flip(a.cuda(), (0, 0)) # CUDA Repeated Dims
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: dims has duplicates, original flip dims size=2, but unique flip dims size=1
>>> torch.flip(a, (0, 7))  # CPU Dim out-of-range
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 7)
>>> torch.flip(a.cuda(), (0, 7)) # CUDA Dim out-of-range
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: The max flip dims out of range, got max flip dims=7

Besides reporting in a different format, both do the same checks.

On the CPU side, the at::dim_list_to_bitset function performs the relevant checks and also returns the bitset for the flip dims:

static inline std::bitset<dim_bitset_size> dim_list_to_bitset(IntArrayRef dims, int64_t ndims) {
  TORCH_CHECK(ndims <= (int64_t) dim_bitset_size, "only tensors with up to ", dim_bitset_size, " dims are supported");
  std::bitset<dim_bitset_size> seen;
  for (size_t i = 0; i < dims.size(); i++) {
    size_t dim = maybe_wrap_dim(dims[i], ndims);
    TORCH_CHECK(!seen[dim], "dim ", dim, " appears multiple times in the list of dims");
    seen[dim] = true;
  }
  return seen;
}

Also note that the test has relevant cases for both devices:

# not allow flip on the same dim more than once
self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
# not allow empty list as input
self.assertRaises(TypeError, lambda: data.flip())
# not allow size of flip dim > total dims
self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3))
# not allow dim > max dim
self.assertRaises(IndexError, lambda: data.flip(3))

So I think it is fine the way it is.

@kshitij12345 (Collaborator) left a comment

LGTM! Great job!

@mruberry will review and give the binding approval.

@@ -10,8 +10,11 @@ namespace at {
namespace native {

static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntArrayRef dims) {

Collaborator:
flip_check_errors used to be called from flip_cpu, too. I believe it was accidentally removed by https://github.com/pytorch/pytorch/pull/13344/files. Could we add it back to flip_cpu?

Collaborator:

Oh, I see the above discussion suggests these checks may be redundant on CPU and we don't need them. Thanks, @kshitij12345, for pointing out that they're already tested for.

@@ -10,8 +10,11 @@ namespace at {
namespace native {

static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntArrayRef dims) {
if (flip_dims_size==0) {

Collaborator:
if (flip_dims_size == 0) { - spaces around the == operator

Contributor (Author):
Should I add this? The pull request is already approved; a new commit will not cause any problem, right?

Collaborator:
It's OK, you don't need to change it.

@mruberry (Collaborator) left a comment

Thank you for fixing this issue, @dheerajgattupalli, and thank you for reviewing it, @kshitij12345.

@facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):
@mruberry merged this pull request in 314351d.

Labels
cla signed, Merged, open source, triaged
Projects
None yet
Development
Successfully merging this pull request may close these issues:
[bug] torch.flip: IndexError for dims=() on CUDA but works on CPU
6 participants