Deterministic integer indexing operation fails in indices size check / missing broadcast #68525

@dbalchev

🐛 Bug

When I run an advanced integer indexing assignment on CUDA with deterministic algorithms enabled, I get:

RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel()INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Indexing.cu":250, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor21

This bug is similar to #61032, but it affects integer indexing instead of boolean indexing.
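The size check in the assertion can be illustrated with plain arithmetic. The sketch below is only my reading of the assert expression quoted in the error message: the function name is hypothetical, and the values plugged in for `sliceSize` and `nElemBefore` (both 1 for a 1-D tensor indexed along its only dimension) are assumptions, not values taken from the PyTorch source. Under that reading, the trailing `21` in the message could be the two compared counts (2 and 1) printed back to back.

```python
# Hypothetical helper mirroring the left-hand side of the failing assertion:
#   linearIndex.numel() * sliceSize * nElemBefore == value.numel()
def expected_value_numel(num_indices: int, slice_size: int = 1, n_elem_before: int = 1) -> int:
    """Number of elements the check expects in the value tensor."""
    return num_indices * slice_size * n_elem_before


# For x[torch.tensor([1, 3])] = 2 on a 1-D tensor (assumed values):
expected = expected_value_numel(num_indices=2)  # two indices are selected
actual = 1                                      # the scalar 2 has a single element
check_passes = expected == actual

print(expected, actual, check_passes)
```

If this reading is right, the scalar on the right-hand side is never broadcast to the number of selected positions before the check runs, which matches the "missing broadcast" in the title.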

To Reproduce

Steps to reproduce the behavior:

  1. Run the following script
    import torch
    
    torch.use_deterministic_algorithms(True)
    
    x = torch.zeros(5).cuda()
    x[torch.tensor([1, 3]).cuda()] = 2
    print(x)
  2. Get the error
    Traceback (most recent call last):
      File "Distilled Example.py", line 6, in <module>
        x[torch.tensor([1, 3]).cuda()] = 2
    RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel()INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Indexing.cu":250, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor21
    

Expected behavior

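The assignment should succeed without a runtime error, and the script should print `tensor([0., 2., 0., 2., 0.], device='cuda:0')`.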
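A possible workaround, which I have not verified on the affected 1.10.0 build, is to do the broadcast by hand so that the value tensor already has one element per selected position. The sketch below assumes that this is enough to satisfy the size check; it falls back to the CPU when CUDA is not available, and the CPU path does not hit the failing kernel.

```python
import torch

torch.use_deterministic_algorithms(True)

# Use the GPU when present; the CPU fallback only exists so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.zeros(5, device=device)
idx = torch.tensor([1, 3], device=device)

# Expand the scalar manually: one value per index instead of relying on broadcasting.
values = torch.full((idx.numel(),), 2.0, dtype=x.dtype, device=device)
x[idx] = values

print(x)
```

This only helps in code you control; it does not apply when the indexing happens inside a library.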

Environment

Collecting environment information...
PyTorch version: 1.10.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Libc version: glibc-2.26

Python version: 3.7.9 (default, Nov 15 2021, 19:06:29)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.141-67.229.amzn2.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: Tesla K80
Nvidia driver version: 460.73.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy==0.910
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.18.0
[pip3] torch==1.10.0
[pip3] torch-scatter==2.0.9
[pip3] torchvision==0.11.0
[conda] Could not collect

Additional context

I originally hit this error when trying to run torchvision's Faster R-CNN on CUDA with deterministic algorithms enabled:

import torch

torch.use_deterministic_algorithms(True)

from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False).cuda()

model(
    torch.zeros(1, 3, 800, 800).cuda(), [{
        'boxes': torch.tensor([[100, 200, 300, 400]]).cuda(), 
        'labels': torch.tensor([1]).cuda(),
    }])
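Since the failing assignment in this case happens inside library code, one stopgap is to switch deterministic algorithms off around the affected call and restore the previous setting afterwards. The sketch below uses a hypothetical helper, `nondeterministic_region`, built only on `torch.use_deterministic_algorithms` and `torch.are_deterministic_algorithms_enabled`. It gives up the determinism guarantee for everything inside the `with` block, so it is a temporary measure rather than a fix.

```python
import contextlib

import torch


@contextlib.contextmanager
def nondeterministic_region():
    """Hypothetical helper: disable deterministic algorithms inside the block."""
    previous = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(False)
    try:
        yield
    finally:
        # Restore the earlier on/off state (a warn_only setting, where the
        # installed version supports one, is not preserved by this sketch).
        torch.use_deterministic_algorithms(previous)


torch.use_deterministic_algorithms(True)

with nondeterministic_region():
    # The model(...) call that triggers the assertion would go here.
    inside = torch.are_deterministic_algorithms_enabled()

after = torch.are_deterministic_algorithms_enabled()
print(inside, after)
```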

cc @ngimel @mruberry @kurtamohler

    Labels

    module: advanced indexing (Related to x[i] = y, index functions)
    module: cuda (Related to torch.cuda, and CUDA support in general)
    module: determinism
    triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
