Skip to content

Error when loading the scripted fasterrcnn_resnet50_fpn model with torch.jit.load #1622

@dao-kun

Description

@dao-kun

There is still an error when I try to load the scripted fasterrcnn_resnet50_fpn model.
Environment:
python: 3.7
pytorch: 1.4.0a0+829499e
torchvision: 0.5.0a0+a44d55d

code:

    # Reproduction script: run eager inference, then script, save, and reload a
    # pretrained Faster R-CNN. The final torch.jit.load call is the one that fails.
    import torch
    import numpy as np  # NOTE(review): not referenced anywhere in this snippet
    from PIL import Image
    from torchvision import transforms
    from torchvision.models.detection import fasterrcnn_resnet50_fpn

    # Eager (non-scripted) model, used only for the inference call below.
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    im =Image.open('./imgs/street.jpg')
    # Resize the image to 500x500 and convert it to a tensor.
    transform = transforms.Compose([ 
        transforms.Resize((500, 500)),
        transforms.ToTensor(),])
    # unsqueeze(0) adds a leading dimension of size 1 (presumably the batch axis).
    im_input = transform(im).unsqueeze(0)
    model.eval()
    # The result is not used later in this snippet.
    predictions = model(im_input) 
    # A second model instance is built here and compiled with torch.jit.script,
    # so despite the variable name this is scripting, not tracing.
    traced_model = torch.jit.script(fasterrcnn_resnet50_fpn(pretrained=True).eval())
    torch.jit.save(traced_model, r'./rcnn.pt')
    load_model = torch.jit.load(r'./rcnn.pt')   # error here

Traceback (most recent call last):
File "/home/imc/anaconda3/envs/pytorch_source/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/imc/anaconda3/envs/pytorch_source/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/cuda100/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/new_ptvsd/wheels/ptvsd/__main__.py", line 45, in <module>
cli.main()
File "/home/cuda100/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/new_ptvsd/wheels/ptvsd/../ptvsd/server/cli.py", line 362, in main
run()
File "/home/cuda100/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/new_ptvsd/wheels/ptvsd/../ptvsd/server/cli.py", line 204, in run_file
runpy.run_path(options.target, run_name="__main__")
File "/home/imc/anaconda3/envs/pytorch_source/lib/python3.7/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/home/imc/anaconda3/envs/pytorch_source/lib/python3.7/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/home/imc/anaconda3/envs/pytorch_source/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/cuda100/daokun/RCNN/questsion.py", line 17, in <module>
load_model = torch.jit.load(r'./rcnn.pt') #error here
File "/home/imc/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/jit/__init__.py", line 235, in load
cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
RuntimeError:
Variable 'last_C' previously has type None but is now being assigned to a value of type int
:
File "/home/imc/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torchvision-0.5.0a0+a44d55d-py3.7-linux-x86_64.egg/torchvision/models/detection/rpn.py", line 220
# for the objectness and the box_regression
last_C = torch.jit.annotate(Optional[int], None)
for box_cls_per_level, box_regression_per_level in zip(
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
box_cls, box_regression
):
Serialized File "code/__torch__/torchvision/models/detection/rpn.py", line 408
box_regression_per_level0 = _156(box_regression_per_level, N, A, 4, H, W, )
_162 = torch.append(box_regression_flattened, box_regression_per_level0)
last_C = C
~~~~~~ <--- HERE
if torch.__isnot__(last_C, None):
last_C0 = unchecked_cast(int, last_C)
'concat_box_prediction_layers' is being compiled since it was called from 'RegionProposalNetwork.forward'
Serialized File "code/__torch__/torchvision/models/detection/rpn.py", line 18
features: Dict[str, Tensor],
targets: Optional[List[Dict[str, Tensor]]]=None) -> Tuple[List[Tensor], Dict[str, Tensor]]:
_0 = __torch__.torchvision.models.detection.rpn.concat_box_prediction_layers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
_1 = uninitialized(List[Dict[str, Tensor]])
features0 = torch.list(torch.values(features))

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions