
[export] run_decompositions fails on torch.ops.aten.index_put_ #141336

@xadupre


🐛 Describe the bug

This example started failing with a recent nightly build (2.6.0.dev20241121, see Versions below); earlier nightlies ran it without error.

import torch

class UpdateModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.params = torch.zeros((4, 4, 10))

    def forward(self, update, index1, index2):
        copy = self.params.clone()
        copy[index1, torch.tensor([1, 2], dtype=torch.int64), index2] = update
        return copy

model = UpdateModel()

update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
index1 = torch.tensor([1, 2]).to(torch.int64)
index2 = torch.tensor([7, 8]).to(torch.int64)
model(update, index1, index2)

ep = torch.export.export(model, (update, index1, index2))
print(ep.graph)
ep.run_decompositions()  # Fails here

Output of print(ep.graph):

graph():
    %c_params : [num_users=1] = placeholder[target=c_params]
    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
    %update : [num_users=1] = placeholder[target=update]
    %index1 : [num_users=1] = placeholder[target=index1]
    %index2 : [num_users=1] = placeholder[target=index2]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_params,), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
    %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
    %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [%index1, %detach_, %index2], %update), kwargs = {})
    return (index_put_,)
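The `index_put_` node above is how export records the assignment `copy[index1, torch.tensor([1, 2], dtype=torch.int64), index2] = update`. The inline `torch.tensor([1, 2])` shows up as `lift_fresh_copy` followed by `detach_`. As a sanity check on the semantics only (eager mode, no export), the sketch below confirms that the Python setitem, the in-place `Tensor.index_put_`, and the out-of-place `Tensor.index_put` all produce the same tensor. The variable name `col` is introduced here for the inline constant index and is not part of the original repro.

```python
import torch

params = torch.zeros((4, 4, 10))
update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
index1 = torch.tensor([1, 2]).to(torch.int64)
index2 = torch.tensor([7, 8]).to(torch.int64)
col = torch.tensor([1, 2], dtype=torch.int64)  # hypothetical name for the inline constant

# Advanced-indexing assignment, as written in forward().
via_setitem = params.clone()
via_setitem[index1, col, index2] = update

# The in-place op that appears in the exported graph (aten.index_put_).
via_index_put_ = params.clone()
via_index_put_.index_put_((index1, col, index2), update)

# Out-of-place variant: returns a new tensor and leaves params untouched.
via_index_put = params.index_put((index1, col, index2), update)

# The three index tensors broadcast to two positions: (1, 1, 7) and (2, 2, 8).
print(via_setitem[1, 1, 7].item(), via_setitem[2, 2, 8].item())
```

All three forms write 10.0 at `(1, 1, 7)` and 11.0 at `(2, 2, 8)`, so the failure is specific to how `run_decompositions` handles the in-place node, not to the indexing itself.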


Traceback of ep.run_decompositions() (last frames only):

  File "site-packages/torch/fx/interpreter.py", line 228, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "site-packages/torch/fx/interpreter.py", line 308, in call_function
    return target(*args, **kwargs)
  File "site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
  File "site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
RuntimeError: false INTERNAL ASSERT FAILED at "/pytorch/build/aten/src/ATen/RegisterFunctionalization_1.cpp":5939, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %index_put_ : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%clone, [%index1, %detach_, %index2], %update), kwargs = {})
Original traceback:
  File "test_issue_pytorch_2024_export.py", line 32, in forward
    copy[index1, torch.tensor([1, 2], dtype=torch.int64), index2] = update

Versions

PyTorch version: 2.6.0.dev20241121+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 538.92
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.3.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 13th Gen Intel(R) Core(TM) i7-13800H
CPU family: 6
Model: 186
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 5836.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] bert_pytorch==0.0.1a4
[pip3] clip-anytorch==2.6.0
[pip3] CoCa-pytorch==0.1.0
[pip3] dalle2-pytorch==1.15.6
[pip3] ema-pytorch==0.7.0
[pip3] executorch==0.4.0
[pip3] flake8==7.1.1
[pip3] mypy==1.11.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.3
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnx==1.18.0
[pip3] onnx-extended==0.3.0
[pip3] onnxconverter-common==1.14.0
[pip3] onnxruntime-gpu==1.21.0
[pip3] onnxruntime-training==1.21.0+cu121
[pip3] onnxscript==0.1.0.dev20240905
[pip3] open_clip_torch==2.26.1
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.8.4
[pip3] torch==2.6.0.dev20241121+cu124
[pip3] torch-fidelity==0.3.0
[pip3] torch_geometric==2.4.0
[pip3] torchao==0.5.0
[pip3] torchaudio==2.5.0.dev20241121+cu124
[pip3] torchmetrics==1.4.3
[pip3] torchvision==0.20.0.dev20241121+cu124
[pip3] triton==3.1.0
[pip3] vector-quantize-pytorch==1.18.1
[conda] Could not collect

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
