Bug Description
To Reproduce
Minimal reproducible code:
import torch
import torch_tensorrt
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model = model.eval().cuda()
inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[4, 3, 224, 224],
        max_shape=[16, 3, 224, 224],
        dtype=torch.float32,
    )
]
# inputs = torch_tensorrt.Input(shape=[2, 3, 224, 224], dtype=torch.float32)
trt_gm = torch_tensorrt.compile(model, "dynamo", inputs)
Expected behavior
The model should compile with dynamic shapes. Instead, compilation fails with the following error:
WARNING:torch_tensorrt.dynamo._compiler:Node scaled_dot_product_attention of op type call_function does not have metadata. This could sometimes lead to undefined behavior.
WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +489, GPU +0, now: CPU 6268, GPU 2121 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1906, GPU +354, now: CPU 8327, GPU 2475 (MiB)
WARNING:torch_tensorrt [TensorRT Conversion Context]:CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.058611
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: [SLICE]-[aten_ops.expand.default]-[/vit_embeddings/expand]: ISliceLayer has out of bounds access on axis 0 Condition '<' violated: 3 >= 1.)
Traceback (most recent call last):
  File "/mnt/bn/hukongtao-infer-speed/mlx/users/kongtao.hu/codebase/EasyGuard_0617/speed_vit_test.py", line 27, in <module>
    trt_gm = torch_tensorrt.compile(model, "dynamo", inputs)
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/_compile.py", line 250, in compile
    trt_graph_module = dynamo_compile(
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 243, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 431, in compile_module
    trt_module = convert_module(
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 107, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 88, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 350, in run
    assert serialized_engine
AssertionError
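For context on the failing node ([SLICE]-[aten_ops.expand.default]-[/vit_embeddings/expand]): ViT prepends a learnable class token to the patch embeddings by expanding it along the batch axis, and the error suggests TensorRT's slice lowering of that expand rejects the dynamic batch range. A minimal pure-PyTorch sketch of the pattern, assuming the standard vit-base-patch16-224 sizes (hidden size 768, 196 patches); variable names here are illustrative, not taken from the model source:

```python
import torch

# Learnable class token, stored with a singleton batch dimension: [1, 1, hidden]
cls_token = torch.zeros(1, 1, 768)
# Patch embeddings for a batch of 4 images: [batch, num_patches, hidden]
patches = torch.zeros(4, 196, 768)

# Expand the class token to the (dynamic) batch size; this is the
# aten_ops.expand.default node the error points at.
cls = cls_token.expand(patches.shape[0], -1, -1)

# Concatenate along the sequence axis: [batch, num_patches + 1, hidden]
embeddings = torch.cat((cls, patches), dim=1)
```

In eager PyTorch this works for any batch size, which is why the failure only surfaces inside the TensorRT engine build with the [1, 16] batch profile.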
Environment
Additional context
The reproduction follows the official documentation on dynamic shapes:
https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html