Hi, is it a known limitation or a bug that `jit.trace` ignores a temporary `requires_grad = False`?
To Reproduce
```python
# EXAMPLE 1
import torch
from torch import nn, jit
from torch.optim import SGD

inputs = torch.tensor([2.0], device="cuda")
model = nn.Linear(1, 1, bias=False).to("cuda")
optimizer = SGD(model.parameters(), lr=1e-1)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())
        param.requires_grad = True
        x = self.model(x).mean()
        param.requires_grad = False
        return x


c = MyModule()
forward = jit.trace(c, (inputs,))
result = forward(inputs)
result.mean().backward()
optimizer.step()
optimizer.zero_grad()
print("It does not work fine!")
```
Running it fails with:

```
Traceback (most recent call last):
  File "src/run_jit.py", line 31, in <module>
    result.mean().backward()
  File "/home/tim/miniconda3/envs/core/lib/python3.7/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/tim/miniconda3/envs/core/lib/python3.7/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
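To see why Example 1 fails, here is a minimal sketch of the same module, assuming CPU instead of CUDA so it runs without a GPU (the name `traced` is ours). `jit.trace` runs `forward` in Python and records only the tensor operations. The two assignments to `requires_grad` execute while tracing but are not part of the recorded graph, so the parameter keeps whatever value the last assignment left behind, and calling the traced module never flips it back.

```python
import torch
from torch import nn, jit

# CPU instead of CUDA so the sketch runs anywhere; otherwise mirrors Example 1.
model = nn.Linear(1, 1, bias=False)
inputs = torch.tensor([2.0])


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())
        param.requires_grad = True   # Python attribute write: not a traced tensor op
        x = self.model(x).mean()     # recorded in the trace
        param.requires_grad = False  # also not traced, but it runs while tracing
        return x


traced = jit.trace(MyModule(), (inputs,))

# The last write made during tracing is still in effect on the shared parameter.
print(model.weight.requires_grad)  # False

# The traced graph replays only the tensor ops, so the flag is never set back
# to True and the output has no grad_fn.
result = traced(inputs)
print(result.requires_grad)  # False
```

This is consistent with Example 2: there the final assignment leaves the flag at `True`, so the traced call produces an output that requires grad.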
But when I swap the two `requires_grad` assignments, it runs without error:
```python
# EXAMPLE 2
import torch
from torch import nn, jit
from torch.optim import SGD

inputs = torch.tensor([2.0], device="cuda")
model = nn.Linear(1, 1, bias=False).to("cuda")
optimizer = SGD(model.parameters(), lr=1e-1)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())
        param.requires_grad = False  # True --> False
        x = self.model(x).mean()
        param.requires_grad = True  # False --> True
        return x


c = MyModule()
forward = jit.trace(c, (inputs,))
result = forward(inputs)
result.mean().backward()
optimizer.step()
optimizer.zero_grad()
print("It does work fine!")
```
```
It does work fine!
```
Expected behavior
Example 2 should fail with the error from Example 1, while Example 1 should run without error.
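One possible workaround, not proposed in the issue itself and sketched here only under stated assumptions: leave `requires_grad=True` on the parameter permanently and express a "frozen" use inside the graph with `.detach()`, which is a tensor operation and so is recorded by the tracer and replayed on every call. `FrozenUse`, `TrainableUse` and `shared` are hypothetical names, and the sketch runs on CPU.

```python
import torch
import torch.nn.functional as F
from torch import nn, jit

# One shared layer; its requires_grad stays True the whole time.
shared = nn.Linear(1, 1, bias=False)
x = torch.tensor([2.0])


class FrozenUse(nn.Module):
    """Hypothetical wrapper that reads the shared weight without training it."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        # detach() is a tensor op, so the tracer records it in the graph
        # (unlike an assignment to .requires_grad).
        return F.linear(x, self.linear.weight.detach()).mean()


class TrainableUse(nn.Module):
    """Hypothetical wrapper that trains the shared weight."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x).mean()


frozen = jit.trace(FrozenUse(shared), (x,))
trainable = jit.trace(TrainableUse(shared), (x,))

frozen_out = frozen(x)
print(frozen_out.requires_grad)  # False: no gradient path back to the weight

trainable(x).backward()
print(shared.weight.grad)  # d(mean(w * 2))/dw = 2, so tensor([[2.]])
```

Because the freeze lives in the recorded graph rather than in a Python-side flag, the traced modules behave the same on every call, regardless of what `forward` did while it was being traced.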
Environment
PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: Could not collect
Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.8.0
[pip3] torchaudio==0.8.0a0+a751e1d
[pip3] torchvision==0.9.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.1.1 h6406543_8 conda-forge
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2020.4 h726a3e6_304 conda-forge
[conda] mkl-service 2.3.0 py37h8f50634_2 conda-forge
[conda] mkl_fft 1.3.0 py37h902c9e0_1 conda-forge
[conda] mkl_random 1.2.0 py37h9fdb41a_1 conda-forge
[conda] numpy 1.19.2 py37h54aff64_0
[conda] numpy-base 1.19.2 py37hfa32c7d_0
[conda] pytorch 1.8.0 py3.7_cuda11.1_cudnn8.0.5_0 pytorch
[conda] torchaudio 0.8.0 py37 pytorch
[conda] torchvision 0.9.0 py37_cu111 pytorch
cc @gmagogsfm