
jacrev fails to compute the gradient for torch.take due to a vmap error #95738

@cafffeeee

Description

🐛 Describe the bug

jacrev fails to compute the gradient for torch.take. By contrast, a plain backward pass completes without any error:

import torch
from torch.func import jacrev

torch.manual_seed(420)

x = torch.rand(1000000)

def func(x):
    y = torch.tensor([1,1,1,1,1,1,1,1,1,1])
    z = torch.take(x, y)
    return z

x_clone = x.clone().requires_grad_()
func(x_clone).sum().backward()
print(x_clone.grad)
# tensor([ 0., 10.,  0.,  ...,  0.,  0.,  0.])
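# the gradient at index 1 is 10 because y gathers element 1 ten times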

jacrev(func)(x)
# RuntimeError: vmap: aten::put_(self, *extra_args) is not possible because there exists a Tensor `other` 
# in extra_args that has more elements than `self`. 
# This happened due to `other` being vmapped over but `self` not being vmapped over at level 1. 
# Please try to use out-of-place operators instead of aten::put_. 
# If said operator is being called inside the PyTorch framework, please file a bug report instead.
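
If a workaround is needed while this is unresolved, one option might be to express the same gather with advanced indexing instead of torch.take. This is only a sketch on my part and assumes the indexing path's backward avoids the in-place aten::put_ that vmap rejects; a smaller input is used here just to keep the full Jacobian small.

import torch
from torch.func import jacrev

x_small = torch.rand(100)

def func_indexing(x):
    # same gather as torch.take(x, y) for a 1-D input, written as advanced indexing
    y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
    return x[y]

jac = jacrev(func_indexing)(x_small)
# if the indexing path is vmap-compatible, jac has shape (10, 100),
# with each row a one-hot vector at index 1
print(jac.shape)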

Versions

PyTorch version: 2.0.0.dev20230105
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.15 (main, Nov 24 2022, 14:31:59)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090

Nvidia driver version: 515.86.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20230105
[pip3] torchaudio==2.0.0.dev20230105
[pip3] torchvision==0.15.0.dev20230105
[conda] blas                      1.0                         mkl
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0            py39h7f8727e_0
[conda] mkl_fft                   1.3.1            py39hd3c417c_0
[conda] mkl_random                1.2.2            py39h51133e4_0
[conda] numpy                     1.23.5           py39h14f4228_0
[conda] numpy-base                1.23.5           py39h31eccc5_0
[conda] pytorch                   2.0.0.dev20230105 py3.9_cuda11.7_cudnn8.5.0_0    pytorch-nightly
[conda] pytorch-cuda              11.7                 h67b0de4_2    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                2.0.0.dev20230105      py39_cu117    pytorch-nightly
[conda] torchtriton               2.0.0+0d7e753227            py39    pytorch-nightly
[conda] torchvision               0.15.0.dev20230105      py39_cu117    pytorch-nightly

cc @zou3519 @Chillee @samdow @soumith @kshitij12345 @janeyx99

Metadata

Labels

module: functorch (Pertaining to torch.func or pytorch/functorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
