[JIT] Scripting arguments for call are not valid #17106

Closed
sto-chastic opened this issue Feb 14, 2019 · 5 comments
Labels
oncall: jit (Add this issue/PR to JIT oncall triage queue)

Comments


sto-chastic commented Feb 14, 2019

🐛 Bug

I am converting my model to TorchScript so that I can save it and load it in C++. After solving various errors that arose along the way, I am now stuck on a RuntimeError: arguments for call are not valid: and I cannot tell exactly where it originates.

To Reproduce

The last reference to my own code in the traceback comes from a call to this function:

def create_yolov3_modules(anchors, anch_mask, n_classes, ignore_thre):
    mlist = nn.ModuleList()
    mlist.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1))
    mlist.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2))
    mlist.append(resblock(ch=64))
    mlist.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2))
    mlist.append(resblock(ch=128, nblocks=2))
    ...

where

def add_conv(in_ch, out_ch, ksize, stride):
    stage = nn.Sequential()
    pad = (ksize - 1) // 2
    stage.add_module('conv', nn.Conv2d(in_channels=in_ch,
                                       out_channels=out_ch, kernel_size=ksize, stride=stride,
                                       padding=pad, bias=False))
    stage.add_module('batch_norm', nn.BatchNorm2d(out_ch))
    stage.add_module('leaky', nn.LeakyReLU(0.1))
    return stage

and

class resblock(ScriptModule):

    __constants__ = ["nblocks", "ch", "shortcut"]

    def __init__(self, ch, nblocks=1, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.nblocks = nblocks
        self.ch = ch
        self.module_list = nn.ModuleList()
        self.blockt1 = add_conv(self.ch, self.ch // 2, 1, 1)
        self.blockt2 = add_conv(self.ch // 2, self.ch, 3, 1)
        for _ in range(nblocks):
            resblock_one = nn.ModuleList()
            self.blockt1  # bare attribute references: these two lines are no-ops
            self.blockt2
            self.module_list.append(resblock_one)

    @script_method
    def forward(self, x):
        for _ in range(self.nblocks):  # in_ch, out_ch, ksize, stride
            h = x
            h = self.blockt1(h)
            h = self.blockt2(h)
            x = x + h if self.shortcut else h
        return x

produces this error:

> RuntimeError: 
> arguments for call are not valid:
>   
>   for operator aten::__interpolate(Tensor input, int? size=<default>, float[]? scale_factor=<default>, string mode=<default>, bool? align_corners=<default>) -> Tensor:
>   expected a value of type float[]? for argument 'scale_factor' but found int
>   @weak_script_method
>   def forward(self, input):
>       warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
>       return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
>                                              ~~~~~~~~~~ <--- HERE
>   
>   for operator aten::__interpolate(Tensor input, int[]? size=<default>, float[]? scale_factor=<default>, string mode=<default>, bool? align_corners=<default>) -> Tensor:
>   expected a value of type float[]? for argument 'scale_factor' but found int
>   @weak_script_method
>   def forward(self, input):
>       warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
>       return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
>                                              ~~~~~~~~~~ <--- HERE
>   
>   for operator aten::__interpolate(Tensor input, int? size=<default>, float? scale_factor=<default>, string mode=<default>, bool? align_corners=<default>) -> Tensor:
>   expected a value of type float? for argument 'scale_factor' but found int
>   @weak_script_method
>   def forward(self, input):
>       warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
>       return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
>                                              ~~~~~~~~~~ <--- HERE
>   
>   for operator aten::__interpolate(Tensor input, int[]? size=<default>, float? scale_factor=<default>, string mode=<default>, bool? align_corners=<default>) -> Tensor:
>   expected a value of type float? for argument 'scale_factor' but found int
>   @weak_script_method
>   def forward(self, input):
>       warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
>       return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
>                                              ~~~~~~~~~~ <--- HERE
> for call at:
> @weak_script_method
> def forward(self, input):
>     warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
>     return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
>            ~~~~~~~~~~~~~ <--- HERE

and the traceback is:

> Traceback (most recent call last):
>   File "train.py", line 204, in <module>
>     main()
>   File "train.py", line 80, in main
>     model = YOLOv3(anchors, anch_mask, n_classes, ignore_thre=ignore_thre)
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 951, in init_then_register
>     original_init(self, *args, **kwargs)
>   File "/home/Documents/yolotorch/training_implementation/models/yolov3.py", line 294, in __init__
>     self.module_list= create_yolov3_modules(anchors, anch_mask, n_classes, ignore_thre)
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 1123, in __setattr__
>     super(ScriptModule, self).__setattr__(attr, _ConstModuleList(value))
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 951, in init_then_register
>     original_init(self, *args, **kwargs)
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 1355, in __init__
>     module = _make_strong(module)
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 1254, in _make_strong
>     proxy = WeakScriptModuleProxy(mod, stubs)
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 951, in init_then_register
>     original_init(self, *args, **kwargs)
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 1191, in __init__
>     _create_methods_from_stubs(self, stubs)
>   File "/home/.conda/envs/python36_ocv_pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 913, in _create_methods_from_stubs
>     self._create_methods(defs, rcbs, defaults)

Expected behavior

The code should produce a ScriptModule that I can then save.
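
For reference, the save/load round trip I am aiming for looks roughly like this (a minimal sketch; yolov3.pt is just a placeholder filename):

import torch

# build the scripted model as in the repro above
model = YOLOv3(anchors, anch_mask, n_classes, ignore_thre=ignore_thre)

# ScriptModule.save serializes the scripted model so it can be
# reloaded via torch.jit.load in Python or torch::jit::load in C++
model.save("yolov3.pt")
loaded = torch.jit.load("yolov3.pt")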

Environment

PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.2.148
GPU models and configuration: GPU 0: GeForce GTX 1070 Ti
Nvidia driver version: 410.48
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.12.1
[pip] torch==1.0.1.post2
[pip] torchvision==0.2.1
[conda] blas 1.0 mkl
[conda] mkl 2018.0.3 1
[conda] pytorch 1.0.1 py3.6_cuda10.0.130_cudnn7.4.2_2 pytorch
[conda] torchvision 0.2.1 py_2 pytorch

@pytorchbot added the oncall: jit label Feb 14, 2019
@eellison
Contributor

We'll look into this, thanks for raising the issue.

@driazati
Contributor

I can't seem to reproduce this issue on master; the (slightly modified, look at __constants__) code below runs fine.

import torch
import torch.nn as nn

def add_conv(in_ch, out_ch, ksize, stride):
    stage = nn.Sequential()
    pad = (ksize - 1) // 2
    stage.add_module('conv', nn.Conv2d(in_channels=in_ch,
                                       out_channels=out_ch, kernel_size=ksize, stride=stride,
                                       padding=pad, bias=False))
    stage.add_module('batch_norm', nn.BatchNorm2d(out_ch))
    stage.add_module('leaky', nn.LeakyReLU(0.1))
    return stage

class resblock(torch.jit.ScriptModule):

    __constants__ = ["nblocks", "ch", "shortcut", "blockt1", "blockt2"]

    def __init__(self, ch, nblocks=1, shortcut=True):
        super(resblock, self).__init__()
        self.shortcut = shortcut
        self.nblocks = nblocks
        self.ch = ch
        self.module_list = nn.ModuleList()
        self.blockt1 = add_conv(self.ch, self.ch // 2, 1, 1)
        self.blockt2 = add_conv(self.ch // 2, self.ch, 3, 1)
        for _ in range(nblocks):
            resblock_one = nn.ModuleList()
            self.blockt1  # no-ops, kept verbatim from the original report
            self.blockt2
            self.module_list.append(resblock_one)

    @torch.jit.script_method
    def forward(self, x):
        for _ in range(self.nblocks):  # in_ch, out_ch, ksize, stride
            h = x
            h = self.blockt1(h)
            h = self.blockt2(h)
            x = x + h if self.shortcut else h
        return x

mlist = nn.ModuleList()
mlist.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1))
mlist.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2))
mlist.append(resblock(ch=64))
mlist.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2))
mlist.append(resblock(ch=128, nblocks=2))
print(mlist)
r = resblock(2)
print(r.graph)

@driazati
Copy link
Contributor

The error is related to our lack of implicit int-to-float conversion, e.g.:

import torch

@torch.jit.script
def inner(x):
    # type: (float) -> float
    return x

@torch.jit.script
def use_float(x):
    # type: (int) -> float
    return inner(x)  # fails here: int is not implicitly converted to float

So I'm leaving this issue open to track that.
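
As a sketch of the workaround (mirroring the example above), an explicit cast compiles because int-to-float conversion is allowed when spelled out:

import torch

@torch.jit.script
def inner(x):
    # type: (float) -> float
    return x

@torch.jit.script
def use_float(x):
    # type: (int) -> float
    return inner(float(x))  # explicit float() cast makes the call valid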

@sto-chastic
Author

Thank you, I'll look into that

@luisa-uber

Hi, is there an update on this issue? I also ran into this error. Here's my code snippet:

import torch
from torch import nn
class CrossScale(torch.jit.ScriptModule):
    __constants__ = ['xc']

    def __init__(self):
        super(CrossScale, self).__init__()
        self.xc = nn.ModuleList((
            nn.Conv2d(2, 2, 2, bias=True),
            nn.Upsample(scale_factor=2, mode='bilinear'),  # int scale_factor triggers the error
        ))

    @torch.jit.script_method
    def forward(self, x):
        cols = []
        i = 0
        for xc in self.xc:
            out = xc(x[i])
            i += 1
            cols.append(out)
        return cols


if __name__ == "__main__":
    cs = CrossScale()

Changing scale_factor=2 to scale_factor=float(2) fixes the issue.
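
For clarity, the fixed constructor line becomes (same snippet as above, with the scale factor spelled as a float):

        self.xc = nn.ModuleList((
            nn.Conv2d(2, 2, 2, bias=True),
            nn.Upsample(scale_factor=float(2), mode='bilinear'),  # float matches the float?/float[]? schema
        ))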
