Bug Description
While compiling a PyTorch model following this tutorial, compilation fails with the error below:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)
To Reproduce
Here is the code used to compile:
import torch_tensorrt
import torch
import sys
from transformers import BertConfig, BertModel
model = BertModel.from_pretrained("bert-base-uncased")
inp = torch.rand(1, 256).long()
model = torch.jit.trace(model, inp)
print(next(model.parameters()).is_cuda)
inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 256],
        opt_shape=[1, 256],
        max_shape=[1, 256],
        dtype=torch.int32,
    )
]
enabled_precisions = {torch.float32}  # Run with fp32
trt_ts_module = torch_tensorrt.ts.compile(
    model, inputs=inputs, enabled_precisions=enabled_precisions)
#input_data = input_data.to("cuda").half()
#result = trt_ts_module(input_data)
#print(result)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
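Note that the print above outputs False (see the stack trace below), so the traced weights stay on the CPU while Torch-TensorRT feeds cuda:0 tensors during compilation. A minimal sketch of a possible workaround, assuming that CPU/GPU split is the cause, would move both the model and the example input to the GPU before tracing (the randint input is only illustrative; the rest of the compile call stays the same as above):
import torch
import torch_tensorrt
from transformers import BertModel

# Assumption: placing the model and the trace input on cuda:0 before
# torch.jit.trace keeps every tensor on the same device during compilation.
model = BertModel.from_pretrained("bert-base-uncased").eval().cuda()
inp = torch.randint(0, 2000, (1, 256), dtype=torch.long).cuda()  # dummy token ids
traced = torch.jit.trace(model, inp)
print(next(traced.parameters()).is_cuda)  # expected to print True now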
Steps to reproduce the behavior:
- Copy the above code to /home/compile.py
- sudo docker run --gpus all -v /home:/home -it --rm nvcr.io/nvidia/pytorch:22.09-py3
- pip install transformers==2.3.0
- python /home/compile.py
Environment
AWS g4dn.xlarge with DLAMI
root@abc:/home/ec2-user# pip list | grep tensor
jupyter-tensorboard 0.2.0
tensorboard 2.10.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorrt 8.5.0.12
torch-tensorrt 1.2.0a0
functorch 0.3.0a0
pytorch-quantization 2.1.2
torch 1.13.0a0+d0d6b1f
torch-tensorrt 1.2.0a0
torchtext 0.11.0a0
torchvision 0.14.0a0
root@abc:/home/ec2-user# python --version
Python 3.8.13
[ec2-user@abc]$ uname -ra
Linux ip-172-31-36-232.us-west-2.compute.internal 4.14.291-218.527.amzn2.x86_64 #1 SMP Fri Aug 26 09:54:31 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
- GPU models and configuration: T4
Stack trace
False
WARNING: [Torch-TensorRT] - For input input_ids, found user specified input dtype as Int32, however when inspecting the graph, the input type expected was inferred to be Float
The compiler is going to use the user setting Int32
This conflict may cause an error at runtime due to partial compilation being enabled and therefore
compatibility with PyTorch's data type convention is required.
If you do indeed see errors at runtime either:
- Remove the dtype spec for input_ids
- Disable partial compilation by setting require_full_compilation to True
Traceback (most recent call last):
File "tensorrt_compile_1.py", line 20, in <module>
trt_ts_module = torch_tensorrt.ts.compile(
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 134, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py(2206): embedding
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/sparse.py(160): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/modeling_bert.py(186): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/modeling_bert.py(735): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(967): trace_module
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(750): trace
tensorrt_compile_1.py(7): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)
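Side note: the warning earlier in the log suggests disabling partial compilation. A sketch of that variant is below; require_full_compilation comes from the warning text, while truncate_long_and_double is my own assumption (an option I would expect to need for BERT's int64 token ids), not something the warning asks for.
trt_ts_module = torch_tensorrt.ts.compile(
    model,
    inputs=inputs,
    enabled_precisions=enabled_precisions,
    # From the warning: disable partial compilation so the Int32 input spec
    # does not need to match the dtype PyTorch infers for the graph input.
    require_full_compilation=True,
    # Assumption: truncate 64-bit token ids down to 32-bit for TensorRT.
    truncate_long_and_double=True,
)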