Closed
Labels: module: functorch (Pertaining to torch.func or pytorch/functorch), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
`jacrev` fails to compute the Jacobian of a function that uses `torch.take`, even though calling `.backward()` on the same function computes the gradient without any error.
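For reference, the Jacobian that `jacrev` should produce here is simple. Since `z[i] = x.flatten()[index[i]]`, the entry `dz[i]/dx[j]` is 1 when `index[i] == j` and 0 otherwise, so every row is a one-hot vector. Summing the rows gives the gradient of `z.sum()`, which counts how often each element of `x` is selected. That is why the backward pass reports 10 at position 1. The sketch below checks this on a small input. The helper `take_jacobian` is a hypothetical illustration, not a PyTorch API. It also assumes that `torch.autograd.functional.jacobian` with its default `vectorize=False` is a usable non-vmap reference, because it runs one backward pass per output element.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian


def take_jacobian(numel: int, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: z[i] = x.flatten()[index[i]], so dz[i]/dx[j] is 1
    # when index[i] == j and 0 otherwise (each row is a one-hot vector).
    return F.one_hot(index.flatten(), num_classes=numel).to(torch.float32)


x = torch.rand(8)  # small stand-in for the report's 1,000,000-element input
index = torch.tensor([1] * 10)  # same index tensor as in the report

J = take_jacobian(x.numel(), index)
print(J.shape)  # torch.Size([10, 8])

# Reference Jacobian: with the default vectorize=False this loops over the
# output elements with ordinary backward passes instead of using vmap.
J_ref = jacobian(lambda t: torch.take(t, index), x)
print(torch.equal(J, J_ref))  # True

# The gradient of z.sum() is the column sum of J; element 1 is selected 10 times.
x_req = x.clone().requires_grad_()
torch.take(x_req, index).sum().backward()
print(torch.equal(J.sum(dim=0), x_req.grad))  # True
```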
```python
import torch
from torch.func import jacrev

torch.manual_seed(420)

x = torch.rand(1000000)

def func(x):
    y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
    z = torch.take(x, y)
    return z

x_clone = x.clone().requires_grad_()
func(x_clone).sum().backward()
print(x_clone.grad)
# tensor([ 0., 10., 0., ..., 0., 0., 0.])

jacrev(func)(x)
# RuntimeError: vmap: aten::put_(self, *extra_args) is not possible because there exists a Tensor `other`
# in extra_args that has more elements than `self`.
# This happened due to `other` being vmapped over but `self` not being vmapped over at level 1.
# Please try to use out-of-place operators instead of aten::put_.
# If said operator is being called inside the PyTorch framework, please file a bug report instead.
```
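The error message describes a general vmap restriction. An in-place op cannot write a vmapped (batched) tensor into a tensor that is not vmapped, because the unbatched `self` has no batch dimension to hold the result. The message names `aten::put_`, which suggests that the backward of `take`, when run under `jacrev`'s vmap, writes into an unbatched buffer in place. That reading is an inference from the message, not a confirmed look at the implementation. The sketch below reproduces the same pattern with hypothetical tensors `base`, `index` and `values`. It contrasts `Tensor.put_` with its out-of-place counterpart `Tensor.put`, as the error message recommends. The exact error text may differ between PyTorch versions.

```python
import torch
from torch.func import vmap

base = torch.zeros(5)
index = torch.tensor([0, 2, 4])
values = torch.arange(6.0).reshape(2, 3)  # a batch of 2 rows, vmapped over dim 0


def put_inplace(v):
    out = base.clone()  # depends only on `base`, so it is NOT batched
    return out.put_(index, v)  # tries to write the batched `v` into it in place


def put_out_of_place(v):
    # Out-of-place put returns a new tensor, which can carry the batch dimension.
    return base.put(index, v)


try:
    vmap(put_inplace)(values)
except RuntimeError as err:
    print("in-place put_ under vmap failed:", str(err).splitlines()[0])

batched = vmap(put_out_of_place)(values)
print(batched)
# tensor([[0., 0., 1., 0., 2.],
#         [3., 0., 4., 0., 5.]])
```

Each row of `batched` is `base` with positions 0, 2 and 4 overwritten by the corresponding row of `values`.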
Versions
PyTorch version: 2.0.0.dev20230105
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.9.15 (main, Nov 24 2022, 14:31:59) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
Nvidia driver version: 515.86.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20230105
[pip3] torchaudio==2.0.0.dev20230105
[pip3] torchvision==0.15.0.dev20230105
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.5 py39h14f4228_0
[conda] numpy-base 1.23.5 py39h31eccc5_0
[conda] pytorch 2.0.0.dev20230105 py3.9_cuda11.7_cudnn8.5.0_0 pytorch-nightly
[conda] pytorch-cuda 11.7 h67b0de4_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 2.0.0.dev20230105 py39_cu117 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py39 pytorch-nightly
[conda] torchvision 0.15.0.dev20230105 py39_cu117 pytorch-nightly
cc @zou3519 @Chillee @samdow @soumith @kshitij12345 @janeyx99