Did Dropout2d change for Pytorch 1.11? #77081

Closed

albertfgu opened this issue May 9, 2022 · 23 comments
Labels
high priority · module: cuda (Related to torch.cuda, and CUDA support in general) · module: random (Related to random number generation in PyTorch (rng generator)) · module: regression (It used to work, and now it doesn't) · triage review · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@albertfgu

albertfgu commented May 9, 2022

🐛 Describe the bug

After having issues reproducing results, my collaborators and I ran small models on many different environments, GPUs, and package versions, and found that there is a definite performance change when using nn.Dropout2d on torch=1.11 compared to torch=1.10. For example, for one of our small models with 100K parameters and dropout=0.1, the version on torch 1.11 is around 3% worse than the version on torch 1.10, and this gap begins within the first few epochs and persists throughout training for 100 epochs. No other hyperparameters in this model seem to cause performance differences between the torch versions. The performance differences are consistent across several different environments (e.g. Google Cloud and our local cluster) and GPUs (T4, P100, A100).

I don't have time to create a minimal reproducible example right now, but can do so later if that is helpful. For now I just want to understand the difference in implementation: what changed for Dropout2d between 1.10 and 1.11?

Versions

Collecting environment information...
PyTorch version: 1.11.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
Clang version: 3.8.0-2ubuntu4 (tags/RELEASE_380/final)
CMake version: version 3.21.0
Libc version: glibc-2.23

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.4.0-130-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: Tesla P100-PCIE-16GB
Nvidia driver version: 430.26
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.5
[pip3] pytorch-fast-transformers==0.4.0
[pip3] pytorch-lightning==1.5.0
[pip3] torch==1.11.0
[pip3] torchaudio==0.11.0
[pip3] torchmetrics==0.6.0
[pip3] torchtext==0.11.0
[pip3] torchvision==0.12.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.20.3 pypi_0 pypi
[conda] numpy-base 1.21.5 py38hf524024_2
[conda] pytorch 1.11.0 py3.8_cuda10.2_cudnn7.6.5_0 pytorch
[conda] pytorch-fast-transformers 0.4.0 pypi_0 pypi
[conda] pytorch-lightning 1.5.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.10.0 pypi_0 pypi
[conda] torchaudio 0.11.0 py38_cu102 pytorch
[conda] torchmetrics 0.6.0 pypi_0 pypi
[conda] torchtext 0.11.0 pypi_0 pypi
[conda] torchvision 0.12.0 py38_cu102 pytorch

cc @ezyang @gchanan @zou3519 @ngimel @pbelevich @mruberry @kurtamohler

@ngimel
Collaborator

ngimel commented May 9, 2022

There's #63937, which wasn't intended to change semantics but could have done so inadvertently (that would be a bug). Why do you think it's a Dropout2d bug? Do you see the difference only in models that use it?
cc @jjsjann123

@ezyang ezyang added the module: cuda, module: random, and triaged labels (and added and removed the module: determinism label) on May 9, 2022
@albertfgu
Author

Yes. I am positive there is a difference. We essentially ran a sweep of {3 different environments with T4, P100, A100} x {model dropout2d=0.1, dropout2d=0.0} x {torch 1.10, 1.11}, and everything is reproducible except that dropout2d=0.1 x torch 1.11 gives completely different results.

@ngimel
Collaborator

ngimel commented May 9, 2022

Are you running scripted or eager models?

@albertfgu
Author

albertfgu commented May 9, 2022

Do you mean torchscript / jit vs. standard PyTorch? We're using the latter; nothing fancy.

I can give you an exact command line for this, but unfortunately I don't have a minimal file. I am running experiments in the S4 repo, and the simple command

python -m train wandb=null pipeline=mnist model=s4 model.dropout=0.1 train.seed=0

should give different results on torch 1.10/1.11.

@ngimel
Collaborator

ngimel commented May 9, 2022

A minimal repro would be better, of course; debugging convergence differences between versions is hard, as there's no bitwise reproducibility guarantee between releases even in the absence of random operations.

@albertfgu
Author

albertfgu commented May 9, 2022

Yes, a minimal repro is of course better. I can try to create one later. I'll note again that although there is noise in the convergence behavior, the performance differences are very dramatic: the final accuracy has a 3-point difference (between 1.10/1.11) where the standard deviation across seeds is ~0.2.

Moreover, the train/val accuracy after the first epoch shows a very consistent difference. When fixing the same seed, all versions of the model are self-consistent across {torch 1.10, 1.11} x {environment/GPU} x {many other hyperparameter changes} when {dropout2d=0.0} is set (up to roughly 0.1 difference in train/val accuracy after 1 epoch), but have a 2-4 point difference (between 1.10 and 1.11) when {dropout2d=0.1} is set. This is how we were able to debug and track down the issue: by running for 1 epoch, which takes about 1-2 minutes.

@thinline72

thinline72 commented Jun 8, 2022

Hi all,
I can confirm that I've faced a similar quality issue after upgrading to 1.11.0. I was able to track it down to the F.dropout2d function; it looks like its behaviour changed in 1.11.0 compared to 1.10.0/1.10.2.

Now, if it's applied to a 3-D tensor of shape (batch_size, num_dim, length), it drops along the batch dimension (meaning a whole example in the batch gets zeroed) instead of the num_dim dimension.

I've created a quick example in the Colab to showcase that: https://colab.research.google.com/drive/1nw4GqAbj0amF5WP9aRJtd0QkMwl8HUok?usp=sharing
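A minimal sketch of the reported difference (the shapes and dropout probability here are illustrative; which dimension gets zeroed depends on the installed torch version):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(4, 8, 16)  # (batch_size, num_dim, length)
out = F.dropout2d(x, p=0.5, training=True)

# On 1.10, whole channels (dim 1) are zeroed per example; on 1.11 the 3-D
# input is treated as unbatched (C, H, W), so whole examples (dim 0) are zeroed.
examples_zeroed = (out.sum(dim=(1, 2)) == 0).sum().item()
channel_rows_zeroed = (out.sum(dim=2) == 0).sum().item()
print(f"fully zeroed examples: {examples_zeroed}, zeroed (example, channel) rows: {channel_rows_zeroed}")
```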

@albertfgu
Author

That makes so much sense, thanks for figuring out the issue @thinline72! I've been meaning to create a minimal example but had a backlog to deal with.

@thinline72

thinline72 commented Jun 8, 2022

👍

I've also added a possible workaround for the time being in that Colab notebook. It seems to work correctly if we first unsqueeze, then apply dropout, and then squeeze back, like:

F.dropout2d(x.unsqueeze(2), p=0.5, training=True).squeeze(2)
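To make the workaround reusable, it could be wrapped in a small module (a sketch only; the name ChannelDropout1d is hypothetical, not a PyTorch API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelDropout1d(nn.Module):
    """Drops entire channels of a (batch, channels, length) input by adding a
    dummy spatial dim so dropout2d always sees a 4-D (N, C, 1, L) tensor."""
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.dropout2d(x.unsqueeze(2), p=self.p, training=self.training).squeeze(2)

layer = ChannelDropout1d(p=0.1)
y = layer(torch.randn(32, 64, 128))  # same shape out; whole channels zeroed in training mode
```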

@ngimel
Collaborator

ngimel commented Jun 13, 2022

cc @zou3519, @jbschlosser: were there any changes to dropout2d to support vmap/no-batch-dim? Marking high priority for silent wrong results; thanks @thinline72 for narrowing down this bug.

@zou3519 zou3519 added this to the 1.12.0 milestone on Jun 13, 2022
@zou3519 zou3519 added the module: regression (It used to work, and now it doesn't) label on Jun 13, 2022
@zou3519
Contributor

zou3519 commented Jun 13, 2022

On "support for functorch" I see #72078 but haven't tested it

EDIT: It's most likely not this one, this should not change the semantics

@zou3519
Contributor

zou3519 commented Jun 13, 2022

#69885 looks suspicious, cc @jbschlosser @kshitij12345

@ngimel
Collaborator

ngimel commented Jun 13, 2022

Yep, #69885 looks like it must be it; thanks @zou3519.

@kshitij12345
Collaborator

kshitij12345 commented Jun 13, 2022

#69885 did update the behaviour of Dropout2d and Dropout3d to zero entire channels (as the docs claimed) instead of a few random values within each channel for the no-batch-dim case. So if a 3-D input is passed, its dims are treated as C, H, W and hence whole channels get dropped, which is what the docs describe.

Relevant Issue which PR fixed: #69801
Docs: https://pytorch.org/docs/stable/generated/torch.nn.Dropout2d.html
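In other words, on 1.11 a 3-D input now follows the documented no-batch-dim behavior, so each channel is dropped or kept as a unit (a small sketch):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(8, 5, 5)  # interpreted as (C, H, W) on 1.11+
out = F.dropout2d(x, p=0.5, training=True)

# Every channel is either entirely zero or entirely scaled by 1/(1-p),
# so each 5x5 slice is constant.
print(all(bool((out[c] == out[c, 0, 0]).all()) for c in range(8)))  # True
```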

@ngimel
Collaborator

ngimel commented Jun 13, 2022

But you can't just assume that a 3D input has no batch dim; that's the problem reported here.

@kshitij12345
Collaborator

No-batch-dim support for Dropout2d was shipped with 1.10, and it did document that the no-batch-dim shape is C, H, W. However, its implementation did not actually drop channels.

1.10 Docs: https://pytorch.org/docs/1.10/generated/torch.nn.Dropout2d.html?highlight=dropout2d#torch.nn.Dropout2d

@ngimel
Collaborator

ngimel commented Jun 13, 2022

This is clearly a BC-breaking change that's leading to a silent change of behavior.

@jbschlosser
Contributor

jbschlosser commented Jun 13, 2022

Old Dropout2d technically supported shape (N, C, H, W), but didn't enforce this and just applied channel-wise dropout along dim 1 underneath. So 3D inputs have historically been supported by Dropout2d, and #69885 silently broke this. The silent breakage was discussed here, but the approach taken was not sufficient to avoid it.

@thinline72

+1 to @ngimel.

None of the release notes mention this BC-breaking change. IIRC, I've had code using Dropout2d since 1.6.

@jbschlosser
Contributor

None of the release notes mention this BC-breaking change. IIRC, I've had code using Dropout2d since 1.6.

True, and we should document this clearly as a BC-breaking change for 1.12, indicating that it was missed in the release notes for 1.11.

Note that I'd recommend using Dropout1d if you want to drop entire channels from inputs with a single spatial dimension, i.e. (N, C, L). Outside of a no-batch-dim input, there's no sensible interpretation of 3D inputs by Dropout2d given its purpose of dropping entire channels: (N, H, W) doesn't have a channel dim.
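For reference, once nn.Dropout1d is available (it landed in 1.12, after this discussion), the recommended pattern would look roughly like this sketch:

```python
import torch
import torch.nn as nn

drop = nn.Dropout1d(p=0.1)    # requires PyTorch >= 1.12
x = torch.randn(32, 64, 128)  # (batch, channels, length)
y = drop(x)                   # in training mode, entire (example, channel) slices are zeroed
```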

@albertfgu
Author

Searching the docs doesn't show any Dropout1d: https://pytorch.org/docs/stable/search.html?q=dropout1d&check_keywords=yes&area=default

It seems like it's been the standard for years to use Dropout2d to emulate Dropout1d functionality: #6442

@jbschlosser
Contributor

jbschlosser commented Jun 14, 2022

It seems like it's been the standard for years to use Dropout2d to emulate Dropout1d functionality: #6442

Sure enough, you're right. For 1.12, I'll propose we add an nn.Dropout1d module for 1D channel-wise dropout and restore Dropout2d's pre-1.11 handling of 3D inputs (treating them as batched (N, C, L)).

This will break those depending on no-batch-dim inputs now, which is unfortunate but unavoidable. I expect fewer users depend on that than on the 1D channel-wise dropout behavior for 3D inputs in Dropout2d. A large warning in the docs will describe the situation.

Thoughts on this?

@jbschlosser
Contributor

Closing as addressed in #79545 and #79549.
