jit tracing error for nn.Sequential with nn.Conv2d in torch 1.1.0 #20101

Closed · oldnaari opened this issue May 3, 2019 · 9 comments

@oldnaari commented May 3, 2019
🐛 Bug

Tracing an nn.Sequential containing an nn.Conv2d in torch 1.1.0 raises:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

To Reproduce

Steps to reproduce the behavior:

from torch import nn
import torch.jit

model = nn.Sequential(nn.Conv2d(2, 2, 1, 1, 1))

# Tracing the bound method rather than the module itself triggers the error.
torch.jit.trace(model.forward, torch.randn(1, 1, 2, 2))

This raises the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-5e9a2f5de8a5> in <module>
      4 model = nn.Sequential(nn.Conv2d(2, 2, 1, 1, 1))
      5 
----> 6 torch.jit.trace(model.forward, torch.randn(1, 1, 2, 2))

~/.virtualenvs/test/lib/python3.6/site-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class)
    693         traced = torch._C._create_function_from_trace(name, func, example_inputs,
    694                                                       var_lookup_fn,
--> 695                                                       _force_outplace)
    696 
    697     # Check the trace against new traces created from user-specified inputs

~/.virtualenvs/test/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

~/.virtualenvs/test/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    489             hook(self, input)
    490         if torch._C._get_tracing_state():
--> 491             result = self._slow_forward(*input, **kwargs)
    492         else:
    493             result = self.forward(*input, **kwargs)

~/.virtualenvs/test/lib/python3.6/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
    479         tracing_state._traced_module_stack.append(self)
    480         try:
--> 481             result = self.forward(*input, **kwargs)
    482         finally:
    483             tracing_state.pop_scope()

~/.virtualenvs/test/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
(1,1,.,.) = 
  0.3550

(2,1,.,.) = 
 0.01 *
  9.7722

(1,2,.,.) = 
 -0.5052

(2,2,.,.) = 
  0.5900
[ Variable[CPUType]{2,2,1,1} ]

Expected behavior

The model is expected to trace without errors.

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.1.85
cuDNN version: /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.0.5

Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] numpy-image-widget==2019.1.6
[pip3] torch==1.1.0
[pip3] torchfile==0.1.0
[pip3] torchvision==0.2.1
[conda] Could not collect

@Krovatkin (Contributor) commented May 3, 2019

@oldnaari Could you please try calling trace on your model rather than model.forward? It will still trace model.forward, but it will also capture all of the module's parameters:

model = nn.Sequential(nn.Conv2d(1, 1, 3))
torch.jit.trace(model, torch.randn(1, 1, 3, 3))

We are going to split trace into trace and trace_module, and then we can consider making this particular error message more helpful.

@SsnL (Collaborator) commented May 3, 2019

@Krovatkin But according to #19070, torch.jit.trace(model.forward, torch.randn(1, 1, 3, 3)) would eventually also work (via the second syntax sugar), right?

@Krovatkin (Contributor) commented May 3, 2019

@SsnL, I believe #19070 addresses a slightly different issue: it allows one to trace multiple methods as part of a single module (#19905).

@SsnL (Collaborator) commented May 3, 2019

@Krovatkin Yes, I agree that this is orthogonal to trace_module. What I was asking is: in the scenario of tracing a single module method, shouldn't trace(n.forward, x) still work eventually, since it would be converted to trace(n.forward.__self__, x)?

@Krovatkin (Contributor) commented May 3, 2019

@SsnL I honestly don't know how @zdevito wants us to handle this particular case. If it's a truly static function (i.e. one that doesn't use the module's parameters), we wouldn't necessarily want to create a whole new module to host it, since we now have standalone functions. This is exactly how trace works now; it will create a standalone function even for a module's methods:

        name = getattr(func, '__name__', 'forward')
        if name == '<lambda>':
            name = '_lambda'  # make name a valid identifier
        traced = torch._C._create_function_from_trace(name, func, example_inputs,
                                                      var_lookup_fn,
                                                      _force_outplace)

On the downside, if we pass a function that does use the module's parameters, we get: RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient.

If we decide that we don't want to encourage users to write static functions like this, we could drop support for this case; then we could tweak trace to do what you are suggesting.
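
[Editor's note] For concreteness, a minimal sketch (my own illustration of the torch 1.1.0 behavior described in this thread, not code from the thread) contrasting the two cases:

import torch
from torch import nn

# Case 1: a truly static function (no module parameters) traces fine as a
# standalone function.
def double(x):
    return x * 2

traced_fn = torch.jit.trace(double, torch.randn(2, 2))

# Case 2: a bound method that uses module parameters (self.weight has
# requires_grad=True). Under torch 1.1.0 this hits the error from this
# issue, because the tracer tries to bake the weight in as a constant.
conv = nn.Conv2d(1, 1, 3)
try:
    torch.jit.trace(conv.forward, torch.randn(1, 1, 3, 3))
except RuntimeError as e:
    print(e)  # "Cannot insert a Tensor that requires grad as a constant..."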

@SsnL (Collaborator) commented May 7, 2019

@Krovatkin Yes, I really hope @zdevito can help clarify. I am a bit lost on the relation between the proposal in #19070 and the trace_module of #19905. It seems that #19905 covers the majority of the functionality proposed in #19070, but with a different API.

It feels to me that #19905 is trying to make the distinction that trace traces regular static functions and trace_module returns a traced module, but I am not 100% certain.

Additionally, I really hope that there will be some nice syntax sugar to make the trace_module API easier to use. Something that can work like a decorator would be nice to have.

Out of curiosity, is there a large overhead in using a traced function vs. a traced module? In particular, if I separately trace two methods of the same module using trace_module and just use the two resulting traced modules (e.g. reassign them as methods), what could go wrong?
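
[Editor's note] A hypothetical sketch of what SsnL describes, assuming the trace_module API that eventually shipped (it takes a dict mapping method names to example inputs); the module and its second method here are made up for illustration:

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    # A second, hypothetical method sharing the same parameters.
    def pooled(self, x):
        return self.conv(x).mean()

net = Net()
example = torch.randn(1, 1, 3, 3)

# Trace two methods of the same module in one call.
traced = torch.jit.trace_module(net, {'forward': example, 'pooled': example})
traced.pooled(example)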

@apaszke (Member) commented May 7, 2019

In a chat with @zdevito we concluded that the API will be the following (see the sketch after this list):

  1. trace_module takes a module and, optionally, a list of methods to trace (defaulting to ['forward']). Not tracing forward is forbidden.
  2. trace always traces a single function (and returns a torch.jit.Function) without capturing any parameters into the context. The only exception to that rule is that if you pass in a module, it will dispatch to trace_module for backward compatibility.
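
[Editor's note] A minimal sketch of the decided API, assuming it behaves exactly as the two points above describe (illustration only, not verified against a specific release):

import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(1, 1, 3))
example = torch.randn(1, 1, 3, 3)

# 1. trace_module takes the module; the methods to trace default to ['forward'].
traced_module = torch.jit.trace_module(model, {'forward': example})

# 2. trace traces a single free function and captures no module parameters.
def scale(x):
    return x * 0.5

traced_fn = torch.jit.trace(scale, example)

# Backward-compatibility exception: passing a module to trace dispatches
# to trace_module.
also_traced = torch.jit.trace(model, example)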
@Krovatkin (Contributor) commented May 10, 2019

@SsnL #20368 should make the trace API consistent with what @apaszke and @zdevito decided. If you pass my_module.my_not_forward_method, it will now complain.

@bhosmer (Contributor) commented Dec 13, 2019

Both repros now run correctly, modulo a dimension tweak in the original (tracing model.forward) to match the suggested one (tracing model); the unmodified original gives an unrelated error about channel agreement. Closing.

@bhosmer closed this Dec 13, 2019