❓ [Question] How to use Torch-TensorRT with multi-headed (multiple output) networks #1645

@kunkcu

❓ Question

I am having trouble using Torch-TensorRT with multi-headed networks. torch_tensorrt.compile(...) works fine, and I can successfully use the resulting ScriptModule for execution. However, when I save and re-load the module, torch.jit.load(...) raises a RuntimeError:

Traceback (most recent call last):
  File "/home/burak/test.py", line 33, in <module>
    net_trt = torch.jit.load('net_trt.ts')
  File "/home/burak/miniconda3/envs/convert/lib/python3.9/site-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0

What you have already tried

I have reproduced this behavior with a very simple multi-headed network:

import torch
import torch_tensorrt

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv0 = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv1b1 = torch.nn.Conv2d(8, 16, kernel_size=3)
        self.conv1b2 = torch.nn.Conv2d(8, 32, kernel_size=3)

    def forward(self, x):
        x = self.conv0(x)
        x = self.relu(x)
        output1 = self.conv1b1(x)
        output2 = self.conv1b2(x)

        return output1, output2

net = Net().eval().cuda()

Then I compiled this network for TensorRT as usual:

net_specs = {
    'inputs': [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)],
    'enabled_precisions': {torch.float32, torch.half},
}

net_trt = torch_tensorrt.compile(net, **net_specs)

No problem so far. net_trt works just fine. However, when I try to save and re-load it:

torch.jit.save(net_trt, 'net_trt.ts')
net_trt = torch.jit.load('net_trt.ts')

I receive the following RuntimeError:

Traceback (most recent call last):
  File "/home/burak/test.py", line 33, in <module>
    net_trt = torch.jit.load('net_trt.ts')
  File "/home/burak/miniconda3/envs/convert/lib/python3.9/site-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0

I have only encountered this with multi-headed networks. Everything seems to work fine with other types of networks.
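One way to narrow this down is a baseline that leaves Torch-TensorRT out entirely. The sketch below round-trips the same two-headed network through plain TorchScript, using an in-memory buffer, and compares both output heads before and after reloading. The helper name roundtrip_outputs_match and the smaller 64x64 input are assumptions made for illustration. It runs on CPU and does not exercise the TensorRT engine.

```python
import io

import torch


class Net(torch.nn.Module):
    """Same two-headed network as in the report above."""

    def __init__(self):
        super().__init__()
        self.conv0 = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv1b1 = torch.nn.Conv2d(8, 16, kernel_size=3)
        self.conv1b2 = torch.nn.Conv2d(8, 32, kernel_size=3)

    def forward(self, x):
        x = self.relu(self.conv0(x))
        return self.conv1b1(x), self.conv1b2(x)


def roundtrip_outputs_match(module, example):
    """Script the module, save it to memory, reload it, and compare every head.

    Hypothetical helper: plain TorchScript only, no Torch-TensorRT involved.
    """
    scripted = torch.jit.script(module)
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)  # torch.jit.save accepts file-like objects
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)
    with torch.no_grad():
        expected = scripted(example)
        actual = reloaded(example)
    return len(expected) == len(actual) and all(
        torch.allclose(e, a) for e, a in zip(expected, actual)
    )


net = Net().eval()  # CPU is enough for this baseline
example = torch.randn(1, 3, 64, 64)  # smaller than 224x224 to keep it quick
print(roundtrip_outputs_match(net, example))
```

If the baseline succeeds, the TorchScript save/load path itself copes with tuple outputs, which would suggest looking at how the embedded TensorRT engine's output bindings are serialized rather than at TorchScript.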

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g., 1.0): 1.13.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.9.15
  • CUDA version: 11.7
  • GPU models and configuration: NVIDIA GeForce RTX 3070 (Laptop)
  • Any other relevant information:
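The fields above can also be gathered programmatically, which helps when comparing the machine that compiles the module with the one that loads it. This is a rough sketch using only the standard library. The helper names and the pip distribution names ("torch", "torch-tensorrt", "tensorrt") are assumptions and may differ for source or conda builds, in which case the version is reported as None.

```python
import platform
from importlib import metadata


def package_version(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def collect_environment(packages=("torch", "torch-tensorrt", "tensorrt")):
    """Collect interpreter, OS, and package versions into a dict."""
    info = {
        "python": platform.python_version(),
        "os": platform.system(),
        "arch": platform.machine(),
    }
    for name in packages:
        info[name] = package_version(name)  # None when not installed via pip
    return info


for key, value in collect_environment().items():
    print(f"{key}: {value}")
```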

Additional context
