
Fix cuDNN error message when it's Conv2d #45729

Closed
wants to merge 1 commit

Conversation

@xwang233 (Collaborator) commented Oct 2, 2020

Originally introduced in #45023. When I was testing the original PR, the convolution was a Conv3d, so this problem was not discovered.

Arrays in `ConvolutionParams` have a fixed length of 3 or 5, because `max_dim` is a constexpr set to 3 regardless of whether the convolution is Conv2d or Conv3d. The current code therefore makes some error messages look wrong. See the comments below.

struct ConvolutionParams
{
  cudnnDataType_t dataType;
  int input_size[2 + max_dim];
  int input_stride[2 + max_dim];
  int weight_size[2 + max_dim];
  int padding[max_dim];
  int stride[max_dim];
  int dilation[max_dim];
  int64_t groups;
  bool deterministic;
  bool allow_tf32;
  // NB: transposed purposely omitted: transposed just swaps
  // forward and backward, so you can reuse the benchmark entry,
};
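
As an illustration of the problem (a minimal Python sketch with illustrative names, not the actual C++ code): a Conv2d only fills the first two slots of these fixed-length arrays, so printing all max_dim entries yields the odd 3-element lists shown in the comment below, while slicing to the actual spatial dimensionality gives the expected 2-element lists.

max_dim = 3

# What ConvolutionParams effectively holds for the Conv2d repro in the comment
# below; only the first two entries of each array are meaningful.
padding  = [0, 0, 0]
stride   = [1, 1, 0]
dilation = [1, 1, 0]

input_ndim = 4                 # NCHW input for a Conv2d
spatial_dim = input_ndim - 2   # 2 for Conv2d, 3 for Conv3d

# Before the fix: all max_dim entries are printed
print(f"padding={padding}, stride={stride}, dilation={dilation}")
# After the fix: only the entries that belong to this convolution would be printed
print(f"padding={padding[:spatial_dim]}, stride={stride[:spatial_dim]}, dilation={dilation[:spatial_dim]}")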

@xwang233 (Collaborator, Author) commented Oct 2, 2020

The misleading error message before this PR:

Traceback (most recent call last):
  something
RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR
You can try to repro this exception using the following code snippet. If that doesn't trigger the error, please include your original repro script when reporting this issue.

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True
data = torch.randn([128, 512, 7, 7], dtype=torch.float, device='cuda', requires_grad=True)
net = torch.nn.Conv2d(512, 2048, kernel_size=[1, 1], padding=[0, 0, 0], stride=[1, 1, 0], dilation=[1, 1, 0], groups=1)
net = net.cuda().float()
out = net(data)
out.backward(torch.randn_like(out))
torch.cuda.synchronize()

ConvolutionParams
    data_type = CUDNN_DATA_FLOAT
    padding = [0, 0, 0]
    stride = [1, 1, 0]
    dilation = [1, 1, 0]
    groups = 1
    deterministic = false
    allow_tf32 = true
input: TensorDescriptor 0x5630c484c770
    type = CUDNN_DATA_FLOAT
    nbDims = 4
    dimA = 128, 512, 7, 7,
    strideA = 25088, 49, 7, 1,
output: TensorDescriptor 0x7f38b8038cd0
    type = CUDNN_DATA_FLOAT
    nbDims = 4
    dimA = 128, 2048, 7, 7,
    strideA = 100352, 49, 7, 1,
weight: FilterDescriptor 0x7f38b8066d40
    type = CUDNN_DATA_FLOAT
    tensor_format = CUDNN_TENSOR_NCHW
    nbDims = 4
    dimA = 2048, 512, 1, 1,
Pointer addresses:
    input: 0x7f3999700000
    output: 0x7f399d440000
    weight: 0x7f3bb7d80000

Note that the padding, stride, and dilation above have three elements even though this is a 2-d convolution. This PR fixes the Python repro script when the convolution is 2-d (a corrected version is sketched below).
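
After this PR, the generated repro for the case above should look roughly like the following; this is a hand-written sketch based on the output above, so the exact generated text may differ slightly.

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True
data = torch.randn([128, 512, 7, 7], dtype=torch.float, device='cuda', requires_grad=True)
# padding, stride, and dilation now have two elements, matching the 2-d convolution
net = torch.nn.Conv2d(512, 2048, kernel_size=[1, 1], padding=[0, 0], stride=[1, 1], dilation=[1, 1], groups=1)
net = net.cuda().float()
out = net(data)
out.backward(torch.randn_like(out))
torch.cuda.synchronize()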

The ConvolutionParams dump in the error message is a bit harder to fix. However, it is mainly for developer debugging purposes, so hopefully developers won't find it confusing.

@xwang233 (Collaborator, Author) commented Oct 2, 2020

cc @ptrblck

codecov bot commented Oct 2, 2020

Codecov Report

Merging #45729 into master will not change coverage.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master   #45729   +/-   ##
=======================================
  Coverage   68.60%   68.60%           
=======================================
  Files         410      410           
  Lines       52670    52670           
=======================================
  Hits        36132    36132           
  Misses      16538    16538           


@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

xwang233 added a commit to xwang233/pytorch that referenced this pull request Oct 2, 2020
Summary:
Originally introduced in pytorch#45023. When I was testing the original PR, the convolution was a Conv3d, so this problem was not discovered.

Arrays in `ConvolutionParams` have a fixed length of 3 or 5, because `max_dim` is a constexpr set to 3 regardless of whether the convolution is Conv2d or Conv3d. The current code therefore makes some error messages look wrong. See the comments below.

https://github.com/pytorch/pytorch/blob/9201c37d020007979e144693d86c8e8599e2fd8f/aten/src/ATen/native/cudnn/Conv.cpp#L212-L226

Pull Request resolved: pytorch#45729

Reviewed By: mruberry

Differential Revision: D24081542

Pulled By: ngimel

fbshipit-source-id: 141f8946f4d0db63a723131775731272abeaa6ab
@facebook-github-bot (Contributor) commented:

@ngimel merged this pull request in 8619de8.

malfet pushed a commit that referenced this pull request Oct 6, 2020