
🐛 [Bug] nn.PReLU crashes torch_tensorrt.compile #789

@styler00dollar

Bug Description

Compiling a model with torch_tensorrt.compile fails when the model contains nn.PReLU. The same model compiles fine when nn.ReLU is used instead.

To Reproduce

Run the following script (compact_tensorrt.py):

import torch
import torch_tensorrt
from torch import nn


# Minimal model adapted from
# https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/archs/srvgg_arch.py
class test_class(nn.Module):
    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4):
        super().__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale

        self.body = nn.ModuleList()
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))

        # nn.PReLU makes compilation fail; nn.ReLU compiles fine
        activation = nn.PReLU(num_parameters=num_feat)
        # activation = nn.ReLU(inplace=True)
        self.body.append(activation)

        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))

    def forward(self, x):
        out = x
        for layer in self.body:
            out = layer(out)
        return out


model = test_class()
example_data = torch.rand(1, 3, 256, 256)
model = torch.jit.trace(model, [example_data])
print(model(example_data).shape)
model = torch_tensorrt.compile(model, inputs=[example_data],
                               enabled_precisions={torch.float}, truncate_long_and_double=True)
print(model(example_data.cuda()).shape)
Running it produces:

python compact_tensorrt.py
torch.Size([1, 48, 256, 256])
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [pointWiseNode.cpp::computeOutputExtents::17] Error Code 2: Internal Error (Assertion nbDims == inputs[i]->extent.nbDims failed.)
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 2: [pointWiseNode.cpp::computeOutputExtents::17] Error Code 2: Internal Error (Assertion nbDims == inputs[i]->extent.nbDims failed.)
Traceback (most recent call last):
  File "compact_tensorrt.py", line 34, in <module>
    model = torch_tensorrt.compile(model, inputs=[example_data], \
  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/conversion/converters/impl/conv_deconv.cpp:115] Expected orig_dims.nbDims > 2 to be true but got false
Unable to create convolution layer from node: %18 : Tensor = aten::_convolution(%input, %22, %23, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.body.2 # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:442:0
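The first two errors come from a TensorRT pointwise-node assertion that all inputs have the same number of dimensions (nbDims). One plausible but unconfirmed reading is that the PReLU slope reaches TensorRT as a 1-D tensor while the activation is 4-D; the later "Expected orig_dims.nbDims > 2" error on body.2 would then follow, because the convolution after the PReLU no longer receives a 4-D input. The sketch below only inspects those shapes in eager PyTorch (no TensorRT involved) and checks that a slope reshaped to [1, C, 1, 1] reproduces nn.PReLU.

```python
import torch
from torch import nn

# Same shapes as the reproduction: 64 feature channels, 1x64x256x256 activation.
num_feat = 64
prelu = nn.PReLU(num_parameters=num_feat)
x = torch.randn(1, num_feat, 256, 256)

# nn.PReLU stores one slope per channel as a 1-D tensor of shape [num_feat],
# while the activation it is applied to is 4-D (NCHW).
print("activation rank:", x.dim(), "slope rank:", prelu.weight.dim())

# Reshaping the slope to [1, C, 1, 1] gives it the same rank as the activation,
# so the elementwise multiply needs no implicit rank expansion.
slope_4d = prelu.weight.view(1, -1, 1, 1)

with torch.no_grad():
    manual = torch.where(x >= 0, x, slope_4d * x)
    reference = prelu(x)

print("max abs difference:", (manual - reference).abs().max().item())
```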

Expected behavior

torch_tensorrt.compile should succeed for the model with nn.PReLU, just as it does with nn.ReLU.
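Until the converter handles this case, one possible workaround, assuming the problem is limited to how aten::prelu is converted, is to swap each nn.PReLU for a mathematically equivalent module built from relu and elementwise arithmetic before tracing. ChannelPReLU and replace_prelu are hypothetical helpers, not part of PyTorch or Torch-TensorRT, and this sketch has not been verified against Torch-TensorRT 1.0.

```python
import torch
from torch import nn


class ChannelPReLU(nn.Module):
    """Hypothetical drop-in replacement for nn.PReLU on NCHW inputs.

    Computes max(0, x) + a * min(0, x) as relu(x) - a * relu(-x), with the
    per-channel slope reshaped to [1, C, 1, 1] so both operands of the
    multiply have the same rank as the activation.
    """

    def __init__(self, prelu: nn.PReLU):
        super().__init__()
        # Reuse the (possibly trained) slope values from the original module.
        self.weight = nn.Parameter(prelu.weight.detach().clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.weight.view(1, -1, 1, 1)
        return torch.relu(x) - a * torch.relu(-x)


def replace_prelu(module: nn.Module) -> nn.Module:
    """Recursively replace every nn.PReLU child with ChannelPReLU, in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.PReLU):
            setattr(module, name, ChannelPReLU(child))
        else:
            replace_prelu(child)
    return module
```

To try it with the reproduction, call model = replace_prelu(test_class()) before torch.jit.trace and then compile as before. The traced graph should then contain aten::relu, aten::mul and aten::sub instead of aten::prelu. The view(1, -1, 1, 1) assumes NCHW inputs; with num_parameters=1 it becomes [1, 1, 1, 1] and still broadcasts.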

Environment

  • Torch-TensorRT Version: 1.0
  • PyTorch Version: 1.10.0+cu113
  • OS: Manjaro + Docker (NVIDIA TensorRT container 21.11-py3)
