Skip to content

Added CPU offloading #3452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@
pre_export_lowering,
)
from torch_tensorrt.dynamo.utils import (
deallocate_module,
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
@@ -98,6 +99,7 @@ def cross_compile_for_windows(
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -362,7 +364,18 @@ def cross_compile_for_windows(
# Apply lowering on the graph module
gm = post_lowering(gm, settings)
logger.debug("Lowered Input graph: " + str(gm.graph))

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
deallocate_module(exported_program.module(), delete_module=False)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory // 2:
logger.warning(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
trt_gm = compile_module(
gm,
trt_arg_inputs,
@@ -421,6 +434,7 @@ def compile(
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -498,6 +512,7 @@ def compile(
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -550,15 +565,6 @@ def compile(
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
)

if (
not immutable_weights
and not refit_identical_engine_weights
and enable_weight_streaming
):
raise ValueError(
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
)

if (
"enable_cross_compile_for_windows" in kwargs.keys()
and kwargs["enable_cross_compile_for_windows"]
@@ -674,6 +680,7 @@ def compile(
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
}

settings = CompilationSettings(**compilation_options)
@@ -690,6 +697,18 @@ def compile(
gm = post_lowering(gm, settings)
logger.debug("Lowered Input graph: " + str(gm.graph))

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
deallocate_module(exported_program.module(), delete_module=False)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory // 2:
logger.warning(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
)
@@ -820,6 +839,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
trt_modules = {}
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
@@ -833,6 +853,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
str(name),
str(submodule.graph),
)
submodule.to(to_torch_device(settings.device))
continue

if name not in submodule_node_dict:
@@ -964,6 +985,7 @@ def convert_exported_program_to_serialized_trt_engine(
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1147,6 +1169,7 @@ def convert_exported_program_to_serialized_trt_engine(
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
}

settings = CompilationSettings(**compilation_options)
@@ -1166,7 +1189,17 @@ def convert_exported_program_to_serialized_trt_engine(

# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

if offload_module_to_cpu:
deallocate_module(exported_program.module(), delete_module=False)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory // 2:
logger.warning(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
try:
interpreter_result = interpret_module_to_result(
gm,
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@
TILING_OPTIMIZATION_LEVEL = "none"
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False


def default_device() -> Device:
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OFFLOAD_MODULE_TO_CPU,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@
get_trt_tensor,
to_torch,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

@@ -736,7 +736,8 @@ def run(
self._create_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)

if self.compilation_settings.offload_module_to_cpu:
deallocate_module(self.module)
serialized_engine = self.builder.build_serialized_network(
self.ctx.net, builder_config
)
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
@@ -84,13 +84,14 @@ class Frameworks(Enum):
}


def delete_module(module: torch.fx.GraphModule) -> None:
def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
"""
This is a helper function to delete the instance of module. We first move it to CPU and then
delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call
"""
module.to(CPU_DEVICE)
del module
if delete_module:
del module
torch.cuda.empty_cache()
gc.collect()

114 changes: 113 additions & 1 deletion tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,11 @@
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
from torch_tensorrt.dynamo.utils import (
COSINE_THRESHOLD,
cosine_similarity,
get_model_device,
)

assertions = unittest.TestCase()

@@ -283,6 +287,53 @@ def test_resnet18(ir):
)


@pytest.mark.unit
def test_resnet18_cpu_offload(ir):
"""
This tests export save and load functionality on Resnet18 model
"""
model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"ir": ir,
"min_block_size": 1,
"cache_built_engines": False,
"reuse_cached_engines": False,
"offload_module_to_cpu": True,
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
assertions.assertTrue(
get_model_device(model).type == "cpu",
msg="Model should be offloaded to CPU",
)
model.cuda()
torchtrt.save(trt_module, trt_ep_path)

deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

outputs_trt_deser = deser_trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_resnet18_dynamic(ir):
"""
@@ -381,6 +432,67 @@ def forward(self, x):
)


@pytest.mark.unit
def test_hybrid_conv_fallback_cpu_offload(ir):
"""
This tests export save and load functionality on a hybrid
model where a conv (a weighted layer) has been forced to fallback to Pytorch.
"""

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
conv = self.conv(x)
relu = self.relu(conv)
mul = relu * 0.5
return mul

model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"ir": ir,
"min_block_size": 1,
"torch_executed_ops": {"torch.ops.aten.convolution.default"},
"cache_built_engines": False,
"reuse_cached_engines": False,
"offload_module_to_cpu": True,
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
model.cuda()
torchtrt.save(trt_module, trt_ep_path)

deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)

for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

outputs_trt_deser = deser_trt_module(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_arange_export(ir):
"""
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.