
[JIT] jit.trace does not support parameter.requires_grad? #53515

Closed
mctigger opened this issue Mar 8, 2021 · 1 comment
Labels: oncall: jit (Add this issue/PR to JIT oncall triage queue)

Comments


mctigger commented Mar 8, 2021

🐛 Bug

Hi, is it a known limitation that jit.trace ignores temporary changes to requires_grad inside forward, or is this a bug?

To Reproduce

# EXAMPLE 1
import torch
from torch import nn, jit
from torch.optim import SGD


inputs = torch.tensor([2.0], device="cuda")
model = nn.Linear(1, 1, bias=False).to("cuda")

optimizer = SGD(model.parameters(), lr=1e-1)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())

        param.requires_grad = True
        x = self.model(x).mean()
        param.requires_grad = False
        return x


c = MyModule()
forward = jit.trace(c, (inputs,))
result = forward(inputs)

result.mean().backward()

optimizer.step()
optimizer.zero_grad()

print("It does not work fine!")
Traceback (most recent call last):
  File "src/run_jit.py", line 31, in <module>
    result.mean().backward()
  File "/home/tim/miniconda3/envs/core/lib/python3.7/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/tim/miniconda3/envs/core/lib/python3.7/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

But when I switch the requires_grad flag:

# EXAMPLE 2
import torch
from torch import nn, jit
from torch.optim import SGD


inputs = torch.tensor([2.0], device="cuda")
model = nn.Linear(1, 1, bias=False).to("cuda")

optimizer = SGD(model.parameters(), lr=1e-1)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())

        param.requires_grad = False # True --> False
        x = self.model(x).mean()
        param.requires_grad = True # False --> True
        return x


c = MyModule()
forward = jit.trace(c, (inputs,))
result = forward(inputs)

result.mean().backward()

optimizer.step()
optimizer.zero_grad()

print("It does work fine!")
It does work fine!

--

Expected behavior

Example 2 should fail with the error from example 1, while example 1 should run just fine.

Environment

PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.8.0
[pip3] torchaudio==0.8.0a0+a751e1d
[pip3] torchvision==0.9.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.1.1               h6406543_8    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2020.4             h726a3e6_304    conda-forge
[conda] mkl-service               2.3.0            py37h8f50634_2    conda-forge
[conda] mkl_fft                   1.3.0            py37h902c9e0_1    conda-forge
[conda] mkl_random                1.2.0            py37h9fdb41a_1    conda-forge
[conda] numpy                     1.19.2           py37h54aff64_0  
[conda] numpy-base                1.19.2           py37hfa32c7d_0  
[conda] pytorch                   1.8.0           py3.7_cuda11.1_cudnn8.0.5_0    pytorch
[conda] torchaudio                0.8.0                      py37    pytorch
[conda] torchvision               0.9.0                py37_cu111    pytorch

cc @gmagogsfm

facebook-github-bot added the oncall: jit label on Mar 8, 2021
github-actions bot added this to Need triage in JIT Triage on Mar 8, 2021
gmagogsfm (Contributor) commented:

Yes, jit.trace only records Tensor operations. Modifying attributes of tensor objects is not recorded, by design.
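
A minimal sketch (not posted in this thread; an assumption that follows from the explanation above) of keeping the flag toggling outside the traced graph: flip requires_grad in eager code around the call to the traced module rather than inside forward. CPU is used here for simplicity.

# Workaround sketch (assumption): since jit.trace only records tensor
# operations, keep the requires_grad toggling in eager Python, outside
# the traced forward.
import torch
from torch import nn, jit

model = nn.Linear(1, 1, bias=False)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Only tensor operations here; no attribute mutation to lose in tracing.
        return self.model(x).mean()


inputs = torch.tensor([2.0])
traced = jit.trace(MyModule(), (inputs,))

param = next(model.parameters())
param.requires_grad = False      # parameter frozen by default, as in the issue

# When a gradient is wanted, flip the flag eagerly around the traced call:
param.requires_grad = True
result = traced(inputs)
result.backward()
param.requires_grad = False

print(param.grad)                # the gradient was accumulated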

JIT Triage automation moved this from Need triage to Done Mar 9, 2021