
CUDA Error in batchNorm #42588

Closed
ghk829 opened this issue Aug 5, 2020 · 11 comments
Labels
module: cuda (Related to torch.cuda, and CUDA support in general)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

ghk829 commented Aug 5, 2020

🐛 Bug


RuntimeError Traceback (most recent call last)
in
----> 1 loss.backward()

/opt/conda/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
183 products. Defaults to False.
184 """
--> 185 torch.autograd.backward(self, gradient, retain_graph, create_graph)
186
187 def register_hook(self, hook):

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
125 Variable._execution_engine.run_backward(
126 tensors, grad_tensors, retain_graph, create_graph,
--> 127 allow_unreachable=True) # allow_unreachable flag
128
129

RuntimeError: Expected grad_output->is_contiguous(grad_output->suggest_memory_format()) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Exception raised from cudnn_batch_norm_backward at /opt/conda/conda-bld/pytorch_1595629403081/work/aten/src/ATen/native/cudnn/BatchNorm.cpp:249 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x4d (0x7f993197377d in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: at::native::cudnn_batch_norm_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, double, at::Tensor const&) + 0x25b2 (0x7f9932a79db2 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xd1150a (0x7f9932aea50a in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0xd3fa3b (0x7f9932b18a3b in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #4: at::cudnn_batch_norm_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, double, at::Tensor const&) + 0x1ef (0x7f9964cff10f in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x2b59cff (0x7f9966946cff in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x2b6b21b (0x7f996695821b in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #7: at::cudnn_batch_norm_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, double, at::Tensor const&) + 0x1ef (0x7f9964cff10f in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::autograd::generated::CudnnBatchNormBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x42c (0x7f99668a9fec in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x30d1017 (0x7f9966ebe017 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #10: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1400 (0x7f9966eb9860 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #11: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7f9966eba401 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #12: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x7f9966eb2579 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #13: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x7f996b1e199a in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #14: <unknown function> + 0xc819d (0x7f99aa6fa19d in /opt/conda/bin/../lib/libstdc++.so.6)
frame #15: <unknown function> + 0x76db (0x7f99adb356db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #16: clone + 0x3f (0x7f99ad85e88f in /lib/x86_64-linux-gnu/libc.so.6)

To Reproduce

Steps to reproduce the behavior:

  1. I used a custom lambda layer before the batch norm layer.

Expected behavior

Environment

On Ubuntu, with CUDA 10.1 and Python 3.7.
Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

cc @ngimel @csarofeen @ptrblck @xwang233

gchanan (Contributor) commented Aug 5, 2020

reproduction? PyTorch Version?

malfet added the module: cuda and triaged labels Aug 5, 2020
malfet (Contributor) commented Aug 5, 2020

Please run python -m torch.utils.collect_env and post the results here. Also, a reproduction would be nice; right now there are not enough details to proceed with the investigation.

ngimel added and then removed the needs reproduction label Aug 5, 2020
ghk829 (Author) commented Aug 6, 2020

@gchanan
I used PyTorch 1.6.0.

ghk829 (Author) commented Aug 6, 2020

@malfet
Here is the environment information:

Collecting environment information...
PyTorch version: 1.6.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: Tesla P40
Nvidia driver version: 418.43
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.6.0
[pip3] torchvision==0.7.0
[pip3] torchviz==0.0.1
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.1.243             h6bb024c_0  
[conda] mkl                       2020.1                      217  
[conda] mkl-service               2.3.0            py37he904b0f_0  
[conda] mkl_fft                   1.1.0            py37h23d657b_0  
[conda] mkl_random                1.1.1            py37h0573a6f_0  
[conda] numpy                     1.18.5           py37ha1c710e_0  
[conda] numpy-base                1.18.5           py37hde5b4d6_0  
[conda] pytorch                   1.6.0           py3.7_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchvision               0.7.0                py37_cu101    pytorch
[conda] torchviz                  0.0.1                    pypi_0    pypi


ghk829 (Author) commented Aug 6, 2020

I'm sorry that I can't share the model, because it belongs to my company.

I used a custom layer before applying the batch norm.

kilasuelika commented:

I have the same problem as you. Did you find any solution?

kilasuelika commented:

I have the same problem again. Strangely, if I run it on CPU, then there is no problem.

kilasuelika commented Jan 15, 2021

Now I can provide a small example to reproduce the problem. I'm using GCC 10.2 on Ubuntu 20.04 with a libtorch nightly build and CUDA 10.2.

In the following code, the structure of DTNet is: conv, then 4 TGCNSABlock layers, then a linear layer. I use a Sequential to store the 4 layers.

Observations:

  1. Running on the CPU: no problem.
  2. Using a single TGCNSABlock (not a one-layer DTNet): no problem.
  3. Removing the two lines sa->forward and einsum, or removing bn(x), in TGCNSABlock::forward: no problem.
  4. Removing the last linear layer in DTNet: no problem.

cmake output:

Found CUDA: /usr/local/cuda-10.2 (found version "10.2") 
-- Caffe2: CUDA detected: 10.2
-- Caffe2: CUDA nvcc is: /usr/local/cuda-10.2/bin/nvcc
-- Caffe2: CUDA toolkit directory: /usr/local/cuda-10.2
-- Caffe2: Header version is: 10.2
-- Found CUDNN: /usr/lib/x86_64-linux-gnu/libcudnn.so  
-- Found cuDNN: v8.0.5  (include: /usr/include, library: /usr/lib/x86_64-linux-gnu/libcudnn.so)
-- Autodetected CUDA architecture(s):  7.0 7.0 7.0 7.0

CMakeLists.txt:

cmake_minimum_required(VERSION 3.10)
project(main)

add_executable(mp src.cpp)

#Add fmt library
find_package(fmt)
target_link_libraries(mp fmt::fmt-header-only)

#torch
list(APPEND CMAKE_PREFIX_PATH ~/libtorch/share/cmake/Torch)
find_package(Torch)
target_link_libraries(mp ${TORCH_LIBRARIES})

target_compile_features(mp PUBLIC cxx_std_20)

DTNet.hpp:

#ifndef _DTNET_
#define _DTNET_
#define _DEBUG_BATCH_ 2
#include <torch/torch.h>
#include <fmt/format.h>

namespace F = torch::nn::functional;

class SABlock : public torch::nn::Module
{
public:
    SABlock(int in, int assets, int len)
    {

        //len is actually L of input NCSL.
        W1 = register_module("W1", torch::nn::Linear(torch::nn::LinearOptions(len, 1).bias(false)));
        W2 = register_module("W2", torch::nn::Linear(torch::nn::LinearOptions(in, len).bias(false)));
        W3 = register_module("W3", torch::nn::Linear(torch::nn::LinearOptions(in, 1).bias(false)));
        V = register_module("V", torch::nn::Linear(torch::nn::LinearOptions(assets, assets)));

        bn1 = register_module("bn1", torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(assets)));
        bn2 = register_module("bn2", torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(assets)));
        bn3 = register_module("bn3", torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(assets)));
    };
    auto forward(torch::Tensor x)
    {
        //x: NCSL
        auto x1 = x.permute({0, 2, 1, 3}); //NSCL
        auto x2 = x.permute({0, 2, 3, 1}); //NSLC

        x1 = bn1(W1(x1).squeeze(-1)); //NSC
        x1 = bn2(W2(x1));             //NSL

        x2 = bn3(W3(x2).squeeze(-1)).permute({0, 2, 1}); //NSL -> NLS

        auto y = F::softmax(V(F::relu(torch::bmm(x1, x2))), F::SoftmaxFuncOptions(-1));
        return y; //NSS
    };

private:
    torch::nn::Linear W1 = nullptr, W2 = nullptr, W3 = nullptr, V = nullptr;
    torch::nn::BatchNorm1d bn1 = nullptr, bn2 = nullptr, bn3 = nullptr;
};

class TGCNSABlock : public torch::nn::Module
{

public:
    TGCNSABlock(int in, int assets, int len, int hidden_dim, torch::Tensor &support, int dilation,
                int kernel = 2)
        : dilation(dilation),
          bn(register_module("bn", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(hidden_dim))))
    {
        len -= dilation;

        sa = register_module("sa", std::make_shared<SABlock>(hidden_dim, assets, len));
    };
    auto forward(torch::Tensor x)
    {
        x = x.slice(3, -(x.size(3) - dilation));

        //Problem happens around here.
        auto attn_weights = sa->forward(x);
        x = torch::einsum("bnm, bfml->bfnl", {attn_weights, x});

        x = bn(x);
        return x;
    };

private:
    int dilation;
    std::shared_ptr<SABlock> sa;
    torch::nn::BatchNorm2d bn;
};

class DTNet : public torch::nn::Module
{
public:
    DTNet(int in, int assets, int len, int hidden_dim, torch::Tensor &support,
          double dropout = 0.3, int kernel_size = 2, int layers = 4)
        : layers(layers),
          conv(register_module("conv", torch::nn::Conv2d(torch::nn::Conv2dOptions(in, hidden_dim, 1)))),
          bn(register_module("bn", torch::nn::BatchNorm2d(hidden_dim))),
          tgcnsablock_seq(register_module("sagcnblock_seq", torch::nn::Sequential())),
          linear1(register_module("linear1", torch::nn::Linear(torch::nn::LinearOptions(hidden_dim, 1))))
    {
        //
        fmt::print("Constructing model...\n");
        int dilation = 1;
        auto x = torch::zeros({_DEBUG_BATCH_, hidden_dim, assets, len});
        fmt::print("Input: [{}, {}, {}, {}]\n", x.size(0), x.size(1), x.size(2), x.size(3));
        for (int i = 0; i < layers; ++i)
        {
            len = x.size(3);
            tgcnsablock_seq->push_back(std::make_shared<TGCNSABlock>(in, assets, len, hidden_dim, support, dilation));
            dilation *= 2;
            //Determine output size of each layer.
            {
                torch::NoGradGuard no_grad;
                auto mod = tgcnsablock_seq[i]->as<TGCNSABlock>();
                x = mod->forward(x);
                fmt::print("Output of layer {}: [{}, {}, {}, {}]\n", i, x.size(0), x.size(1), x.size(2), x.size(3));
            };
        };
        fmt::print("Constructing model finished.\n");
    };
    auto forward(torch::Tensor x)
    {
        x = bn(conv(x));
        x = tgcnsablock_seq->forward(x);
        x = x.squeeze(-1).permute({0, 2, 1});
        x = linear1(x).squeeze(-1);
        return x;
    };

private:
    torch::nn::Conv2d conv;
    torch::nn::BatchNorm2d bn;
    torch::nn::Sequential tgcnsablock_seq;
    torch::nn::Linear linear1;
    int layers;
};

#endif

src.cpp:

#include "DTNet.hpp"
#include <torch/torch.h>

int main()
{
    bool use_gpu = true;
    auto support = torch::randn({30, 30});
    auto x = torch::randn({2, 20, 30, 16});
    auto y = torch::randn({2, 30, 1});
    DTNet net(20, 30, 16, 128, support);
    //TGCNSABlock mod(20, 30, 16, 128, support, 2);
    if (use_gpu)
    {
        x = x.to(torch::kCUDA);
        y = y.to(torch::kCUDA);
        net.to(torch::kCUDA);
    };

    torch::optim::Adam optimizer(net.parameters(), torch::optim::AdamOptions(0.001));

    //Run backpropagation once.
    optimizer.zero_grad();
    x = net.forward(x);
    std::cout << "Final x: " << x.sizes() << std::endl;
    auto loss = torch::nn::functional::mse_loss(x.flatten(1), y.flatten(1));
    loss.backward();
    optimizer.step();
};

Error:

terminate called after throwing an instance of 'c10::Error'
  what():  Expected grad_output->is_contiguous(grad_output->suggest_memory_format()) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
Exception raised from cudnn_batch_norm_backward at ../aten/src/ATen/native/cudnn/BatchNorm.cpp:249 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x7f40963687f9 in /home/zhouyao/libtorch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xd2 (0x7f4096365e22 in /home/zhouyao/libtorch/lib/libc10.so)
frame #2: at::native::cudnn_batch_norm_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, double, at::Tensor const&) + 0x1fda (0x7f40417d86aa in /home/zhouyao/libtorch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x3612572 (0x7f40418b9572 in /home/zhouyao/libtorch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x159b5ec (0x7f4085ee15ec in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #5: at::cudnn_batch_norm_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, double, at::Tensor const&) + 0xf3 (0x7f4085d3e9e3 in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x2bffa77 (0x7f4087545a77 in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0x2c0017a (0x7f408754617a in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x159b5ec (0x7f4085ee15ec in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #9: at::cudnn_batch_norm_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, double, at::Tensor const&) + 0xf3 (0x7f4085d3e9e3 in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #10: torch::autograd::generated::CudnnBatchNormBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x7de (0x7f40874e8ebe in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x3249e31 (0x7f4087b8fe31 in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #12: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x16bd (0x7f4087b8b17d in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #13: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x28d (0x7f4087b8b9dd in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #14: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x9a (0x7f4087b8743a in /home/zhouyao/libtorch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0xd10ce (0x7f403d91b0ce in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #16: <unknown function> + 0x76db (0x7f408472e6db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #17: clone + 0x3f (0x7f403d36271f in /lib/x86_64-linux-gnu/libc.so.6)

Aborted (core dumped)

ngimel added the module: cudnn and module: dependency bug labels Jan 15, 2021
kilasuelika commented Jan 16, 2021

Here is a smaller failing example.

However, changing the line x = linear1(x.permute({0, 2, 1, 3}).flatten(1)); in DTNet::forward() to x = linear1(x.flatten(1)); makes the problem go away.

DTNet.hpp:

#define _DEBUG_BATCH_ 2
#include <torch/torch.h>
#include <fmt/format.h>

namespace F = torch::nn::functional;

class TGCNSABlock : public torch::nn::Module
{
public:
    TGCNSABlock(int in, int assets, int len, int hidden_dim, int dilation)
        : dilation(dilation),
          bn(register_module("bn", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(hidden_dim))))
    {
        len -= dilation;
    };
    auto forward(torch::Tensor x)
    {
        x = x.slice(3, -(x.size(3) - dilation));

        //Problem happens around here.
        auto attn_weights = x.slice(3, 0, 1).squeeze(-1).slice(1, 0, 30);
        x = torch::einsum("bnm, bfml->bfnl", {attn_weights, x});

        x = bn(x);
        return x;
    };

private:
    int dilation;
    torch::nn::BatchNorm2d bn;
};

class DTNet : public torch::nn::Module
{
public:
    DTNet(int in, int assets, int len, int hidden_dim)
        : conv(register_module("conv", torch::nn::Conv2d(torch::nn::Conv2dOptions(in, hidden_dim, 1)))),
          bn(register_module("bn", torch::nn::BatchNorm2d(hidden_dim))),
          tgcnsablock_seq(register_module("sagcnblock_seq", torch::nn::Sequential())),
          linear1(register_module("linear1", torch::nn::Linear(torch::nn::LinearOptions(hidden_dim * assets * 1, 2))))
    {
        //
        fmt::print("Constructing model...\n");
        int dilation = 1;
        auto x = torch::zeros({_DEBUG_BATCH_, hidden_dim, assets, len});
        fmt::print("Input: [{}, {}, {}, {}]\n", x.size(0), x.size(1), x.size(2), x.size(3));
        for (int i = 0; i < 4; ++i)
        {
            len = x.size(3);
            tgcnsablock_seq->push_back(std::make_shared<TGCNSABlock>(in, assets, len, hidden_dim, dilation));
            dilation *= 2;
            //Determine output size of each layer.
            {
                torch::NoGradGuard no_grad;
                auto mod = tgcnsablock_seq[i]->as<TGCNSABlock>();
                x = mod->forward(x);
                fmt::print("Output of layer {}: [{}, {}, {}, {}]\n", i, x.size(0), x.size(1), x.size(2), x.size(3));
            };
        };
        fmt::print("Constructing model finished.\n");
    };
    auto forward(torch::Tensor x)
    {
        x = bn(conv(x));
        x = tgcnsablock_seq->forward(x);
        //Change this line to x = linear1(x.flatten(1)); doesn't cause the problem
        x = linear1(x.permute({0, 2, 1, 3}).flatten(1));
        return x;
    };

private:
    torch::nn::Conv2d conv;
    torch::nn::BatchNorm2d bn;
    torch::nn::Sequential tgcnsablock_seq;
    torch::nn::Linear linear1;
};

src.cpp:

#include "DTNet.hpp"
#include <torch/torch.h>

int main()
{
    bool use_gpu = true;
    auto x = torch::randn({2, 20, 30, 16});
    auto y = torch::randn({2, 2});
    DTNet net(20, 30, 16, 128);
    if (use_gpu)
    {
        x = x.to(torch::kCUDA);
        y = y.to(torch::kCUDA);
        net.to(torch::kCUDA);
    };

    torch::optim::Adam optimizer(net.parameters(), torch::optim::AdamOptions(0.001));

    //Run backpropagation once.
    optimizer.zero_grad();
    x = net.forward(x);
    std::cout << "Final x: " << x.sizes() << std::endl;
    auto loss = torch::nn::functional::mse_loss(x.flatten(1), y.flatten(1));
    loss.backward();
    optimizer.step();
};
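
One way to narrow this down further (a diagnostic sketch of my own, not part of the example above, and assuming the C++ tensor API exposes register_hook the same way the Python API does) is to register a gradient hook right before the linear1 call in DTNet::forward above. It should print, for the gradient that is about to flow into the last batch norm's backward, the exact condition that the assertion in cudnn_batch_norm_backward checks:

//Diagnostic only: print whether the gradient arriving at the output of the
//block sequence is contiguous in its suggested memory format, i.e. the check
//grad_output->is_contiguous(grad_output->suggest_memory_format()).
x.register_hook([](torch::Tensor grad) {
    fmt::print("grad contiguous in its suggested memory format: {}\n",
               grad.is_contiguous(grad.suggest_memory_format()));
});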

kilasuelika commented:

A super small failing example:
DTNet.hpp:

#include <torch/torch.h>

namespace F = torch::nn::functional;

class DTNet : public torch::nn::Module
{
public:
    DTNet(int in, int assets)
        : bn(register_module("bn", torch::nn::BatchNorm2d(in))),
          linear1(register_module("linear1", torch::nn::Linear(torch::nn::LinearOptions(in * assets * 1, 2)))){};
    auto forward(torch::Tensor x)
    {
        //If removing this line then no problem.
        x = x.slice(3, 0, 1);
        auto attn_weights = x.slice(3, 0, 1).squeeze(-1);
        x = torch::einsum("bnm, bfml->bfnl", {attn_weights, x});

        x = bn(x);

        std::cout << "size x: " << x.sizes() << std::endl;
        //Change this line to x = linear1(x.flatten(1)); doesn't cause the problem
        x = linear1(x.permute({0, 2, 1, 3}).flatten(1));
        std::cout << "size x: " << x.sizes() << std::endl;
        //x = linear1(x.flatten(1));
        return x;
    };

private:
    torch::nn::BatchNorm2d bn;
    torch::nn::Linear linear1;
};

src.cpp

#include "DTNet.hpp"

int main()
{
    bool use_gpu = true;
    auto x = torch::randn({2, 30, 30, 2});
    auto y = torch::randn({2, 2});
    DTNet net(30, 30);
    if (use_gpu)
    {
        x = x.to(torch::kCUDA);
        y = y.to(torch::kCUDA);
        net.to(torch::kCUDA);
    };

    torch::optim::Adam optimizer(net.parameters(), torch::optim::AdamOptions(0.001));

    //Run backpropagation once.
    optimizer.zero_grad();
    x = net.forward(x);
    std::cout << "Final x: " << x.sizes() << std::endl;
    auto loss = torch::nn::functional::mse_loss(x.flatten(1), y.flatten(1));
    loss.backward();
    optimizer.step();
};

It looks like the problem is related to torch::einsum when the shape of x is [x, x, x, 1]. If I don't slice(3, 0, 1), then x stays [2, 30, 30, 2] and there is no problem.
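
If the ambiguous layout of the einsum output really is the culprit (with a trailing dimension of size 1, a tensor can satisfy both the plain contiguous and the channels-last contiguity check, so forward and backward may end up disagreeing about the memory format), here is a sketch of two things worth trying. Both are untested guesses on my side, not confirmed fixes: forcing a plain contiguous layout before the batch norm, and falling back to the non-cuDNN batch norm by disabling cuDNN globally.

#include <torch/torch.h>
#include <iostream>

//Untested sketch based on the super small example above: inspect the layout
//of the einsum output and force a plain contiguous layout before the batch
//norm. Whether this (or disabling cuDNN) actually avoids the failing check
//in cudnn_batch_norm_backward is an assumption, not something I verified.
int main()
{
    //Blunt alternative: skip the cuDNN batch norm path entirely
    //(same effect as torch.backends.cudnn.enabled = False in Python).
    //at::globalContext().setUserEnabledCuDNN(false);

    torch::nn::BatchNorm2d bn(30);
    torch::nn::Linear linear1(30 * 30, 2);
    bn->to(torch::kCUDA);
    linear1->to(torch::kCUDA);

    auto x = torch::randn({2, 30, 30, 2}, torch::kCUDA);
    auto y = torch::randn({2, 2}, torch::kCUDA);

    x = x.slice(3, 0, 1);                             //[2, 30, 30, 1]
    auto attn_weights = x.slice(3, 0, 1).squeeze(-1); //[2, 30, 30]
    x = torch::einsum("bnm, bfml->bfnl", {attn_weights, x});

    //Inspect how the einsum output is laid out in memory.
    std::cout << "einsum output contiguous: " << x.is_contiguous()
              << ", channels-last contiguous: "
              << x.is_contiguous(at::MemoryFormat::ChannelsLast) << std::endl;

    x = x.contiguous(); //workaround attempt: pin down a plain contiguous layout
    x = bn(x);
    x = linear1(x.permute({0, 2, 1, 3}).flatten(1));

    auto loss = torch::nn::functional::mse_loss(x.flatten(1), y.flatten(1));
    loss.backward();
}

If the printout reports both layouts as contiguous, that would at least confirm the memory-format ambiguity is in play here.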

ngimel removed the module: cudnn and module: dependency bug labels Jan 16, 2021
ngimel (Collaborator) commented Jan 16, 2021

Awesome, thanks for the repro. I can reproduce it; we'll take a look.
