[JIT] Scripting arguments for call are not valid #17106
Comments
We'll look into this, thanks for raising the issue.
I can't seem to repro this issue on master; the slightly modified script below compiles fine:

```python
import torch
import torch.nn as nn

def add_conv(in_ch, out_ch, ksize, stride):
    stage = nn.Sequential()
    pad = (ksize - 1) // 2
    stage.add_module('conv', nn.Conv2d(in_channels=in_ch,
                                       out_channels=out_ch, kernel_size=ksize,
                                       stride=stride, padding=pad, bias=False))
    stage.add_module('batch_norm', nn.BatchNorm2d(out_ch))
    stage.add_module('leaky', nn.LeakyReLU(0.1))
    return stage

class resblock(torch.jit.ScriptModule):
    __constants__ = ["nblocks", "ch", "shortcut", "blockt1", "blockt2"]

    def __init__(self, ch, nblocks=1, shortcut=True):
        super(resblock, self).__init__()
        self.shortcut = shortcut
        self.nblocks = nblocks
        self.ch = ch
        self.module_list = nn.ModuleList()
        self.blockt1 = add_conv(self.ch, self.ch // 2, 1, 1)
        self.blockt2 = add_conv(self.ch // 2, self.ch, 3, 1)
        for _ in range(nblocks):
            resblock_one = nn.ModuleList()
            self.module_list.append(resblock_one)

    @torch.jit.script_method
    def forward(self, x):
        for _ in range(self.nblocks):  # in_ch, out_ch, ksize, stride
            h = x
            h = self.blockt1(h)
            h = self.blockt2(h)
            x = x + h if self.shortcut else h
        return x

mlist = nn.ModuleList()
mlist.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1))
mlist.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2))
mlist.append(resblock(ch=64))
mlist.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2))
mlist.append(resblock(ch=128, nblocks=2))
print(mlist)
r = resblock(2)
print(r.graph)
```
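As an aside on the repro above, the `pad = (ksize - 1) // 2` line in `add_conv` is the usual "same" padding for odd kernel sizes. A quick pure-Python check (helper names are made up for illustration; no torch required) against the standard convolution output-size formula:

```python
def same_pad(ksize):
    # "same" padding for an odd kernel: (ksize - 1) // 2
    return (ksize - 1) // 2

def conv_out_size(n, ksize, stride, pad):
    # standard formula: floor((n + 2*pad - ksize) / stride) + 1
    return (n + 2 * pad - ksize) // stride + 1

# stride 1 preserves spatial size; stride 2 halves it,
# matching how add_conv is used in the repro
print(conv_out_size(32, 3, 1, same_pad(3)))  # → 32
print(conv_out_size(32, 3, 2, same_pad(3)))  # → 16
```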
The error is related to our lack of implicit int-to-float conversion when calling one script function from another:

```python
@torch.jit.script
def inner(x):
    # type: (float) -> float
    return x

@torch.jit.script
def use_float(x):
    # type: (int) -> float
    return inner(x)  # int passed where a float is expected
```

So leaving this issue open to track that.
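Until implicit conversions are supported, one workaround (a sketch, not from this thread) is to insert the conversion explicitly at the call site; TorchScript supports the `float(...)` cast:

```python
import torch

@torch.jit.script
def inner(x):
    # type: (float) -> float
    return x

@torch.jit.script
def use_float(x):
    # type: (int) -> float
    return inner(float(x))  # explicit int -> float cast satisfies the type checker

print(use_float(3))  # prints 3.0
```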
Thank you, I'll look into that.
Hi, is there an update on this issue? I also ran into this error. Here's my code snippet
Changing |
🐛 Bug
I am converting my model to TorchScript so that I can save it and load it in C++. After solving various errors that arose along the way, I am left with a `RuntimeError: arguments for call are not valid:` whose exact origin I cannot pinpoint.
To Reproduce
The last frame that references my own code comes from a call to this function:
where
and
produces this error:
Expected behavior
The code should produce a ScriptModule which I should be able to save.
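For reference, the save/load round trip being attempted looks roughly like this in a recent PyTorch (the `Doubler` module is made up for illustration; in C++ the counterpart of `torch.jit.load` is `torch::jit::load`):

```python
import torch
import torch.nn as nn

class Doubler(nn.Module):
    def forward(self, x):
        return x * 2

sm = torch.jit.script(Doubler())       # compile the module to a ScriptModule
sm.save("doubler.pt")                  # serialize to an archive loadable from C++
loaded = torch.jit.load("doubler.pt")  # reload; same API the libtorch side mirrors
out = loaded(torch.ones(3))            # tensor([2., 2., 2.])
```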
Environment
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.2.148
GPU models and configuration: GPU 0: GeForce GTX 1070 Ti
Nvidia driver version: 410.48
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.12.1
[pip] torch==1.0.1.post2
[pip] torchvision==0.2.1
[conda] blas 1.0 mkl
[conda] mkl 2018.0.3 1
[conda] pytorch 1.0.1 py3.6_cuda10.0.130_cudnn7.4.2_2 pytorch
[conda] torchvision 0.2.1 py_2 pytorch