
❓ [Question] Partitioning for unsupported operations #1653

@kunkcu

Description


❓ Question

As far as I understand, Torch-TensorRT performs a partitioning step when it encounters unsupported operations: the resulting graph runs the generated TensorRT engine(s) for the supported partition(s) and falls back to TorchScript JIT everywhere else. I can generally observe this behavior in the generated graphs. However, with certain blocks I receive errors, and I can't see why those blocks are problematic.

For instance, for the following (example) block:

"""block(for+cond)"""
retval=[]
for slice in x: # x: torch.Tensor
    if slice.sum() > 0: # any cond. dep. on tensor/slice
        retval.append(slice + 100)
    else:
        retval.append(slice + 50)
"""block(for+cond)"""

When calling torch_tensorrt.compile(...), I receive a RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Expected ivalues_maps.count(input) to be true but got false:

Traceback (most recent call last):
  File "/home/burak/test.py", line 36, in <module>
    net_trt = torch_tensorrt.compile(net, **net_specs)
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 125, in compile
    return torch_tensorrt.ts.compile(
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 136, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* slice.1 produced from %slice.1 : Tensor = aten::select(%158, %6, %19) # /home/burak/test.py:20:8 in lowering graph for mini graph input.
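
For what it's worth, the per-slice logic in block(for+cond) could in principle be written without Python control flow. Below is a minimal branch-free sketch of my own using torch.where (illustration only; it assumes x is a 4-D (N, C, H, W) tensor and returns a stacked tensor rather than a list). My question, though, is about why the original data-dependent form trips up partitioning, not how to rewrite it.

import torch

def block_vectorized(x: torch.Tensor) -> torch.Tensor:
    # Per-slice sums along dim 0, kept broadcastable against x.
    sums = x.sum(dim=(1, 2, 3), keepdim=True)
    # +100 where a slice's sum is positive, +50 otherwise.
    return torch.where(sums > 0, x + 100, x + 50)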

What you have already tried

I have reproduced this behavior with the following example script:

import torch
import torch_tensorrt
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Info)

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv0 = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv1 = torch.nn.Conv2d(8, 16, kernel_size=3)

    def forward(self, x):
        x = self.conv0(x)
        x = self.relu(x)
        x = self.conv1(x)

        """block(for+cond)"""
        retval = []
        for slice in x:
            if slice.sum() > 0: # any cond. dep. on tensor/slice
                retval.append(slice + 100)
            else:
                retval.append(slice + 50)
        """block(for+cond)"""

        return retval

net = Net().eval().cuda()

net_specs = {
    'inputs': [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)],
    'enabled_precisions': {torch.float32, torch.half},
}

net_trt = torch_tensorrt.compile(net, **net_specs)
print(net_trt.graph)

I receive the following RuntimeError (full output, info log-level):

INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
INFO: [Torch-TensorRT] - Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
  %self.conv0.weight.1 : Float(8, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.conv0.bias.1 : Float(8, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value= 0.1437  0.0745  0.1127  0.1185  0.1406  0.1445 -0.0802  0.0562 [ CUDAFloatType{8} ]]()
  %self.conv1.weight.1 : Float(16, 8, 3, 3, strides=[72, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.conv1.bias.1 : Float(16, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %9 : int = prim::Constant[value=1]()
  %8 : NoneType = prim::Constant()
  %7 : bool = prim::Constant[value=1]() # /home/burak/test.py:20:8
  %6 : int = prim::Constant[value=0]() # /home/burak/test.py:21:29
  %5 : int = prim::Constant[value=100]() # /home/burak/test.py:22:38
  %4 : int = prim::Constant[value=50]() # /home/burak/test.py:24:38
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %153 : bool = prim::Constant[value=0]()
  %154 : int[] = prim::Constant[value=[0, 0]]()
  %155 : Tensor = aten::_convolution(%x.1, %self.conv0.weight.1, %self.conv0.bias.1, %3, %2, %3, %153, %154, %9, %153, %153, %153, %153)
  %17 : Tensor[] = prim::ListConstruct()
  %137 : Tensor = aten::relu(%155) # /home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch/nn/functional.py:1455:17
  %156 : bool = prim::Constant[value=0]()
  %157 : int[] = prim::Constant[value=[0, 0]]()
  %158 : Tensor = aten::_convolution(%137, %self.conv1.weight.1, %self.conv1.bias.1, %3, %2, %3, %156, %157, %9, %156, %156, %156, %156)
  %144 : int = aten::len(%158) # /home/burak/test.py:20:8
   = prim::Loop(%144, %7) # /home/burak/test.py:20:8
    block0(%19 : int):
      %slice.1 : Tensor = aten::select(%158, %6, %19) # /home/burak/test.py:20:8
      %123 : Tensor = aten::sum(%slice.1, %8) # /home/burak/test.py:21:15
      %125 : Tensor = aten::gt(%123, %6) # /home/burak/test.py:21:15
      %126 : bool = aten::Bool(%125) # /home/burak/test.py:21:15
       = prim::If(%126) # /home/burak/test.py:21:12
        block0():
          %24 : Tensor = aten::add(%slice.1, %5, %9) # /home/burak/test.py:22:30
          %25 : Tensor[] = aten::append(%17, %24) # /home/burak/test.py:22:16
          -> ()
        block1():
          %26 : Tensor = aten::add(%slice.1, %4, %9) # /home/burak/test.py:24:30
          %27 : Tensor[] = aten::append(%17, %26) # /home/burak/test.py:24:16
          -> ()
      -> (%7)
  return (%17)

INFO: [Torch-TensorRT] - Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.
Unsupported operators listed below:
  - aten::Bool.Tensor(Tensor a) -> bool
  - aten::len.Tensor(Tensor t) -> int

Traceback (most recent call last):
  File "/home/burak/test.py", line 36, in <module>
    net_trt = torch_tensorrt.compile(net, **net_specs)
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 125, in compile
    return torch_tensorrt.ts.compile(
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 136, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* slice.1 produced from %slice.1 : Tensor = aten::select(%158, %6, %19) # /home/burak/test.py:20:8 in lowering graph for mini graph input.
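
For reference, my understanding is that the TorchScript path exposes partitioning controls such as min_block_size and torch_executed_ops (key names taken from the torch_tensorrt.ts.compile signature). A minimal sketch of pinning the ops listed as unsupported above to Torch; I have not confirmed whether this avoids the error:

net_specs_fallback = {
    'inputs': [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)],
    'enabled_precisions': {torch.float32, torch.half},
    # Keep the data-dependent ops in Torch rather than TensorRT.
    'torch_executed_ops': ['aten::Bool', 'aten::len', 'aten::select'],
    # Allow smaller TensorRT blocks so the convolutions can still convert.
    'min_block_size': 1,
}
net_trt = torch_tensorrt.compile(net, **net_specs_fallback)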

I have also experimented with the following three blocks:

"""block(none)"""
retval = x  # no partitioning
"""block(none)"""

"""block(for)"""
retval = []
for slice in x:
    retval.append(slice)
"""block(for)"""

"""block(cond)"""
if x.sum() > 0: # any cond. dep. on tensor
    retval = x + 100
else:
    retval = x + 50
"""block(cond)"""

I haven't received any errors with any of these. It also seems that torch.jit.script(...) generates a valid graph for all tested blocks (including block(for+cond)), so the problem appears to lie in the partitioning step.
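
To confirm that scripting alone succeeds, a trivial check (shown for completeness):

scripted = torch.jit.script(net)  # scripts fine, including block(for+cond)
print(scripted.graph)             # valid graph containing prim::Loop / prim::If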

I still can't see why block(for+cond) is problematic. Any comments or suggestions? Thanks in advance!

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g., 1.0): 1.13.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.9.15
  • CUDA version: 11.7
  • GPU models and configuration: NVIDIA GeForce RTX 3070 (Laptop)
  • Any other relevant information:

