❓ Question
As far as I understand, Torch-TensorRT performs a partitioning step when it encounters unsupported operations. The resulting graph then uses the generated TensorRT engine(s) for the supported partition(s) and falls back to the TorchScript JIT everywhere else. In general, I can observe this behavior in the generated graphs. However, I receive errors with specific blocks, and I can't work out why those blocks are problematic.
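To make the partitioning idea concrete, here is a toy sketch. It is not Torch-TensorRT's actual partitioner: it only groups consecutive op names into TensorRT or Torch (fallback) segments. The op list is a simplified subset of the top level of the lowered graph in the log below. The `UNSUPPORTED` set and the `partition` helper are hypothetical, and treating `prim::Loop`/`prim::If` as fallback nodes is an assumption. The real partitioner also has to track data dependencies and nested blocks.

```python
from itertools import groupby

# Hypothetical set of ops that must fall back to TorchScript. aten::len and
# aten::Bool come from the "Unsupported operators" list in the log; treating
# the control-flow nodes prim::Loop / prim::If as fallback nodes is an assumption.
UNSUPPORTED = {"aten::len", "aten::Bool", "prim::Loop", "prim::If"}


def partition(ops):
    """Toy partitioner: group consecutive ops into (target, [ops]) segments."""
    segments = []
    for unsupported, group in groupby(ops, key=lambda op: op in UNSUPPORTED):
        target = "Torch" if unsupported else "TensorRT"
        segments.append((target, list(group)))
    return segments


# Simplified top-level op sequence of the lowered graph
# (constants and prim::ListConstruct omitted).
top_level_ops = [
    "aten::_convolution",
    "aten::relu",
    "aten::_convolution",
    "aten::len",
    "prim::Loop",
]

for target, ops in partition(top_level_ops):
    print(target, ops)
```

This toy ignores the nested blocks inside `prim::Loop` and `prim::If` entirely. Those nested blocks are what distinguish block(for+cond) from the simpler cases.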
For instance, for the following (example) block:
"""block(for+cond)"""
retval = []
for slice in x:  # x: torch.Tensor
    if slice.sum() > 0:  # any cond. dep. on tensor/slice
        retval.append(slice + 100)
    else:
        retval.append(slice + 50)
"""block(for+cond)"""
I receive a RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Expected ivalues_maps.count(input) to be true but got false when calling torch_tensorrt.compile(...):
Traceback (most recent call last):
  File "/home/burak/test.py", line 36, in <module>
    net_trt = torch_tensorrt.compile(net, **net_specs)
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 125, in compile
    return torch_tensorrt.ts.compile(
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 136, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* slice.1 produced from %slice.1 : Tensor = aten::select(%158, %6, %19) # /home/burak/test.py:20:8 in lowering graph for mini graph input.
What you have already tried
I reproduced this behavior with the following example script:
import torch
import torch_tensorrt

torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Info)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv1 = torch.nn.Conv2d(8, 16, kernel_size=3)

    def forward(self, x):
        x = self.conv0(x)
        x = self.relu(x)
        x = self.conv1(x)
        """block(for+cond)"""
        retval = []
        for slice in x:
            if slice.sum() > 0:  # any cond. dep. on tensor/slice
                retval.append(slice + 100)
            else:
                retval.append(slice + 50)
        """block(for+cond)"""
        return retval


net = Net().eval().cuda()
net_specs = {
    'inputs': [torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)],
    'enabled_precisions': {torch.float32, torch.half},
}
net_trt = torch_tensorrt.compile(net, **net_specs)
print(net_trt.graph)

I receive the following RuntimeError (full output, Info log level):
INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
INFO: [Torch-TensorRT] - Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript
INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor):
  %self.conv0.weight.1 : Float(8, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.conv0.bias.1 : Float(8, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value= 0.1437 0.0745 0.1127 0.1185 0.1406 0.1445 -0.0802 0.0562 [ CUDAFloatType{8} ]]()
  %self.conv1.weight.1 : Float(16, 8, 3, 3, strides=[72, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.conv1.bias.1 : Float(16, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %9 : int = prim::Constant[value=1]()
  %8 : NoneType = prim::Constant()
  %7 : bool = prim::Constant[value=1]() # /home/burak/test.py:20:8
  %6 : int = prim::Constant[value=0]() # /home/burak/test.py:21:29
  %5 : int = prim::Constant[value=100]() # /home/burak/test.py:22:38
  %4 : int = prim::Constant[value=50]() # /home/burak/test.py:24:38
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %153 : bool = prim::Constant[value=0]()
  %154 : int[] = prim::Constant[value=[0, 0]]()
  %155 : Tensor = aten::_convolution(%x.1, %self.conv0.weight.1, %self.conv0.bias.1, %3, %2, %3, %153, %154, %9, %153, %153, %153, %153)
  %17 : Tensor[] = prim::ListConstruct()
  %137 : Tensor = aten::relu(%155) # /home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch/nn/functional.py:1455:17
  %156 : bool = prim::Constant[value=0]()
  %157 : int[] = prim::Constant[value=[0, 0]]()
  %158 : Tensor = aten::_convolution(%137, %self.conv1.weight.1, %self.conv1.bias.1, %3, %2, %3, %156, %157, %9, %156, %156, %156, %156)
  %144 : int = aten::len(%158) # /home/burak/test.py:20:8
  = prim::Loop(%144, %7) # /home/burak/test.py:20:8
    block0(%19 : int):
      %slice.1 : Tensor = aten::select(%158, %6, %19) # /home/burak/test.py:20:8
      %123 : Tensor = aten::sum(%slice.1, %8) # /home/burak/test.py:21:15
      %125 : Tensor = aten::gt(%123, %6) # /home/burak/test.py:21:15
      %126 : bool = aten::Bool(%125) # /home/burak/test.py:21:15
      = prim::If(%126) # /home/burak/test.py:21:12
        block0():
          %24 : Tensor = aten::add(%slice.1, %5, %9) # /home/burak/test.py:22:30
          %25 : Tensor[] = aten::append(%17, %24) # /home/burak/test.py:22:16
          -> ()
        block1():
          %26 : Tensor = aten::add(%slice.1, %4, %9) # /home/burak/test.py:24:30
          %27 : Tensor[] = aten::append(%17, %26) # /home/burak/test.py:24:16
          -> ()
      -> (%7)
  return (%17)
INFO: [Torch-TensorRT] - Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.
Unsupported operators listed below:
- aten::Bool.Tensor(Tensor a) -> bool
- aten::len.Tensor(Tensor t) -> int
Traceback (most recent call last):
  File "/home/burak/test.py", line 36, in <module>
    net_trt = torch_tensorrt.compile(net, **net_specs)
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 125, in compile
    return torch_tensorrt.ts.compile(
  File "/home/burak/miniconda3/envs/convert/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 136, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* slice.1 produced from %slice.1 : Tensor = aten::select(%158, %6, %19) # /home/burak/test.py:20:8 in lowering graph for mini graph input.
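Both operators in the "Unsupported operators" list convert a Tensor into a Python scalar that drives control flow. In the lowered graph, `%144 = aten::len(%158)` is the trip count of `prim::Loop`, and `%126 = aten::Bool(%125)` is the condition of `prim::If`. The sketch below checks this with a simple def-use lookup over a hand-transcribed subset of that graph. The `(outputs, op, inputs)` tuple representation and the helper name are hypothetical, nested blocks are flattened, and the inputs of `aten::_convolution` are trimmed.

```python
# Hand-transcribed subset of the lowered graph (hypothetical representation).
# Each node is (outputs, op, inputs); nesting inside prim::Loop / prim::If is
# flattened because only def-use links matter for this lookup.
NODES = [
    (["%158"], "aten::_convolution", ["%137"]),  # inputs trimmed
    (["%144"], "aten::len", ["%158"]),
    ([], "prim::Loop", ["%144", "%7"]),
    (["%slice.1"], "aten::select", ["%158", "%6", "%19"]),
    (["%123"], "aten::sum", ["%slice.1", "%8"]),
    (["%125"], "aten::gt", ["%123", "%6"]),
    (["%126"], "aten::Bool", ["%125"]),
    ([], "prim::If", ["%126"]),
]

# Taken from the "Unsupported operators" list in the log.
UNSUPPORTED_OPS = {"aten::len", "aten::Bool"}


def consumers_of_unsupported(nodes):
    """Map each unsupported op to the ops that read one of its outputs."""
    result = {}
    for outputs, op, _ in nodes:
        if op not in UNSUPPORTED_OPS:
            continue
        result[op] = [
            other_op
            for _, other_op, inputs in nodes
            if any(out in inputs for out in outputs)
        ]
    return result


print(consumers_of_unsupported(NODES))
```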
I have also experimented with the following three blocks:
"""block(none)"""
retval = x  # no partitioning
"""block(none)"""

"""block(for)"""
retval = []
for slice in x:
    retval.append(slice)
"""block(for)"""

"""block(cond)"""
if x.sum() > 0:  # any cond. dep. on tensor
    retval = x + 100
else:
    retval = x + 50
"""block(cond)"""
I didn't receive any errors with these. Also, torch.jit.script(...) appears to generate a valid graph for all tested blocks (including block(for+cond)), so the problem seems to be in the partitioning step.
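Comparing the four blocks structurally, block(for+cond) is the only one with a `prim::If` whose branches read a value defined inside the enclosing `prim::Loop` body. That value is `%slice.1`, which is also the value the error message reports as missing. The sketch below checks for that pattern on hand-written models of each block. The `node` tuple layout and the helper names are hypothetical. Value names for block(for+cond) follow the lowered graph in the log, while those for the other blocks are made up. Whether this nesting is what actually trips Torch-TensorRT's shape analysis is an assumption, not something the log proves.

```python
# Hypothetical nested-block model (not Torch-TensorRT's IR classes).
# A block is a list of nodes; a node is (op, defs, uses, sub_blocks).
def node(op, defs=(), uses=(), blocks=()):
    return (op, set(defs), set(uses), list(blocks))


def all_uses(block):
    """Every value read anywhere in a block, including nested blocks."""
    used = set()
    for _, _, uses, sub_blocks in block:
        used |= uses
        for sub in sub_blocks:
            used |= all_uses(sub)
    return used


def loop_values_read_by_branches(block, loop_defs=frozenset()):
    """Values defined in an enclosing prim::Loop body that a prim::If branch reads."""
    found = set()
    for op, _, _, sub_blocks in block:
        if op == "prim::If" and loop_defs:
            for branch in sub_blocks:
                found |= all_uses(branch) & loop_defs
        for sub in sub_blocks:
            inner_defs = loop_defs
            if op == "prim::Loop":
                # Values produced directly in the loop body are loop-local.
                inner_defs = loop_defs | {v for _, defs, _, _ in sub for v in defs}
            found |= loop_values_read_by_branches(sub, inner_defs)
    return found


NONE = []  # retval = x

FOR_ONLY = [
    node("aten::len", defs={"%n"}, uses={"%x"}),
    node("prim::Loop", uses={"%n"}, blocks=[[
        node("aten::select", defs={"%slice"}, uses={"%x", "%i"}),
        node("aten::append", uses={"%retval", "%slice"}),
    ]]),
]

COND_ONLY = [
    node("aten::sum", defs={"%s"}, uses={"%x"}),
    node("aten::gt", defs={"%g"}, uses={"%s"}),
    node("aten::Bool", defs={"%c"}, uses={"%g"}),
    node("prim::If", uses={"%c"}, blocks=[
        [node("aten::add", defs={"%a"}, uses={"%x"})],
        [node("aten::add", defs={"%b"}, uses={"%x"})],
    ]),
]

FOR_COND = [  # mirrors the lowered graph in the log
    node("aten::len", defs={"%144"}, uses={"%158"}),
    node("prim::Loop", uses={"%144", "%7"}, blocks=[[
        node("aten::select", defs={"%slice.1"}, uses={"%158", "%6", "%19"}),
        node("aten::sum", defs={"%123"}, uses={"%slice.1", "%8"}),
        node("aten::gt", defs={"%125"}, uses={"%123", "%6"}),
        node("aten::Bool", defs={"%126"}, uses={"%125"}),
        node("prim::If", uses={"%126"}, blocks=[
            [node("aten::add", defs={"%24"}, uses={"%slice.1", "%5", "%9"}),
             node("aten::append", defs={"%25"}, uses={"%17", "%24"})],
            [node("aten::add", defs={"%26"}, uses={"%slice.1", "%4", "%9"}),
             node("aten::append", defs={"%27"}, uses={"%17", "%26"})],
        ]),
    ]]),
]

for name, graph in [("none", NONE), ("for", FOR_ONLY),
                    ("cond", COND_ONLY), ("for+cond", FOR_COND)]:
    print(f"block({name}): {sorted(loop_values_read_by_branches(graph))}")
```

In this model, only block(for+cond) reports a loop-local value (`%slice.1`) read inside an `If` branch. block(cond) reads only `%x` from the top level, and block(for) has no branch at all.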
I couldn't understand why block(for+cond) is problematic. Any comments or suggestions? Thanks in advance!
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
- PyTorch Version (e.g., 1.0): 1.13.1
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.9.15
- CUDA version: 11.7
- GPU models and configuration: NVIDIA GeForce RTX 3070 (Laptop)
- Any other relevant information: