Tests fail on A100 GPUs due to inaccurate/differing float values #52278

Closed
Flamefire opened this issue Feb 15, 2021 · 12 comments
Labels
  • module: ddp - Issues/PRs related to distributed data parallel training
  • module: nn - Related to torch.nn
  • module: tests - Issues related to tests (not the torch.testing module)
  • triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Flamefire
Collaborator

Flamefire commented Feb 15, 2021

🐛 Bug

Testing PyTorch on our new cluster with A100 (compute capability 8.0) GPUs, I'm seeing multiple failures in PyTorch's distributed tests.

E.g. from the TestDistBackendWithFork suite the following tests fail:

ERROR: test_DistributedDataParallel (__main__.TestDistBackendWithFork)
ERROR: test_DistributedDataParallel_SyncBatchNorm (__main__.TestDistBackendWithFork)
ERROR: test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient (__main__.TestDistBackendWithFork)
ERROR: test_DistributedDataParallel_with_grad_is_view (__main__.TestDistBackendWithFork)

The same setup (same versions of CUDA, compiler, NCCL, ...) works on many other systems with e.g. V100 GPUs.

To Reproduce

Steps to reproduce the behavior:

  1. BACKEND=nccl WORLD_SIZE=3 TEST_REPORT_SOURCE_OVERRIDE=dist-nccl python distributed/test_distributed_fork.py

AssertionError: False is not true : Tensors failed to compare as equal! With rtol=1.3e-06 and atol=1e-05, found 4 element(s) (out of 200) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 1.271069049835205e-05 (-0.1342240869998932 vs. -0.13423679769039154), which occurred at index (3, 47).
AssertionError: False is not true : Tensors failed to compare as equal! With rtol=1.3e-06 and atol=1e-05, found 2 element(s) (out of 80) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 1.558661460876465e-05 (-0.46960899233818054 vs. -0.4696245789527893), which occurred at index (29, 1).
AssertionError: False is not true : Tensors failed to compare as equal! With rtol=1.3e-06 and atol=1e-05, found 30 element(s) (out of 160) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 2.6047229766845703e-05 (-0.050193969160318375 vs. -0.05016792193055153), which occurred at index (1, 33).
AssertionError: False is not true : Tensors failed to compare as equal! With rtol=1.3e-06 and atol=1e-05, found 4 element(s) (out of 200) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 1.271069049835205e-05 (-0.1342240869998932 vs. -0.13423679769039154), which occurred at index (3, 47).

More verbose log: error.log

Environment

  • PyTorch Version (e.g., 1.0): 1.7.1
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): source
  • Python version: 3.8.6
  • CUDA/cuDNN version: 11.1.1, cuDNN 8.0.4.30, NCCL 2.8.3
  • GPU models and configuration: 8x A100

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry @VitalyFedyunin @walterddr @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu

@facebook-github-bot facebook-github-bot added the oncall: distributed label Feb 15, 2021
@ngimel
Collaborator

ngimel commented Feb 16, 2021

cc @ptrblck, did you see similar failures in your CI?

@ptrblck
Collaborator

ptrblck commented Feb 16, 2021

Nightly CI isn't failing for TestDistBackendWithFork, but since the setup differs (nightly build vs. 1.7.1; different CUDA, cuDNN, and NCCL versions), I can try to recreate the test and reproduce it.

@Flamefire I assume you've built PyTorch from source with CUDA11.1 + cudnn8.0.4 + NCCL 2.8.3?

EDIT: yes, you have:

How you installed PyTorch (conda, pip, source): source

@Flamefire
Collaborator Author

FWIW: We are trying to automate the installation of PyTorch on HPC clusters, so the software environment is as reproducible as possible (same versions except for core system stuff like glibc). See easybuilders/easybuild-easyconfigs#12003 for our "recipe".
On some other systems those tests seem to work, so the A100s are the major difference I found. Maybe also the AMD Epyc CPUs, but I'd guess the test runs on the GPU anyway.
I also tested via CUDA_VISIBLE_DEVICES to reduce it down to 3 or 6 GPUs, but that didn't seem to change anything.

To make the test pass I had to change the comparison to self.assertEqual(p_gpu, p_DDP, atol=2e-04, rtol=1.3e-06) (rtol unchanged), as lower atol values would still fail, and now I got one more failure:

test_DistributedDataParallel_SyncBatchNorm (__main__.TestDistBackendWithSpawn)
The greatest difference was 0.00024185329675674438 (-0.07157765328884125 vs. -0.07181950658559799)

I'm unsure whether this isn't a more general issue, or whether reproducibility is simply not that easy for DDP.
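
For reference, a hypothetical sketch of the kind of comparison that was loosened (illustrative names only, not the actual helper in PyTorch's internal distributed test module):

import torch

def assert_params_close(model_ref: torch.nn.Module, model_ddp: torch.nn.Module) -> None:
    # Hypothetical helper: compare the DDP model's parameters against the
    # single-process reference with the loosened absolute tolerance from above.
    for p_ref, p_ddp in zip(model_ref.parameters(), model_ddp.parameters()):
        torch.testing.assert_allclose(p_ddp, p_ref, rtol=1.3e-6, atol=2e-4)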

@Flamefire
Collaborator Author

Flamefire commented Feb 16, 2021

I did some more experiments with CUDA_VISIBLE_DEVICES=0,1,2 OMP_NUM_THREADS=1 BACKEND=nccl WORLD_SIZE=3 python distributed/test_distributed_fork.py TestDistBackendWithFork.test_DistributedDataParallel, looking at the first iteration only (where the comparison already fails).

Inside _test_DDP_helper I got loss values of

  • 0.6859778165817261 for the base model (same on all 3 ranks)
  • 0.23933732509613037 + 0.7668043375015259 + 1.051791548728942, averaged => 0.685977737 for the DDP model. That's a difference of ~8e-8.

For another node where it works, the values are:

  • 0.685981035232544
  • 0.7668123245239258 + 1.0517898797988892 + 0.2393409162759781, averaged => 0.685981040, so about a 5e-9 difference

Printing the output of the model, I get pretty much bitwise-identical results between the base and DDP model, so the model isn't the problem.

Next I checked the individual loss values (using reduction="none" and then manually calling .mean()). Again the loss values are bitwise identical, but the means differ.
Data:

Base losses: 0.0014468701556324959,0.19974617660045624,0.2590259611606598,3.7469475269317627,0.027207523584365845,0.2559516727924347,0.1512461155653,2.632812023162842,0.1647290289402008,0.4095090329647064,0.24858342111110687,0.1345278024673462
PyTorch mean: 0.6859778165817261
Numpy mean:  0.6859777629530678

losses GPU0: 0.027207523584365845,0.2559516727924347,0.1512461155653,2.632812023162842
PT mean: 0.7668043375015259
NP mean: 0.7668043337762356

losses GPU1: 0.1647290289402008,0.4095090329647064,0.24858342111110687,0.1345278024673462
PT mean: 0.23933732509613037
NP mean: 0.23933732137084007

losses GPU2: 0.0014468701556324959,0.19974617660045624,0.2590259611606598,3.7469475269317627
PT mean: 1.0517915487289429
NP mean: 1.0517916337121278

So, comparing the final means:

PT (all 12 values at once):   0.6859778165817261
PT (mean of per-rank means):  0.6859777371088663
NP (all 12 values at once):   0.6859777629530678
NP (mean of per-rank means):  0.6859777629530678

Hence I conclude that the mean calculation in PyTorch in that configuration is inexact.

Example code to reproduce that:

import torch
import numpy as np

# Per-sample losses collected from the failing test run
data = [0.0014468701556324959,0.19974617660045624,0.2590259611606598,3.7469475269317627,0.027207523584365845,0.2559516727924347,0.1512461155653,2.632812023162842,0.1647290289402008,0.4095090329647064,0.24858342111110687,0.1345278024673462]
ptTensorCPU = torch.Tensor(data)   # float32
ptTensorGPU = ptTensorCPU.cuda()

print('np : %s' % np.mean(data))              # numpy accumulates the Python floats in float64
print('cpu: %s' % float(ptTensorCPU.mean()))  # float32 reduction on CPU
print('gpu: %s' % float(ptTensorGPU.mean()))  # float32 reduction on GPU
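
A possible cross-check, continuing the snippet above (an assumption, not part of the original report): repeat the reduction in float64; if float32 accumulation order is the only difference, CPU, GPU, and numpy will typically agree to all printed digits.

# Same tensors as above, but accumulated in float64 before taking the mean.
print('cpu64: %s' % float(ptTensorCPU.double().mean()))
print('gpu64: %s' % float(ptTensorGPU.double().mean()))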

@wayi1 wayi1 added the triaged label Feb 16, 2021
@Flamefire
Collaborator Author

Sorry for spamming, but wanted to share my progress:

For the mean I found that the imprecision is caused by multiplying the sum of values by 1/12 instead of dividing by 12, as is done on the CPU. I can understand that, given that the GPU implementation seemingly can handle multiple outputs and the factor is num_outputs/num_el. However, I don't see how a mean operation with multiple outputs would be useful or what its meaning is, and the CPU implementation doesn't handle that case at all. So I changed the code to use division directly and assert single outputs.
This made the mean results from the GPU match the CPU (and numpy) results, so I think this might be worth considering for upstream.
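
A minimal standalone sketch of that numeric effect (illustrative numpy code, not the actual kernel); whether the two results differ depends on how the particular sum rounds:

import numpy as np

# Compare "divide the sum by n" (CPU-style) with "multiply the sum by the
# pre-rounded reciprocal 1/n" (what the GPU kernel appeared to do); the two
# can differ by one float32 ULP depending on the value of the sum.
data = np.array([0.0014468701556324959, 0.19974617660045624, 0.2590259611606598,
                 3.7469475269317627, 0.027207523584365845, 0.2559516727924347,
                 0.1512461155653, 2.632812023162842, 0.1647290289402008,
                 0.4095090329647064, 0.24858342111110687, 0.1345278024673462],
                dtype=np.float32)
s = data.sum(dtype=np.float32)
n = np.float32(data.size)

mean_div = s / n                      # direct division of the sum
mean_rcp = s * (np.float32(1.0) / n)  # multiplication by the rounded reciprocal
print(mean_div, mean_rcp, mean_div == mean_rcp)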

However, it did not change the failing test results at all. I get the exact same error: 1.271069049835205e-05 (-0.1342240869998932 vs. -0.13423679769039154), which occurred at index (3, 47)

To sanity-check that the reason is not my build, I tested with the Docker container docker://pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel, and yes: I get the very same results.

Running further I see similar issues in test_nn.py:

FAIL: test_TransformerDecoderLayer_gelu_activation_cuda (__main__.TestNN)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/easybuild-tmp/eb-j7hKdB/tmpQF7fHX/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py", line 827, in wrapper
    method(*args, **kwargs)
  File "/tmp/easybuild-tmp/eb-j7hKdB/tmpQF7fHX/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py", line 827, in wrapper
    method(*args, **kwargs)
  File "test_nn.py", line 9226, in <lambda>
    add(cuda_test_name, lambda self, test=test, kwargs=kwargs: test.test_cuda(self, **kwargs))
  File "/tmp/easybuild-tmp/eb-j7hKdB/tmpQF7fHX/lib/python3.8/site-packages/torch/testing/_internal/common_nn.py", line 4871, in test_cuda
    test_case.assertEqualIgnoreType(cpu_output, gpu_output, atol=self.precision, rtol=0)
  File "/tmp/easybuild-tmp/eb-j7hKdB/tmpQF7fHX/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py", line 1046, in assertEqualIgnoreType
    return self.assertEqual(*args, exact_dtype=False, **kwargs)
  File "/tmp/easybuild-tmp/eb-j7hKdB/tmpQF7fHX/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py", line 1136, in assertEqual
    self.assertTrue(result, msg=msg)
AssertionError: False is not true : Tensors failed to compare as equal! With rtol=0 and atol=0.0002, found 25 element(s) (out of 36) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.0010109298180119852 (-0.5147604100655095 vs. -0.5137494802474976), which occurred at index (0, 2, 2).

@zhaojuanmao
Contributor

@Flamefire Based on the comment, I want to confirm with you that the tests don't only fail in the distributed tests, right? test_nn.py also fails?

Regarding the distributed tests, how about the tests in "test_distributed_spawn.py"? "test_distributed_spawn.py" runs the same tests as "test_distributed_fork.py"; the only difference is that one spawns subprocesses and the other forks them.
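
A minimal sketch of that difference (assuming plain torch.multiprocessing; the real test harness has more machinery):

import torch.multiprocessing as mp

def _worker(rank):
    # Placeholder for the per-rank test body.
    print("running rank", rank)

if __name__ == "__main__":
    # test_distributed_fork.py style: fork the worker processes
    mp.start_processes(_worker, nprocs=3, start_method="fork")
    # test_distributed_spawn.py style: spawn fresh interpreter processes
    mp.spawn(_worker, nprocs=3)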

@Flamefire
Collaborator Author

This is correct. But as the tests run in order and stop at the first failed file, this was the first one to fail. Spawn fails too, and I have confirmation from other HPC centers where the same test (in this case the distributed one) fails with the exact same values and differences.
This only seems to happen on A100s, so I guess this is something new on that architecture.
Yes, the other distributed test fails too. For the spawn test the message is: The greatest difference was 1.1842697858810425e-05 (-0.025052597746253014 vs. -0.025040755048394203), which occurred at index (3, 1)

@zhaojuanmao zhaojuanmao removed the oncall: distributed and triaged labels Feb 19, 2021
@zhaojuanmao
Contributor

@Flamefire I see, looks like it is a general issue, not specific to the distributed package. In this case, would you please change the title so that the general PyTorch oncall can help you better? Thanks

@Flamefire Flamefire changed the title from "Distributed tests fail on A100 GPUs" to "Tests fail on A100 GPUs due to inaccurate/differing float values" Feb 19, 2021
@heitorschueroff heitorschueroff added the high priority, module: nn, and module: testing labels Feb 19, 2021
@heitorschueroff heitorschueroff added the triaged label and removed the triage review label Feb 19, 2021
@mruberry mruberry added the module: ddp and module: tests labels and removed the module: testing label Feb 22, 2021
@surak

surak commented Mar 17, 2021

I have the same problem, with 8 RTX 3090, and an AMD EPYC 7F72 processor.

@ngimel
Collaborator

ngimel commented Mar 17, 2021

cc @zasdfgbnm, were the accuracies for the tests in question adjusted?

@zasdfgbnm
Collaborator

@ngimel I didn't change any test accuracy on DDP tests, I only disabled TF32 globally: #52941, but this issue seems to be unrelated to TF32.

@Flamefire
Collaborator Author

Flamefire commented Jun 8, 2021

Any update here? FWIW the following tests from test_nn.py fail on A100:

  • TestNN.test_Conv3d_1x1x1_no_bias_cuda
  • TestNN.test_TransformerDecoderLayer_gelu_activation_cuda_tf32
  • TestNN.test_TransformerDecoderLayer_relu_activation_cuda_tf32
  • TestNN.test_TransformerEncoderLayer_gelu_activation_cuda_tf32

When running only those, just the first test fails, so the failures are input dependent (there is a torch.randn call for the input, which is influenced by prior calls to it).

The first fails always, even when run standalone. Error message:

AssertionError: False is not true : Tensors failed to compare as equal!With rtol=0 and atol=0.0002, found 6 element(s) (out of 6) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.009475813915337028 (10.580324278881157 vs. 10.57084846496582), which occurred at index (2, 1, 0, 0, 0).

When run inside the full test_nn, the greatest difference is 0.01662997265166055, which matches what I see with the other tf32 tests.
Hence this test runs with TF32 calculations and returns results with reduced precision. And indeed, when manually disabling TF32 for this test, it succeeds.
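
A minimal sketch of disabling TF32 globally for such a check (these are the standard torch.backends flags available since PyTorch 1.7, not the test suite's own override):

import torch

# Force full-precision float32 matmuls and cuDNN convolutions on Ampere so
# results match the CPU reference within the test's tolerance.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False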
