
🐛 [Bug] RuntimeError: [Error thrown at core/conversion/conversion.cpp:220] List type. Only a single tensor or a TensorList type is supported. #1170

Closed
wanduoz opened this issue Jul 9, 2022 · 3 comments
Labels
bug Something isn't working component: core Issues re: The core compiler



wanduoz commented Jul 9, 2022

Bug Description

I am using torch_tensorrt to convert a multi-head classification model. I found a similar issue at https://github.com/pytorch/TensorRT/issues/899, but couldn't find a clear solution or workaround there. Any suggestions or conclusions?

Traceback (most recent call last):
  File "pt2torch_tensorrt.py", line 45, in <module>
    trt_model_fp32 = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((2, 3, 512, 512), dtype=torch.float32)],enabled_precisions = torch.float32)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 116, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/conversion/conversion.cpp:220] List type. Only a single tensor or a TensorList type is supported

To Reproduce

The code snippet is as follows:

def forward(self, x):
    x = self.features(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = [classifier(x) for classifier in self.classifiers]
    return x
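Returning a Python list of tensors from `forward` is what triggers the converter error. A common workaround is to wrap the network so the head outputs are unpacked into a fixed-size tuple before compiling. A minimal sketch with a hypothetical toy model (`MultiHeadNet` and `OutputTupleWrapper` are illustrative names, not the module from this issue):

```python
import torch
import torch.nn as nn

class MultiHeadNet(nn.Module):
    # toy stand-in for the multi-head classifier in the issue
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.classifiers = nn.ModuleList(
            [nn.Linear(16, n) for n in (1, 1, 11, 53)])

    def forward(self, x):
        x = self.backbone(x)
        # returning a list here is what the converter rejects
        return [head(x) for head in self.classifiers]

class OutputTupleWrapper(nn.Module):
    # hypothetical wrapper: unpack the list into a fixed-size tuple
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        out = self.net(x)
        return (out[0], out[1], out[2], out[3])

wrapped = OutputTupleWrapper(MultiHeadNet())
outs = wrapped(torch.randn(2, 8))
```

The wrapper only works when the number of heads is fixed and known ahead of time, which is the case for this model.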

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

Using the prebuilt container, v22.03

  • CUDA version: 11.6
  • GPU models and configuration: 1080ti

Additional context

I tried returning just the first element of the list output and successfully got a TRT model.

def forward(self, x):
    x = self.features(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = [classifier(x) for classifier in self.classifiers]
    return x[0]
root@b41c3016cf4c:/models/1-inference-investigation/mycode/aimachine-master# python pt2torch_tensorrt.py
results/pth2onnx-20220621
training in device = cuda
Creating model with parameters: {'heads_type': 'mult', 'input_dim': 3, 'model_dir': 'results', 'net_name': 'efficientnet_multiheads_b5', 'num_classes_list': [1, 1, 11, 53]}
loading model: model_36_20220621_v2.pth
graph(%self_1 : __torch__.aimachine.models.classification.classificationnet.ClassificationNet_trt,
      %input_0 : Tensor):
  %__torch___aimachine_models_classification_classificationnet_ClassificationNet_trt_engine_ : __torch__.torch.classes.tensorrt.Engine = prim::GetAttr[name="__torch___aimachine_models_classification_classificationnet_ClassificationNet_trt_engine_"](%self_1)
  %3 : Tensor[] = prim::ListConstruct(%input_0)
  %4 : Tensor[] = tensorrt::execute_engine(%3, %__torch___aimachine_models_classification_classificationnet_ClassificationNet_trt_engine_)
  %5 : Tensor = prim::ListUnpack(%4)
  return (%5)
root@b41c3016cf4c:/models/1-inference-investigation/mycode/aimachine-master#
@wanduoz wanduoz added the bug Something isn't working label Jul 9, 2022
@ncomly-nvidia ncomly-nvidia added the component: core Issues re: The core compiler label Jul 11, 2022

wanduoz commented Jul 12, 2022

I guess this problem is caused by the following code:

def forward(self, images):
    ypred = self.net(images)
    if isinstance(ypred, (list, tuple)):
        if self.softmax:
            ypred = [F.softmax(pred, dim=1) for pred in ypred]
        elif self.softmax == False:
            ypred = [F.sigmoid(pred) for pred in ypred]
    else:
        if self.softmax:
            ypred = F.softmax(ypred, dim=1)
        elif self.softmax == False:
            ypred = F.sigmoid(ypred)
    return ypred

Since I use sigmoid for all of the model's outputs, I changed the code to:

def forward(self, images):
    ypred = self.net(images)
    output0 = torch.sigmoid(ypred[0])
    output1 = torch.sigmoid(ypred[1])
    output2 = torch.sigmoid(ypred[2])
    output3 = torch.sigmoid(ypred[3])

    return (output0, output1, output2, output3)

And it works.
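As a quick sanity check, a tuple-returning forward like the one above scripts cleanly under TorchScript, since the output has a fixed `Tuple[Tensor, Tensor]` type. A sketch with hypothetical toy heads (`SigmoidHeads` is an illustrative name, assuming torch is available):

```python
import torch
import torch.nn as nn

class SigmoidHeads(nn.Module):
    # hypothetical stand-in for the multi-head wrapper above
    def __init__(self):
        super().__init__()
        self.head0 = nn.Linear(4, 1)
        self.head1 = nn.Linear(4, 3)

    def forward(self, images: torch.Tensor):
        # fixed-size tuple output, analogous to the fix in this comment
        return (torch.sigmoid(self.head0(images)),
                torch.sigmoid(self.head1(images)))

# scripting succeeds because the output type is a fixed-size tuple
scripted = torch.jit.script(SigmoidHeads())
out = scripted(torch.randn(2, 4))
```

Running the scripted module is a cheap way to confirm the graph's output type before handing the module to torch_tensorrt.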

Collaborator

narendasan commented Aug 12, 2022

I can't verify this exact use case myself because the full module was not shared, but this should be solved in the upcoming release. You can try now on master (#1201); reopen if there are still issues.

@chriskeraly

I am still having this issue when trying to compile just a method (which I need to do for a few reasons)

class test_module(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x + y
        x1 = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x1)), 2)
        return (x, x1)

mymodule = test_module()

serialized_engine = torch_tensorrt.convert_method_to_trt_engine(mymodule, 'forward', inputs=torch_tensorrt_inputs)

RuntimeError: [Error thrown at external/torch-tensorrt/core/conversion/conversion.cpp:254] Tuple type. Only a single tensor or a TensorList type is supported.

Any reason this should not work?
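Given that the error message says only a single tensor or a TensorList is supported, one possible workaround on older builds (a sketch under stated assumptions, not the fix that landed on master) is to return a single concatenated tensor and split it again on the caller side. This assumes the flattened sizes of both outputs are known, e.g. 16*5*5 = 400 and 6*13*13 = 1014 for a 28x28 input to the module above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TestModuleSingleOutput(nn.Module):
    # hypothetical variant of the module above that returns one tensor
    # instead of a tuple
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x + y
        x1 = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x1)), 2)
        # flatten both outputs and concatenate into one tensor;
        # the caller splits at the known boundary afterwards
        return torch.cat([x.flatten(1), x1.flatten(1)], dim=1)

m = TestModuleSingleOutput()
out = m(torch.randn(1, 1, 28, 28), torch.randn(1, 1, 28, 28))
# split back into the two logical outputs at the known sizes
x_part, x1_part = out.split([16 * 5 * 5, 6 * 13 * 13], dim=1)
```

The caller can then reshape `x_part` and `x1_part` back to their original 4-D shapes; the cost is losing the shape information inside the compiled engine.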
