
Conversation

@kshitij12345
Collaborator

Fixes #47176

@dr-ci

dr-ci bot commented Nov 7, 2020

💊 CI failures summary and remediations

As of commit 80694d0 (more details on the Dr. CI page):



🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (1/2)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Nov 11 18:14:42 sccache: error: couldn't connect to server
Nov 11 18:14:42 +++ eval 'extract_trap_cmd ' 
Nov 11 18:14:42 ++++ extract_trap_cmd 
Nov 11 18:14:42 ++++ printf '%s\n' '' 
Nov 11 18:14:42 +++ printf '%s\n' cleanup 
Nov 11 18:14:42 ++ trap -- ' 
Nov 11 18:14:42 cleanup' EXIT 
Nov 11 18:14:42 ++ [[ pytorch-linux-xenial-py3.6-gcc5.4-test != *pytorch-win-* ]] 
Nov 11 18:14:42 ++ which sccache 
Nov 11 18:14:42 ++ sccache --stop-server 
Nov 11 18:14:42 Stopping sccache server... 
Nov 11 18:14:42 sccache: error: couldn't connect to server 
Nov 11 18:14:42 sccache: caused by: Connection refused (os error 111) 
Nov 11 18:14:42 ++ true 
Nov 11 18:14:42 ++ rm /var/lib/jenkins/sccache_error.log 
Nov 11 18:14:42 ++ [[ pytorch-linux-xenial-py3.6-gcc5.4-test == *rocm* ]] 
Nov 11 18:14:42 ++ SCCACHE_ERROR_LOG=/var/lib/jenkins/sccache_error.log 
Nov 11 18:14:42 ++ SCCACHE_IDLE_TIMEOUT=1200 
Nov 11 18:14:42 ++ RUST_LOG=sccache::server=error 
Nov 11 18:14:42 ++ sccache --start-server 
Nov 11 18:14:42 sccache: Starting the server... 
Nov 11 18:14:42 ++ sccache --zero-stats 

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (2/2)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Nov 11 19:16:57 ERROR [11.782s]: test_kthvalue_xla_float64 (__main__.TestTorchDeviceTypeXLA)
Nov 11 19:15:54   test_where_scalar_valid_combination_xla_float16 (__main__.TestTorchDeviceTypeXLA) ... skip (0.002s) 
Nov 11 19:15:57   test_where_scalar_valid_combination_xla_float32 (__main__.TestTorchDeviceTypeXLA) ... ok (3.114s) 
Nov 11 19:16:54   test_where_scalar_valid_combination_xla_float64 (__main__.TestTorchDeviceTypeXLA) ... ok (56.358s) 
Nov 11 19:16:54   test_where_scalar_valid_combination_xla_int16 (__main__.TestTorchDeviceTypeXLA) ... ok (0.504s) 
Nov 11 19:16:55   test_where_scalar_valid_combination_xla_int32 (__main__.TestTorchDeviceTypeXLA) ... ok (0.573s) 
Nov 11 19:16:56   test_where_scalar_valid_combination_xla_int64 (__main__.TestTorchDeviceTypeXLA) ... ok (1.606s) 
Nov 11 19:16:57   test_where_scalar_valid_combination_xla_int8 (__main__.TestTorchDeviceTypeXLA) ... ok (0.502s) 
Nov 11 19:16:57   test_where_scalar_valid_combination_xla_uint8 (__main__.TestTorchDeviceTypeXLA) ... ok (0.576s) 
Nov 11 19:16:57  
Nov 11 19:16:57 ====================================================================== 
Nov 11 19:16:57 ERROR [11.782s]: test_kthvalue_xla_float64 (__main__.TestTorchDeviceTypeXLA) 
Nov 11 19:16:57 ---------------------------------------------------------------------- 
Nov 11 19:16:57 Traceback (most recent call last): 
Nov 11 19:16:57   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 272, in instantiated_test 
Nov 11 19:16:57     result = test_fn(self, *args) 
Nov 11 19:16:57   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 9329, in test_kthvalue 
Nov 11 19:16:57     self.assertEqual(x.squeeze().kthvalue(1), x.kthvalue(1)) 
Nov 11 19:16:57 RuntimeError: torch_xla/csrc/helpers.cpp:97 : Check failed: min_shape_dim <= dim && dim <= max_shape_dim  
Nov 11 19:16:57 *** Begin stack trace *** 
Nov 11 19:16:57 	tensorflow::CurrentStackTrace[abi:cxx11]() 
Nov 11 19:16:57 	torch_xla::XlaHelpers::GetCanonicalDimensionIndex(long long, long long) 

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch

Since your merge base is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.


ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 36 times.

  // Convert `linearIndex` into an offset of `b`
  const IndexType bOffset =
-      cuda::detail::IndexToOffset<scalar_t, IndexType, 1>::get(li, b);
+      cuda::detail::IndexToOffset<scalar_t, IndexType, ADims>::get(li, b);
Collaborator

This is not a correct fix. In most cases mask and output will be contiguous even for a discontiguous input, and the dim here needs to be 1 (they are coalesced to 1). This error happens only when suggest_memory_format is channels_last and, for some reason, the vectorized kernel is not used.
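
A minimal sketch of an input that hits this path (my own construction, assuming a CUDA build; the shape and slice are chosen so that the strides match the ones printed later in this thread):

    import torch

    # channels-last tensor sliced along W: suggest_memory_format() is channels_last,
    # but the tensor is contiguous in neither memory format, so the vectorized
    # dropout kernel is skipped and the non-vectorized fused kernel runs
    x = torch.randn(2, 3, 3, 6, device="cuda").contiguous(memory_format=torch.channels_last)
    x = x[:, :, :, ::2]   # shape (2, 3, 3, 3), strides (54, 1, 18, 6)
    out = torch.nn.functional.dropout(x, p=0.5, training=True)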

Collaborator Author

@kshitij12345 kshitij12345 Nov 8, 2020

auto self_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(self);
auto ret_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(ret);
auto mask_info = cuda::detail::getTensorInfo<uint8_t, unsigned int>(mask);
self_info.collapseDims();
ret_info.collapseDims();
mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor

After this, I just printed the dims for each Tensor:

      std::cout << self_info.dims << " " << mask_info.dims << " " << ret_info.dims << "\n";
      auto print_dims = [](auto tensor_info) {
        for (int i = 0; i < tensor_info.dims; ++i) {
          std::cout << tensor_info.sizes[i] << " ";
        }
        std::cout << "\n";
      };
      std::cout << "SELF_INFO SIZES:";
      print_dims(self_info);
      std::cout << "MASK_INFO SIZES:";
      print_dims(mask_info);
      std::cout << "RET_INFO SIZES:";
      print_dims(ret_info);

And the output (for the case in the issue, with self of shape (2, 3, 3, 3)) is:

3 3 3
SELF_INFO SIZES:2 3 9 
MASK_INFO SIZES:2 3 9 
RET_INFO SIZES:2 3 9 

This is why I believe replacing 1 with ADims is correct.
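
For reference, a rough Python re-implementation of what TensorInfo::collapseDims does (a simplified sketch of my own; the real ATen logic also handles size-1 dims and a preserved dimension):

    def collapse_dims(sizes, strides):
        # merge adjacent dims i and i+1 whenever they describe one contiguous
        # run in memory, i.e. strides[i] == strides[i + 1] * sizes[i + 1]
        sizes, strides = list(sizes), list(strides)
        i = 0
        while i < len(sizes) - 1:
            if strides[i] == strides[i + 1] * sizes[i + 1]:
                sizes[i + 1] *= sizes[i]
                del sizes[i], strides[i]
            else:
                i += 1
        return sizes, strides

    # for the tensors above:
    # collapse_dims((2, 3, 3, 3), (54, 1, 18, 6)) -> ([2, 3, 9], [54, 1, 6])   # self
    # collapse_dims((2, 3, 3, 3), (27, 1, 9, 3))  -> ([2, 3, 9], [27, 1, 3])   # ret / mask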

Collaborator

The example in the issue is only one of the possible failure modes. In other cases the input could be discontiguous (not collapsible), but mask/output will be collapsible. In that case ADims will be -1, so using -1 for the output will work too, but it's less efficient than 1.

Collaborator Author

fused_dropout_cuda(const Tensor& self, double p, c10::optional<Generator> gen_){
auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
Tensor ret = at::empty_like(self, self.suggest_memory_format());
Tensor mask = at::empty(self.sizes(), self.options().dtype(kByte), self.suggest_memory_format());

Also, printing the strides after mask is allocated (for the case in the issue, with self of shape (2, 3, 3, 3)):

  std::cout << "SELF STRIDES:" << self.strides() << "\n";
  std::cout << "RET STRIDES:" << ret.strides() << "\n";
  std::cout << "MASK STRIDES:" << mask.strides() << "\n";
  std::cout << "RET CONTIGUOUS:" << ret.is_contiguous() << "\n";
  std::cout << "MASK CONTIGUOUS:" << mask.is_contiguous() << "\n";

Output

SELF STRIDES:[54, 1, 18, 6]
RET STRIDES:[27, 1, 9, 3]
MASK STRIDES:[27, 1, 9, 3]
RET CONTIGUOUS:0
MASK CONTIGUOUS:0

Note that ret and mask are also not contiguous, and their strides are similar to self's.
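
The same allocation behavior can be reproduced from Python (a sketch; I pass memory_format=torch.channels_last explicitly to stand in for what suggest_memory_format() returns for this input):

    import torch

    self_ = torch.randn(2, 3, 3, 6).contiguous(memory_format=torch.channels_last)[:, :, :, ::2]
    ret = torch.empty_like(self_, memory_format=torch.channels_last)
    mask = torch.empty(self_.size(), dtype=torch.uint8, memory_format=torch.channels_last)

    print(self_.stride())       # (54, 1, 18, 6)
    print(ret.stride())         # (27, 1, 9, 3)
    print(mask.stride())        # (27, 1, 9, 3)
    print(ret.is_contiguous())  # False -- contiguous only as channels_last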

Collaborator Author

The example in the issue is only one of the possible failure modes. In other cases the input could be discontiguous (not collapsible), but mask/output will be collapsible. In that case ADims will be -1, so using -1 for the output will work too, but it's less efficient than 1.

I think mask and ret would mostly preserve the memory layout that the input has 🤔

Collaborator

They don't

import torch
mod = torch.nn.Dropout(0.5)
inp = torch.randn(3,4, device="cuda").t()
out=mod(inp)
print(inp.stride(), out.stride()) # (1, 4) (3, 1), output is contiguous, input is not
inp = torch.randn(3,4, device="cuda")[:,::2]
out = mod(inp)
print(inp.stride(), out.stride()) # (4, 2) (2, 1) output is contiguous, input is not

Collaborator Author

Ah. I see. Thank you very much!

Collaborator Author

Note:

After changing,

Tensor ret = at::empty_like(self, self.suggest_memory_format()); 
Tensor mask = at::empty(self.sizes(), self.options().dtype(kByte), self.suggest_memory_format()); 

to

Tensor ret = at::empty_like(self);
Tensor mask = at::empty_like(self, self.options().dtype(kByte));

The output of the above snippet is:

(1, 4) (1, 4)
(4, 2) (2, 1)

Collaborator

Yep, that's what we want to happen: dropout preserves the strides of the input if the input is memory-dense, and makes the output dense while preserving the input permutation if the input is not memory-dense.
Note also that the first case can now go to the vectorized kernel.
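
A quick sketch of the empty_like semantics behind this (with the default memory_format=torch.preserve_format, strides are kept when the input is non-overlapping and dense, otherwise the output is allocated dense in the suggested format):

    import torch

    a = torch.randn(3, 4).t()               # memory-dense, just permuted
    print(torch.empty_like(a).stride())     # (1, 4) -- strides preserved

    b = torch.randn(3, 4)[:, ::2]           # not memory-dense (stride-2 slice)
    print(torch.empty_like(b).stride())     # (2, 1) -- dense, contiguous output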

@ngimel
Collaborator

ngimel commented Nov 7, 2020

For a good fix:

  1. please allocate mask and output with empty_like; no need to explicitly specify the format, empty_like should do the right thing
  2. change the eligibility condition for the vectorized kernel to all the tensors being non_overlapping_and_dense instead of contiguous in one of the 2 memory formats
  3. if the non-vectorized path is taken, add another specialization for non-contiguous input and contiguous mask/output, so there are (1, 1), (-1, 1), (-1, -1)
  4. beef up testing, so that differently permuted, discontiguous, channels-last and misaligned inputs are tested. Also, test with a reasonable dropout probability and check that the sum of the values is approximately equal to the original sum (tolerance can be fairly large); a sketch of such a check appears below
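
A minimal sketch of the sum check from item 4 (the helper name, shape and tolerance are my own choices, not the final test):

    import torch

    def check_dropout_sum(device="cuda", p=0.5, shape=(3, 3, 4, 5)):
        # channels-last input, discontiguous in the default memory format
        inp = torch.ones(shape, device=device).contiguous(memory_format=torch.channels_last)
        out = torch.nn.functional.dropout(inp, p=p, training=True)
        # surviving elements are scaled by 1 / (1 - p), so the expected sum is
        # unchanged; the check is statistical, hence the fairly large tolerance
        assert abs(out.sum().item() - inp.sum().item()) < 0.5 * inp.numel()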

* use empty_like without passing layout.
* update vectorized condition to be is_non_overlapping_and_dense.
@kshitij12345 kshitij12345 force-pushed the fix/dropout/channel-last-discontig branch from 4c0d464 to d5629ca Compare November 8, 2020 08:27
str(module)

    def _test_dropout_discontiguous(self, cls, device, memory_format=torch.contiguous_format):
        # In this test, we verify that dropout preserves the layout and data for different memory formats.
Collaborator

This test does not verify that the layout was preserved, so the comment is misleading.

test/test_nn.py Outdated
    def _test_dropout_mean_check(self, cls, device):
        def _test_dropout_mean(inp, out):
            self.assertNotEqual(inp, out)
            self.assertEqual(inp.mean(), out.mean(), rtol=3, atol=1)
Collaborator

This also does not test that the layout was preserved. Also, rtol=3 seems way too large. Maybe try initializing the tensors with ones(); then the tolerances can be small?

Collaborator Author

Makes sense. Will try that. Thanks!

test/test_nn.py Outdated
self.assertNotEqual(inp, out)
self.assertEqual(inp.mean(), out.mean(), rtol=3, atol=1)

for memory_format in [torch.contiguous_format, torch.channels_last]:
Collaborator

instead of looping over memory formats, loop over all possible permutations, like it's done here. Then you won't need the memory format and transpose branches that you have below https://github.com/pytorch/pytorch/pull/47124/files#diff-8aa1a200ec63d23db422aa31b6dca1e6cb372887c43b064ef435210b1b0dec0aR18028
Also, these tests won't be testing the non-vectorized kernel. To do that, you'll need to have unaligned tensors (tensors allocated at the 1: index of a regular aligned tensor; see test_softmax_results).
Please do something like

for p in permutations...:
    for shift in shifts:
         input = torch.ones(...).permute(p).contiguous().permute(inverse_p)
         input = input[shift[0]:, shift[1]]
         compare results
         compare layout
additional branch for non-memory dense
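
A runnable sketch of that loop (the shapes, shift values and the layout check are my own illustrative choices, not the final test):

    import itertools
    import torch

    def dense_in_some_order(t):
        # the tensor should be contiguous once its dims are sorted by stride,
        # i.e. dense in whatever permutation the input dictated
        order = sorted(range(t.dim()), key=t.stride, reverse=True)
        return t.permute(*order).is_contiguous()

    def run_dropout_layout_checks(device="cuda", p=0.5):
        for perm in itertools.permutations(range(4)):
            inverse_p = [perm.index(i) for i in range(4)]
            for shift in [(0, 0), (1, 0), (0, 1), (1, 1)]:
                # permuted-contiguous input; slicing the first two dims gives
                # discontiguous (and possibly misaligned) variants
                inp = torch.ones(3, 3, 4, 5, device=device).permute(*perm).contiguous().permute(*inverse_p)
                inp = inp[shift[0]:, shift[1]:]
                out = torch.nn.functional.dropout(inp, p=p, training=True)
                # compare results: kept values are scaled by 1 / (1 - p), dropped are 0
                kept = out != 0
                assert torch.allclose(out[kept], torch.full_like(out[kept], 1 / (1 - p)))
                # compare layout: output should be dense in the same dim order as the input
                assert dense_in_some_order(out)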

Collaborator Author

Thank you for the reference!

@heitorschueroff heitorschueroff added the module: nn (Related to torch.nn) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Nov 9, 2020
test/test_nn.py Outdated
inp_discontiguous.copy_(inp)
mod = cls(p=p)
out = mod(inp_discontiguous)
self.assertEqual(out.layout, inp_discontiguous.layout)
Collaborator

layout in all cases is torch.strided; you want to compare strides or memory format, not layout
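
For example (a quick illustration of my own; torch.strided is the layout of every dense tensor regardless of memory format):

    import torch

    a = torch.randn(2, 3, 4, 5)
    b = a.contiguous(memory_format=torch.channels_last)
    print(a.layout, b.layout)        # torch.strided torch.strided -- always equal
    print(a.stride(), b.stride())    # (60, 20, 5, 1) vs (60, 1, 15, 3) -- this is what differs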

test/test_nn.py Outdated

    def _test_dropout_mean_check(self, cls, device):
        def _test_dropout_mean(inp, out):
            self.assertEqual(inp.layout, out.layout)
Collaborator

same here, the invariant should be that the inverse permutation of out is contiguous, right? Probably start with a (3, 3, 4, 5) tensor so that even with the shift there are no size-1 dimensions; there might be annoying corner cases for size-1 dimensions.

test/test_nn.py Outdated
        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
        for perm in itertools.permutations((0, 1, 2, 3), r=4):
            for shift in shifts:
                for p in [0.3, 0.5, 0.7]:
Collaborator

Can you please add a small probability here, and check for that case that all the values are preserved, similar to _test_dropout_discontiguous?

@kshitij12345 kshitij12345 force-pushed the fix/dropout/channel-last-discontig branch from 7d7a1ab to 80694d0 Compare November 11, 2020 17:39
Contributor

@facebook-github-bot facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in 4b25d83.


Labels

cla signed · Merged · module: nn (Related to torch.nn) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Dropout broken on cuda for discontiguous channels-last input

5 participants