
nn.Linear with empty tensor backward error (CUDA) #34202

Description

@lsrock1

EDITED by @colesbury

🐛 Bug

Calling an nn.Linear layer with an empty tensor results in an error in the backward pass. In master, the error only occurs with CUDA tensors. In PyTorch 1.4 the error also occurs with CPU tensors.

To Reproduce

Steps to reproduce the behavior:

import torch
from torch import nn

test = torch.ones(0, 2100).cuda()
f = nn.Linear(2100, 2100).cuda()
f(test).sum().backward()

RuntimeError: at::cuda::blas::gemm<float> argument ldb must be positive and less than 2147483647 but got 0 (gemm<float> at ../aten/src/ATen/cuda/CUDABlas.cpp:163)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6c (0x7fce4fd9a70c in /scratch/sgross/pytorch/torch/lib/libc10.so)
frame #1: <unknown function> + 0x33af519 (0x7fce53368519 in /scratch/sgross/pytorch/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xafcb37 (0x7fce50ab5b37 in /scratch/sgross/pytorch/torch/lib/libtorch_cuda.so)
frame #3: THCudaTensor_addmm + 0x6c (0x7fce50abb01c in /scratch/sgross/pytorch/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x34cfc59 (0x7fce53488c59 in /scratch/sgross/pytorch/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x33d3058 (0x7fce5338c058 in /scratch/sgross/pytorch/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0x7e26f4 (0x7fce630476f4 in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0x2c563fc (0x7fce654bb3fc in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x7e26f4 (0x7fce630476f4 in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #9: at::Tensor::mm(at::Tensor const&) const + 0x103 (0x7fce632b8a23 in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x28909bb (0x7fce650f59bb in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #11: torch::autograd::generated::AddmmBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x1c0 (0x7fce650f5cc0 in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x2d5f18d (0x7fce655c418d in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #13: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x177b (0x7fce655c12db in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #14: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x4e9 (0x7fce655c2169 in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #15: torch::autograd::Engine::thread_init(int) + 0x49 (0x7fce655b8c49 in /scratch/sgross/pytorch/torch/lib/libtorch_cpu.so)
frame #16: torch::autograd::python::PythonEngine::thread_init(int) + 0x48 (0x7fce689c4018 in /scratch/sgross/pytorch/torch/lib/libtorch_python.so)
frame #17: <unknown function> + 0xc819d (0x7fce6b36c19d in /scratch/sgross/miniconda3/bin/../lib/libstdc++.so.6)
frame #18: <unknown function> + 0x76db (0x7fce977316db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #19: clone + 0x3f (0x7fce9745a88f in /lib/x86_64-linux-gnu/libc.so.6)
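
The trace points at the mm call inside AddmmBackward, i.e. the weight-gradient matmul. As a rough sketch (a plausible reading of the trace, not something stated in the issue), the backward for nn.Linear with an empty batch reduces to a matrix multiply whose inner dimension is zero, and the gemm argument check in CUDABlas.cpp appears to reject the resulting zero leading dimension:

import torch

# Sketch of the op the backward likely reduces to (shapes assumed from
# nn.Linear's addmm formulation): input.t() is (2100, 0), grad_output is
# (0, 2100), so the underlying gemm sees a zero-sized dimension.
a = torch.ones(2100, 0, device="cuda")
b = torch.ones(0, 2100, device="cuda")
c = a.mm(b)  # mathematically a (2100, 2100) zero matrix; may hit the same gemm check on affected builds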

Environment

PyTorch 1.5.0a0+f3a5081

Note that with PyTorch 1.4.0 the equivalent CPU code also fails:

import torch
from torch import nn

test = torch.ones(0, 2100)
f = nn.Linear(2100, 2100)
f(test).sum().backward()
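
Until empty batches are handled in the backward, one possible caller-side workaround (an assumption, not something proposed in the issue) is to skip the layer when the batch dimension is zero, since an empty batch contributes nothing to the loss:

import torch
from torch import nn

f = nn.Linear(2100, 2100)
test = torch.ones(0, 2100)

# Hypothetical guard: only run the forward/backward for non-empty batches.
if test.shape[0] > 0:
    f(test).sum().backward()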

The original environment reported by @lsrock1:

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
CMake version: version 3.15.0-rc3

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.1

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.4.0
[pip] torchvision==0.5.0
[conda] mkl 2020.0 166
[conda] pytorch 1.4.0 py3.7_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchvision 0.5.0 py37_cu101 pytorch


Labels

module: derivatives (related to derivatives of operators)
module: nn (related to torch.nn)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
