Status: Closed
Labels: bug, component: converters, component: partitioning
Description
Bug Description
Compiling a model with torch_tensorrt.compile fails when the model contains nn.PReLU. The same model compiles fine when nn.ReLU is used instead.
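For context, nn.PReLU(num_parameters=C) learns one slope per channel and stores it as a 1-D tensor of shape [C], while the activation input in this model is 4-D (NCHW). Eager PyTorch broadcasts that slope along dim 1. The short check below illustrates this in plain PyTorch on CPU (no Torch-TensorRT involved); the 8x8 input size is an arbitrary choice for illustration.

```python
import torch
from torch import nn

num_feat = 64  # channel count used in the reproduction script
prelu = nn.PReLU(num_parameters=num_feat)
x = torch.randn(1, num_feat, 8, 8)  # small NCHW input for illustration

# The learnable slope is a 1-D tensor with one entry per channel...
print(prelu.weight.shape)  # torch.Size([64])
# ...while the activation input is 4-D.
print(x.dim())  # 4

# Eager PyTorch applies the slope along dim 1, which matches an explicit
# reshape to (1, C, 1, 1) followed by ordinary broadcasting.
w = prelu.weight.view(1, -1, 1, 1)
manual = torch.where(x >= 0, x, w * x)
print(torch.allclose(prelu(x), manual))  # True
```

Any converter for this op has to reproduce that channel-wise broadcast of a rank-1 slope against a rank-4 input.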
To Reproduce
Run the following script (compact_tensorrt.py):
from torch import nn as nn
from torch.nn import functional as F
# https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/archs/srvgg_arch.py
class test_class(nn.Module):
    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4):
        super(test_class, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.body = nn.ModuleList()
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # prelu does not work, relu does
        activation = nn.PReLU(num_parameters=num_feat)
        #activation = nn.ReLU(inplace=True)
        self.body.append(activation)
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)
        return out

import torch
import torch_tensorrt
model = test_class()
example_data = torch.rand(1,3,256,256)
model = torch.jit.trace(model, [example_data])
print(model(example_data).shape)
model = torch_tensorrt.compile(model, inputs=[example_data], \
                               enabled_precisions={torch.float}, truncate_long_and_double=True)
print(model(example_data.cuda()).shape)

python compact_tensorrt.py
torch.Size([1, 48, 256, 256])
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [pointWiseNode.cpp::computeOutputExtents::17] Error Code 2: Internal Error (Assertion nbDims == inputs[i]->extent.nbDims failed.)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [pointWiseNode.cpp::computeOutputExtents::17] Error Code 2: Internal Error (Assertion nbDims == inputs[i]->extent.nbDims failed.)
Traceback (most recent call last):
File "compact_tensorrt.py", line 34, in <module>
model = torch_tensorrt.compile(model, inputs=[example_data], \
File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py", line 97, in compile
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/conversion/converters/impl/conv_deconv.cpp:115] Expected orig_dims.nbDims > 2 to be true but got false
Unable to create convolution layer from node: %18 : Tensor = aten::_convolution(%input, %22, %23, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.body.2 # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:442:0
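One possible reading of this log, not confirmed against the Torch-TensorRT sources: the assertion nbDims == inputs[i]->extent.nbDims suggests that the PReLU conversion hands TensorRT a pointwise operation whose operands differ in rank (the 1-D slope vs. the 4-D activation), and the later conv_deconv.cpp:115 error (orig_dims.nbDims > 2) may simply be a knock-on effect on the next convolution. Under that assumption, one workaround to try is to express PReLU without aten::prelu and to store the slope with shape (1, C, 1, 1) so it already matches the input rank. ChannelPReLU below is a hypothetical helper, not part of PyTorch or Torch-TensorRT, and it has not been tested with Torch-TensorRT 1.0.

```python
import torch
from torch import nn
from torch.nn import functional as F


class ChannelPReLU(nn.Module):
    """Hypothetical drop-in replacement for nn.PReLU on NCHW inputs.

    The per-channel slope is stored with shape (1, C, 1, 1) so it has the
    same rank as the input, and the activation is built from relu/mul/sub
    instead of aten::prelu. Whether Torch-TensorRT converts this graph
    successfully is an assumption, not something verified here.
    """

    def __init__(self, num_parameters: int, init: float = 0.25):
        super().__init__()
        self.weight = nn.Parameter(torch.full((1, num_parameters, 1, 1), init))

    @classmethod
    def from_prelu(cls, prelu: nn.PReLU) -> "ChannelPReLU":
        # Copy the (possibly trained) slopes from an existing nn.PReLU module.
        module = cls(prelu.num_parameters)
        with torch.no_grad():
            module.weight.copy_(prelu.weight.view(1, -1, 1, 1))
        return module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # PReLU(x) = relu(x) - w * relu(-x): the positive part passes through,
        # the negative part is scaled by the per-channel slope w.
        return F.relu(x) - self.weight * F.relu(-x)


# Sanity check against eager nn.PReLU on CPU.
reference = nn.PReLU(num_parameters=64)
replacement = ChannelPReLU.from_prelu(reference)
x = torch.randn(1, 64, 16, 16)
print(torch.allclose(reference(x), replacement(x)))  # True
```

To try it in the reproduction script, replace nn.PReLU(num_parameters=num_feat) with ChannelPReLU(num_parameters=num_feat) before calling torch.jit.trace; for an already trained model, ChannelPReLU.from_prelu(...) keeps the learned slopes.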
Expected behavior
Compilation should succeed with nn.PReLU, just as it does with nn.ReLU.
Environment
- Torch-TensorRT Version: 1.0
- PyTorch Version: 1.10.0+cu113
- OS: Manjaro + Docker (NVIDIA TensorRT container 21.11-py3)