
Misaligned Address / Lane User Stack Overflow in cunn_SpatialSoftmax #56325

Closed · ptrblck opened this issue Apr 17, 2021 · 5 comments

Labels: high priority · module: cuda (related to torch.cuda and CUDA support in general) · triage review

ptrblck (Collaborator) commented Apr 17, 2021

🐛 Bug

Reported in the forum by cameronb (thanks for reporting this issue!)

To Reproduce

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
import torch.nn as nn

device = torch.device("cuda")

def make_token_tensor(id, vocab_len, should_squeeze=True):
  # print("Making token tensor of id:", id)
  t = torch.zeros(vocab_len).to(device)
  t[id] = 1
  if should_squeeze:
    return t.unsqueeze(0).unsqueeze(0)
  else:
    return t

h_size = 1536 # The Hidden size that goes into the decoder
o_size = 30522 # 30522 = Vocabulary size of default BERT tokenizer

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.gru = nn.GRU(output_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output, hidden = self.gru(input, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden


decoder = DecoderRNN(h_size, o_size)
decoder.to(device)

# The initial inputs to the decoder
d_hidden = torch.rand((1,1,h_size)).to(device)
prev_token_pred = make_token_tensor(1, o_size) # Has dimensions 1 x 1 x o_size

ans_tokens = [1, 2, 3, 4] # Imagine that in a real model these would be used for teacher forcing
max_len = len(ans_tokens)
seq_preds = []
for i in range(max_len):
  token_pred, d_hidden = decoder(prev_token_pred, d_hidden)
  prev_token_id = torch.argmax(token_pred)
  prev_token_pred = make_token_tensor(prev_token_id, o_size)
  seq_preds.append(token_pred.squeeze(0))
test_preds = torch.stack(seq_preds)

loss = nn.NLLLoss()
input = test_preds
# each element in target has to have 0 <= value < C
target = torch.tensor(ans_tokens).to(device)
output = loss(input, target)
output.backward()

Original error message:

CUDA error: Misaligned Address

$pc info:

(cuda-gdb) x/4i $pc-32
   0x5579fb8d3ab0 <_ZN2at6native78_GLOBAL__N__54_tmpxft_00004b40_00000000_13_SoftMax_compute_86_cpp1_ii_a331004220cunn_SoftMaxBackwardILi4EfffNS1_26LogSoftMaxBackwardEpilogueEEEvPT0_PT2_S7_i+1200>:   SHF.L.U32 R12, R13, 0x2, RZ
   0x5579fb8d3ac0 <_ZN2at6native78_GLOBAL__N__54_tmpxft_00004b40_00000000_13_SoftMax_compute_86_cpp1_ii_a331004220cunn_SoftMaxBackwardILi4EfffNS1_26LogSoftMaxBackwardEpilogueEEEvPT0_PT2_S7_i+1216>:   ISETP.GE.AND P2, PT, R12, R15, PT
=> 0x5579fb8d3ad0 <_ZN2at6native78_GLOBAL__N__54_tmpxft_00004b40_00000000_13_SoftMax_compute_86_cpp1_ii_a331004220cunn_SoftMaxBackwardILi4EfffNS1_26LogSoftMaxBackwardEpilogueEEEvPT0_PT2_S7_i+1232>:   FADD R8, R4, R9
   0x5579fb8d3ae0 <_ZN2at6native78_GLOBAL__N__54_tmpxft_00004b40_00000000_13_SoftMax_compute_86_cpp1_ii_a331004220cunn_SoftMaxBackwardILi4EfffNS1_26LogSoftMaxBackwardEpilogueEEEvPT0_PT2_S7_i+1248>:   FADD R9, R5, R8
(cuda-gdb) info registers $R8 $R4 $R9
R8             0xffff88c5          -30523
R4             0x0                 0
R9             0x0                 0
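
For context, "misaligned address" is the error CUDA raises when a vectorized load or store (for example, the 128-bit accesses used on the ILP code path) goes through a pointer that is not aligned to the access width. A minimal standalone sketch, hypothetical code unrelated to the PyTorch kernel, that triggers the same error class:

// Hypothetical sketch: a float4 load requires 16-byte alignment, so
// offsetting an aligned buffer by one float (4 bytes) makes the
// vectorized load fault with "misaligned address".
#include <cstdio>
#include <cuda_runtime.h>

__global__ void vec_load(const float* p, float* out) {
  float4 v = *reinterpret_cast<const float4*>(p);  // 128-bit load
  *out = v.x + v.y + v.z + v.w;
}

int main() {
  float *buf = nullptr, *out = nullptr;
  cudaMalloc(&buf, 64 * sizeof(float));  // cudaMalloc returns aligned memory
  cudaMalloc(&out, sizeof(float));
  vec_load<<<1, 1>>>(buf + 1, out);      // buf + 1 is only 4-byte aligned
  cudaError_t err = cudaDeviceSynchronize();
  printf("%s\n", cudaGetErrorString(err));  // expect "misaligned address"
  cudaFree(buf);
  cudaFree(out);
  return 0;
}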

After rebuilding with -g -G the error changes to:

CUDA_EXCEPTION_2, Lane User Stack Overflow.

Backtrace:

(cuda-gdb) bt
#0  0x000055abe3dafad0 in void at::native::ReduceOp<float, at::native::ArgMaxOps<float>, unsigned int, long, 4>::run<1>() const ()
#1  0x000055abe292f720 in void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::native::ArgMaxOps<float>, unsigned int, long, 4> >(at::native::ReduceOp<float, at::native::ArgMaxOps<float>, unsigned int, long, 4>)
   <<<(1,1,1),(512,1,1)>>> ()

I'm currently unsure whether the stack overflow is caused by the debug flags or whether it's the real issue.

Anyway, both issues point to cunn_SpatialSoftMax.

Environment

PyTorch version: 1.9.0a0+2ecb2c7
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.19.6

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.3.58
GPU models and configuration:
GPU 0: A100-SXM4-40GB
[...]

Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
[..]
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] nvidia-dlprof-pytorch-nvtx==1.1.0
[pip3] pytorch-quantization==2.1.0
[pip3] pytorch-transformers==1.1.0
[pip3] torch==1.9.0a0+2ecb2c7
[pip3] torchtext==0.10.0a0
[pip3] torchvision==0.9.0a0
[conda] magma-cuda110             2.5.2                         5    local
[conda] mkl                       2019.4                      243
[conda] mkl-include               2019.4                      243
[conda] nomkl                     3.0                           0
[conda] numpy                     1.19.2           py38h6163131_0
[conda] numpy-base                1.19.2           py38h75fe3a5_0
[conda] nvidia-dlprof-pytorch-nvtx 1.1.0                    pypi_0    pypi
[conda] pytorch-quantization      2.1.0                    pypi_0    pypi
[conda] pytorch-transformers      1.1.0                    pypi_0    pypi
[conda] torch                     1.9.0a0+2ecb2c7           dev_0    <develop>
[conda] torchtext                 0.10.0a0                 pypi_0    pypi
[conda] torchvision               0.9.0a0                  pypi_0    pypi

@eqy would you like to take a shot at it?

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @ngimel

ngimel added the module: cuda and high priority labels on Apr 19, 2021
ngimel (Collaborator) commented Apr 19, 2021

High priority for a crash.

eqy (Collaborator) commented Apr 19, 2021

After a quick look, the failure appears to be in the call to blockReduce. It doesn't look like sdata is misaligned, so another part of the setup for the reduction may be incorrect. I'll take a deeper look.
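
For reference, blockReduce here is the block-wide reduction over the shared-memory buffer sdata in SoftMax.cu. A simplified sketch of the pattern, assuming the usual shared-memory tree reduction rather than copying the actual implementation:

// Simplified sketch of a block-wide tree reduction over shared memory
// (assumption: illustrates the pattern only, not PyTorch's blockReduce;
// assumes blockDim.x is a power of two).
template <typename T, typename Op>
__device__ T blockReduce(T* sdata, T val, Op op) {
  sdata[threadIdx.x] = val;
  __syncthreads();
  // Halve the active range each step, combining pairs of partial results.
  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) {
      sdata[threadIdx.x] = op(sdata[threadIdx.x], sdata[threadIdx.x + s]);
    }
    __syncthreads();
  }
  return sdata[0];
}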

ngimel (Collaborator) commented Apr 19, 2021

ilpReduce should be called with grad_output_shift, and not shift

shift, gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));
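
In other words, the backward kernel computes a separate alignment shift for each pointer but then reuses gradInput's shift for the gradOutput reduction. A paraphrased sketch of the fix (abbreviated; names follow the comment above rather than the exact code in aten/src/ATen/native/cuda/SoftMax.cu):

// gradInput and gradOutput are separate allocations, so their offsets from a
// 16-byte boundary can differ; each pointer needs its own shift so the ILP
// reduction peels off the right number of unaligned leading elements.
const int shift             = ((uint64_t)gradInput)  % ALIGN_BYTES / sizeof(scalar_t);
const int grad_output_shift = ((uint64_t)gradOutput) % ALIGN_BYTES / sizeof(outscalar_t);

accscalar_t threadSum = ilpReduce<AddFloat, ILP>(
    grad_output_shift,  // bug: `shift` (gradInput's offset) was passed here
    gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));

Passing gradInput's shift means the wrong number of leading elements gets peeled off gradOutput, so the subsequent vectorized loads can land on unaligned addresses, which matches the misaligned-address trap shown earlier.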

eqy (Collaborator) commented Apr 19, 2021

ilpReduce should be called with grad_output_shift, and not shift

shift, gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));

Yup, can confirm this fixes the issue on V100.

eqy added a commit to eqy/pytorch that referenced this issue Apr 19, 2021
facebook-github-bot pushed a commit that referenced this issue Apr 20, 2021
Summary:
CC ngimel ptrblck
ref: #56325

Pull Request resolved: #56403

Reviewed By: mruberry

Differential Revision: D27866625

Pulled By: ngimel

fbshipit-source-id: 9dff0e9749f8de57fac6a653f685c14854611a02
ngimel (Collaborator) commented Apr 26, 2021

Fixed in #56304

ngimel closed this as completed on Apr 26, 2021
krshrimali pushed a commit to krshrimali/pytorch that referenced this issue May 19, 2021