🐛 [Bug] Cannot compile with F.interpolate,encounter with "TypeError: torch.int64 is not supported by tensorrt"  #2616

@DK223

Bug Description

I am compiling a model containing the interpolate operator with the torch_tensorrt backend, but I encounter the following error: TypeError: torch.int64 is not supported by tensorrt. What confuses me is that the input I provide has dtype torch.float32, and torch.int64 does not appear anywhere in the parameters. How can I resolve this issue and keep using the interpolate operator in my model?

The following are the library versions I am using.
torch==2.0.1+cu118
torch-tensorrt==1.4.0

To Reproduce

import torch
import os
import torch_tensorrt

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, x: torch.Tensor):
        x_out = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)
        return torch.mean(x_out)

inputs = [torch.rand((1, 3, 112, 112)).cuda()]
model = Model().eval().cuda()
optimized_model_custom = torch.compile(model, backend="torch_tensorrt")

with torch.no_grad():
    for _ in range(2):
        output = optimized_model_custom(*inputs)

And the error is:


[2024-01-23 15:19:09,061] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2024-01-23 15:19:09,100] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2024-01-23 15:19:09,101] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function torch_tensorrt_backend
ERROR:torch_tensorrt.dynamo.backend.backends:FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead.
Traceback (most recent call last):
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 74, in _pretraced_backend
    trt_compiled = _compile_module(
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 134, in _compile_module
    trt_mod = convert_module(
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch_tensorrt/dynamo/backend/conversion.py", line 34, in convert_module
    r = interp.run(
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch_tensorrt/fx/fx2trt.py", line 204, in run
    super().run()
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch_tensorrt/fx/fx2trt.py", line 275, in run_node
    trt_node = super().run_node(n)
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch_tensorrt/fx/fx2trt.py", line 303, in placeholder
    name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
  File "/home/users/dengke.wang/.envs/torch20/lib/python3.8/site-packages/torch_tensorrt/fx/utils.py", line 45, in torch_dtype_to_trt
    raise TypeError("%s is not supported by tensorrt" % dtype)
TypeError: torch.int64 is not supported by tensorrt

While executing %unsqueeze_1 : [#users=1] = placeholder[target=unsqueeze_1]
Original traceback:
  File "/home/users/dengke.wang/WorkSpace/HAT/HAT/Tools_test/interpolate_test.py", line 25, in forward
    x_out = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)

[2024-01-23 15:19:12,975] torch._dynamo.output_graph: [INFO] Step 2: done compiler function torch_tensorrt_backend

Expected behavior

The model should compile successfully with the torch_tensorrt backend. What is causing this error, and how can I resolve it?
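For reference, a possible workaround sketch (an assumption on my part, not verified against torch-tensorrt 1.4.0): if the int64 placeholder comes from tracing the size computation that `scale_factor=` triggers, then computing the output size eagerly as Python ints and passing `size=` may keep the integer tensor out of the traced graph. The class name `ModelWithSize` is hypothetical.

```python
import torch

class ModelWithSize(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # Python ints for a static input shape, so no int64 size tensor
        # needs to be traced into the graph (assumption).
        h, w = x.shape[-2:]
        x_out = torch.nn.functional.interpolate(
            x, size=(h * 2, w * 2), mode="bilinear", align_corners=False
        )
        return torch.mean(x_out)

# Eager-mode check that this variant matches the original numerically:
# scale_factor=2.0 on a 112x112 input yields the same 224x224 output.
x = torch.rand((1, 3, 112, 112))
ref = torch.nn.functional.interpolate(
    x, scale_factor=2.0, mode="bilinear", align_corners=False
).mean()
out = ModelWithSize()(x)
assert torch.allclose(ref, out)
```

Whether this actually avoids the `torch_dtype_to_trt` failure under the dynamo backend is exactly what I would like a maintainer to confirm.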

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version: 1.4.0
  • PyTorch Version: 2.0.1+cu118
  • CPU Architecture:
  • OS (e.g., Linux): Linux
  • How you installed PyTorch: pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.8.3
  • CUDA version: 11.8 (cu118)
  • GPU models and configuration: RTX 3090
  • Any other relevant information:
