
Assert Tripped in cross entropy when using expanded tensor #71550

Closed
tmarkovich opened this issue Jan 20, 2022 · 4 comments
Labels: high priority, module: error checking, module: loss, module: regression, triage review, triaged

tmarkovich commented Jan 20, 2022

🐛 Describe the bug

import torch
from torch.nn import functional as F
npos = 1000
logits = torch.zeros((npos, 2), device="cuda:0", requires_grad=True)

for _ in range(10):
    # expand a 0-dim long tensor to length npos (a stride-0, non-contiguous view)
    targets = torch.zeros((), dtype=torch.long, device="cuda:0").expand(npos)
    loss = F.cross_entropy(logits, targets)
    loss.backward()

This code block throws the following errors:

/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [8,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [9,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [10,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [11,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [12,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [13,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [14,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [15,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [16,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [17,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [18,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [19,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [20,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [21,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [22,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [23,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [24,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [25,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [26,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [27,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [28,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [29,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1634272178570/work/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [31,0,0] Assertion `t >= 0 && t < n_classes` failed.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_12302/1279623632.py in <module>
      4     targets = torch.zeros((), dtype=torch.long, device="cuda:0").expand(npos)
      5     loss = F.cross_entropy(logits, targets)
----> 6     loss.backward()

/opt/conda/lib/python3.7/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    305                 create_graph=create_graph,
    306                 inputs=inputs)
--> 307         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    308 
    309     def register_hook(self, hook):

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    154     Variable._execution_engine.run_backward(
    155         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 156         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    157 
    158 

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
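
As an aside (not part of the original report), the flag mentioned in the error message can be set from Python before any CUDA work is launched; kernel launches then run synchronously, so the assert surfaces at the offending call rather than a later API call:

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before the first CUDA operation
import torch  # launches after this point are synchronous, making the failing call visible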

This code runs fine on PyTorch 1.9 but throws the above errors on 1.10.
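
A possible workaround on 1.10.0 (a sketch, not verified in this thread) is to materialize the expanded target tensor instead of passing the stride-0 view directly to the loss:

# hypothetical workaround: make the targets contiguous before the loss call
targets = torch.zeros((), dtype=torch.long, device="cuda:0").expand(npos).contiguous()
# or simply allocate the full-size target tensor up front
targets = torch.zeros(npos, dtype=torch.long, device="cuda:0")
loss = F.cross_entropy(logits, targets)
loss.backward()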

Versions

(base) jupyter@pytorch-1-10-20220110-143858:~$ python get_env.py 
Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.13.4
Libc version: glibc-2.10

Python version: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:53)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-4.19.0-18-cloud-amd64-x86_64-with-debian-10.11
Is CUDA available: True
CUDA runtime version: 11.0.221
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 460.73.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.19.5
[pip3] torch==1.10.0
[pip3] torchbiggraph==1.0.1.dev0
[pip3] torchvision==0.11.1+cu111
[conda] blas                      2.112                       mkl    conda-forge
[conda] blas-devel                3.9.0            12_linux64_mkl    conda-forge
[conda] cudatoolkit               11.1.1               h6406543_9    conda-forge
[conda] dlenv-pytorch-1-10-gpu    1.0.20211218     py37h0ee201a_0    file:///tmp/conda-pkgs
[conda] libblas                   3.9.0            12_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            12_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            12_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            12_linux64_mkl    conda-forge
[conda] mkl                       2021.4.0           h8d4b97c_729    conda-forge
[conda] mkl-devel                 2021.4.0           ha770c72_730    conda-forge
[conda] mkl-include               2021.4.0           h8d4b97c_729    conda-forge
[conda] mypy_extensions           0.4.3            py37h89c1867_4    conda-forge
[conda] numpy                     1.19.5           py37h038b26d_2    conda-forge
[conda] pytorch                   1.10.0          py3.7_cuda11.1_cudnn8.0.5_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchbiggraph             1.0.1.dev0               pypi_0    pypi
[conda] torchvision               0.11.1+cu111             pypi_0    pypi

cc @ezyang @gchanan @zou3519 @bdhirsh

tmarkovich (Author)

[Screenshot attached: Screen Shot 2022-01-20 at 9 34 47 AM]

zou3519 added the high priority, module: error checking, module: loss, and triaged labels on Jan 20, 2022
zou3519 (Contributor) commented Jan 20, 2022

Tentatively marking hi-pri because this seems like something that should work and that people will run into. I'm not sure if this is a regression or not.

tmarkovich (Author)

It seems to have worked in PyTorch 1.9. Looking at the git blame for Loss.cu, it appears there was a relatively large rewrite a few months back -- perhaps that's the issue?

zou3519 added the module: regression label on Jan 20, 2022
ngimel (Collaborator) commented Jan 20, 2022

This is fixed in 1.10.1, nightlies, and master; please reopen if you still see the issue after updating.
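
After upgrading, a quick sanity check (a minimal sketch reusing the repro above) is to confirm the installed version and re-run the failing step:

import torch
from torch.nn import functional as F

print(torch.__version__)  # expect 1.10.1+, a recent nightly, or a build from master

npos = 1000
logits = torch.zeros((npos, 2), device="cuda:0", requires_grad=True)
targets = torch.zeros((), dtype=torch.long, device="cuda:0").expand(npos)
F.cross_entropy(logits, targets).backward()  # should no longer trip the device-side assert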

ngimel closed this as completed on Jan 20, 2022