
🐛 [Bug] Compiling BERT model from transformers fails #1401

@vishwanath-gowda

Description

Bug Description

While compiling a PyTorch model by following this tutorial, compilation fails with the error below:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

To Reproduce

Here is the code used for compilation:

import torch_tensorrt
import torch
import sys
from transformers import BertConfig, BertModel
model = BertModel.from_pretrained("bert-base-uncased")
inp = torch.rand(1, 256).long()
model = torch.jit.trace(model, inp)

print(next(model.parameters()).is_cuda)

inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 256],
        opt_shape=[1, 256],
        max_shape=[1, 256],
        dtype=torch.int32,
    )
]
enabled_precisions = {torch.float32}  # Run with FP32
trt_ts_module = torch_tensorrt.ts.compile(
    model, inputs=inputs, enabled_precisions=enabled_precisions)

#input_data = input_data.to("cuda").half()
#result = trt_ts_module(input_data)
#print(result)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

Steps to reproduce the behavior:

  1. Copy the code above to /home/compile.py
  2. sudo docker run --gpus all -v /home:/home -it --rm nvcr.io/nvidia/pytorch:22.09-py3
  3. pip install transformers==2.3.0
  4. python /home/compile.py

Environment

AWS g4dn.xlarge with DLAMI

Build information about Torch-TensorRT can be found by turning on debug messages

root@abc:/home/ec2-user# pip list | grep tensor
jupyter-tensorboard           0.2.0
tensorboard                   2.10.0
tensorboard-data-server       0.6.1
tensorboard-plugin-wit        1.8.1
tensorrt                      8.5.0.12
torch-tensorrt                1.2.0a0
functorch                     0.3.0a0
pytorch-quantization          2.1.2
torch                         1.13.0a0+d0d6b1f
torch-tensorrt                1.2.0a0
torchtext                     0.11.0a0
torchvision                   0.14.0a0
root@abc:/home/ec2-user# python --version
Python 3.8.13
[ec2-user@abc]$ uname -ra
Linux ip-172-31-36-232.us-west-2.compute.internal 4.14.291-218.527.amzn2.x86_64 #1 SMP Fri Aug 26 09:54:31 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux

GPU model and configuration: T4

Stack trace


False
WARNING: [Torch-TensorRT] - For input input_ids, found user specified input dtype as Int32, however when inspecting the graph, the input type expected was inferred to be Float
The compiler is going to use the user setting Int32
This conflict may cause an error at runtime due to partial compilation being enabled and therefore
compatibility with PyTorch's data type convention is required.
If you do indeed see errors at runtime either:
- Remove the dtype spec for input_ids
- Disable partial compilation by setting require_full_compilation to True
Traceback (most recent call last):
  File "tensorrt_compile_1.py", line 20, in <module>
    trt_ts_module = torch_tensorrt.ts.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 134, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py(2206): embedding
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/sparse.py(160): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/modeling_bert.py(186): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/modeling_bert.py(735): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(967): trace_module
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(750): trace
tensorrt_compile_1.py(7): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)
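
The `False` printed at the start of the stack trace shows that the traced module's parameters are still on the CPU, while the failing `index_select` (inside the embedding lookup) received an index on `cuda:0`. A plausible workaround, assuming Torch-TensorRT runs the traced graph with CUDA inputs during compilation, is to move the model to the GPU before tracing and to trace with an example input on that same device. The sketch below illustrates only that pattern: `TinyEmbedder` is a hypothetical stand-in for `BertModel` (so nothing is downloaded) and `parameter_devices` is a hypothetical helper. It falls back to the CPU when no GPU is present and does not call Torch-TensorRT itself.

```python
from typing import Set

import torch
import torch.nn as nn


class TinyEmbedder(nn.Module):
    """Hypothetical stand-in for BertModel: just the embedding lookup,
    which is the op (index_select) that fails in the stack trace."""

    def __init__(self, vocab_size: int = 30522, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)


def parameter_devices(module: nn.Module) -> Set[str]:
    """Hypothetical helper: the set of devices holding the module's parameters."""
    return {str(p.device) for p in module.parameters()}


# Prefer the GPU; fall back to the CPU so the sketch still runs anywhere.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move the weights to the target device BEFORE tracing, and trace with an
# example input on the same device, so weights and indices never diverge.
model = TinyEmbedder().eval().to(device)
example = torch.randint(0, 30522, (1, 256), dtype=torch.long, device=device)

with torch.no_grad():
    traced = torch.jit.trace(model, example)
    out = traced(example)

# A single device, e.g. {'cuda:0'} on a GPU machine.
print(parameter_devices(traced))
```

For the real model, the same idea would be `BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")`, traced with a CUDA `input_ids` tensor, followed by the `torch_tensorrt.ts.compile` call from the report. `torch.randint` is used for the example input because `torch.rand(1, 256).long()` truncates every value in [0, 1) to 0, so every token ID in the original example input is 0.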
