diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index c7f1ba1d6b..66a81c1fd4 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -72,7 +72,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/ popd @@ -98,7 +98,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py popd @@ -125,7 +125,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/ ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py popd @@ -152,7 +152,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/ ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/ ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/ diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 4e0153dcc3..99b0aa2674 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -78,16 +78,15 @@ jobs: script: | export USE_HOST_DEPS=1 export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH - export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH pushd . cd tests/modules # Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now. - ${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers timm pybind11==2.6.2 + ${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers==4.39.3 timm==0.9.16 pybind11==2.6.2 ${CONDA_RUN} python hub.py popd pushd . 
cd tests/py/ts - ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/ ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/ ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/ @@ -115,10 +114,9 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 - export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo - ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/ popd @@ -144,10 +142,9 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 - export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo - ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py popd @@ -174,10 +171,9 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 - export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo - ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py popd @@ -203,10 +199,9 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 - export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH pushd . 
cd tests/py/dynamo
-          ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+          ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
           ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
@@ -234,10 +229,9 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
         export USE_HOST_DEPS=1
-        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
         pushd .
         cd tests/py/dynamo
-        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
        ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
        ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
        ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
@@ -264,9 +258,8 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
         export USE_HOST_DEPS=1
-        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
         pushd .
         cd tests/py/core
-        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
+        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
        ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
        popd
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index a1ee30e994..0267f9ce04 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -124,6 +124,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   }

+  // this is a buffer to store shape tensor input addresses throughout the runtime scope
+  std::list<std::vector<int32_t>> inputShapeTensorValues;
   {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
     if (compiled_engine->profile_execution) {
@@ -142,12 +144,30 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       auto dims = core::util::toDims(inputs[i].sizes());
       auto shape = core::util::toVec(dims);
       LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
-      compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
-      compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
+      if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
+        // Shape tensor inputs are casted to int32 explicitly.
+        // Refer to
+        // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
+        auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32);
+        std::vector<int32_t> inputs_cpu_vec(
+            input_cpu.data_ptr<int32_t>(), input_cpu.data_ptr<int32_t>() + input_cpu.numel());
+        inputShapeTensorValues.emplace_back(inputs_cpu_vec);
+        compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data());
+      } else {
+        compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
+        compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
+      }
     }

+    // Check if input shapes can be inferred.
+    int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
+    std::vector<char const*> names(io_size);
+    int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
     TORCHTRT_CHECK(
-        compiled_engine->exec_ctx->allInputShapesSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
+        nbNames == 0,
+        "The shapes of the inputs: "
+            << names
+            << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
   }

   std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 32f19ce1f0..18636f8114 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -47,6 +47,7 @@ class _ShapeMode(Enum):
     high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
     torch_tensor: torch.Tensor = None
     name: str = ""
+    is_shape_tensor: bool = False

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
@@ -161,6 +162,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         else:
             self._explicit_set_dtype = False

+        if "is_shape_tensor" in kwargs:
+            self.is_shape_tensor = kwargs["is_shape_tensor"]
+
         if "format" in kwargs:
             self.format = memory_format._from(kwargs["format"])

@@ -174,7 +178,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         if "torch_tensor" in kwargs:
             self.torch_tensor = kwargs["torch_tensor"]
         else:
-            if self.shape_mode == Input._ShapeMode.DYNAMIC:
+            if self.is_shape_tensor:
+                self.torch_tensor = torch.tensor(
+                    kwargs["opt_shape"], dtype=kwargs["dtype"]
+                )
+            elif self.shape_mode == Input._ShapeMode.DYNAMIC:
                 self.torch_tensor = self.example_tensor("opt_shape")
             else:
                 self.torch_tensor = self.example_tensor()
diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
index 11c0f6b3ac..6bc334f427 100644
--- a/py/torch_tensorrt/dynamo/_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -58,13 +58,9 @@ def trace(
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_inputs = get_torch_inputs(inputs, device)
-    dynamic_shapes = {}
+    dynamic_shapes = []
     for input in inputs:
         if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
-            if not input.name:
-                raise AssertionError(
-                    f"Expected a name for a dynamic input with shape {input.shape} but found none"
-                )
             min_shape = input.shape["min_shape"]
             opt_shape = input.shape["opt_shape"]
             max_shape = input.shape["max_shape"]
@@ -80,8 +76,8 @@ def trace(
                     max=max_shape[dim],
                 )

-            dynamic_shapes[input.name] = dynamic_dims
+            dynamic_shapes.append(dynamic_dims)

-    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes)
+    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))

     return exp_program
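
[Editor's note] For context on the new `is_shape_tensor` flag introduced above: a minimal usage sketch, not part of this diff, mirroring the `test_arange_dynamic` case added later in tests/py/dynamo/conversion/test_arange_aten.py. For a shape tensor, `min_shape`/`opt_shape`/`max_shape` describe the runtime *values* the input may take, not tensor dimensions:

```python
import torch
import torch_tensorrt

# An integer input whose value (not its shape) feeds shape computations,
# e.g. the dynamic `end` argument of aten.arange.start_step.
end = torch_tensorrt.Input(
    min_shape=(5,),
    opt_shape=(7,),
    max_shape=(10,),
    dtype=torch.int64,
    torch_tensor=torch.tensor(7, dtype=torch.int64).cuda(),
    is_shape_tensor=True,  # routes this input through the shape-tensor handling above
)
```

diff --git 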
a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 66a9729cc0..dbb900009a 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -96,6 +96,8 @@ def _pretraced_backend( gm = apply_lowering_passes(gm, torch_inputs) + logger.debug("Lowered Input graph:\n " + str(gm.graph)) + torchtrt_inputs = prepare_inputs( torch_inputs, disable_memory_format_check=True ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 59d2c5d6c0..4de6aeb98f 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -22,10 +23,10 @@ get_node_name, get_trt_tensor, ) +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -365,18 +366,29 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: max_shape = current_input.shape["max_shape"] # TODO: Does not support disjoint optimization profiles? assert self.optimization_profiles is not None - self.optimization_profiles[0].set_shape( - target, min_shape, opt_shape, max_shape - ) - assert len(min_shape) == len(opt_shape) == len(max_shape) - for i in range(len(min_shape)): - if min_shape[i] == opt_shape[i] == max_shape[i]: - shape.append(min_shape[i]) - else: - # -1 to represent the dynamic dimension - shape.append(-1) - elif current_input.shape_mode == Input._ShapeMode.STATIC: + if current_input.is_shape_tensor: + # For shape_tensors, min/opt/max_shapes correspond to actual values + # of the shapes provided during runtime + self.optimization_profiles[0].set_shape_input( + target, min_shape, opt_shape, max_shape + ) + shape.append(len(opt_shape)) + else: + self.optimization_profiles[0].set_shape( + target, min_shape, opt_shape, max_shape + ) + + for i in range(len(min_shape)): + if min_shape[i] == opt_shape[i] == max_shape[i]: + shape.append(min_shape[i]) + else: + # -1 to represent the dynamic dimension + shape.append(DYNAMIC_DIM) + elif ( + not current_input.is_shape_tensor + and current_input.shape_mode == Input._ShapeMode.STATIC + ): assert isinstance(current_input.shape, tuple) shape = list(current_input.shape) else: @@ -388,6 +400,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: _LOGGER.debug( f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]" ) + return self.ctx.net.add_input( name=target, shape=tuple(shape), diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 64bb14ad21..1de955f680 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,7 +4,9 @@ import logging from typing import List, Sequence +import tensorrt as trt import torch +from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -17,8 +19,6 @@ from 
torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import get_torch_inputs -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -28,12 +28,12 @@ def infer_module_output_dtypes( device: Device, truncate_double: bool = False, ) -> List[dtype]: - torch_inputs = get_torch_inputs(inputs, device) - module = module.to(device.to(torch.device)) - module_outputs = module(*torch_inputs) - - if not isinstance(module_outputs, (list, tuple)): - module_outputs = [module_outputs] + with maybe_disable_fake_tensor_mode(): + torch_inputs = get_torch_inputs(inputs, device) + module = module.to(device.to(torch.device)) + module_outputs = module(*torch_inputs) + if not isinstance(module_outputs, (list, tuple)): + module_outputs = [module_outputs] # Int64 outputs can sometimes be generated from within other operators # such as aten.sum - such outputs can be truncated diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 23483986a5..69b29cf400 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -128,10 +128,14 @@ def aten_ops_batch_norm_legit_no_training( @dynamo_tensorrt_converter( - torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator + torch.ops.aten.native_layer_norm.default, + capability_validator=one_user_validator, + supports_dynamic_shapes=True, ) -@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) -@dynamo_tensorrt_converter(torch.ops.aten.layer_norm) +@dynamo_tensorrt_converter( + torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True +) +@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -236,7 +240,10 @@ def aten_ops_cat( ) -@dynamo_tensorrt_converter(torch.ops.aten.embedding.default) +@dynamo_tensorrt_converter( + torch.ops.aten.embedding.default, + supports_dynamic_shapes=True, +) def aten_ops_embedding( ctx: ConversionContext, target: Target, @@ -426,7 +433,7 @@ def aten_ops_index( ) -@dynamo_tensorrt_converter(torch.ops.aten.tanh.default) +@dynamo_tensorrt_converter(torch.ops.aten.tanh.default, supports_dynamic_shapes=True) def aten_ops_tanh( ctx: ConversionContext, target: Target, @@ -517,10 +524,10 @@ def aten_ops_hard_sigmoid( ) -@dynamo_tensorrt_converter(torch.ops.aten.matmul) -@dynamo_tensorrt_converter(torch.ops.aten.mm.default) -@dynamo_tensorrt_converter(torch.ops.aten.mv.default) -@dynamo_tensorrt_converter(torch.ops.aten.bmm.default) +@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.bmm.default, supports_dynamic_shapes=True) def aten_ops_matmul( ctx: ConversionContext, target: Target, @@ -601,7 +608,9 @@ def aten_ops_erf( ) -@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default) +@dynamo_tensorrt_converter( + torch.ops.aten.unsqueeze.default, supports_dynamic_shapes=True +) def aten_ops_unsqueeze( ctx: ConversionContext, target: Target, @@ -614,7 +623,9 @@ def aten_ops_unsqueeze( ) -@dynamo_tensorrt_converter(torch.ops.aten._softmax.default) +@dynamo_tensorrt_converter( + torch.ops.aten._softmax.default, supports_dynamic_shapes=True +) def aten_ops_softmax( ctx: 
ConversionContext, target: Target, @@ -709,7 +720,7 @@ def aten_ops_select( ) -@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -805,7 +816,7 @@ def aten_ops_tile( ) -@dynamo_tensorrt_converter(torch.ops.aten.permute.default) +@dynamo_tensorrt_converter(torch.ops.aten.permute.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -876,10 +887,12 @@ def validator(to_copy_node: Node) -> bool: @dynamo_tensorrt_converter( torch.ops.aten.clone.default, capability_validator=lambda node: not is_only_operator_on_placeholder(node), + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator(placeholder_only=False), + supports_dynamic_shapes=True, ) def aten_ops_clone_copy_dtype( ctx: ConversionContext, @@ -928,7 +941,7 @@ def aten_ops_clone_copy_placeholder( ) -@dynamo_tensorrt_converter(torch.ops.aten.expand.default) +@dynamo_tensorrt_converter(torch.ops.aten.expand.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -1551,6 +1564,7 @@ def aten_ops_isnan( ) +@dynamo_tensorrt_converter(operator.add, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True) def aten_ops_add( @@ -1583,8 +1597,8 @@ def aten_ops_add( ) -@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar, supports_dynamic_shapes=True) def aten_ops_mul( ctx: ConversionContext, target: Target, @@ -1670,11 +1684,11 @@ def aten_ops_sub( ) -@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) -@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) -@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) -@dynamo_tensorrt_converter(torch.ops.prims.div.default) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.prims.div.default, supports_dynamic_shapes=True) def aten_ops_div( ctx: ConversionContext, target: Target, @@ -1717,9 +1731,13 @@ def aten_ops_div( ) -@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar) -@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) +@dynamo_tensorrt_converter( + torch.ops.aten.pow.Tensor_Tensor, supports_dynamic_shapes=True +) +@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter( + torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True +) def aten_ops_pow( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 135309443e..d1bdf72d21 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ 
-194,6 +194,7 @@ def create_constant(
     value: Union[int, float, bool, np.ndarray, torch.Tensor],
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]],
+    min_rank: Optional[int] = 1,
 ) -> TRTTensor:
     """
     Add a TensorRT constant layer whose value is `value` to `ctx.net`.
@@ -205,14 +206,19 @@ def create_constant(
         name (str): Name of the added TensorRT Constant layer.
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If a dtype is given, we will convert the type of the given `value` to this dtype.
+        min_rank (int): minimum rank of the constant tensor.
     Returns:
         A TensorRT ITensor that represents the given value.
     """
+    shape = (1,)
+    # Rank 0 constant is required in IFillLayer inputs.
+    if min_rank == 0:
+        shape = trt.Dims()
     numpy_value = to_numpy(
         value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None
     )
     constant = ctx.net.add_constant(
-        (1,) if isinstance(value, (int, float, bool)) else value.shape,
+        shape if isinstance(value, (int, float, bool)) else value.shape,
         numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
     )
     constant.name = name
@@ -224,6 +230,7 @@ def get_trt_tensor(
     input_val: Any,
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
+    min_rank: int = 1,
 ) -> TRTTensor:
     """
     Given a value of random type, we try to convert it to a TensorRT ITensor.
@@ -236,6 +243,7 @@
             one.
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If dtype is provided, the given value will be converted to this dtype.
+        min_rank (int): minimum rank of the constant tensor.
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -248,7 +256,7 @@ def get_trt_tensor(
         input_val = input_val.astype(np.float32)

     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
-        return create_constant(ctx, input_val, name, dtype)
+        return create_constant(ctx, input_val, name, dtype, min_rank)
     elif isinstance(input_val, TRTTensor):
         return input_val
     else:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index d6ffc77377..d7d32b5bb0 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,11 +1,14 @@
 from typing import Optional, Sequence, Union

 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
+from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
 )
@@ -20,14 +23,19 @@ def cat(
     name: str,
     input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     dim: int,
+    cast_dtype: Union[_enums.dtype, trt.DataType, np.dtype] = None,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     trt_inputs = []
     for i, each_input in enumerate(input):
         if not isinstance(each_input, TRTTensor):
             each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
+        if cast_dtype:
+            each_input = cast_trt_tensor(
+                ctx, each_input, cast_dtype, f"{name}_tensor_int32_cast_{i}"
+            )
         trt_inputs.append(each_input)

     concat_layer = ctx.net.add_concatenation(trt_inputs)
-    dim = get_positive_dim(dim, len(input[0].shape))
+    dim = get_positive_dim(dim, len(trt_inputs[0].shape))
     concat_layer.axis = dim
     set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
     return concat_layer.get_output(0)
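
[Editor's note] A standalone sketch, not part of this diff, of why `create_constant` grew a `min_rank` parameter: IFillLayer's LINSPACE `start` operand (input 1) must be rank 0, while scalar constants were previously always emitted with shape `(1,)`. The builder-flag and dtype choices below are illustrative, assuming a TensorRT 10 Python install:

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Rank-0 constant, shape trt.Dims() == () -- what min_rank=0 now produces.
start = network.add_constant(trt.Dims(), np.array(0, dtype=np.int32)).get_output(0)
# Rank-1 delta with one entry per output dimension.
delta = network.add_constant((1,), np.array([1], dtype=np.int32)).get_output(0)

fill = network.add_fill((4,), trt.FillOperation.LINSPACE, trt.int32)
fill.set_input(1, start)  # rejected at build time if start has shape (1,) instead of ()
fill.set_input(2, delta)
print(fill.get_output(0).shape)  # (4,)
```

diff --git 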
a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index b2a79af5cb..b2d005b175 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -2,8 +2,13 @@ import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor @@ -25,7 +30,13 @@ def reshape( for i, s in enumerate(shape): if isinstance(s, TRTTensor): - trt_shape.append(s) + dim_int32 = cast_trt_tensor( + ctx, + s, + _enums.dtype.int32, + name + f"_int32_casted_{i}", + ) + trt_shape.append(dim_int32) else: a = get_trt_tensor(ctx, s, f"{name}_{i}") trt_shape.append(a) diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py index ea2f1c4d89..43009c306a 100644 --- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py +++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py @@ -3,13 +3,20 @@ from typing import Dict, Sequence, Tuple, Union import numpy as np +import tensorrt as trt import torch from torch.fx.node import Argument, Node, Target +from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( ConverterRegistry, dynamo_tensorrt_converter, ) +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.dynamo.conversion.impl.elementwise import sub, trunc_div from torch_tensorrt.fx.types import TRTTensor _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -28,7 +35,7 @@ def getitem_validator(getitem_node: Node) -> bool: capability_validator=getitem_validator, supports_dynamic_shapes=True, ) -@dynamo_tensorrt_converter(torch.ops.aten.detach.default) +@dynamo_tensorrt_converter(torch.ops.aten.detach.default, supports_dynamic_shapes=True) def generic_evaluator( ctx: ConversionContext, target: Target, @@ -42,7 +49,9 @@ def generic_evaluator( return target(*args) -@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step) +@dynamo_tensorrt_converter( + torch.ops.aten.arange.start_step, supports_dynamic_shapes=True +) def aten_ops_arange_start_step( ctx: ConversionContext, target: Target, @@ -50,4 +59,38 @@ def aten_ops_arange_start_step( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: + # Case where inputs to arange are dynamic + if any(isinstance(tensor, TRTTensor) for tensor in args): + start_rank_0 = get_trt_tensor(ctx, args[0], name + "_start_rank_0", min_rank=0) + start_rank_1 = get_trt_tensor(ctx, args[0], name + "_start_rank_1", min_rank=1) + end = get_trt_tensor(ctx, args[1], name + "_end", min_rank=1) + step = args[2] if len(args) > 2 else 1 + step = get_trt_tensor(ctx, step, name + "_step", min_rank=1) + # Calculate shape = (end-start) / step + shape = sub( + ctx, + target, + SourceIR.ATEN, + name + "_sub", + end, + start_rank_1, + ) + shape = trunc_div( + ctx, + target, + SourceIR.ATEN, + name + "_shape", + shape, + step, + ) + shape = 
cast_trt_tensor(ctx, shape, end.dtype, name + "_shape_casted") + fill_layer = ctx.net.add_fill( + shape.shape, trt.FillOperation.LINSPACE, shape.dtype + ) + fill_layer.set_input(0, shape) + # Set start index + fill_layer.set_input(1, start_rank_0) + # Set delta/step + fill_layer.set_input(2, step) + return fill_layer.get_output(0) return np.arange(*args) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py index 312926e870..aa7403f94e 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py @@ -2,7 +2,6 @@ from typing import Sequence import torch -from torch.fx.passes.shape_prop import ShapeProp from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) @@ -17,16 +16,6 @@ def fuse_prims_broadcast( """Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True""" modified_graph = False - # Propagate shapes through the graph to determine if broadcast can be resolved - try: - ShapeProp(gm).propagate(*sample_inputs) - except (RuntimeError, AssertionError): - logger.warning( - "Shape Propagation Failed on Graph, skipping fuse_prims_broadcast lowering pass", - exc_info=True, - ) - return gm - for node in gm.graph.nodes: # If the node is a sum prims operator, with broadcast_in_dim being the only consumer # it is a candidate for fusing diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 270973c8c3..9ac677484f 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -2,6 +2,8 @@ from typing import Any, Dict, Optional, Sequence, Set, Tuple import torch +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import DEBUG @@ -12,13 +14,12 @@ def contains_sym_int(tensor: torch.Tensor) -> bool: """ Returns true if the given tensor has symbolic shape. 
""" - for dim in tensor: - if isinstance(dim, torch.SymInt): - return True - return False + return any(isinstance(dim, torch.SymInt) for dim in tensor) -def construct_dynamic_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input: +def construct_dynamic_input( + input_shape: torch.Size, input_dtype: torch.dtype, is_shape_tensor: bool = False +) -> Input: """ Constructs a torch_tensorrt.Input based on a symbolic input Args: @@ -50,18 +51,26 @@ def construct_dynamic_input(input_shape: torch.Size, input_dtype: torch.dtype) - max_shape.append(dim) return Input( - min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input_dtype + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + dtype=input_dtype, + is_shape_tensor=is_shape_tensor, ) -def get_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input: +def get_input( + input_shape: torch.Size, dtype: torch.dtype, is_shape_tensor: bool = False +) -> Input: """ Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs """ if contains_sym_int(input_shape): - return construct_dynamic_input(input_shape, input_dtype) + return construct_dynamic_input( + input_shape, dtype, is_shape_tensor=is_shape_tensor + ) else: - return Input(shape=input_shape, dtype=input_dtype) + return Input(shape=input_shape, dtype=dtype, is_shape_tensor=is_shape_tensor) def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: @@ -73,28 +82,42 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: Returns: Sequence of torch_tensorrt.Input's representing inputs to given module """ - torchtrt_inputs = [] - module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"] - for input in module_inputs: - if input.meta: - if "val" in input.meta: - input_meta = input.meta["val"] - input_shape = input_meta.size() - torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) - elif "tensor_meta" in input.meta: - input_meta = input.meta["tensor_meta"] - input_shape = input_meta.shape - torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) + with maybe_disable_fake_tensor_mode(): + torchtrt_inputs = [] + module_inputs = [ + node for node in module.graph.nodes if node.op == "placeholder" + ] + for input in module_inputs: + if input.meta: + if "val" in input.meta: + input_meta = input.meta["val"] + if isinstance(input_meta, (FakeTensor, torch.Tensor)): + input_shape = input_meta.size() + torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) + elif isinstance(input_meta, torch.SymInt): + # Assuming sym_integers | shape inputs always have torch.int64 dtype + torchtrt_inputs.append( + get_input([input_meta], torch.int64, is_shape_tensor=True) + ) + else: + raise ValueError( + f"The meta val for input node {input.target} is of type : {type(input_meta)}. Supported types: torch.Tensor|FakeTensor|torch.SymInt" + ) + + elif "tensor_meta" in input.meta: + input_meta = input.meta["tensor_meta"] + input_shape = input_meta.shape + torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) + else: + raise AssertionError( + f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly" + ) else: raise AssertionError( - f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly" + f"Input {input.name} does not contain metadata. 
Please ensure you have exported the graph correctly" ) - else: - raise AssertionError( - f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly" - ) - return torchtrt_inputs + return torchtrt_inputs def run_shape_analysis( diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 0c152e15f1..4aa520542d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -15,6 +15,7 @@ _select_rt_device, multi_gpu_device_check, ) +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER logger = logging.getLogger(__name__) @@ -174,7 +175,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs] - bindings = [] for i, input_name in enumerate(self.input_names): if not contiguous_inputs[i].is_cuda: logger.warning( @@ -193,9 +193,27 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . contiguous_inputs[i].dtype == self.input_dtypes[i] ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - bindings.append(contiguous_inputs[i].data_ptr()) - self.context.set_input_shape( - input_name, tuple(contiguous_inputs[i].shape) + if self.engine.is_shape_inference_io(input_name): + # Shape tensor inputs are casted to int32 explicitly. + # Refer to https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 + inputs_cpu = contiguous_inputs[i].cpu().to(torch.int32) + self.context.set_tensor_address( + input_name, inputs_cpu.data_ptr() + ) + else: + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) + ) + self.context.set_tensor_address( + input_name, contiguous_inputs[i].data_ptr() + ) + + # Check if input shapes can be inferred. + uninferred_input_names = self.context.infer_shapes() + if uninferred_input_names: + logger.warning( + f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ + This could happen if the input tensor addresses/shapes haven't been configured correctly" ) with ( @@ -211,20 +229,19 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . for i, output_name in enumerate(self.output_names): shape = tuple(self.context.get_tensor_shape(output_name)) + if DYNAMIC_DIM in shape: + raise ValueError( + "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." 
+ ) + output = torch.empty( size=shape, dtype=self.output_dtypes[i].to(torch.dtype), device=torch.cuda.current_device(), ) - bindings.append(output.data_ptr()) + self.context.set_tensor_address(output_name, output.data_ptr()) outputs.append(output) - # Assign tensor address appropriately - for idx in range(self.engine.num_io_tensors): - self.context.set_tensor_address( - self.engine.get_tensor_name(idx), bindings[idx] - ) - with ( torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:TensorRTRuntime" diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 6ea9503b84..4ea8687016 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) COSINE_THRESHOLD = 0.99 +DYNAMIC_DIM = -1 def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool: @@ -158,6 +159,10 @@ def parse_complex_tensor_structs( """ if isinstance(inputs, (torch.Tensor, Input)): return apply_fn(getattr(inputs, attribute_to_extract, None)) + elif isinstance(inputs, (int, float, bool)): + # inputs is a python scalar value + inputs_torch = torch.tensor(inputs) + return apply_fn(getattr(inputs_torch, attribute_to_extract, None)) elif isinstance(inputs, (list, tuple)): torchtrt_input_list = [] diff --git a/tests/modules/requirements.txt b/tests/modules/requirements.txt index da63a6dad1..03e51d0b61 100644 --- a/tests/modules/requirements.txt +++ b/tests/modules/requirements.txt @@ -1,3 +1,2 @@ -timm -transformers -torchvision +timm==0.9.12 +transformers==4.40.2 diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 7ce3939371..6bb9b0c500 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -6,16 +6,20 @@ from typing import Callable, List, Optional, Set, Tuple import torch +import torch_tensorrt +from torch.fx.passes.shape_prop import ShapeProp from torch.testing._internal.common_utils import TestCase from torch_tensorrt import Input from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._settings import CompilationSettings # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes -from torch_tensorrt.dynamo.lowering import apply_lowering_passes +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule +from torch_tensorrt.dynamo.utils import get_torch_inputs _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -55,13 +59,13 @@ def run_test( rtol, atol, check_dtype=True, + pyt_inputs=None, ): with torch.no_grad(): cuda_inputs = [] for i in inputs: cuda_inputs.append(i.cuda()) - mod.eval() start = time.perf_counter() interpreter_result = interpreter.run() sec = time.perf_counter() - start @@ -71,9 +75,11 @@ def run_test( interpreter_result.input_names, interpreter_result.output_names, ) - mod = mod.cuda() - ref_outputs = mod(*cuda_inputs) + if pyt_inputs is not None: + ref_outputs = mod(*pyt_inputs) + else: + ref_outputs = mod(*cuda_inputs) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) @@ -198,19 +204,30 @@ def generate_graph( original_inputs: List[torch.Tensor], use_dynamo_tracer: bool, enable_passes: bool, + propagate_shapes: bool = 
False,
    ):
+        mod = mod.eval()
+        torch_inputs = get_torch_inputs(original_inputs, _defaults.DEVICE)
         if use_dynamo_tracer:
-            fx_module = torch._dynamo.export(
-                mod,
-                *original_inputs,
-                aten_graph=True,
-                assume_static_by_default=True,
-                tracing_mode="real",
-            ).graph_module
+            exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs))
+            exported_program = exported_program.run_decompositions(
+                get_decompositions(False)
+            )
+            fx_module = exported_program.module()
         else:
             fx_module = torch.fx.symbolic_trace(mod)
         if enable_passes:
-            fx_module = apply_lowering_passes(fx_module, original_inputs)
+            fx_module = apply_lowering_passes(fx_module, torch_inputs)
+
+        if propagate_shapes:
+            # TODO: This is currently being used to test embedding_bag_aten due to https://github.com/pytorch/TensorRT/issues/2843
+            try:
+                ShapeProp(fx_module).propagate(*torch_inputs)
+            except (RuntimeError, AssertionError):
+                _LOGGER.warning(
+                    "Shape Propagation failed on Graph, skipping it",
+                    exc_info=False,
+                )
         return fx_module

     def run_test(
@@ -223,6 +240,7 @@ def run_test(
         check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
+        propagate_shapes=False,
     ):
         mod.eval()
         mod = self.generate_graph(
@@ -230,6 +248,7 @@ def run_test(
             inputs,
             use_dynamo_tracer=use_dynamo_tracer,
             enable_passes=enable_passes,
+            propagate_shapes=propagate_shapes,
         )

         # Previous instance of the interpreter auto-casted 64-bit inputs
@@ -279,14 +298,17 @@ def run_test_with_dynamic_shape(
         output_dtypes=None,
         use_dynamo_tracer=False,
         enable_passes=False,
+        use_example_tensors=True,
+        pyt_inputs=None,
+        propagate_shapes=False,
     ):
         mod.eval()
-        inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
         mod = self.generate_graph(
             mod,
-            inputs,
+            input_specs,
             use_dynamo_tracer=use_dynamo_tracer,
             enable_passes=enable_passes,
+            propagate_shapes=propagate_shapes,
         )

         # Previous instance of the interpreter auto-casted 64-bit inputs
@@ -302,4 +324,6 @@
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. 
max shape) for testing dynamic shape inputs_max = [spec.example_tensor("max_shape") for spec in input_specs] - super().run_test(mod, inputs_max, interp, rtol, atol) + if not use_example_tensors: + inputs_max = [spec.torch_tensor for spec in input_specs] + super().run_test(mod, inputs_max, interp, rtol, atol, pyt_inputs=pyt_inputs) diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py index e06239eb4e..32a243330f 100644 --- a/tests/py/dynamo/conversion/test_arange_aten.py +++ b/tests/py/dynamo/conversion/test_arange_aten.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch_tensorrt from parameterized import parameterized from torch.testing._internal.common_utils import run_tests @@ -33,6 +34,29 @@ def forward(self, x): use_dynamo_tracer=True, ) + def test_arange_dynamic(self): + class Arange(nn.Module): + def forward(self, end_tensor): + return torch.ops.aten.arange.start_step(0, end_tensor, 1) + + pyt_input = 7 + inputs = [ + torch_tensorrt.Input( + min_shape=(5,), + opt_shape=(7,), + max_shape=(10,), + dtype=torch.int64, + torch_tensor=torch.tensor(pyt_input, dtype=torch.int64).cuda(), + is_shape_tensor=True, + ) + ] + self.run_test_with_dynamic_shape( + Arange(), + inputs, + use_example_tensors=False, + pyt_inputs=[pyt_input], + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py index 5c2a78a18a..8e7d8cef73 100644 --- a/tests/py/dynamo/conversion/test_bitwise_and_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -26,6 +26,7 @@ def forward(self, lhs_val, rhs_val): bitwise_and(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -46,6 +47,7 @@ def forward(self, tensor): bitwise_and(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -66,6 +68,7 @@ def forward(self, tensor): bitwise_and(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) diff --git a/tests/py/dynamo/conversion/test_bitwise_not_aten.py b/tests/py/dynamo/conversion/test_bitwise_not_aten.py index b811f1e51a..33d8629aff 100644 --- a/tests/py/dynamo/conversion/test_bitwise_not_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_not_aten.py @@ -25,6 +25,7 @@ def forward(self, val): bitwise_not(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) diff --git a/tests/py/dynamo/conversion/test_bitwise_or_aten.py b/tests/py/dynamo/conversion/test_bitwise_or_aten.py index b5e0200734..e912a9c473 100644 --- a/tests/py/dynamo/conversion/test_bitwise_or_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_or_aten.py @@ -26,6 +26,7 @@ def forward(self, lhs_val, rhs_val): bitwise_or(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -46,6 +47,7 @@ def forward(self, tensor): bitwise_or(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -66,6 +68,7 @@ def forward(self, tensor): bitwise_or(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) diff --git a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py index 8c1a8136ef..4bd2790bf9 100644 --- a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py @@ -26,6 +26,7 @@ def forward(self, lhs_val, rhs_val): bitwise_xor(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -46,6 +47,7 @@ def forward(self, 
tensor): bitwise_xor(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) @parameterized.expand( @@ -66,6 +68,7 @@ def forward(self, tensor): bitwise_xor(), inputs, enable_passes=True, + use_dynamo_tracer=True, ) diff --git a/tests/py/dynamo/conversion/test_convolution_aten.py b/tests/py/dynamo/conversion/test_convolution_aten.py index 7d69c871a9..95f4de92b5 100644 --- a/tests/py/dynamo/conversion/test_convolution_aten.py +++ b/tests/py/dynamo/conversion/test_convolution_aten.py @@ -1,7 +1,6 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests - from torch_tensorrt import Input from .harness import DispatchTestCase @@ -138,7 +137,7 @@ def forward(self, x): Input( shape=(-1, 3, -1, -1), dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + shape_ranges=[((1, 3, 1, 1), (2, 3, 4, 4), (32, 3, 128, 128))], ), ] self.run_test_with_dynamic_shape( @@ -201,7 +200,7 @@ def forward(self, x): Input( shape=(-1, 3, -1, -1, -1), dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + shape_ranges=[((1, 3, 1, 1, 1), (2, 3, 4, 4, 4), (8, 3, 32, 32, 32))], ), ] self.run_test_with_dynamic_shape( diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py index 6024b6946e..307275dba1 100644 --- a/tests/py/dynamo/conversion/test_deconvolution_aten.py +++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py @@ -1,7 +1,6 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests - from torch_tensorrt import Input from .harness import DispatchTestCase @@ -152,7 +151,7 @@ def forward(self, x): Input( shape=(-1, 3, -1, -1), dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + shape_ranges=[((1, 3, 1, 1), (2, 3, 4, 4), (32, 3, 128, 128))], ), ] self.run_test_with_dynamic_shape( @@ -221,7 +220,7 @@ def forward(self, x): Input( shape=(-1, 3, -1, -1, -1), dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + shape_ranges=[((1, 3, 1, 1, 1), (2, 3, 4, 4, 4), (8, 3, 32, 32, 32))], ), ] self.run_test_with_dynamic_shape( diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 2154937b43..9664e1be58 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -144,6 +144,7 @@ def forward(self, weight, indices): inputs=[weight, indices], precision=weight.dtype, enable_passes=True, + propagate_shapes=True, ) @parameterized.expand( @@ -340,6 +341,7 @@ def forward(self, weight, indices, offsets): inputs=[weight, indices, offsets], precision=weight.dtype, enable_passes=True, + propagate_shapes=True, ) @parameterized.expand( @@ -403,6 +405,7 @@ def forward(self, weight, indices, offsets): inputs=[weight, indices, offsets], precision=weight.dtype, enable_passes=True, + propagate_shapes=True, ) diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index b6519815a4..de216de916 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -1,7 +1,6 @@ import unittest import pytest -import timm import torch import torch_tensorrt as torchtrt import torchvision.models as models diff --git a/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py 
b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py index c23684646a..975f0b7ffa 100644 --- a/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py +++ b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py @@ -27,7 +27,7 @@ def forward(self, a, b): # Inference on TRT Engine py_trt_module = PythonTorchTensorRTModule( - trt_engine_str, ["a", "b"], ["output0"] + trt_engine_str, ["arg0_1", "arg1_1"], ["output0"] ) trt_output = py_trt_module(input_data_0, input_data_1).cpu()
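
[Editor's note] A hypothetical end-to-end smoke test, not part of this diff, tying the pieces together: a graph whose aten.arange end value depends on a dynamic input dimension, exercising the new shape-tensor plumbing (runtime address pinning, the IFillLayer-based arange converter, and `is_shape_tensor` inputs). The module name and shapes are illustrative; whether this exact graph fully converts depends on the converter coverage added here:

```python
import torch
import torch_tensorrt

class ArangeAdd(torch.nn.Module):
    def forward(self, x):
        # x.shape[0] traces to a SymInt, so aten.arange.start_step receives a
        # dynamic end value and takes the fill-layer path added in this diff.
        return torch.arange(0, x.shape[0], 1, device=x.device) + x

inputs = [
    torch_tensorrt.Input(
        min_shape=(5,), opt_shape=(7,), max_shape=(10,), dtype=torch.float32
    )
]
trt_mod = torch_tensorrt.compile(ArangeAdd().cuda().eval(), ir="dynamo", inputs=inputs)
print(trt_mod(torch.randn(8).cuda()))
```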