37 changes: 31 additions & 6 deletions core/runtime/TRTEngine.cpp
@@ -61,7 +61,8 @@ TRTEngine::TRTEngine(
const Platform& target_platform,
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata)
const std::string& serialized_metadata,
const ResourceAllocationStrategy& resource_allocation_strategy)
: TRTEngine(
"deserialized_trt",
serialized_engine,
@@ -71,7 +72,8 @@ TRTEngine::TRTEngine(
target_platform,
hardware_compatible,
requires_output_allocator,
serialized_metadata) {}
serialized_metadata,
resource_allocation_strategy) {}

TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
: TRTEngine(
@@ -83,7 +85,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
Platform(serialized_info[TARGET_PLATFORM_IDX]),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
serialized_info[SERIALIZED_METADATA_IDX]) {}
serialized_info[SERIALIZED_METADATA_IDX],
resource_allocation_strategy_from_string(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) {}

TRTEngine::TRTEngine(
const std::string& mod_name,
@@ -94,7 +97,8 @@ TRTEngine::TRTEngine(
const Platform& target_platform,
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata) {
const std::string& serialized_metadata,
const ResourceAllocationStrategy& resource_allocation_strategy) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -124,7 +128,12 @@ TRTEngine::TRTEngine(
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
}

exec_ctx = make_trt(cuda_engine->createExecutionContext());
if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
this->exec_ctx =
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
} else {
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
}
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");

runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
@@ -436,7 +445,8 @@ FlattenedState TRTEngine::__obj_flatten__() {
std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
@@ -459,6 +469,8 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
resource_allocation_strategy_to_string(this->resource_allocation_strategy);

return serialized_info;
}
@@ -467,6 +479,19 @@ void TRTEngine::reset_captured_graph() {
cudagraph.reset();
}

void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) {
  if (new_strategy != this->resource_allocation_strategy) {
    this->resource_allocation_strategy = new_strategy;
    if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
      LOG_DEBUG("Setting resource allocation strategy to dynamic");
      // Match the constructor: dynamic mode uses the on-profile-change allocation strategy
      this->exec_ctx = make_trt(
          cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
    } else {
      LOG_DEBUG("Setting resource allocation strategy to static");
      this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
    }
  }
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
16 changes: 13 additions & 3 deletions core/runtime/TRTEngine.h
@@ -29,7 +29,8 @@ using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // HW compatibility
std::tuple<std::string, std::string>, // requires_output_allocator
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform
std::tuple<std::string, std::string>, // Platform
std::tuple<std::string, std::string>>; // Resource Allocation Strategy

struct TorchTRTRuntimeStates {
// Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -98,6 +99,8 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
};

struct TRTEngine : torch::CustomClassHolder {
// Resource Allocation Strategy
enum ResourceAllocationStrategy { kStatic, kDynamic };
// Each engine needs it's own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
@@ -128,7 +131,9 @@ struct TRTEngine : torch::CustomClassHolder {
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "");
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);

TRTEngine(std::vector<std::string> serialized_info);

@@ -141,7 +146,9 @@ struct TRTEngine : torch::CustomClassHolder {
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "");
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);

TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
@@ -200,6 +207,9 @@ struct TRTEngine : torch::CustomClassHolder {
std::string cuda_graph_debug_path;
std::mutex mu;
std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
ResourceAllocationStrategy resource_allocation_strategy = kStatic;
void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy);
ResourceAllocationStrategy get_resource_allocation_strategy();
};

} // namespace runtime
6 changes: 6 additions & 0 deletions core/runtime/execute_engine.cpp
@@ -201,6 +201,12 @@ void create_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
  torch::Tensor dynamic_workspace;
  if (compiled_engine->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
    // getDeviceMemorySizeV2() reports the required activation workspace in bytes, so allocate a
    // uint8 buffer from the PyTorch CUDA caching allocator and lend it to the context for this call.
    dynamic_workspace = torch::empty(
        compiled_engine->cuda_engine->getDeviceMemorySizeV2(),
        torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA));
    compiled_engine->exec_ctx->setDeviceMemory(dynamic_workspace.data_ptr());
  }

auto run_standard_execution = [&]() {
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
bool shape_changed = _validate_shapes(inputs, compiled_engine);
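With the dynamic strategy, the activation workspace allocated above lives only for the duration of a single `execute_engine` call and comes from PyTorch's CUDA caching allocator, rather than being held by TensorRT for the engine's lifetime. A rough way to observe the difference from Python is to compare device memory around a call; this is only a sketch using the same `torch.cuda.mem_get_info()` pattern as the example script added later in this PR, and `mod` is an assumed, already-compiled Torch-TensorRT module:

```python
import torch

def cuda_mem_in_use_gib() -> float:
    # mem_get_info() returns (free, total) in bytes; the difference is the memory
    # currently handed out on the device (by all allocators, not just PyTorch's).
    free, total = torch.cuda.mem_get_info()
    return (total - free) / 1024**3

# Usage sketch (assumes a CUDA device and a compiled module `mod` with inputs `inputs`):
# before = cuda_mem_in_use_gib()
# mod(*inputs)
# print(f"delta after call: {cuda_mem_in_use_gib() - before:.2f} GiB")
```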
23 changes: 23 additions & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -22,6 +22,21 @@ std::string serialize_bindings(const std::vector<std::string>& bindings) {
return serialized_binding_info;
}

std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy) {
if (strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
return std::string("kDynamic");
} else {
return std::string("kStatic");
}
}

TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str) {
if (str == "kDynamic")
return TRTEngine::ResourceAllocationStrategy::kDynamic;
else
return TRTEngine::ResourceAllocationStrategy::kStatic;
}

static const std::string sym_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //=
std::string base64_encode(const std::string& in) {
std::string out;
@@ -90,6 +105,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def(
"_use_dynamically_allocated_resources",
[](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
self->set_resource_allocation_strategy(
dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic
: TRTEngine::ResourceAllocationStrategy::kStatic);
})
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
.def_property(
Expand Down Expand Up @@ -135,6 +157,7 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; });
m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; });
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; });
m.def("_platform_linux_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
return it->second;
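The `_use_dynamically_allocated_resources` binding above is the hook the Python runtime can use to flip a live engine between the two strategies. A minimal usage sketch follows; how you reach the underlying `__torch__.classes.tensorrt.Engine` object from a compiled module is an assumption here and not defined in this file (the example script added later in this PR goes through `torch_trt.dynamo.runtime.ResourceAllocatorContext` instead):

```python
# Hedged sketch: `engine` is assumed to be the registered TorchScript custom-class
# engine instance (e.g. held by a TorchTensorRTModule); attribute paths vary.
def use_dynamic_resources(engine, enabled: bool = True) -> None:
    # Recreates the execution context with the matching allocation strategy
    # (see TRTEngine::set_resource_allocation_strategy in this PR).
    engine._use_dynamically_allocated_resources(enabled)
```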
4 changes: 4 additions & 0 deletions core/runtime/runtime.h
@@ -38,13 +38,17 @@ typedef enum {
SERIALIZED_METADATA_IDX,
TARGET_PLATFORM_IDX,
REQUIRES_OUTPUT_ALLOCATOR_IDX,
RESOURCE_ALLOCATION_STRATEGY_IDX,
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

std::string base64_encode(const std::string& in);
std::string base64_decode(const std::string& in);
std::string serialize_bindings(const std::vector<std::string>& bindings);

std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy);
TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str);

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
const RTDevice& curr_device = RTDevice(),
Expand Down
2 changes: 1 addition & 1 deletion examples/apps/flux_demo.py
@@ -120,7 +120,7 @@ def forward_loop(mod):
"enabled_precisions": enabled_precisions,
"truncate_double": True,
"min_block_size": 1,
"use_python_runtime": False,
"use_python_runtime": True,
"immutable_weights": False,
"offload_module_to_cpu": args.low_vram_mode,
"use_explicit_typing": use_explicit_typing,
Expand Down
36 changes: 36 additions & 0 deletions examples/dynamo/dynamic_memory_allocation.py
@@ -0,0 +1,36 @@
# %%
import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]

settings = {
"ir": "dynamo",
"use_python_runtime": False,
"enabled_precisions": {torch.float32},
"immutable_weights": False,
}

model = models.resnet152(pretrained=True).eval().to("cuda")
compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
print("Memory used (GB):", (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)
compiled_module(*inputs)

with torch_trt.dynamo.runtime.ResourceAllocatorContext(compiled_module):
print(
"Memory used (GB):",
(torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
)
compiled_module(*inputs)
print(
"Memory used (GB):",
(torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -531,7 +531,7 @@ def aten_ops_gelu(
)


@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
Expand Down
10 changes: 8 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -1,8 +1,8 @@
import logging
import operator
import warnings
from typing import Any, Callable, Optional, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
@@ -20,6 +20,8 @@
)
from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor

logger = logging.getLogger(__name__)


def get_python_op_from_trt_elementwise_op(
trt_op: TRTElementWiseOp,
@@ -148,7 +150,11 @@ def convert_binary_elementwise(
ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
)

if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
if len(lhs_val.shape) == len(rhs_val.shape) and all(
a == b or a == 1 or b == 1 for a, b in zip(lhs_val.shape, rhs_val.shape)
):
logger.info(f"skip broadcast for {name}")
elif has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
lhs_val, rhs_val = broadcast(
ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
)
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -170,6 +170,7 @@
aten.upsample_trilinear3d.vec,
aten.upsample_bicubic2d.vec,
aten.linear.default,
aten.matmul.default,
}


Expand Down
15 changes: 10 additions & 5 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -463,11 +463,13 @@ def scaled_dot_product_attention_decomposition(
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
device = query.device
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)

if is_causal or attn_mask is not None:
attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)

if is_causal:
assert attn_mask is None, "attn_mask must be None when is_causal=True"
temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
@@ -480,17 +482,20 @@
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)
attn_weight = torch.matmul(query, key.transpose(-2, -1))

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale

attn_weight = attn_weight + attn_bias
if is_causal or attn_mask is not None:
        # Only add attn_bias when it is actually needed; adding it unconditionally hurts performance even when it is all zeros.
attn_weight = attn_weight + attn_bias

attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value
return torch.matmul(attn_weight, value)


@register_torch_trt_decomposition(
Expand Down
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -10,6 +10,18 @@


def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Splits all `torch.ops.aten.addmm.default` nodes in the FX graph into separate
`add` and `mm` nodes. This is useful for passes that want to insert additional
logic (such as FP32 accumulation) specifically around the matrix multiplication
operation, rather than the fused addmm.

Args:
gm (torch.fx.GraphModule): The FX graph module to transform.

Returns:
torch.fx.GraphModule: The modified FX graph module with addmm nodes split.
"""
target = torch.ops.aten.addmm.default
addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
for addmm_node in addmm_nodes:
@@ -52,6 +64,7 @@ def accumulate_fp32_matmul(
matmul_targets = [
torch.ops.aten.mm.default,
torch.ops.aten.bmm.default,
torch.ops.aten.matmul.default,
]

# Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
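To make the new `split_addmm_nodes` docstring concrete, here is a minimal numeric sketch (not the pass itself) of why the split is safe: `addmm(bias, a, b)` equals `bias + mm(a, b)`, which is what lets a later pass wrap only the `mm` node in FP32 casts:

```python
import torch

bias, a, b = torch.randn(4, 8), torch.randn(4, 16), torch.randn(16, 8)

fused = torch.ops.aten.addmm.default(bias, a, b)
split = torch.ops.aten.add.Tensor(bias, torch.ops.aten.mm.default(a, b))

# The split form is numerically equivalent, so the FP32-accumulation pass can
# insert casts around the mm alone without changing results.
torch.testing.assert_close(fused, split)
```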