Skip to content

Commit 48a7c94

Browse files
committed Mar 18, 2025
Refined the Flux demo, fixed a device-mismatch bug, and prototyped CUDA Graphs and weight streaming
1 parent 18918ff commit 48a7c94

File tree

3 files changed

+84
-16
lines changed

3 files changed

+84
-16
lines changed
 

‎examples/apps/flux-demo.py

+19 -5
Original file line number | Diff line number | Diff line change
@@ -43,14 +43,20 @@
4343
"debug": False,
4444
"use_python_runtime": True,
4545
"immutable_weights": False,
46+
# "cache_built_engines": True,
47+
# "reuse_cached_engines": True,
48+
# "timing_cache_path": "/home/engine_cache/flux.bin",
49+
# "engine_cache_size": 40 * 1 << 30,
50+
# "enable_weight_streaming": False,
51+
# "enable_cuda_graph": True,
4652
}
4753

4854
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
4955
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
5056
pipe.transformer = trt_gm
5157

5258

53-
def generate_image(prompt, inference_step, batch_size=1):
59+
def generate_image(prompt, inference_step, batch_size=2):
5460
image = pipe(
5561
prompt,
5662
output_type="pil",
@@ -60,7 +66,8 @@ def generate_image(prompt, inference_step, batch_size=1):
6066
return image
6167

6268

63-
generate_image(["A golden retriever holding a sign to code"], 2)
69+
generate_image(["Test"], 2)
70+
torch.cuda.empty_cache()
6471

6572

6673
def model_change(model):
@@ -76,14 +83,20 @@ def model_change(model):
7683
def load_lora(path):
7784

7885
pipe.load_lora_weights(
79-
path,
86+
"/home/TensorRT/examples/apps/NGRVNG.safetensors",
8087
adapter_name="lora1",
8188
)
8289
pipe.set_adapters(["lora1"], adapter_weights=[1])
8390
pipe.fuse_lora()
8491
pipe.unload_lora_weights()
85-
print("LoRA loaded!")
92+
print("LoRA loaded! Begin refitting")
93+
generate_image(["Test"], 2)
94+
print("Refitting Finished!")
95+
8696

97+
generate_image(["Test"], 2)
98+
load_lora("")
99+
generate_image(["A golden retriever holding a sign to code"], 2)
87100

88101
# Create Gradio interface
89102
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
@@ -103,7 +116,8 @@ def load_lora(path):
103116

104117
lora_upload_path = gr.Textbox(
105118
label="LoRA Path",
106-
placeholder="/home/TensorRT/examples/apps/NGRVNG.safetensors",
119+
placeholder="Enter the LoRA checkpoint path here",
120+
value="/home/TensorRT/examples/apps/NGRVNG.safetensors",
107121
lines=2,
108122
)
109123
num_steps = gr.Slider(

‎py/torch_tensorrt/dynamo/_refit.py

+34 -3
Original file line number | Diff line number | Diff line change
@@ -317,37 +317,62 @@ def refit_module_weights(
317317

318318
new_gm = post_lowering(new_gm, settings)
319319

320-
logger.info("Compilation Settings: %s\n", settings)
320+
logger.debug("Lowered Input graph: " + str(new_gm.graph))
321321

322322
# Set torch-executed ops
323-
CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
323+
CONVERTERS.set_compilation_settings(settings)
324+
325+
# Check the number of supported operations in the graph
326+
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
327+
new_gm, settings.debug, settings.torch_executed_ops
328+
)
329+
330+
if num_supported_ops == 0 or (
331+
num_supported_ops < settings.min_block_size and not settings.dryrun
332+
):
333+
logger.warning(
334+
f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
335+
f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
336+
)
337+
return new_gm
338+
else:
339+
logger.debug(
340+
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
341+
)
324342

325343
# If specified, try using the fast partitioner and fall back to the global one on failure
326344
if settings.use_fast_partitioner:
327345
try:
346+
logger.info("Partitioning the graph via the fast partitioner")
328347
new_partitioned_module, supported_ops = partitioning.fast_partition(
329348
new_gm,
330349
verbose=settings.debug,
331350
min_block_size=settings.min_block_size,
332351
torch_executed_ops=settings.torch_executed_ops,
352+
require_full_compilation=settings.require_full_compilation,
353+
skip_fusion=(num_supported_ops == total_ops),
333354
)
355+
334356
except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
335357
logger.error(
336358
"Partitioning failed on the subgraph with fast partition. See trace above. "
337-
+ "Retrying with global partition.",
359+
"Retrying with global partition.",
338360
exc_info=True,
339361
)
340362

341363
settings.use_fast_partitioner = False
342364

343365
if not settings.use_fast_partitioner:
366+
logger.info("Partitioning the graph via the global partitioner")
344367
new_partitioned_module, supported_ops = partitioning.global_partition(
345368
new_gm,
346369
verbose=settings.debug,
347370
min_block_size=settings.min_block_size,
348371
torch_executed_ops=settings.torch_executed_ops,
372+
require_full_compilation=settings.require_full_compilation,
349373
)
350374

375+
# Done Partition
351376
if inline_module:
352377
# Preprocess the partitioned module to be in the same format as the inline module
353378
inline_torch_modules(new_partitioned_module)
@@ -495,6 +520,12 @@ def refit_module_weights(
495520
refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
496521
setattr(compiled_module, f"{name}_engine", refitted_engine)
497522

523+
# TODO: Memory control prototyping. Under discussion
524+
if settings.offload_module_to_cpu:
525+
del new_partitioned_module
526+
gc.collect()
527+
torch.cuda.empty_cache()
528+
498529
if verify_output and arg_inputs is not None:
499530
if check_module_output(
500531
new_module=new_gm,

‎py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+31 -8
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import torch
9+
import torch_tensorrt
910
from torch_tensorrt._Device import Device
1011
from torch_tensorrt.dynamo import _defaults
1112
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
@@ -61,6 +62,7 @@ def __init__(
6162
*,
6263
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
6364
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
65+
enable_cuda_graph: bool = True,
6466
immutable_weights: bool = False,
6567
strict: bool = True,
6668
allow_complex_guards_as_runtime_asserts: bool = False,
@@ -127,6 +129,7 @@ def __init__(
127129
self.arg_inputs: tuple[Any, ...] = tuple()
128130
self.kwarg_inputs: dict[str, Any] = {}
129131
self.additional_settings = kwargs
132+
self.enable_cuda_graph = enable_cuda_graph
130133
self.strict = strict
131134
self.allow_complex_guards_as_runtime_asserts = (
132135
allow_complex_guards_as_runtime_asserts
@@ -142,7 +145,11 @@ def __init__(
142145
self.run_info: Optional[tuple[Any, ...]] = None
143146
self.state_dict_metadata: dict[str, torch.Size] = {}
144147
self._store_state_dict_metadata()
145-
148+
self.enable_weight_streaming = (
149+
kwargs["enable_weight_streaming"]
150+
if "enable_weight_streaming" in kwargs
151+
else False
152+
)
146153
cls = self.__class__
147154
self.__class__ = type(
148155
self.original_model.__class__.__name__,
@@ -193,7 +200,7 @@ def forward(a, b, c=0, d=0):
193200

194201
self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
195202

196-
def _get_total_dynamic_shapes(self) -> Union[dict[str, Any], None]:
203+
def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
197204
if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
198205
return None
199206
total_dynamic_shape = {}
@@ -266,15 +273,17 @@ def refit_gm(self) -> None:
266273
MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module.
267274
If it fails to catch the changes, please call this function manually to update the TRT graph module.
268275
"""
269-
self.original_model.to(to_torch_device(self.trt_device))
276+
270277
if self.exp_program is None:
278+
self.original_model.to(to_torch_device(self.trt_device))
271279
self.exp_program = self.get_exported_program()
272280
else:
273281
self.exp_program._state_dict = (
274282
MutableTorchTensorRTModule._transform_state_dict(
275283
self.original_model.state_dict()
276284
)
277285
)
286+
self.exp_program.module().to(to_torch_device(self.trt_device))
278287
self.gm = refit_module_weights(
279288
self.gm,
280289
self.exp_program,
@@ -284,7 +293,7 @@ def refit_gm(self) -> None:
284293
in_place=True,
285294
)
286295

287-
self.original_model.cpu()
296+
self.original_model.to("cpu")
288297
torch.cuda.empty_cache()
289298

290299
def get_exported_program(self) -> torch.export.ExportedProgram:
@@ -324,8 +333,15 @@ def compile(self) -> None:
324333
use_python_runtime=self.use_python_runtime,
325334
**self.additional_settings,
326335
)
327-
self.original_model.cpu()
336+
self.original_model.to("cpu")
328337
torch.cuda.empty_cache()
338+
# torch_tensorrt.runtime.set_cudagraphs_mode(self.enable_cuda_graph)
339+
# if self.enable_cuda_graph:
340+
# self.gm = torch_tensorrt.runtime.enable_cudagraphs(self.gm)
341+
if self.enable_weight_streaming:
342+
self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm)
343+
requested_budget = int(16 * 2 << 20)
344+
self.weight_streaming_ctx.device_budget = requested_budget
329345

330346
def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
331347

@@ -446,14 +462,21 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
446462
self._store_state_dict_metadata()
447463
self.refit_state.set_state(RefitFlag.LIVE)
448464

465+
# weight_streaming_ctx = self.weight_streaming_ctx if self.enable_weight_streaming else None
449466
result = self.gm(*args, **kwargs)
450467
# Storing inputs and outputs for verification when the state is unknown
451468
self.run_info = (args, kwargs, result)
452469
return result
453470

454-
def to(self, device: str) -> None:
455-
logger.warning("Original PyTorch model is moved. CPU offload may failed.")
456-
self.original_model.to(device)
471+
def to(self, *args: Any, **kwargs: Any) -> None:
472+
logger.warning(
473+
"Trying to move the original PyTorch model. This will cause CPU offloading failing and increase GPU memory usage."
474+
+ "If this is absolute necessary, please call module.pytorch_model.to(...)"
475+
)
476+
477+
@property
478+
def device(self) -> torch.device:
479+
return to_torch_device(self.trt_device)
457480

458481
def __deepcopy__(self, memo: Any) -> Any:
459482
cls = self.__class__

0 commit comments

Comments
 (0)
Failed to load comments.