Closed
Labels
bug: triaged [verified] (We can replicate the bug), question (Further information is requested)
Description
❓ Question
I am having trouble using Torch-TensorRT with multi-headed networks. torch_tensorrt.compile(...) works fine and I can successfully use the resulting ScriptModule for execution. However, when I try to save and re-load the module I receive a RuntimeError on torch.jit.load(...):
Traceback (most recent call last):
  File "/home/burak/test.py", line 33, in <module>
    net_trt = torch.jit.load('net_trt.ts')
  File "/home/burak/miniconda3/envs/convert/lib/python3.9/site-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0
What you have already tried
I have reproduced this behavior with a very simple multi-headed network:
import torch
import torch_tensorrt

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv1b1 = torch.nn.Conv2d(8, 16, kernel_size=3)
        self.conv1b2 = torch.nn.Conv2d(8, 32, kernel_size=3)

    def forward(self, x):
        x = self.conv0(x)
        x = self.relu(x)
        output1 = self.conv1b1(x)
        output2 = self.conv1b2(x)
        return output1, output2

net = Net().eval().cuda()

Then, I have compiled this network for TensorRT as usual:
net_specs = {
    'inputs': [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)],
    'enabled_precisions': {torch.float32, torch.half},
}
net_trt = torch_tensorrt.compile(net, **net_specs)

No problem so far. net_trt works just fine. However, when I try to save and re-load it:
torch.jit.save(net_trt, 'net_trt.ts')
net_trt = torch.jit.load('net_trt.ts')

I receive the following RuntimeError:
Traceback (most recent call last):
  File "/home/burak/test.py", line 33, in <module>
    net_trt = torch.jit.load('net_trt.ts')
  File "/home/burak/miniconda3/envs/convert/lib/python3.9/site-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0
I have only encountered this with multi-headed networks. Everything seems to work fine with other types of networks.
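For anyone triaging the same failure, here is a minimal sketch of a helper that recognizes this specific error at load time and reports which engine binding is missing. It is not part of Torch-TensorRT: the names `missing_binding` and `load_or_explain` are hypothetical, and the regular expression assumes the exact wording shown in the traceback above, which other versions may phrase differently. The loader is passed in as an argument so the helper itself does not depend on torch.

```python
import re
from typing import Callable, Optional, Tuple

# Matches the last line of the RuntimeError quoted in this issue.
# Assumption: only the "output" form appears in the issue; the pattern
# accepts any single word in that position in case other kinds exist.
_BINDING_ERROR = re.compile(
    r"Could not find a TensorRT engine binding for (\w+) named (\S+)"
)


def missing_binding(message: str) -> Optional[Tuple[str, str]]:
    """Return (kind, name) of the missing binding, or None if no match."""
    match = _BINDING_ERROR.search(message)
    if match is None:
        return None
    return match.group(1), match.group(2)


def load_or_explain(path: str, loader: Callable):
    """Call loader(path) and add a hint when a binding lookup fails.

    In the scenario above, `loader` would be torch.jit.load.
    Unrelated RuntimeErrors are re-raised unchanged.
    """
    try:
        return loader(path)
    except RuntimeError as err:
        found = missing_binding(str(err))
        if found is None:
            raise
        kind, name = found
        raise RuntimeError(
            f"Serialized TensorRT engine has no {kind} binding {name!r}; "
            "the module ran in-process but failed to deserialize."
        ) from err
```

With torch installed, this would be called as `load_or_explain('net_trt.ts', torch.jit.load)`. It only makes the failure easier to spot in logs; it does not work around the underlying problem.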
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
- PyTorch Version (e.g., 1.0): 1.13.1
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.9.15
- CUDA version: 11.7
- GPU models and configuration: NVIDIA GeForce RTX 3070 (Laptop)
- Any other relevant information:
Additional context