torch.jit.trace fails claiming "forward method already defined" #65611

Open
BryceEakin opened this issue Sep 24, 2021 · 5 comments
Labels: oncall: jit

Comments

BryceEakin commented Sep 24, 2021

🐛 Bug

torch.jit.trace fails while tracing a normal module with the error "RuntimeError: method 'torch.dl_research.pytorch.projects.guidance.full_model.___torch_mangle_56.FullGuidanceModel.forward' already defined."

To Reproduce

Steps to reproduce the behavior:

A simplified version of the model that reproduces the behavior:

import torch

class SampleModel(torch.nn.Module):
    def __init__(self):
        """Generates a Guidance Model suitable for inference.
        """
        super().__init__()

    @torch.jit.export
    def forward(self, images):
        images = images.to(torch.float32) / 255.0

        predictions = torch.mean(images, dim=(1, 2, 3))

        return predictions

    
model = SampleModel()
sample_input = (torch.rand((10,1,128,128))*255).to(torch.uint8)

# Runs fine
model(sample_input)

# Fails!
torch.jit.trace(model, example_inputs=sample_input)


# Error message and stack trace
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_2908/862869542.py in <module>
----> 1 torch.jit.trace(model, example_inputs=torch.tensor(images[:10]))

~/.pyenv/versions/3.8.7/envs/dl-research-pytorch/lib/python3.8/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    733 
    734     if isinstance(func, torch.nn.Module):
--> 735         return trace_module(
    736             func,
    737             {"forward": example_inputs},

~/.pyenv/versions/3.8.7/envs/dl-research-pytorch/lib/python3.8/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    950             example_inputs = make_tuple(example_inputs)
    951 
--> 952             module._c._create_method_from_trace(
    953                 method_name,
    954                 func,

RuntimeError: method '__torch__.___torch_mangle_96.SampleModel.forward' already defined.

Expected behavior

torch.jit.trace should produce a traced module without raising an error.

Environment

  • PyTorch Version (e.g., 1.0): 1.9.1
  • OS (e.g., Linux): Ubuntu 20.04 (WSL)
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8.7
  • CUDA/cuDNN version: n/a
  • GPU models and configuration: n/a
  • Any other relevant information: reproduced on macOS

Additional context

facebook-github-bot added the oncall: jit label Sep 24, 2021
github-actions bot added this to Need triage in JIT Triage Sep 24, 2021
BouazzaSE commented Sep 24, 2021

Remove @torch.jit.export when using jit.trace; those decorators are for jit.script instead (and even in that case, there's no need to use the decorator on forward, since it is included by default in the set of methods to compile).
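
For illustration, a minimal sketch of that workaround against the repro above (same toy model, decorator removed; illustrative only, not the actual guidance model):

import torch

class SampleModel(torch.nn.Module):
    # Same toy model as in the repro, but without @torch.jit.export on forward
    def forward(self, images):
        images = images.to(torch.float32) / 255.0
        return torch.mean(images, dim=(1, 2, 3))

model = SampleModel()
sample_input = (torch.rand((10, 1, 128, 128)) * 255).to(torch.uint8)

# With the decorator removed, tracing completes without the
# "forward ... already defined" RuntimeError
traced = torch.jit.trace(model, example_inputs=sample_input)
print(traced(sample_input).shape)  # torch.Size([10])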

BryceEakin (Author) commented Sep 24, 2021

> Remove @torch.jit.export when using jit.trace; those decorators are for jit.script instead (and even in that case, there's no need to use the decorator on forward, since it is included by default in the set of methods to compile).

That does remove the error, but it's not unreasonable to use a module that was intended to be scripted inside a module that is then traced; this is a toy version of my actual error case. It's unexpected behavior that marking a method as available for scripting export would clash with tracing.
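
For what it's worth, the pattern I'd expect to support that is the documented script-then-trace mix: script the submodule up front, then trace the parent that calls it. A hedged sketch with made-up module names (not the actual guidance model):

import torch

class ScriptedPart(torch.nn.Module):
    def forward(self, x):
        # data-dependent control flow, which tracing alone would not capture
        if x.sum() > 0:
            return x * 2
        return x

class TracedWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # script the submodule first; tracing the parent preserves its compiled graph
        self.part = torch.jit.script(ScriptedPart())

    def forward(self, x):
        return self.part(x) + 1

traced = torch.jit.trace(TracedWrapper(), torch.rand(4, 3))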

@BouazzaSE

I'm not familiar with the internals of PyTorch, but my guess would be that since a scripted function/module produces custom kernels that are (mostly) the fusion of many element-wise ops, it's impossible for jit.trace to trace the operations in such functions/modules, which in my opinion is why you can't use something that was scripted inside something that you intend to trace.

@BryceEakin (Author)

To be clear, the module had not yet been scripted; that method was merely marked as an entry point in case the module were scripted, so at this point it is a fully standard torch.nn.Module. My point is more that this is an undocumented incompatibility and therefore either an unintended side effect (a bug) or, at a minimum, an undocumented failure mode (which needs documentation).
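
As a point of comparison, a hedged sketch of what I mean (assuming 1.9-era behavior): scripting the decorated module appears to compile fine, while tracing the very same module raises the error above.

import torch

class SampleModel(torch.nn.Module):
    @torch.jit.export  # only marks forward as a scripting entry point
    def forward(self, images):
        images = images.to(torch.float32) / 255.0
        return torch.mean(images, dim=(1, 2, 3))

# Scripting the decorated module compiles without complaint...
scripted = torch.jit.script(SampleModel())

# ...whereas tracing the same module, e.g.
#   torch.jit.trace(SampleModel(), example_inputs=sample_input)
# raises "method '...SampleModel.forward' already defined" on 1.9.1.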

@BouazzaSE

I agree this behaviour needs to be documented at https://pytorch.org/docs/stable/jit.html#torch.jit.export; it was a headache for me too when I first stumbled on this error.
