🐛 [Bug] Cannot export models to TensorRT with int8 quantization #1222

@domef

🐛 Describe the bug

I'm trying to convert a resnet18 model to TensorRT. Compilation works when enabled_precisions is set to torch.float or torch.float16, but it fails with torch.int8.

import torch
import torchvision
import torch_tensorrt


model = torchvision.models.resnet18().eval().cuda()
model_jit = torch.jit.script(model)
# model_jit = torch.jit.trace(model, torch.rand((1, 3, 256, 256), device="cuda"))

trt_model = torch_tensorrt.ts.compile(
    model_jit,
    inputs=[torch_tensorrt.Input((1, 3, 256, 256))],
    device={
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "dla_core": 0,
        "allow_gpu_fallback": True,
    },
    enabled_precisions={torch.int8},
)

With the model exported via torch.jit.script, compilation fails with the following error:

  File "test_tensorrt.py", line 10, in <module>
    trt_model = torch_tensorrt.ts.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: Unknown type bool encountered in graph lowering. This type is not supported in ONNX export.

With the model exported via torch.jit.trace, the program exits with:

Segmentation fault (core dumped)
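
A segmentation fault kills the interpreter before Python can raise an exception, so it cannot be caught with try/except in the same process. One way to record it next to the scripted-model error is to run each attempt in a child interpreter. The following is a minimal sketch using only the standard library on a POSIX system; `run_isolated` is a hypothetical helper name and is not part of Torch-TensorRT.

```python
import signal
import subprocess
import sys


def run_isolated(source: str, timeout: float = 600.0) -> str:
    """Run Python source in a child interpreter and describe how it ended."""
    proc = subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if proc.returncode == 0:
        return "ok"
    if proc.returncode < 0:
        # On POSIX, a negative return code means the child died from a signal.
        return f"killed by {signal.Signals(-proc.returncode).name}"
    # Otherwise report the last line of stderr, usually the exception message.
    lines = proc.stderr.strip().splitlines()
    return lines[-1] if lines else f"exit code {proc.returncode}"
```

Passing the contents of the reproduction script as `source` would then presumably return a `killed by SIGSEGV` string for the traced model and the `RuntimeError` line for the scripted one, without stopping the driving script.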

I'm using the NVIDIA container nvcr.io/nvidia/pytorch:22.06-py3.

Edit:

If I compile the nn.Module directly, I get the same error as with the scripted model:

import torch
import torch_tensorrt
import torchvision

model = torchvision.models.resnet18().eval().cuda()

trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 256, 256))],
    device={
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "dla_core": 0,
        "allow_gpu_fallback": True,
    },
    enabled_precisions={torch.int8},
)
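
The precision comparison described in this report (torch.float and torch.float16 succeed, torch.int8 fails) can be collected in a single run with a small harness. This is a rough sketch, not part of Torch-TensorRT: `sweep_precisions` and `compile_fn` are hypothetical names, and the commented usage assumes the `model_jit` object from the reproduction above.

```python
from typing import Any, Callable, Dict, Iterable


def sweep_precisions(
    compile_fn: Callable[[Any], Any],
    precisions: Iterable[Any],
) -> Dict[str, str]:
    """Call compile_fn once per precision and record 'ok' or the error text."""
    results: Dict[str, str] = {}
    for precision in precisions:
        try:
            compile_fn(precision)
            results[str(precision)] = "ok"
        except Exception as exc:  # a segfault cannot be caught here
            results[str(precision)] = f"{type(exc).__name__}: {exc}"
    return results


# Hypothetical usage with the reproduction above (needs torch_tensorrt and a GPU):
# sweep_precisions(
#     lambda p: torch_tensorrt.ts.compile(
#         model_jit,
#         inputs=[torch_tensorrt.Input((1, 3, 256, 256))],
#         enabled_precisions={p},
#     ),
#     [torch.float, torch.float16, torch.int8],
# )
```

The harness only catches Python exceptions, so it suits the scripted-model and nn.Module paths; the traced-model crash would still terminate the whole process.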

Versions

PyTorch version: 1.13.0a0+340c412
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.23.2
Libc version: glibc-2.31

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10) [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.13.0-52-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.99
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 Ti
Nvidia driver version: 510.73.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.13.0a0+340c412
[pip3] torch-tensorrt==1.1.0a0
[pip3] torchtext==0.13.0a0
[pip3] torchvision==0.13.0a0
[conda] mkl 2020.4 h726a3e6_304 conda-forge
[conda] mkl-include 2020.4 h726a3e6_304 conda-forge
[conda] numpy 1.22.4 py38h99721a1_0 conda-forge
[conda] pytorch-quantization 2.1.2 pypi_0 pypi
[conda] torch 1.13.0a0+340c412 pypi_0 pypi
[conda] torch-tensorrt 1.1.0a0 pypi_0 pypi
[conda] torchtext 0.13.0a0 pypi_0 pypi
[conda] torchvision 0.13.0a0 pypi_0 pypi

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo
