Closed
Labels
- module: advanced indexing (Related to x[i] = y, index functions)
- module: cuda (Related to torch.cuda, and CUDA support in general)
- module: determinism
- triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Bug
When I assign through advanced integer indexing on a CUDA tensor with deterministic algorithms enabled, I get:
RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel()INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Indexing.cu":250, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor21
This bug is similar to #61032, but it affects integer indexing instead of boolean indexing.
To Reproduce
Steps to reproduce the behavior:
- Run the following script
import torch
torch.use_deterministic_algorithms(True)
x = torch.zeros(5).cuda()
x[torch.tensor([1, 3]).cuda()] = 2
print(x)
- Get the error
Traceback (most recent call last):
  File "Distilled Example.py", line 6, in <module>
    x[torch.tensor([1, 3]).cuda()] = 2
RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel()INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Indexing.cu":250, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor21
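The assertion compares the number of indexed elements with value.numel(). In the script above two elements are indexed, but the scalar value has one element. A possible workaround, sketched below, is to build a value tensor that already has the indexed shape instead of relying on broadcasting. The helper name assign_scalar_at is hypothetical. Whether this avoids the assertion on the affected build is an assumption and has not been verified here. The sketch falls back to CPU when no GPU is present.

```python
import torch


def assign_scalar_at(x: torch.Tensor, index: torch.Tensor, scalar: float) -> torch.Tensor:
    """Hypothetical helper: write `scalar` into x[index] along dim 0 using a
    value tensor with one element per indexed element (no broadcasting)."""
    # Shape of the region selected by an integer index tensor on dim 0.
    target_shape = tuple(index.shape) + tuple(x.shape[1:])
    value = torch.full(target_shape, scalar, dtype=x.dtype, device=x.device)
    x[index] = value
    return x


# Use the GPU when available; otherwise run the same code on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.use_deterministic_algorithms(True)

x = torch.zeros(5, device=device)
idx = torch.tensor([1, 3], device=device)
assign_scalar_at(x, idx, 2.0)
print(x)
```

The only change from the failing script is that the right-hand side has as many elements as the index selects.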
Expected behavior
The indexed assignment should complete without a runtime error, setting elements 1 and 3 of x to 2.
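For comparison, here is a minimal sketch of the same assignment on a CPU tensor, which is not the configuration this report is about. It only shows the result the CUDA path would be expected to produce.

```python
import torch

torch.use_deterministic_algorithms(True)

# Same steps as the reproduction script, but on CPU tensors.
x = torch.zeros(5)
x[torch.tensor([1, 3])] = 2
print(x)  # tensor([0., 2., 0., 2., 0.])
```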
Environment
Collecting environment information...
PyTorch version: 1.10.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Libc version: glibc-2.26
Python version: 3.7.9 (default, Nov 15 2021, 19:06:29) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.141-67.229.amzn2.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: Tesla K80
Nvidia driver version: 460.73.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] mypy==0.910
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.18.0
[pip3] torch==1.10.0
[pip3] torch-scatter==2.0.9
[pip3] torchvision==0.11.0
[conda] Could not collect
Additional context
I originally hit this error when running torchvision's FasterRCNN on CUDA with deterministic algorithms enabled:
import torch
torch.use_deterministic_algorithms(True)
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False).cuda()
model(
    torch.zeros(1, 3, 800, 800).cuda(), [{
        'boxes': torch.tensor([[100, 200, 300, 400]]).cuda(),
        'labels': torch.tensor([1]).cuda(),
    }])