
torch.matmul output contains some nan value for large size fp16 tensors in V100 GPU #45724

Closed
bbfrog opened this issue Oct 2, 2020 · 11 comments

bbfrog commented Oct 2, 2020

Issue description

Please see the simple code below. When running on an NVIDIA V100 GPU with randomly generated fp16 tensors of size [13269, 8, 22, 64] as input, the torch.matmul output contains some unexpected NaN values.

  • This problem cannot be reproduced on a P100 or 1080Ti; it seems related to the fp16 computation kernel on the V100 (both the 16GB and 32GB V100 reproduce it).

  • This problem can be reproduced with many other large fp16 tensor sizes; the code below is just one example.

  • This problem can be reproduced with both PyTorch 1.3.1 and PyTorch 1.5.1. I haven't tried other PyTorch versions.

  • This problem cannot be reproduced on the V100 when using fp32 computation.

  • The stdout when running on a P100 or 1080Ti:
    CUDA name: GeForce GTX 1080 Ti
    nan items count: 0, ratio: 0.0%

  • The stdout when running on a V100:
    CUDA name: Tesla V100-SXM2-16GB
    nan items count: 305560, ratio: 0.59473426223678%
    nan examples:
    index: 8191,7,2,0
    Computed attention_scores: nan
    Expected attention_scores: 15.4609375
    index: 8191,7,11,4
    Computed attention_scores: nan
    Expected attention_scores: 14.203125
    ...

Code example

import torch
print("CUDA name: {}".format(torch.cuda.get_device_name(0)))

torch.manual_seed(12345)
query_layer = torch.rand(13269, 8, 22, 64).cuda().half()
key_layer = torch.rand(13269, 8, 22, 64).cuda().half()

# This computation is part of transformer self attention
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

indexs = torch.isnan(attention_scores).nonzero(as_tuple=True)
nan_item_count = indexs[0].shape[0]
print("nan items count: {}, ratio: {}%\n".format(nan_item_count, 100 * nan_item_count / attention_scores.nelement()))

print('nan examples:')
for i in range(min(20, nan_item_count)):
    print("index: {},{},{},{}".format(indexs[0][i], indexs[1][i], indexs[2][i], indexs[3][i]))

    current_query_layer = query_layer[indexs[0][i], indexs[1][i], indexs[2][i]]
    current_key_layer = key_layer[indexs[0][i], indexs[1][i], indexs[3][i]]

    current_attention_score = attention_scores[indexs[0][i], indexs[1][i], indexs[2][i], indexs[3][i]]
    expected_attention_score = torch.matmul(current_query_layer.unsqueeze(0), current_key_layer.unsqueeze(1))

    #print('query_layer')
    #print(current_query_layer)
    #print('key_layer')
    #print(current_key_layer)

    print('Computed attention_scores: {}'.format(current_attention_score.item()))
    print('Expected attention_scores: {}'.format(expected_attention_score.item()))

    print('\n')

cc @ezyang @gchanan @zou3519 @csarofeen @ptrblck @ngimel

colesbury added the module: cublas and module: cuda labels on Oct 2, 2020
colesbury (Member) commented:

This is strange... I can reproduce the issue on V100 with PyTorch 1.6. Is FP16 batched matrix multiply buggy for large batch sizes?
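For what it's worth, a minimal sketch (my own, not from the original report) that isolates the batched fp16 GEMM: flatten the leading dimensions into a single batch axis, run torch.bmm in fp16, and compare against an fp32 reference on the same data. The shapes mirror the repro above.

import torch

torch.manual_seed(12345)
q = torch.rand(13269 * 8, 22, 64, device='cuda')   # fp32 source data
k = torch.rand(13269 * 8, 22, 64, device='cuda')

out_fp16 = torch.bmm(q.half(), k.half().transpose(-1, -2))  # fp16 batched GEMM
out_fp32 = torch.bmm(q, k.transpose(-1, -2))                # fp32 reference

print("NaNs in fp16 bmm:", torch.isnan(out_fp16).sum().item())
print("NaNs in fp32 bmm:", torch.isnan(out_fp32).sum().item())
finite = torch.isfinite(out_fp16)
print("max abs diff over finite entries:",
      (out_fp16.float() - out_fp32).abs()[finite].max().item())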

colesbury added the high priority and triaged labels on Oct 2, 2020
bbfrog commented Oct 2, 2020

Thanks @colesbury. Some more information: if the inputs are created directly on CUDA with the fp16 dtype as below, about 38% of the torch.matmul outputs are unexpectedly 0, so the problem seems even worse.
query_layer = torch.rand(13269, 8, 22, 64, dtype=torch.float16, device='cuda')
key_layer = torch.rand(13269, 8, 22, 64, dtype=torch.float16, device='cuda')

The full updated repro code:

import torch
print("CUDA name: {}".format(torch.cuda.get_device_name(0)))

torch.cuda.manual_seed(12345)
query_layer = torch.rand(13269, 8, 22, 64, dtype=torch.float16, device='cuda')
key_layer = torch.rand(13269, 8, 22, 64, dtype=torch.float16, device='cuda')

# This computation is part of transformer self attention
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

indexs = torch.isnan(attention_scores) + torch.isinf(attention_scores) + (attention_scores == 0.)
indexs = indexs.nonzero(as_tuple=True)
problem_item_count = indexs[0].shape[0]
print("problem items count: {}, ratio: {}%\n".format(problem_item_count, 100 * problem_item_count / attention_scores.nelement()))

print('problem examples:')
for i in range(min(20, problem_item_count)):
    print("index: {},{},{},{}".format(indexs[0][i], indexs[1][i], indexs[2][i], indexs[3][i]))

    current_query_layer = query_layer[indexs[0][i], indexs[1][i], indexs[2][i]]
    current_key_layer = key_layer[indexs[0][i], indexs[1][i], indexs[3][i]]

    current_attention_score = attention_scores[indexs[0][i], indexs[1][i], indexs[2][i], indexs[3][i]]
    expected_attention_score = torch.matmul(current_query_layer.unsqueeze(0), current_key_layer.unsqueeze(1))

    print('Computed attention_scores: {}'.format(current_attention_score.item()))
    print('Expected attention_scores: {}'.format(expected_attention_score.item()))

    print('\n')

ngimel commented Oct 2, 2020

This looks like a cuBLAS bug. Can you please post your CUDA version, obtained with

print(*torch.__config__.show().split("\n"), sep="\n")

bbfrog commented Oct 2, 2020

print(*torch.__config__.show().split("\n"), sep="\n")

CUDA name: Tesla V100-SXM2-16GB
PyTorch built with:

  • GCC 5.3
  • C++ Version: 201402
  • Intel(R) Math Kernel Library Version 2019.0.1 Product Build 20180928 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)
  • OpenMP 201307 (a.k.a. OpenMP 4.0)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 10.1
  • NVCC architecture flags: -gencode;arch=compute_35,code=sm_35;-gencode;arch=compute_52,code=sm_52;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_70,code=compute_70
  • CuDNN 7.6.4
  • Magma 2.5.2
  • Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DNDEBUG -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=ON, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,

bbfrog commented Oct 2, 2020

More info (a small sweep over shapes that mirrors these observations follows the list):

  1. As the input size increases, the ratio of problem items also increases.
  2. If the third dimension is changed to a multiple of 8, such as 24 or 32, the problem seems to disappear.
  • No problem for input size [4096, 8, 22, 64]
  • Problem item ratio is 0.00152587890625% for [8192, 8, 22, 64]
  • Problem item ratio is 38.26305674881302% for [13269, 8, 22, 64]
  • Problem item ratio is 50.000762939453125% for [16384, 8, 22, 64]
  • Problem item ratio is 75.00038146972656% for [32768, 8, 22, 64]
  • No problem for [X, 8, 24, 64] or [X, 8, 32, 64].
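Sweep sketch (my own, assuming the same V100/PyTorch environment as the repro above; problem_ratio is a hypothetical helper name) that measures the NaN/Inf/zero ratio for several of the shapes listed here:

import torch

def problem_ratio(batch, heads, seq, dim):
    q = torch.rand(batch, heads, seq, dim, dtype=torch.float16, device='cuda')
    k = torch.rand(batch, heads, seq, dim, dtype=torch.float16, device='cuda')
    scores = torch.matmul(q, k.transpose(-1, -2))
    # Count NaN, Inf, and exact-zero outputs, as in the repro above.
    bad = torch.isnan(scores) | torch.isinf(scores) | (scores == 0.)
    return 100.0 * bad.sum().item() / scores.numel()

for shape in [(4096, 8, 22, 64), (8192, 8, 22, 64), (13269, 8, 22, 64),
              (16384, 8, 22, 64), (8192, 8, 24, 64), (8192, 8, 32, 64)]:
    print(shape, "-> problem ratio: {:.5f}%".format(problem_ratio(*shape)))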

ptrblck commented Oct 2, 2020

@bbfrog Thanks for reporting this issue.
We've reproduced the invalid outputs on a V100 using cuBLAS 10.2.2.89, while cuBLAS 11.1.0.024 produces valid outputs.
We have also forwarded this issue to our cuBLAS team to find the root cause of the problem in 10.2.

As a workaround, could you try installing the nightly binaries with CUDA 11?
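If it helps, a quick check (assuming the nightly wheel is installed into the same environment) to confirm which CUDA toolkit the installed binaries were built against before re-running the repro:

import torch

print(torch.__version__)               # build/version string of the installed wheel
print(torch.version.cuda)              # should report 11.x for the CUDA 11 nightlies
print(torch.backends.cudnn.version())  # cuDNN build that ships with the wheel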

mcarilli commented Oct 2, 2020

THCudaBlas_HgemmStridedBatched correctly sets computeType to CUDA_R_32F, so I think this is a cuBLAS bug, not a PyTorch bug.

Disclaimer: I did not confirm that this particular failing case routes to THCudaBlas_HgemmStridedBatched; heuristics along the call chain may send it to another function.

bbfrog commented Oct 2, 2020

> We have also forwarded this issue to our cuBLAS team to find the root cause of the problem in 10.2.

Thanks @ptrblck very much! I will try out cuBLAS 11.1.0.024 on a dev machine and report the results. But in our production training environment it may take a long time to upgrade to cuBLAS 11.1.0.024 (I will try to push hard :)), so any other workaround is also appreciated. Thanks!
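Not an official fix, but since the issue description notes that the fp32 path does not reproduce the problem, one possible interim workaround (a sketch, reusing the query_layer/key_layer names from the repro above) is to upcast just this matmul and cast the result back to half:

import torch

def attention_scores_fp32(query_layer, key_layer):
    # Upcast, multiply, downcast; slower and more memory-hungry than the fp16 path,
    # but the fp32 kernel did not show the problem in the tests above.
    return torch.matmul(query_layer.float(), key_layer.float().transpose(-1, -2)).half()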

ezyang added the module: dependency bug label on Oct 5, 2020
ezyang commented Oct 5, 2020

At the very least we should raise an error when we hit this case on cuBLAS 10, for as long as we support CUDA 10. We don't think it would be too difficult to split the batch into smaller chunks when we know we have a buggy cuBLAS, but it will be a little tricky to test whether we did it correctly (though not too bad: it only takes roughly 150M of CUDA data to trigger the problem).
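A rough user-side sketch of the batch-splitting idea (not the in-library implementation; the chunk size of 4096 is an arbitrary choice on my part, picked so the flattened batch per cuBLAS call stays well under 65536):

import torch

def chunked_matmul(q, k_t, chunk=4096):
    # Run the batched matmul on slices along the first dimension and concatenate,
    # so each underlying cuBLAS call sees a much smaller flattened batch.
    outs = []
    for start in range(0, q.shape[0], chunk):
        outs.append(torch.matmul(q[start:start + chunk], k_t[start:start + chunk]))
    return torch.cat(outs, dim=0)

# e.g. attention_scores = chunked_matmul(query_layer, key_layer.transpose(-1, -2))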

ezyang added this to the 1.7.0 milestone on Oct 5, 2020
ngimel commented Oct 5, 2020

The surface of the bug:

  • V100 architecture (possibly Turing as well? Safer to use cc 7.0)
  • cuBLAS major version 10
  • half datatype
  • batch size >= 65536
  • at least one of the matrix sizes not a multiple of 8

@ptrblck does this sound right? Can you guys submit a PR throwing an error under those conditions (or, ideally, implementing a workaround that calls batched gemm multiple times)?
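For illustration, a hedged Python-level sketch of what such a guard might check, using torch.version.cuda as a stand-in for the cuBLAS major version (they track each other in the bundled builds); hits_buggy_cublas_case is a hypothetical helper, not PyTorch API, and the thresholds are simply the ones quoted above:

import torch

def hits_buggy_cublas_case(a, b):
    # a, b: the fp16 operands of the batched matmul (b already transposed).
    major, _ = torch.cuda.get_device_capability(a.device)
    cuda_major = int(torch.version.cuda.split('.')[0]) if torch.version.cuda else 0
    flat_batch = a.numel() // (a.shape[-2] * a.shape[-1])
    dims = (a.shape[-2], a.shape[-1], b.shape[-1])
    return (major >= 7
            and cuda_major == 10
            and a.dtype == torch.float16
            and flat_batch >= 65536
            and any(d % 8 != 0 for d in dims))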

zasdfgbnm added a commit that referenced this issue Oct 8, 2020
Summary:
Fixes #45724

Pull Request resolved: #46001

Reviewed By: mruberry

Differential Revision: D24184058

Pulled By: ngimel

fbshipit-source-id: 7d2bab3206ddbc10a7cae3efd9b5e253f38400a9
malfet pushed a commit that referenced this issue Oct 8, 2020

bbfrog commented Oct 8, 2020

I checked the code-fix check-in, and it looks great to me. Thanks!
b2bff9e
