torch.dropout: fix non-contiguous layout input #47552
Conversation
💊 CI failures summary and remediations
As of commit 80694d0 (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patterns
The following CI failures do not appear to be due to upstream breakages:
aten/src/ATen/native/cuda/Dropout.cu
Outdated
// Convert `linearIndex` into an offset of `b`
const IndexType bOffset =
-    cuda::detail::IndexToOffset<scalar_t, IndexType, 1>::get(li, b);
+    cuda::detail::IndexToOffset<scalar_t, IndexType, ADims>::get(li, b);
This is not a correct fix. In most cases mask and output will be contiguous for discontiguous input, and the dim here needs to be 1 (they are coalesced to 1). Only in the case where suggest_memory_format is channels_last and for some reason vectorized kernel is not used will this error happen.
pytorch/aten/src/ATen/native/cuda/Dropout.cu
Lines 225 to 230 in 5a5258c
auto self_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(self);
auto ret_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(ret);
auto mask_info = cuda::detail::getTensorInfo<uint8_t, unsigned int>(mask);
self_info.collapseDims();
ret_info.collapseDims();
mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor
After this, I printed the dims for each tensor:
std::cout << self_info.dims << " " << mask_info.dims << " " << ret_info.dims << "\n";
auto print_dims = [](auto tensor_info) {
for (int i = 0; i < tensor_info.dims; ++i) {
std::cout << tensor_info.sizes[i] << " ";
}
std::cout << "\n";
};
std::cout << "SELF_INFO SIZES:";
print_dims(self_info);
std::cout << "MASK_INFO SIZES:";
print_dims(mask_info);
std::cout << "RET_INFO SIZES:";
print_dims(ret_info);
And the output is (for the case in the issue, with self of shape (2, 3, 3, 3)):
3 3 3
SELF_INFO SIZES:2 3 9
MASK_INFO SIZES:2 3 9
RET_INFO SIZES:2 3 9
This is why I believe replacing 1 with ADims is correct.
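For reference, a rough Python model of this adjacent-dimension collapsing (an approximation of TensorInfo::collapseDims, not the actual ATen code, using the strides printed further down in this thread) shows why self and ret/mask all end up 3-dimensional here:

def collapse(sizes, strides):
    # Merge neighbouring dims i, i+1 whenever stride[i] == stride[i+1] * size[i+1].
    sizes, strides = list(sizes), list(strides)
    i = 0
    while i < len(sizes) - 1:
        if strides[i] == strides[i + 1] * sizes[i + 1]:
            sizes[i] *= sizes[i + 1]
            strides[i] = strides[i + 1]
            del sizes[i + 1], strides[i + 1]
        else:
            i += 1
    return sizes, strides

print(collapse((2, 3, 3, 3), (54, 1, 18, 6)))  # ([2, 3, 9], [54, 1, 6])   (self)
print(collapse((2, 3, 3, 3), (27, 1, 9, 3)))   # ([2, 3, 9], [27, 1, 3])   (ret/mask)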
The example in the issue is one of the possible failure modes. In other cases the input could be discontiguous (not collapsible), but mask/output will be collapsible. In that case ADims will be -1, so using -1 for the output will work too, but it's less efficient than 1.
pytorch/aten/src/ATen/native/cuda/Dropout.cu
Lines 200 to 203 in 5a5258c
fused_dropout_cuda(const Tensor& self, double p, c10::optional<Generator> gen_){
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor ret = at::empty_like(self, self.suggest_memory_format());
  Tensor mask = at::empty(self.sizes(), self.options().dtype(kByte), self.suggest_memory_format());
Also, printing the strides after mask is created (for the case in the issue, with self of shape (2, 3, 3, 3)):
std::cout << "SELF STRIDES:" << self.strides() << "\n";
std::cout << "RET STRIDES:" << ret.strides() << "\n";
std::cout << "MASK STRIDES:" << mask.strides() << "\n";
std::cout << "RET CONTIGUOUS:" << ret.is_contiguous() << "\n";
std::cout << "MASK CONTIGUOUS:" << mask.is_contiguous() << "\n";Output
SELF STRIDES:[54, 1, 18, 6]
RET STRIDES:[27, 1, 9, 3]
MASK STRIDES:[27, 1, 9, 3]
RET CONTIGUOUS:0
MASK CONTIGUOUS:0
Note that ret and mask are also not contiguous, and their strides are similar to self's.
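For reference, one hypothetical way to build an input with exactly these sizes and strides (an assumption; the actual repro in issue #47176 may differ) is to slice the channel dimension of a channels_last tensor:

import torch

# Slicing channels of a channels_last tensor gives shape (2, 3, 3, 3) with
# strides (54, 1, 18, 6): non-contiguous and not memory-dense.
base = torch.randn(2, 6, 3, 3, device="cuda").to(memory_format=torch.channels_last)
inp = base[:, :3]
print(inp.size(), inp.stride())  # torch.Size([2, 3, 3, 3]) (54, 1, 18, 6)
out = torch.nn.functional.dropout(inp, p=0.5, training=True)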
The example in the issue is one of the possible failure modes. In other cases the input could be discontiguous (not collapsible), but mask/output will be collapsible. In that case ADims will be -1, so using -1 for the output will work too, but it's less efficient than 1.
I think mask and ret would mostly preserve the behavior that input has 🤔
They don't
import torch
mod = torch.nn.Dropout(0.5)
inp = torch.randn(3,4, device="cuda").t()
out=mod(inp)
print(inp.stride(), out.stride()) # (1, 4) (3, 1), output is contiguous, input is not
inp = torch.randn(3,4, device="cuda")[:,::2]
out = mod(inp)
print(inp.stride(), out.stride()) # (4, 2) (2, 1) output is contiguous, input is not
Ah. I see. Thank you very much!
Note:
After changing

Tensor ret = at::empty_like(self, self.suggest_memory_format());
Tensor mask = at::empty(self.sizes(), self.options().dtype(kByte), self.suggest_memory_format());

to

Tensor ret = at::empty_like(self);
Tensor mask = at::empty_like(self, self.options().dtype(kByte));

the output of the above snippet is
(1, 4) (1, 4)
(4, 2) (2, 1)
Yep, that's what we want to happen: dropout preserves the strides of the input if the input is memory-dense, and makes the output dense while preserving the input's permutation if the input is not memory-dense.
Note also that the first case can go to vectorized kernel now.
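A quick sketch of that behavior (the channels_last case is an assumed example; the sliced case and its expected strides come from the snippet above):

import torch

mod = torch.nn.Dropout(0.5)

# Memory-dense channels_last input: strides are preserved, so the output stays channels_last.
inp = torch.randn(2, 3, 3, 3, device="cuda").to(memory_format=torch.channels_last)
out = mod(inp)
print(out.is_contiguous(memory_format=torch.channels_last))  # True

# Non-memory-dense input (strided slice): the output is made dense while keeping
# the input's dimension ordering, so it comes out contiguous.
inp = torch.randn(3, 4, device="cuda")[:, ::2]  # strides (4, 2)
out = mod(inp)
print(inp.stride(), out.stride())               # (4, 2) (2, 1)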
For a good fix
4c0d464 to d5629ca (Compare)
str(module)

def _test_dropout_discontiguous(self, cls, device, memory_format=torch.contiguous_format):
    # In this test, we verify that dropout preserves the layout and data for different memory formats.
this test does not test that layout was preserved, so this is a misleading comment.
test/test_nn.py
Outdated
def _test_dropout_mean_check(self, cls, device):
    def _test_dropout_mean(inp, out):
        self.assertNotEqual(inp, out)
        self.assertEqual(inp.mean(), out.mean(), rtol=3, atol=1)
this also does not test that layout was preserved. Also, rtol=3 seems way too large. Maybe try initializing tensors with ones(), then tolerances can be small?
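For instance (a rough sketch of that suggestion; the shape and tolerance are assumptions): with an input of ones, kept values are scaled by 1 / (1 - p), so the mean stays close to 1 and a tight tolerance suffices.

import torch

p = 0.5
inp = torch.ones(1000, 1000, device="cuda")
out = torch.nn.Dropout(p)(inp)
# Kept elements equal 1 / (1 - p), dropped ones are 0, so the mean is ~1.
assert abs(out.mean().item() - 1.0) < 0.05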
Makes sense. Will try that. Thanks!
test/test_nn.py
Outdated
self.assertNotEqual(inp, out)
self.assertEqual(inp.mean(), out.mean(), rtol=3, atol=1)

for memory_format in [torch.contiguous_format, torch.channels_last]:
Instead of looping over memory formats, loop over all possible permutations, like it's done here: https://github.com/pytorch/pytorch/pull/47124/files#diff-8aa1a200ec63d23db422aa31b6dca1e6cb372887c43b064ef435210b1b0dec0aR18028. Then you won't need the memory format and transpose branches that you have below.
Also, these tests won't be testing the non-vectorized kernel. To do that, you'll need unaligned tensors (taken from the 1: index of a regular aligned tensor; see test_softmax_results).
Please do something like

for p in permutations...:
    for shift in shifts:
        input = torch.ones(...).permute(p).contiguous().permute(inverse_p)
        input = input[shift[0]:, shift[1]:]
        # compare results
        # compare layout
        # additional branch for non-memory dense
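A runnable sketch of that structure (the function name, shape, and value check below are placeholders/assumptions, not the final test_nn.py code):

import itertools
import torch

def check_dropout_permutations(device="cuda", p=0.5):
    # Placeholder sketch of the suggested loop over permutations and shifts.
    mod = torch.nn.Dropout(p)
    shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
    for perm in itertools.permutations((0, 1, 2, 3), r=4):
        inv = [0] * 4
        for i, d in enumerate(perm):
            inv[d] = i
        for shift in shifts:
            # Permuted-contiguous input; the shift offsets the data pointer so
            # some cases can also exercise the non-vectorized kernel.
            inp = torch.ones(3, 3, 4, 5, device=device).permute(perm).contiguous().permute(inv)
            inp = inp[shift[0]:, shift[1]:]
            out = mod(inp)
            # Value check: every element is either dropped (0) or scaled by 1 / (1 - p).
            assert bool(((out == 0) | ((out - 1.0 / (1 - p)).abs() < 1e-6)).all())
            # Layout checks (stride preservation / inverse-permutation contiguity)
            # would follow the discussion above.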
Thank you for the reference!
test/test_nn.py
Outdated
inp_discontiguous.copy_(inp)
mod = cls(p=p)
out = mod(inp_discontiguous)
self.assertEqual(out.layout, inp_discontiguous.layout)
The layout in all cases is torch.strided; you want to compare strides or memory format, not layout.
test/test_nn.py
Outdated
def _test_dropout_mean_check(self, cls, device):
    def _test_dropout_mean(inp, out):
        self.assertEqual(inp.layout, out.layout)
Same here: the invariant should be that the inverse permutation of out is contiguous, right? Probably start with a (3, 3, 4, 5) tensor so that even with the shift there are no size-1 dimensions; there might be annoying corner cases for size-1 dimensions.
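A small standalone illustration of that invariant (an assumed example, not the test itself):

import torch

# Starting from a (3, 3, 4, 5) tensor avoids size-1 dimensions even after a shift.
# The output, permuted by the permutation that makes the input contiguous, should
# itself be contiguous, given the stride-preserving behavior discussed above.
perm, inv = (1, 3, 0, 2), (2, 0, 3, 1)
inp = torch.ones(3, 3, 4, 5, device="cuda").permute(perm).contiguous().permute(inv)
out = torch.nn.Dropout(0.5)(inp)
assert out.permute(perm).is_contiguous()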
test/test_nn.py
Outdated
shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
for perm in itertools.permutations((0, 1, 2, 3), r=4):
    for shift in shifts:
        for p in [0.3, 0.5, 0.7]:
Can you please add a small probability here, and check for that case that all the values are preserved, similar to _test_dropout_discontiguous?
7d7a1ab to 80694d0 (Compare)
facebook-github-bot left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #47176