
Error when loading jit traced FasterRCNN model in C++ #35881

Open
ruoyussh opened this issue Apr 2, 2020 · 23 comments
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ruoyussh

ruoyussh commented Apr 2, 2020

🐛 Bug

I used torch.jit.trace() in Python to trace the fasterrcnn_resnet50_fpn model provided in the latest torchvision (installed from source). However, an unhandled exception occurred when I used the LibTorch API to load this model.

To Reproduce

Steps to reproduce the behavior:

1. Use the Python API to script the model:

    import torch
    import torchvision

    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    script_model = torch.jit.script(model)
    script_model.save("./data/rcnn.pt")

2. Load "rcnn.pt" in LibTorch:

    #include <torch/script.h>

    torch::jit::script::Module module = torch::jit::load("./data/rcnn.pt");

3. The following exception occurred:
    Unhandled exception at 0x00007FFF5D1CA839 in RCNN_Pytorch_Demo.exe: Microsoft C++ exception: torch::jit::script::ErrorReport at memory location 0x00000037EDCEEF30.

Expected behavior

It should load the traced model without throwing an exception.

Environment

Visual Studio 2017

PyTorch version: 1.4.0+cu92
Is debug build: No
CUDA used to build PyTorch: 9.2

OS: Microsoft Windows 10 Professional
GCC version: Could not collect
CMake version: version 3.16.0-rc3

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.2.148
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 432.00
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2\bin\cudnn64_7.dll

Versions of relevant libraries:
[pip] efficientnet-pytorch==0.5.1
[pip] numpy==1.17.4
[pip] numpydoc==0.8.0
[pip] torch==1.4.0+cu92
[pip] torchfile==0.1.0
[pip] torchnet==0.0.4
[pip] torchvision==0.5.0+cu92
[conda] blas 1.0 mkl defaults
[conda] efficientnet-pytorch 0.5.1
[conda] mkl 2018.0.2 1 defaults
[conda] mkl-service 1.1.2 py36h57e144c_4 defaults
[conda] mkl_fft 1.0.1 py36h452e1ab_0 defaults
[conda] mkl_random 1.0.1 py36h9258bd6_0 defaults
[conda] numpy 1.14.3 py36h9fa60d3_1 defaults
[conda] numpy 1.17.4
[conda] numpy-base 1.14.3 py36h555522e_1 defaults
[conda] numpydoc 0.8.0 py36_0 defaults
[conda] torch 1.4.0+cu92
[conda] torchfile 0.1.0
[conda] torchnet 0.0.4
[conda] torchvision 0.6.0a0
Libtorch 1.4.0

cc @peterjc123 @maxluk @nbcsm @guyang3532 @gunandrose4u @mszhanyi @skyline75489 @gmagogsfm @suo

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Apr 2, 2020
@peterjc123
Collaborator

Please check the following items:

  1. The configuration of the project and the libtorch library is the same.
  2. The data is put in the right directory and they are not corrupted.
  3. The DLLs are copied to the directory of your executable.
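Checks 2 and 3 above can be scripted before launching the C++ program. This is a minimal sketch, not an official tool: it assumes a file written by `torch.jit.save` is a zip archive (so a truncated or corrupted file should fail the zip checks), and it expects you to supply the DLL names yourself. `preflight` is a hypothetical helper name. Check 1 (build configuration) cannot be verified by inspecting files, so it is left out.

```python
import os
import zipfile


def preflight(model_path, exe_dir, required_dlls):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []

    # Check 2: the model file exists and is an intact zip archive
    # (assumption: TorchScript files saved by torch.jit.save are zip archives).
    if not os.path.isfile(model_path):
        problems.append(f"model not found: {model_path}")
    elif not zipfile.is_zipfile(model_path):
        problems.append(f"not a zip archive (corrupted?): {model_path}")
    else:
        with zipfile.ZipFile(model_path) as zf:
            bad = zf.testzip()  # name of the first entry with a bad CRC, or None
            if bad is not None:
                problems.append(f"corrupted entry {bad!r} in {model_path}")

    # Check 3: every required DLL sits next to the executable.
    # Windows file names are case-insensitive, so compare lowercased names.
    present = set()
    if os.path.isdir(exe_dir):
        present = {name.lower() for name in os.listdir(exe_dir)}
    for dll in required_dlls:
        if dll.lower() not in present:
            problems.append(f"missing DLL in {exe_dir}: {dll}")

    return problems
```

For example, `preflight("./data/rcnn.pt", "path/to/exe_dir", ["c10.dll", "torch_cpu.dll"])` returns an empty list when both checks pass. The directory is a placeholder, and the two DLL names are taken from the stack traces quoted later in this thread; a real deployment may need more.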

@eellison eellison added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 2, 2020
@ruoyussh
Author

ruoyussh commented Apr 3, 2020

Thanks for your reply.

I've checked all of the items above, and they are all correct.

Meanwhile, I ran two small experiments with the latest released torchvision package (0.5.0):

  1. LibTorch cannot load the scripted R-CNN model and throws the same exception:
    import torch
    import torchvision

    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    script_model = torch.jit.script(model)
    script_model.save("./data/rcnn.pt")

Unhandled exception at 0x00007FFF5D1CA839 in RCNN_Pytorch_Demo.exe: Microsoft C++ exception: torch::jit::script::ErrorReport at memory location 0x000000E7E559E9A0.

  2. LibTorch can load a scripted DenseNet model:
    import torch
    import torchvision

    model = torchvision.models.densenet121(pretrained=True)
    script_model = torch.jit.script(model)
    script_model.save("./data/densenet.pt")

Both Python snippets run successfully, so I was wondering: is it because LibTorch doesn't support the R-CNN family of models?

@ruoyussh
Author

ruoyussh commented Apr 3, 2020

I caught the exception in Libtorch when loading the scripted model:

schemas.size() > 0 INTERNAL ASSERT FAILED at ....\torch\csrc\jit\script\schema_matching.cpp:476, please report a bug to PyTorch. (matchSchemas at ....\torch\csrc\jit\script\schema_matching.cpp:476)
(no backtrace available)

@suo suo self-assigned this Apr 3, 2020
@grady1006

Same problem here.

@joshzhung

joshzhung commented Apr 15, 2020

Same problem here too.
MobileNet loads fine in my tests, but FCN crashes.
Maybe there is a bug related to two-stage FCN?

@jtavrisov1

Any update on this? I've hit the same roadblock. I'm using nightly builds for both the Python training code and the LibTorch C++ inference code.

@suo
Member

suo commented Jun 9, 2020

@fmassa, any ideas here? It seems that torchvision models cannot be loaded when traced.

@jtavrisov1

jtavrisov1 commented Jun 11, 2020

Just to add more information because my situation is slightly different than OP's. I'm looking to switch from TF2 to pytorch and this is my last roadblock before switching over.

import torch
import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
script_model = torch.jit.script(model)
script_model.save("test.pt")

This works fine in Python. I also did the same for fasterrcnn_resnet50_fpn, because that's what I'm actually using; both loading and a forward call on dummy input work.

In C++, torch::jit::load throws the same exception as above for the same model I just tested in Python:

schemas.size() > 0 INTERNAL ASSERT FAILED at "..\..\torch\csrc\jit\frontend\schema_matching.cpp":491, please report a bug to PyTorch.

I'm using Windows 10 with the preview (nightly) binaries downloaded from the main website. For Python, I use the same nightly option, installed with pip. The C++ code is built in Qt Creator with qmake, and the compiler is MSVC 2017. My other models have worked and given good outputs.

I ran it in Debug to catch the exception. In Release it just crashes on jit::load (the call is wrapped in try...catch).

EDIT: Never mind the FCN part; I ran the test again and it works. Not sure what I got wrong this morning. R-CNN still does not work.

@JavaAiNiU

Same problem here.

@suo
Member

suo commented Nov 17, 2020

Hm, @peterjc123, were you planning to look into this? It seems to happen only for Windows users; I suspect it has to do with torchvision custom ops not getting registered.

@soad89

soad89 commented Nov 27, 2020

Same problem for me. Is there any workaround? Thanks

@dc986

dc986 commented Jan 26, 2021

I have a similar problem on Windows, using Visual Studio 2019 with LibTorch 1.7.1+cpu.
When I try to load fasterrcnn_resnet50_fpn, I get this exception:
Unhandled exception at 0x000007FEFCC0A06D in Test_Pytorch.exe: Microsoft C++ exception: torch::jit::ErrorReport at memory location 0x0000000000127440.

I am loading the model with this code:

    torch::jit::script::Module module;
    module = torch::jit::load(model_path);

The model was scripted with the same versions of torch and torchvision:

    import torch
    import torchvision

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    model.eval()

    scripted_model = torch.jit.script(model)
    scripted_model.save("my_scripted_model_save_torch171_tochvision082_false.pt")

Do you have any suggestions?
Thanks

@EinarBjorn

Same problem with the Mask R-CNN ResNet-50 model.

@bs1119

bs1119 commented Apr 21, 2021

Any solution to this issue so far?

@skyline75489
Contributor

This is not a Windows-specific issue. The exception in Python looks like this:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chester/miniconda3/envs/py37/lib/python3.7/site-packages/torch/jit/_serialization.py", line 161, in load
    cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
RuntimeError:
Unknown builtin op: torchvision::nms.
Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript.
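Two symptoms of what appears to be the same problem are reported in this thread: `Unknown builtin op: torchvision::nms` (above) and the `schemas.size() > 0 INTERNAL ASSERT FAILED` assert seen with the Windows LibTorch builds. When logging load failures, they can be told apart with a small classifier. This is a minimal sketch: the regular expressions are built only from the messages quoted in this thread, so other PyTorch versions may word them differently, and `classify_load_error` is a hypothetical helper.

```python
import re

# Patterns built from the messages quoted in this thread; other versions may differ.
_MISSING_OP = re.compile(r"(?:Unknown builtin op:|similar ops to) (\w+)::(\w+)")
# The path before "schema_matching.cpp" varies by build, so skip over it lazily.
_SCHEMA_ASSERT = re.compile(
    r'schemas\.size\(\) > 0\s*INTERNAL ASSERT FAILED at .*?schema_matching\.cpp"?:(\d+)'
)


def classify_load_error(message):
    """Return (kind, detail) for a TorchScript load error message.

    kind is "missing_op", "schema_assert", or "other".
    """
    match = _MISSING_OP.search(message)
    if match:
        namespace, op = match.groups()
        return "missing_op", f"{namespace}::{op}"
    match = _SCHEMA_ASSERT.search(message)
    if match:
        return "schema_assert", f"schema_matching.cpp:{match.group(1)}"
    return "other", None
```

The missing-op pattern is checked first because it names the op that has to be registered, which is the more actionable detail.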

@JavaAiNiU

JavaAiNiU commented Apr 21, 2021 via email

@SorourMo

SorourMo commented Jun 29, 2022

Any update on this? I get the same error with a pretrained maskrcnn_resnet50_fpn on CPU. The scripted version of the model can be loaded in Python and gives the expected results, but the C++ side doesn't work. The same C++ code works fine with other models (U-Net type) in the same setting. This is a Mask R-CNN issue.

Windows 10
libtorch: libtorch-win-shared-with-deps-debug-1.10.1+cu102
MSVS 2019
Using cpu (no cuda)

The error is as follows:
An error ocurred: schemas.size() > 0INTERNAL ASSERT FAILED at "..\\..\\torch\\csrc\\jit\\frontend\\schema_matching.cpp":526, please report a bug to PyTorch.
Exception raised from matchSchemas at ..\..\torch\csrc\jit\frontend\schema_matching.cpp:526 (most recent call first):
00007FFB91429A1A00007FFB91428AA0 c10.dll!c10::detail::LogAPIUsageFakeReturn [<unknown file> @ <unknown line number>]
00007FFB9142957A00007FFB91428AA0 c10.dll!c10::detail::LogAPIUsageFakeReturn [<unknown file> @ <unknown line number>]
00007FFB9142A78100007FFB91428AA0 c10.dll!c10::detail::LogAPIUsageFakeReturn [<unknown file> @ <unknown line number>]
00007FFB9142A3A500007FFB91428AA0 c10.dll!c10::detail::LogAPIUsageFakeReturn [<unknown file> @ <unknown line number>]
00007FFB91427FAF00007FFB91427F40 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFB914268C600007FFB91426800 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00007FFB913BD62600007FFB913BD5E0 c10.dll!c10::detail::torchInternalAssertFail [<unknown file> @ <unknown line number>]
00007FFB1992999C00007FFB19929900 torch_cpu.dll!torch::jit::matchSchemas [<unknown file> @ <unknown line number>]
00007FFB1992A63100007FFB1992A170 torch_cpu.dll!torch::jit::emitBuiltinCall [<unknown file> @ <unknown line number>]
00007FFB19944B7900007FFB19944A70 torch_cpu.dll!torch::jit::BuiltinFunction::call [<unknown file> @ <unknown line number>]
00007FFB198DE81A00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB1990147500007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FD19A00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198E560C00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FE77000007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198EB36E00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198B2FC100007FFB198B1E30 torch_cpu.dll!torch::jit::ScriptTypeParser::ScriptTypeParser [<unknown file> @ <unknown line number>]
00007FFB198C17EE00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198A4EB200007FFB19883DD0 torch_cpu.dll!torch::jit::meaningfulName [<unknown file> @ <unknown line number>]
00007FFB1988E11200007FFB19883DD0 torch_cpu.dll!torch::jit::meaningfulName [<unknown file> @ <unknown line number>]
00007FFB198C94EE00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB197B6FAF00007FFB197B5B50 torch_cpu.dll!torch::jit::GraphFunction::getSchema [<unknown file> @ <unknown line number>]
00007FFB197B5ADE00007FFB197B5A40 torch_cpu.dll!torch::jit::GraphFunction::ensure_defined [<unknown file> @ <unknown line number>]
00007FFB198813FF00007FFB19880DC0 torch_cpu.dll!torch::jit::CompilationUnit::define [<unknown file> @ <unknown line number>]
00007FFB19F0D35800007FFB19F08190 torch_cpu.dll!torch::jit::readArchiveAndTensors [<unknown file> @ <unknown line number>]
00007FFB19F0BD4700007FFB19F08190 torch_cpu.dll!torch::jit::readArchiveAndTensors [<unknown file> @ <unknown line number>]
00007FFB19F12CEF00007FFB19F12950 torch_cpu.dll!torch::jit::ClassNamespaceValue::attr [<unknown file> @ <unknown line number>]
00007FFB199013C900007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198DE56900007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB1990147500007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FD19A00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198E560C00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FE77000007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FE4BB00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FD9B200007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198EF6B800007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198EF3F500007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FE60E00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198EB36E00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198B2FC100007FFB198B1E30 torch_cpu.dll!torch::jit::ScriptTypeParser::ScriptTypeParser [<unknown file> @ <unknown line number>]
00007FFB198C17EE00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198A4EB200007FFB19883DD0 torch_cpu.dll!torch::jit::meaningfulName [<unknown file> @ <unknown line number>]
00007FFB1988E11200007FFB19883DD0 torch_cpu.dll!torch::jit::meaningfulName [<unknown file> @ <unknown line number>]
00007FFB198C94EE00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB197B6FAF00007FFB197B5B50 torch_cpu.dll!torch::jit::GraphFunction::getSchema [<unknown file> @ <unknown line number>]
00007FFB197B5ADE00007FFB197B5A40 torch_cpu.dll!torch::jit::GraphFunction::ensure_defined [<unknown file> @ <unknown line number>]
00007FFB198813FF00007FFB19880DC0 torch_cpu.dll!torch::jit::CompilationUnit::define [<unknown file> @ <unknown line number>]
00007FFB19F0D35800007FFB19F08190 torch_cpu.dll!torch::jit::readArchiveAndTensors [<unknown file> @ <unknown line number>]
00007FFB19F0BD4700007FFB19F08190 torch_cpu.dll!torch::jit::readArchiveAndTensors [<unknown file> @ <unknown line number>]
00007FFB19F12CEF00007FFB19F12950 torch_cpu.dll!torch::jit::ClassNamespaceValue::attr [<unknown file> @ <unknown line number>]
00007FFB199013C900007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FD19A00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198E560C00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198FE77000007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198EB36E00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198B2FC100007FFB198B1E30 torch_cpu.dll!torch::jit::ScriptTypeParser::ScriptTypeParser [<unknown file> @ <unknown line number>]
00007FFB198C17EE00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198A4EB200007FFB19883DD0 torch_cpu.dll!torch::jit::meaningfulName [<unknown file> @ <unknown line number>]
00007FFB1988E11200007FFB19883DD0 torch_cpu.dll!torch::jit::meaningfulName [<unknown file> @ <unknown line number>]
00007FFB198C94EE00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB197B6FAF00007FFB197B5B50 torch_cpu.dll!torch::jit::GraphFunction::getSchema [<unknown file> @ <unknown line number>]
00007FFB197B5ADE00007FFB197B5A40 torch_cpu.dll!torch::jit::GraphFunction::ensure_defined [<unknown file> @ <unknown line number>]
00007FFB198D9D1100007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]
00007FFB198DE81A00007FFB198B8660 torch_cpu.dll!torch::jit::ScriptTypeParser::operator= [<unknown file> @ <unknown line number>]

@AshvantSelvam

AshvantSelvam commented Sep 12, 2022

@suo
An update on this would be appreciated. All of the object detection models give this error when loaded in C++ (traced in Python). Is there a workaround that would let me load these models in LibTorch C++?

schemas.size() > 0 INTERNAL ASSERT FAILED at "C:\actions-runner_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\jit\frontend\schema_matching.cpp":575, please report a bug to PyTorch.
error loading the model
Exception raised from matchSchemas at C:\actions-runner_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\jit\frontend\schema_matching.cpp:575 (most recent call first):

@mantaionut
Collaborator

I will remove the "module: windows" label, since I reproduced this issue on Linux as well. The error there is:
Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript. :
  File "code/__torch__/torchvision/ops/boxes.py", line 138
    _59 = __torch__.torchvision.extension._assert_has_ops
    _60 = _59()
    _61 = ops.torchvision.nms(boxes, scores, iou_threshold)
          ~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _61
'nms' is being compiled since it was called from '_batched_nms_vanilla'
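The serialized code in the message above calls `ops.torchvision.nms(...)` from `code/__torch__/torchvision/ops/boxes.py`. That suggests a way to list, before deploying, which extension ops a C++ loader would need to have registered: scan the saved archive. This is a minimal, heuristic sketch. It assumes the .pt file is a zip archive whose TorchScript sources are stored as `.py` files under a `code/` directory (consistent with the path in the message), and it treats `aten`, `prim`, and `quantized` as built-in namespaces. Both are assumptions, `find_custom_ops` is a hypothetical helper, and ops invoked in some other form would be missed.

```python
import re
import zipfile

# Matches custom-op calls such as `ops.torchvision.nms(` in serialized TorchScript code.
# The lookbehind skips dotted paths like `__torch__.torchvision.ops.boxes.batched_nms(`,
# which are calls to other scripted functions, not registered ops.
_OP_CALL = re.compile(r"(?<![\w.])ops\.(\w+)\.(\w+)\s*\(")

# Namespaces assumed to be provided by libtorch itself; adjust as needed.
BUILTIN_NAMESPACES = frozenset({"aten", "prim", "quantized"})


def find_custom_ops(archive_path, builtin=BUILTIN_NAMESPACES):
    """Return sorted 'namespace::op' names called from a saved TorchScript archive."""
    found = set()
    with zipfile.ZipFile(archive_path) as zf:
        for name in zf.namelist():
            # Assumed layout: serialized sources are .py files under a code/ directory.
            if "code/" not in name or not name.endswith(".py"):
                continue
            source = zf.read(name).decode("utf-8", errors="replace")
            for namespace, op in _OP_CALL.findall(source):
                if namespace not in builtin:
                    found.add(f"{namespace}::{op}")
    return sorted(found)
```

For the Mask R-CNN model in this issue, the output would be expected to include `torchvision::nms`. Following the diagnosis in this thread, any namespace it reports has to be registered in the C++ process before `torch::jit::load` runs; for torchvision, that means building and linking it as in the workaround later in this thread.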

@mantaionut mantaionut removed the module: windows Windows support for PyTorch label Nov 23, 2022
@mantaionut mantaionut removed their assignment Nov 23, 2022
@Blackhex Blackhex moved this from Todo to Done in PyTorch On Windows Dec 2, 2022
@withkun

withkun commented Dec 11, 2022

Same problem; waiting for a solution.

@Monibsediqi

Any update on this issue? I am getting the same error when loading a traced swinUNetR model in C++ using torch::jit::load("path to the traced model").

@soad89

soad89 commented Jun 24, 2023

As a workaround for this issue, include torchvision in your C++ code:

    #include <torchvision/csrc/vision.h>

As far as I know, torchvision has to be built from source (I did not find any prebuilt binaries). You can use the official LibTorch build with a matching CUDA toolkit installed.
torchvision source

@iremyux
Collaborator

iremyux commented Aug 9, 2023

Since the Windows label was removed, I am removing this issue from the 'PyTorch on Windows' project as well.
