From d2babd179562de4cfd87481d973c527b05390263 Mon Sep 17 00:00:00 2001 From: Chen Fu Date: Mon, 21 Jul 2025 21:27:42 +0000 Subject: [PATCH 1/6] skipped unnecessary broadcast --- .../dynamo/conversion/impl/elementwise/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index ab9629b0db..097a81b8d1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -1,8 +1,8 @@ +import logging import operator import warnings from typing import Any, Callable, Optional, Union -import numpy as np import tensorrt as trt import torch from torch.fx.node import Target @@ -20,6 +20,8 @@ ) from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor +logger = logging.getLogger(__name__) + def get_python_op_from_trt_elementwise_op( trt_op: TRTElementWiseOp, @@ -148,7 +150,11 @@ def convert_binary_elementwise( ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir ) - if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape): + if len(lhs_val.shape) == len(rhs_val.shape) and all( + a == b or a == 1 or b == 1 for a, b in zip(lhs_val.shape, rhs_val.shape) + ): + logger.info(f"skip broadcast for {name}") + elif has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape): lhs_val, rhs_val = broadcast( ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs" ) From 88c1a774caaeffec5db5b1ee476c8751fb193cb6 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 22 Jul 2025 22:26:25 +0000 Subject: [PATCH 2/6] Fixed SDPA perf gap --- .../dynamo/conversion/aten_ops_converters.py | 2 +- .../dynamo/lowering/_decomposition_groups.py | 1 + .../dynamo/lowering/_decompositions.py | 14 +++++++++----- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index f1a7f9a8fc..9ac81d2981 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -531,7 +531,7 @@ def aten_ops_gelu( ) -@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True) diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 91ac0476d2..9d28ae70a5 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -170,6 +170,7 @@ aten.upsample_trilinear3d.vec, aten.upsample_bicubic2d.vec, aten.linear.default, + aten.matmul.default, } diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index b0cfdee4f0..1cba827d48 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -463,11 +463,13 @@ def scaled_dot_product_attention_decomposition( ) -> torch.Tensor: L, S = query.size(-2), key.size(-2) device = query.device 
- attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device) + + if is_causal or attn_mask is not None: + attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device) if is_causal: assert attn_mask is None, "attn_mask must be None when is_causal=True" - temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0) + temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0) attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf")) if attn_mask is not None: @@ -480,7 +482,7 @@ def scaled_dot_product_attention_decomposition( key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) - attn_weight = query @ key.transpose(-2, -1) + attn_weight = torch.matmul(query, key.transpose(-2, -1)) if scale is None: scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)) @@ -488,9 +490,11 @@ def scaled_dot_product_attention_decomposition( else: attn_weight = attn_weight * scale - attn_weight = attn_weight + attn_bias + if is_causal or attn_mask is not None: + attn_weight = attn_weight + attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) - return attn_weight @ value + return torch.matmul(attn_weight, value) @register_torch_trt_decomposition( From 799a5d5c05f16ee7ee1ebe5b3fd034a3141963af Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 23 Jul 2025 22:57:29 +0000 Subject: [PATCH 3/6] Added comments --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 1cba827d48..baeb92d22e 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -491,6 +491,7 @@ def scaled_dot_product_attention_decomposition( attn_weight = attn_weight * scale if is_causal or attn_mask is not None: + # We only add attn_bias when we have to, otherwise this will have a negative impact on the performance even it's 0. attn_weight = attn_weight + attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) From b00ddecb6cf4bc8aea11c04b1f5915f1177c9d9c Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 23 Jul 2025 23:22:57 +0000 Subject: [PATCH 4/6] Added fp32 matmul around matmul node --- .../lowering/passes/accumulate_fp32_matmul.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py index e569c45cfa..282693d299 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py @@ -10,6 +10,18 @@ def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Splits all `torch.ops.aten.addmm.default` nodes in the FX graph into separate + `add` and `mm` nodes. This is useful for passes that want to insert additional + logic (such as FP32 accumulation) specifically around the matrix multiplication + operation, rather than the fused addmm. + + Args: + gm (torch.fx.GraphModule): The FX graph module to transform. + + Returns: + torch.fx.GraphModule: The modified FX graph module with addmm nodes split. 
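+        For example, ``addmm(bias, mat1, mat2)`` computes ``bias + mat1 @ mat2``
+        (with the default ``beta`` and ``alpha`` of 1), so rewriting it as an explicit
+        ``mm`` followed by ``add`` lets a pass such as FP32 accumulation wrap only the
+        matrix multiplication in cast nodes.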
+ """ target = torch.ops.aten.addmm.default addmm_nodes = [node for node in gm.graph.nodes if node.target == target] for addmm_node in addmm_nodes: @@ -52,6 +64,7 @@ def accumulate_fp32_matmul( matmul_targets = [ torch.ops.aten.mm.default, torch.ops.aten.bmm.default, + torch.ops.aten.matmul.default, ] # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes From 8855fc85465f8c19a792075c7388e67108b5a2ae Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 24 Jul 2025 21:57:55 +0000 Subject: [PATCH 5/6] Added updated benchmark script --- examples/apps/flux_demo.py | 2 +- tools/perf/Flux/benchmark.sh | 13 ++++++++++++- tools/perf/Flux/flux_perf.py | 20 +++++++++++++++++++- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index c0028d1459..5220f38ec6 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -120,7 +120,7 @@ def forward_loop(mod): "enabled_precisions": enabled_precisions, "truncate_double": True, "min_block_size": 1, - "use_python_runtime": False, + "use_python_runtime": True, "immutable_weights": False, "offload_module_to_cpu": args.low_vram_mode, "use_explicit_typing": use_explicit_typing, diff --git a/tools/perf/Flux/benchmark.sh b/tools/perf/Flux/benchmark.sh index 79f5e4b66c..3b29ac0989 100644 --- a/tools/perf/Flux/benchmark.sh +++ b/tools/perf/Flux/benchmark.sh @@ -1,9 +1,20 @@ #TODO: Enter the HF Token huggingface-cli login --token HF_TOKEN +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> pytorch_fp16_gpu_utilization.txt & +NVIDIA_SMI_PID=$! +python flux_perf.py --pytorch --max_batch_size 3 > pytorch_fp16_benchmark.txt +kill $NVIDIA_SMI_PID + nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp8_gpu_utilization.txt & NVIDIA_SMI_PID=$! -python flux_perf.py --dtype fp8 --low_vram_mode> fp8_benchmark.txt +python flux_perf.py --dtype fp8 --max_batch_size 3 > fp8_benchmark.txt +kill $NVIDIA_SMI_PID + + +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp16_gpu_utilization.txt & +NVIDIA_SMI_PID=$! 
+python flux_perf.py --dtype fp16 --max_batch_size 3 > fp16_benchmark.txt kill $NVIDIA_SMI_PID diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 1d3b2acbbc..969f4c93d8 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -44,9 +44,22 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): return +from diffusers import FluxPipeline + + def main(args): print(f"Running flux_perfwith args: {args}") - pipe, backbone, trt_gm = compile_model(args) + if not args.pytorch: + pipe, backbone, trt_gm = compile_model(args) + else: + pipe = ( + FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, + ) + .to(torch.float16) + .to("cuda:0") + ) benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3) @@ -83,6 +96,11 @@ def main(args): action="store_true", help="Use dynamic shapes", ) + parser.add_argument( + "--pytorch", + action="store_true", + help="Use pytorch runtime and no tensorrt", + ) parser.add_argument("--max_batch_size", type=int, default=1) args = parser.parse_args() main(args) From f8bc2d56e93e755ce8a0d5c374717a2e8d325a50 Mon Sep 17 00:00:00 2001 From: Chen Fu Date: Tue, 29 Jul 2025 20:25:12 +0000 Subject: [PATCH 6/6] Added initial implementation --- core/runtime/TRTEngine.cpp | 37 ++++++++++++++++--- core/runtime/TRTEngine.h | 16 ++++++-- core/runtime/execute_engine.cpp | 6 +++ core/runtime/register_jit_hooks.cpp | 23 ++++++++++++ core/runtime/runtime.h | 4 ++ examples/dynamo/dynamic_memory_allocation.py | 36 ++++++++++++++++++ .../dynamo/runtime/_ResourceAllocator.py | 30 +++++++++++++++ .../dynamo/runtime/_TorchTensorRTModule.py | 12 +++++- py/torch_tensorrt/dynamo/runtime/__init__.py | 3 ++ 9 files changed, 157 insertions(+), 10 deletions(-) create mode 100644 examples/dynamo/dynamic_memory_allocation.py create mode 100644 py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 52a9b47c12..aabb40c6dd 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -61,7 +61,8 @@ TRTEngine::TRTEngine( const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata) + const std::string& serialized_metadata, + const ResourceAllocationStrategy& resource_allocation_strategy) : TRTEngine( "deserialized_trt", serialized_engine, @@ -71,7 +72,8 @@ TRTEngine::TRTEngine( target_platform, hardware_compatible, requires_output_allocator, - serialized_metadata) {} + serialized_metadata, + resource_allocation_strategy) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -83,7 +85,8 @@ TRTEngine::TRTEngine(std::vector serialized_info) Platform(serialized_info[TARGET_PLATFORM_IDX]), static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX])), static_cast(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])), - serialized_info[SERIALIZED_METADATA_IDX]) {} + serialized_info[SERIALIZED_METADATA_IDX], + resource_allocation_strategy_from_string(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) {} TRTEngine::TRTEngine( const std::string& mod_name, @@ -94,7 +97,8 @@ TRTEngine::TRTEngine( const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata) { + const std::string& serialized_metadata, + const ResourceAllocationStrategy& resource_allocation_strategy) { TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was 
not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -124,7 +128,12 @@ TRTEngine::TRTEngine( cuda_engine->setWeightStreamingBudgetV2(budget_bytes); } - exec_ctx = make_trt(cuda_engine->createExecutionContext()); + if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { + this->exec_ctx = + make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE)); + } else { + this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); + } TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); runtime_states.old_cudagraphs = CUDAGRAPHS_MODE; @@ -436,7 +445,8 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]), std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]), std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), - std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX])); + std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), + std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])); } std::vector TRTEngine::serialize() { @@ -459,6 +469,8 @@ std::vector TRTEngine::serialize() { serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0"; serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata; serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); + serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = + resource_allocation_strategy_to_string(this->resource_allocation_strategy); return serialized_info; } @@ -467,6 +479,19 @@ void TRTEngine::reset_captured_graph() { cudagraph.reset(); } +void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { + if (new_strategy != this->resource_allocation_strategy) { + this->resource_allocation_strategy = new_strategy; + if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { + std::cout << "Setting resource allocation strategy to dynamic" << std::endl; + this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); + } else { + this->exec_ctx = make_trt( + cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE)); + } + } +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 15d723ce4e..9c77ab325a 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -29,7 +29,8 @@ using FlattenedState = std::tuple< std::tuple, // HW compatibility std::tuple, // requires_output_allocator std::tuple, // serialized metadata - std::tuple>; // Platform + std::tuple, // Platform + std::tuple>; // Resource Allocation Strategy struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -98,6 +99,8 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator { }; struct TRTEngine : torch::CustomClassHolder { + // Resource Allocation Strategy + enum ResourceAllocationStrategy { kStatic, kDynamic }; // Each engine needs it's own runtime object std::shared_ptr rt; std::shared_ptr cuda_engine; @@ -128,7 +131,9 @@ struct TRTEngine : torch::CustomClassHolder { const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, 
- const std::string& serialized_metadata = ""); + const std::string& serialized_metadata = "", + const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy = + TRTEngine::ResourceAllocationStrategy::kStatic); TRTEngine(std::vector serialized_info); @@ -141,7 +146,9 @@ struct TRTEngine : torch::CustomClassHolder { const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = ""); + const std::string& serialized_metadata = "", + const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy = + TRTEngine::ResourceAllocationStrategy::kStatic); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; @@ -200,6 +207,9 @@ struct TRTEngine : torch::CustomClassHolder { std::string cuda_graph_debug_path; std::mutex mu; std::unique_ptr trt_engine_profiler; + ResourceAllocationStrategy resource_allocation_strategy = kStatic; + void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); + ResourceAllocationStrategy get_resource_allocation_strategy(); }; } // namespace runtime diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..d36cc98c80 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -201,6 +201,12 @@ void create_output_allocator(c10::intrusive_ptr compiled_engine) { } std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine) { + torch::Tensor dynamic_workspace; + if (compiled_engine->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { + dynamic_workspace = torch::empty(compiled_engine->cuda_engine->getDeviceMemorySizeV2(), {torch::kCUDA}); + compiled_engine->exec_ctx->setDeviceMemory(dynamic_workspace.data_ptr()); + } + auto run_standard_execution = [&]() { bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); bool shape_changed = _validate_shapes(inputs, compiled_engine); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 173ff8c35f..99633a4e47 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -22,6 +22,21 @@ std::string serialize_bindings(const std::vector& bindings) { return serialized_binding_info; } +std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy) { + if (strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { + return std::string("kDynamic"); + } else { + return std::string("kStatic"); + } +} + +TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str) { + if (str == "kDynamic") + return TRTEngine::ResourceAllocationStrategy::kDynamic; + else + return TRTEngine::ResourceAllocationStrategy::kStatic; +} + static const std::string sym_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //= std::string base64_encode(const std::string& in) { std::string out; @@ -90,6 +105,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) .def("reset_captured_graph", &TRTEngine::reset_captured_graph) + .def( + "_use_dynamically_allocated_resources", + [](const c10::intrusive_ptr& self, bool dynamic) -> void { + self->set_resource_allocation_strategy( + dynamic ? 
TRTEngine::ResourceAllocationStrategy::kDynamic + : TRTEngine::ResourceAllocationStrategy::kStatic); + }) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( @@ -135,6 +157,7 @@ TORCH_LIBRARY(tensorrt, m) { m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; }); m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); + m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 894df55bfe..233b4bb274 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -38,6 +38,7 @@ typedef enum { SERIALIZED_METADATA_IDX, TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, + RESOURCE_ALLOCATION_STRATEGY_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; @@ -45,6 +46,9 @@ std::string base64_encode(const std::string& in); std::string base64_decode(const std::string& in); std::string serialize_bindings(const std::vector& bindings); +std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy); +TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str); + c10::optional get_most_compatible_device( const RTDevice& target_device, const RTDevice& curr_device = RTDevice(), diff --git a/examples/dynamo/dynamic_memory_allocation.py b/examples/dynamo/dynamic_memory_allocation.py new file mode 100644 index 0000000000..be7bc7e1bd --- /dev/null +++ b/examples/dynamo/dynamic_memory_allocation.py @@ -0,0 +1,36 @@ +# %% +import numpy as np +import torch +import torch_tensorrt as torch_trt +import torchvision.models as models +from diffusers import DiffusionPipeline + +np.random.seed(5) +torch.manual_seed(5) +inputs = [torch.rand((100, 3, 224, 224)).to("cuda")] + +settings = { + "ir": "dynamo", + "use_python_runtime": False, + "enabled_precisions": {torch.float32}, + "immutable_weights": False, +} + +model = models.resnet152(pretrained=True).eval().to("cuda") +compiled_module = torch_trt.compile(model, inputs=inputs, **settings) +print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3) +compiled_module(*inputs) + +breakpoint() +with torch_trt.dynamo.runtime.ResourceAllocatorContext(compiled_module): + print( + "Memory used (GB):", + (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3, + ) + breakpoint() + compiled_module(*inputs) + print( + "Memory used (GB):", + (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3, + ) + breakpoint() diff --git a/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py b/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py new file mode 100644 index 0000000000..5c72d4e180 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py @@ -0,0 +1,30 @@ +from typing import Any + +import torch + + +class ResourceAllocatorContext(torch.nn.Module): # type: ignore[misc] + """ + ResourceAllocatorContext is a context manager module that temporarily enables dynamic resource allocation + for all TRT submodules of the given compiled_module. 
When entering the context, + it sets these submodules to use dynamically allocated resources. Upon exiting, it restores them to their + original (static) resource allocation mode. + """ + + def __init__( + self, + compiled_module: torch.nn.Module, + ) -> None: + super(ResourceAllocatorContext, self).__init__() + self.compiled_module = compiled_module + + def __enter__(self) -> None: + print("Entering resource allocator context") + for name, submodule in self.compiled_module.named_modules(): + if "_run_on_acc" in name: + submodule.use_dynamically_allocated_resources(dynamic=True) + + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: + for name, submodule in self.compiled_module.named_modules(): + if "_run_on_acc" in name: + submodule.use_dynamically_allocated_resources(dynamic=False) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..c5929c16a7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -50,7 +50,10 @@ REQUIRES_OUTPUT_ALLOCATOR_IDX = ( torch.ops.tensorrt.REQUIRES_OUTPUT_ALLOCATOR_IDX() ) # 9 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 10 + RESOURCE_ALLOCATION_STRATEGY_IDX = ( + torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() + ) # 10 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 11 @for_all_methods(needs_torch_tensorrt_runtime) @@ -139,6 +142,7 @@ def __init__( self.serialized_engine = serialized_engine self.engine = None self.requires_output_allocator = requires_output_allocator + self.resource_allocation_strategy = 0 # Default to static allocation TODO: Make this configurable with the context manager if ( serialized_engine @@ -184,6 +188,9 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str( int(self.requires_output_allocator) ) + engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( + int(self.resource_allocation_strategy) + ) return engine_info @@ -212,6 +219,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: def _reset_captured_graph(self) -> None: self.engine.reset_captured_graph() + def use_dynamically_allocated_resources(self, dynamic: bool = False) -> None: + self.engine._use_dynamically_allocated_resources(dynamic) + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py index de47d942e9..19843a0a54 100644 --- a/py/torch_tensorrt/dynamo/runtime/__init__.py +++ b/py/torch_tensorrt/dynamo/runtime/__init__.py @@ -2,6 +2,9 @@ from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( # noqa: F401 PythonTorchTensorRTModule, ) +from torch_tensorrt.dynamo.runtime._ResourceAllocator import ( # noqa: F401 + ResourceAllocatorContext, +) from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401 TorchTensorRTModule, )
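
To make the shape test from PATCH 1/6 easy to inspect on its own, here is a minimal, self-contained sketch of the same predicate; ``needs_no_broadcast`` is an illustrative name, not a helper that exists in the codebase. The converter skips the extra broadcast layers whenever both operands already have the same rank and every dimension pair is either equal or has a 1 on one side.

# Minimal standalone sketch (outside TensorRT) of the check added in PATCH 1/6.
# needs_no_broadcast is an illustrative name only.
from typing import Sequence


def needs_no_broadcast(lhs: Sequence[int], rhs: Sequence[int]) -> bool:
    # Same rank, and each dimension pair is already elementwise-compatible:
    # equal, or a 1 that the elementwise layer can broadcast by itself.
    return len(lhs) == len(rhs) and all(
        a == b or a == 1 or b == 1 for a, b in zip(lhs, rhs)
    )


print(needs_no_broadcast((2, 8, 128, 64), (2, 8, 128, 64)))  # True: skip broadcast
print(needs_no_broadcast((2, 8, 128, 64), (2, 1, 128, 1)))   # True: skip broadcast
print(needs_no_broadcast((2, 8, 128, 64), (128, 64)))        # False: ranks differ, broadcast as before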
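
PATCH 2/6 only builds and adds ``attn_bias`` when ``is_causal`` is set or an ``attn_mask`` is supplied, and routes the products through ``torch.matmul``. The snippet below is a standalone sketch of the resulting math for the common no-mask, non-causal case, checked against eager SDPA; ``sdpa_no_mask`` is an illustrative name, and it folds the default ``1/sqrt(d)`` scale into a multiply rather than the patch's divide.

# Standalone sketch of the decomposition's no-mask, non-causal path from PATCH 2/6.
import math
from typing import Optional

import torch


def sdpa_no_mask(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale
    # No attn_bias tensor is created or added on this path.
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)


q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(torch.allclose(sdpa_no_mask(q, k, v), ref, atol=1e-5))  # expected: True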
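
PATCH 6/6 exposes the new allocation strategy through ``ResourceAllocatorContext``; the sketch below condenses ``examples/dynamo/dynamic_memory_allocation.py`` into the minimal usage pattern. The ResNet-18 model and batch size here are placeholders, and a CUDA device plus the C++ runtime are assumed.

# Condensed usage sketch based on examples/dynamo/dynamic_memory_allocation.py.
# Assumes a CUDA device; the model and input shape are placeholders.
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

inputs = [torch.rand((8, 3, 224, 224)).to("cuda")]
model = models.resnet18(pretrained=True).eval().to("cuda")
compiled_module = torch_trt.compile(
    model,
    inputs=inputs,
    ir="dynamo",
    use_python_runtime=False,  # the context manager toggles the C++ TRTEngine
    immutable_weights=False,
)

# While the context is active, each _run_on_acc TRT submodule switches to
# dynamically allocated execution-context memory; on exit it reverts to the
# default static allocation.
with torch_trt.dynamo.runtime.ResourceAllocatorContext(compiled_module):
    compiled_module(*inputs)

compiled_module(*inputs)  # back to static allocation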