
[functorch] fix batching rule for dropout #92975

Closed
wants to merge 3 commits

Conversation

kshitij12345
Collaborator

@kshitij12345 kshitij12345 commented Jan 25, 2023

Fixes #92283

The repro now works:

import torch
import torch.func
import torch.nn as nn

x = torch.randn(3, device='cuda')
y = torch.randn(1, 3, device='cuda')

def fn(x, y):
    # Previously the output of dropout was incorrectly [B, 3] (B=1), so `mean(1)` failed.
    # After the fix the output is [B, 1, 3] and `mean(1)` works.
    return x + nn.functional.dropout(y, 0.3).mean(1)


o = torch.func.vmap(fn, in_dims=(0, None), randomness='different')(x, y)

NOTE:
native_dropout_batching_rule(const Tensor& tensor, double p, c10::optional<bool> train) was called only for CUDA tensors. Hence this issue affected CUDA tensors but not CPU tensors.
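To illustrate the shape issue without torch, here is a minimal pure-Python sketch (not the actual functorch code): under vmap with randomness='different', the physical input to the dropout rule already carries a leading batch dimension B, so for the repro's y of logical shape [1, 3] the physical shape is [B, 1, 3]. A correct rule samples an independent mask over the full physical shape and preserves it; collapsing the output to [B, 3] is exactly what made `mean(1)` fail. The helpers below (`shape`, `dropout_physical`) are hypothetical, using nested lists as stand-ins for tensors.

```python
import random

def shape(t):
    # Shape of a nested-list "tensor" (assumes rectangular, non-empty lists).
    s = []
    while isinstance(t, list):
        s.append(len(t))
        t = t[0]
    return tuple(s)

def dropout_physical(t, p, rng):
    # Inverted dropout applied elementwise: each scalar is independently
    # zeroed with probability p, survivors scaled by 1/(1-p).
    # Crucially, the full physical shape (including the batch dim) is kept.
    if isinstance(t, list):
        return [dropout_physical(x, p, rng) for x in t]
    return 0.0 if rng.random() < p else t / (1.0 - p)

rng = random.Random(0)
B = 4
y = [[[1.0, 2.0, 3.0]] for _ in range(B)]   # physical shape (B, 1, 3)
out = dropout_physical(y, 0.3, rng)
assert shape(out) == (B, 1, 3)              # batch dim preserved, mean over dim 1 stays valid
```

With the shape preserved as [B, 1, 3], reducing over dim 1 yields [B, 3], which broadcasts cleanly against the vmapped x in the repro.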

Ref:

Tensor dropout(const Tensor& input, double p, bool train) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (train && is_fused_kernel_acceptable(input, p)) {
      return std::get<0>(at::native_dropout(input, p, train));
    }
    return _dropout<false>(input, p, train);
  }();
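The dispatch above explains why the bug was CUDA-only. A simplified Python sketch of that decision (the `dropout_dispatch` helper is hypothetical, and `is_fused_kernel_acceptable` is assumed here to require a CUDA input, matching the NOTE above; the real ATen check also inspects p and training mode):

```python
def dropout_dispatch(is_cuda: bool, train: bool, p: float) -> str:
    # Sketch of the C++ lambda above: the fused path is only acceptable
    # for CUDA inputs with a nontrivial dropout probability.
    fused_ok = is_cuda and 0.0 < p < 1.0
    if train and fused_ok:
        return "at::native_dropout"   # hits native_dropout_batching_rule under vmap
    return "_dropout"                 # decomposed path; CPU tensors land here
```

So `dropout_dispatch(True, True, 0.3)` takes the `at::native_dropout` branch (the buggy batching rule), while CPU tensors always fall through to `_dropout` and were never affected.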

@pytorch-bot

pytorch-bot bot commented Jan 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92975

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6acfa4c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kshitij12345 kshitij12345 added the release notes: functorch release notes category; Pertaining to torch.func or pytorch/functorch label Jan 25, 2023
@kshitij12345 kshitij12345 marked this pull request as ready for review January 25, 2023 16:04
@Skylion007 Skylion007 self-requested a review January 25, 2023 17:37
Collaborator

@Skylion007 Skylion007 left a comment


Looks good to me now.

@kshitij12345
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 26, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Successfully merging this pull request may close these issues.

Randomness 'different' results in weird behavior of Dropout
5 participants