
torch.nn.ZeroPad2d is not scriptable #47528

Closed
BowenBao opened this issue Nov 6, 2020 · 5 comments

@BowenBao (Collaborator) commented Nov 6, 2020

🐛 Bug

A runtime error is thrown when trying to script a model that calls nn.ZeroPad2d inside forward.

To Reproduce

Steps to reproduce the behavior:

  1. Run the script below in Python:

import torch
from torch import nn

class Pad(torch.nn.Module):
    def forward(self, x):
        pad_op = nn.ZeroPad2d(padding=(10, 20, 0, 0))
        return pad_op(x)

m = torch.jit.script(Pad())

  2. Observe the error:
RuntimeError:
Unknown type name '_size_4_t':
  File "/home/bowbao/repos/pytorch/torch/nn/modules/padding.py", line 453
    def __init__(self, padding: _size_4_t) -> None:
                                ~~~~~~~~~ <--- HERE
        super(ZeroPad2d, self).__init__(padding, 0.)
'ZeroPad2d.__init__' is being compiled since it was called from '__torch__.torch.nn.modules.padding.ZeroPad2d'
  File "repro_1p_rnn_fairseq.py", line 6
    def forward(self, x):
        pad_op =  nn.ZeroPad2d(padding=(10, 20, 0, 0))
                  ~~~~~~~~~~~~ <--- HERE
        return pad_op(x)
'__torch__.torch.nn.modules.padding.ZeroPad2d' is being compiled since it was called from 'Pad.forward'
  File "repro_1p_rnn_fairseq.py", line 6
    def forward(self, x):
        pad_op =  nn.ZeroPad2d(padding=(10, 20, 0, 0))
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return pad_op(x)

Expected behavior

The scripted module is created successfully.

cc @gmagogsfm

@ssnl (Collaborator) commented Nov 6, 2020

Just curious: do we support this at all, given that the correct approach is either to create the module in __init__ (and assign it as an attribute) or to use the functional form?
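
For reference, a minimal sketch of the two patterns mentioned above (the class names are illustrative, not from this thread). Both avoid compiling ZeroPad2d.__init__, since the module is either constructed eagerly in Python or skipped in favor of the functional op:

import torch
import torch.nn.functional as F
from torch import nn

class PadAttr(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Constructed eagerly; torch.jit.script only compiles forward.
        self.pad = nn.ZeroPad2d(padding=(10, 20, 0, 0))

    def forward(self, x):
        return self.pad(x)

class PadFunctional(torch.nn.Module):
    def forward(self, x):
        # Functional form: no module construction inside forward.
        return F.pad(x, [10, 20, 0, 0], mode="constant", value=0.0)

m1 = torch.jit.script(PadAttr())
m2 = torch.jit.script(PadFunctional())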

@bdhirsh added the oncall: jit label Nov 7, 2020
@github-actions github-actions bot added this to Need triage in JIT Triage Nov 7, 2020
@wanchaol (Contributor) commented:

@malfet, since you are adding inline type annotations for nn.Module inference, this might be a regression from that work; can you take a look?

@wanchaol wanchaol moved this from Need triage to In discussion in JIT Triage Nov 10, 2020
@malfet (Contributor) commented Nov 10, 2020

The problem originates from the fact that torch.jit.script does not support Unions. For example:

import torch
from typing import Tuple, Union

size_2_t = Union[float, Tuple[float, float]]

def foo(x: size_2_t) -> float:
    return x[0] + 1.0

f = torch.jit.script(foo)

This results in the same non-descriptive error, Unknown type name 'size_2_t':, although it works if the Union is removed. Adding generic Union support sounds like a complicated task, but I wonder if adding an internal list_or_scalar_type would be an acceptable workaround, since this is a very common pattern in PyTorch.
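
For comparison, a minimal sketch of the non-Union variant (my own illustration, not from the thread), which is expected to script cleanly:

import torch
from typing import Tuple

# Same body, but annotated with the concrete tuple type instead of the
# Union alias, so TorchScript can resolve it.
def foo(x: Tuple[float, float]) -> float:
    return x[0] + 1.0

f = torch.jit.script(foo)
print(f((2.0, 3.0)))  # 3.0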

@eellison (Contributor) commented:

I think this is what BroadcastingList is equivalent to

@malfet (Contributor) commented Nov 10, 2020

@eellison, I see, so this should work:

import torch
from torch._jit_internal import BroadcastingList2

def foo(x: BroadcastingList2[float]) -> float:
    return x[0] + x[1]

print(torch.jit.script(foo)(3))
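
(If I read BroadcastingList2 correctly, the scalar 3 is broadcast to both elements of the list, so this prints 6.0; an explicit pair should be accepted as well.)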

malfet added a commit to malfet/pytorch that referenced this issue Nov 25, 2020:
"Because they are one and the same"

Fixes pytorch#47528
JIT Triage automation moved this from In discussion to Done Nov 26, 2020