From bf817ea66baf54f63992be813470e084ea18d551 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Wed, 4 Jun 2025 06:25:17 +0000
Subject: [PATCH 1/2] Changed the debug setting

---
 py/torch_tensorrt/dynamo/__init__.py          |  2 +-
 py/torch_tensorrt/dynamo/_compiler.py         | 29 +++++++++-
 py/torch_tensorrt/dynamo/_defaults.py         |  1 -
 py/torch_tensorrt/dynamo/_settings.py         |  2 -
 .../dynamo/conversion/_TRTInterpreter.py      | 14 +++--
 .../dynamo/{ => debug}/_Debugger.py           | 55 +++++++++++++++----
 .../dynamo/debug/_DebuggerConfig.py           | 10 ++++
 .../dynamo/debug/_supports_debugger.py        | 19 +++++++
 .../runtime/_PythonTorchTensorRTModule.py     | 11 +++-
 9 files changed, 122 insertions(+), 21 deletions(-)
 rename py/torch_tensorrt/dynamo/{ => debug}/_Debugger.py (78%)
 create mode 100644 py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
 create mode 100644 py/torch_tensorrt/dynamo/debug/_supports_debugger.py

diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index 15a17a4f02..607dca76bf 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -14,9 +14,9 @@
         load_cross_compiled_exported_program,
         save_cross_compiled_exported_program,
     )
-    from ._Debugger import Debugger
    from ._exporter import export
    from ._refit import refit_module_weights
    from ._settings import CompilationSettings
    from ._SourceIR import SourceIR
    from ._tracer import trace
+    from .debug._Debugger import Debugger
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index b7914efd37..65de06550e 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -31,6 +32,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -503,6 +506,13 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
+    if debug:
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -633,7 +643,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -694,12 +703,15 @@ def compile(
     return trt_gm
 
 
+@fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,
+    *,
+    _debugger_settings: Optional[DebuggerConfig] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
 
@@ -900,6 +912,21 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
             trt_modules[name] = trt_module
 
+            if _debugger_settings and _debugger_settings.save_engine_profile:
+                if settings.use_python_runtime:
+                    logger.warning(
+                        "Profiling can only be enabled when using the C++ runtime"
+                    )
+                else:
+                    path = os.path.join(
+                        _debugger_settings.logging_dir, "engine_visualization"
+                    )
+                    os.makedirs(path, exist_ok=True)
+                    trt_module.enable_profiling(
+                        profiling_results_dir=path,
+                        profile_format="trex",
+                    )
+
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
 
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 379a196e2e..297c8db52c 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -6,7 +6,6 @@
 from torch_tensorrt._enums import EngineCapability, dtype
 
 ENABLED_PRECISIONS = {dtype.f32}
-DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
 ASSUME_DYNAMIC_SHAPE_SUPPORT = False
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index d9b0e05e4d..8cff9dd63a 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -7,7 +7,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -100,7 +99,6 @@ class CompilationSettings:
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index b54dc6d461..ecd7aa45b2 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -45,6 +45,8 @@
     get_trt_tensor,
     to_torch,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
@@ -70,6 +72,7 @@ class TRTInterpreterResult(NamedTuple):
     requires_output_allocator: bool
 
 
+@cls_supports_debugger
 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
     def __init__(
         self,
@@ -78,12 +81,14 @@ def __init__(
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
         engine_cache: Optional[BaseEngineCache] = None,
+        *,
+        _debugger_settings: Optional[DebuggerConfig] = None,
     ):
         super().__init__(module)
 
         self.logger = TRT_LOGGER
         self.builder = trt.Builder(self.logger)
-
+        self._debugger_settings = _debugger_settings
         flag = 0
         if compilation_settings.use_explicit_typing:
             STRONGLY_TYPED = 1 << (int)(
@@ -204,7 +209,7 @@ def _populate_trt_builder_config(
     ) -> trt.IBuilderConfig:
         builder_config = self.builder.create_builder_config()
 
-        if self.compilation_settings.debug:
+        if self._debugger_settings and self._debugger_settings.engine_builder_monitor:
             builder_config.progress_monitor = TRTBulderMonitor()
 
         if self.compilation_settings.workspace_size != 0:
@@ -215,7 +220,8 @@ def _populate_trt_builder_config(
         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 trt.ProfilingVerbosity.DETAILED
-                if self.compilation_settings.debug
+                if self._debugger_settings
+                and self._debugger_settings.save_engine_profile
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
@@ -713,7 +719,7 @@ def run(
         if self.compilation_settings.reuse_cached_engines:
             interpreter_result = self._pull_cached_engine(hash_val)
             if interpreter_result is not None:  # hit the cache
-                return interpreter_result  # type: ignore[no-any-return]
+                return interpreter_result
 
         self._construct_trt_network_def()
 
diff --git a/py/torch_tensorrt/dynamo/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py
similarity index 78%
rename from py/torch_tensorrt/dynamo/_Debugger.py
rename to py/torch_tensorrt/dynamo/debug/_Debugger.py
index 2b92e1fa51..14a1f37917 100644
--- a/py/torch_tensorrt/dynamo/_Debugger.py
+++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py
@@ -1,10 +1,18 @@
+import contextlib
+import functools
 import logging
 import os
 import tempfile
 from logging.config import dictConfig
 from typing import Any, List, Optional
+from unittest import mock
 
 import torch
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import (
+    _DEBUG_ENABLED_CLS,
+    _DEBUG_ENABLED_FUNCS,
+)
 from torch_tensorrt.dynamo.lowering import (
     ATEN_POST_LOWERING_PASSES,
     ATEN_PRE_LOWERING_PASSES,
@@ -18,13 +26,21 @@ class Debugger:
     def __init__(
         self,
-        log_level: str,
+        log_level: str = "debug",
         capture_fx_graph_before: Optional[List[str]] = None,
         capture_fx_graph_after: Optional[List[str]] = None,
         save_engine_profile: bool = False,
-        logging_dir: Optional[str] = None,
+        engine_builder_monitor: bool = True,
+        logging_dir: str = tempfile.gettempdir(),
     ):
-        self.debug_file_dir = tempfile.TemporaryDirectory().name
+
+        os.makedirs(logging_dir, exist_ok=True)
+        self.cfg = DebuggerConfig(
+            log_level=log_level,
+            save_engine_profile=save_engine_profile,
+            engine_builder_monitor=engine_builder_monitor,
+            logging_dir=logging_dir,
+        )
 
         if log_level == "debug":
             self.log_level = logging.DEBUG
@@ -47,14 +63,10 @@ def __init__(
         self.capture_fx_graph_before = capture_fx_graph_before
         self.capture_fx_graph_after = capture_fx_graph_after
 
-        if logging_dir is not None:
-            self.debug_file_dir = logging_dir
-            os.makedirs(self.debug_file_dir, exist_ok=True)
-
     def __enter__(self) -> None:
         self.original_lvl = _LOGGER.getEffectiveLevel()
         self.rt_level = torch.ops.tensorrt.get_logging_level()
-        dictConfig(self.get_config())
+        dictConfig(self.get_customized_logging_config())
 
         if self.capture_fx_graph_before or self.capture_fx_graph_after:
             self.old_pre_passes, self.old_post_passes = (
@@ -63,7 +75,7 @@ def __enter__(self) -> None:
             )
             pre_pass_names = [p.__name__ for p in self.old_pre_passes]
             post_pass_names = [p.__name__ for p in self.old_post_passes]
-            path = os.path.join(self.debug_file_dir, "lowering_passes_visualization")
+            path = os.path.join(self.cfg.logging_dir, "lowering_passes_visualization")
             if self.capture_fx_graph_before is not None:
                 pre_vis_passes = [
                     p for p in self.capture_fx_graph_before if p in pre_pass_names
                 ]
@@ -85,9 +97,25 @@ def __enter__(self) -> None:
             ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path)
             ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path)
 
+        self._context_stack = contextlib.ExitStack()
+
+        for f in _DEBUG_ENABLED_FUNCS:
+            f.__kwdefaults__["_debugger_settings"] = self.cfg
+
+        [
+            self._context_stack.enter_context(
+                mock.patch.object(
+                    c,
+                    "__init__",
+                    functools.partialmethod(c.__init__, _debugger_settings=self.cfg),
+                )
+            )
+            for c in _DEBUG_ENABLED_CLS
+        ]
+
     def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
-        dictConfig(self.get_default_config())
+        dictConfig(self.get_default_logging_config())
         torch.ops.tensorrt.set_logging_level(self.rt_level)
         if self.capture_fx_graph_before or self.capture_fx_graph_after:
             ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
@@ -96,6 +124,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
             )
             self.debug_file_dir = tempfile.TemporaryDirectory().name
 
+        for f in _DEBUG_ENABLED_FUNCS:
+            f.__kwdefaults__["_debugger_settings"] = None
+
+        self._context_stack.close()
+
     def get_customized_logging_config(self) -> dict[str, Any]:
         config = {
             "version": 1,
@@ -114,7 +147,7 @@ def get_customized_logging_config(self) -> dict[str, Any]:
             "file": {
                 "level": self.log_level,
                 "class": "logging.FileHandler",
-                "filename": f"{self.debug_file_dir}/torch_tensorrt_logging.log",
+                "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
                 "formatter": "standard",
             },
             "console": {
diff --git a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
new file mode 100644
index 0000000000..5dc51e286e
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
@@ -0,0 +1,10 @@
+import tempfile
+from dataclasses import dataclass
+
+
+@dataclass
+class DebuggerConfig:
+    log_level: str = "debug"
+    save_engine_profile: bool = False
+    engine_builder_monitor: bool = True
+    logging_dir: str = tempfile.gettempdir()
diff --git a/py/torch_tensorrt/dynamo/debug/_supports_debugger.py b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py
new file mode 100644
index 0000000000..627be743df
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py
@@ -0,0 +1,19 @@
+from typing import Any, Callable, Type, TypeVar
+
+T = TypeVar("T")
+
+
+_DEBUG_ENABLED_FUNCS = []
+_DEBUG_ENABLED_CLS = []
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def fn_supports_debugger(func: F) -> F:
+    _DEBUG_ENABLED_FUNCS.append(func)
+    return func
+
+
+def cls_supports_debugger(cls: Type[T]) -> Type[T]:
+    _DEBUG_ENABLED_CLS.append(cls)
+    return cls
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 6415ce11c3..8d1a31564d 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -12,6 +12,8 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
@@ -111,6 +113,7 @@ def set_runtime_states(
     )
 
 
+@cls_supports_debugger
 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
     """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
 
@@ -128,6 +131,7 @@ def __init__(
         settings: CompilationSettings = CompilationSettings(),
         weight_name_map: Optional[dict[Any, Any]] = None,
         requires_output_allocator: bool = False,
+        _debugger_settings: Optional[DebuggerConfig] = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -157,6 +161,7 @@ def __init__(
 
         """
         self.context: Any
+        self._debugger_settings: Optional[DebuggerConfig] = _debugger_settings
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
 
@@ -193,7 +198,11 @@ def __init__(
             self.target_device_properties = torch.cuda.get_device_properties(
                 self.target_device_id
             )
-        self.profiling_enabled = settings.debug if settings.debug is not None else False
+        self.profiling_enabled = (
+            _debugger_settings.save_engine_profile
+            if _debugger_settings is not None
+            else False
+        )
         self.settings = settings
         self.engine = None
         self.weight_name_map = weight_name_map

From 197069e134ef4bbc3ae884e5151eb319050a0f22 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Wed, 4 Jun 2025 17:57:05 +0000
Subject: [PATCH 2/2] Added profile_format and added docstring to the debugger

---
 py/torch_tensorrt/dynamo/_compiler.py         | 61 ++++++++++++-------
 py/torch_tensorrt/dynamo/_tracer.py           |  8 +--
 .../dynamo/conversion/_TRTInterpreter.py      |  2 +-
 py/torch_tensorrt/dynamo/debug/_Debugger.py   | 26 ++++++++
 .../dynamo/debug/_DebuggerConfig.py           |  2 +
 .../dynamo/debug/_supports_debugger.py        |  4 +-
 .../runtime/_MutableTorchTensorRTModule.py    | 11 +++-
 7 files changed, 80 insertions(+), 34 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 65de06550e..cfde9acce4 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -44,7 +44,6 @@
     get_output_metadata,
     parse_graph_io,
     prepare_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -66,7 +65,7 @@ def cross_compile_for_windows(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -187,7 +186,11 @@ def cross_compile_for_windows(
         )
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -298,7 +301,6 @@ def cross_compile_for_windows(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -389,7 +391,7 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -912,20 +914,33 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
             trt_modules[name] = trt_module
 
-            if _debugger_settings and _debugger_settings.save_engine_profile:
-                if settings.use_python_runtime:
-                    logger.warning(
-                        "Profiling can only be enabled when using the C++ runtime"
-                    )
-                else:
-                    path = os.path.join(
-                        _debugger_settings.logging_dir, "engine_visualization"
-                    )
-                    os.makedirs(path, exist_ok=True)
-                    trt_module.enable_profiling(
-                        profiling_results_dir=path,
-                        profile_format="trex",
-                    )
+            if _debugger_settings:
+
+                if _debugger_settings.save_engine_profile:
+                    if settings.use_python_runtime:
+                        if _debugger_settings.profile_format == "trex":
+                            logger.warning(
+                                "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only supports cudagraph visualization."
+                            )
+                        trt_module.enable_profiling()
+                    else:
+                        path = os.path.join(
+                            _debugger_settings.logging_dir, "engine_visualization"
+                        )
+                        os.makedirs(path, exist_ok=True)
+                        trt_module.enable_profiling(
+                            profiling_results_dir=path,
+                            profile_format=_debugger_settings.profile_format,
+                        )
+
+                if _debugger_settings.save_layer_info:
+                    with open(
+                        os.path.join(
+                            _debugger_settings.logging_dir, "engine_layer_info.json"
+                        ),
+                        "w",
+                    ) as f:
+                        f.write(trt_module.get_layer_info())
 
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
@@ -954,7 +969,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1056,7 +1071,11 @@ def convert_exported_program_to_serialized_trt_engine(
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
index 78f7989777..5f4bdd0a8d 100644
--- a/py/torch_tensorrt/dynamo/_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -7,8 +7,8 @@
 import torch
 from torch.export import Dim, export
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import DEBUG, default_device
-from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
+from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
 
 logger = logging.getLogger(__name__)
 
@@ -70,10 +70,6 @@ def trace(
     if kwarg_inputs is None:
         kwarg_inputs = {}
 
-    debug = kwargs.get("debug", DEBUG)
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)
     torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index ecd7aa45b2..054bd4215b 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -719,7 +719,7 @@ def run(
         if self.compilation_settings.reuse_cached_engines:
             interpreter_result = self._pull_cached_engine(hash_val)
             if interpreter_result is not None:  # hit the cache
-                return interpreter_result
+                return interpreter_result  # type: ignore[no-any-return]
 
         self._construct_trt_network_def()
 
diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py
index 14a1f37917..bb9dffbfc1 100644
--- a/py/torch_tensorrt/dynamo/debug/_Debugger.py
+++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py
@@ -30,9 +30,33 @@ def __init__(
         capture_fx_graph_before: Optional[List[str]] = None,
         capture_fx_graph_after: Optional[List[str]] = None,
         save_engine_profile: bool = False,
+        profile_format: str = "perfetto",
         engine_builder_monitor: bool = True,
         logging_dir: str = tempfile.gettempdir(),
+        save_layer_info: bool = False,
     ):
+        """Initialize a debugger for TensorRT conversion.
+
+        Args:
+            log_level (str): Logging level to use. Valid options are:
+                'debug', 'info', 'warning', 'error', 'internal_errors', 'graphs'.
+                Defaults to 'debug'.
+            capture_fx_graph_before (List[str], optional): List of pass names to visualize FX graph
+                before execution of a lowering pass. Defaults to None.
+            capture_fx_graph_after (List[str], optional): List of pass names to visualize FX graph
+                after execution of a lowering pass. Defaults to None.
+            save_engine_profile (bool): Whether to save TensorRT engine profiling information.
+                Defaults to False.
+            profile_format (str): Format for profiling data. Can be either 'perfetto' or 'trex'.
+                If you need to generate an engine graph using the profiling files, set it to 'trex'.
+                Defaults to 'perfetto'.
+            engine_builder_monitor (bool): Whether to monitor the TensorRT engine building process.
+                Defaults to True.
+            logging_dir (str): Directory to save debug logs and profiles.
+                Defaults to the system temp directory.
+            save_layer_info (bool): Whether to save layer info.
+                Defaults to False.
+        """
 
         os.makedirs(logging_dir, exist_ok=True)
         self.cfg = DebuggerConfig(
@@ -40,6 +64,8 @@ def __init__(
             save_engine_profile=save_engine_profile,
             engine_builder_monitor=engine_builder_monitor,
             logging_dir=logging_dir,
+            profile_format=profile_format,
+            save_layer_info=save_layer_info,
         )
 
         if log_level == "debug":
diff --git a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
index 5dc51e286e..3c409b0aa8 100644
--- a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
+++ b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
@@ -8,3 +8,5 @@ class DebuggerConfig:
     save_engine_profile: bool = False
     engine_builder_monitor: bool = True
     logging_dir: str = tempfile.gettempdir()
+    profile_format: str = "perfetto"
+    save_layer_info: bool = False
diff --git a/py/torch_tensorrt/dynamo/debug/_supports_debugger.py b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py
index 627be743df..2d9fd2a149 100644
--- a/py/torch_tensorrt/dynamo/debug/_supports_debugger.py
+++ b/py/torch_tensorrt/dynamo/debug/_supports_debugger.py
@@ -1,13 +1,11 @@
 from typing import Any, Callable, Type, TypeVar
 
 T = TypeVar("T")
-
+F = TypeVar("F", bound=Callable[..., Any])
 
 _DEBUG_ENABLED_FUNCS = []
 _DEBUG_ENABLED_CLS = []
 
-F = TypeVar("F", bound=Callable[..., Any])
-
 
 def fn_supports_debugger(func: F) -> F:
     _DEBUG_ENABLED_FUNCS.append(func)
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index eaeb6a8c28..c6bd22f938 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -1,5 +1,6 @@
 import inspect
 import logging
+import warnings
 from copy import deepcopy
 from enum import Enum, auto
 from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union
@@ -71,7 +72,7 @@ def __init__(
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
     immutable_weights: bool = False,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -109,7 +110,6 @@ def __init__(
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable.
-        debug (bool): Enable debuggable engine
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
         workspace_size (int): Maximum size of workspace given to TensorRT
@@ -156,6 +156,12 @@ def __init__(
         self.kwarg_inputs: dict[str, Any] = {}
         device = to_torch_tensorrt_device(device)
         enabled_precisions = {dtype._from(p) for p in enabled_precisions}
+        if debug:
+            warnings.warn(
+                "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
         assert (
             not immutable_weights
         ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule."
@@ -165,7 +171,6 @@ def __init__(
             "enabled_precisions": (
                 enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
             ),
-            "debug": debug,
            "device": device,
            "assume_dynamic_shape_support": assume_dynamic_shape_support,
            "workspace_size": workspace_size,
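
Below is a minimal usage sketch of the Debugger API these two patches introduce; it is not part of the patches. The model, input shape, and logging path are hypothetical placeholders, and only `Debugger` keyword arguments that appear in the diffs above are used.

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # hypothetical user model
    inputs = (torch.randn(1, 3, 224, 224, device="cuda"),)

    # Replaces the deprecated compile(..., debug=True) flag: entering the
    # context injects _debugger_settings into every function and class marked
    # with fn_supports_debugger / cls_supports_debugger (compile_module,
    # TRTInterpreter, PythonTorchTensorRTModule).
    with torch_tensorrt.dynamo.Debugger(
        log_level="debug",
        save_engine_profile=True,
        profile_format="trex",  # 'perfetto' (default) or 'trex'
        engine_builder_monitor=True,
        save_layer_info=True,
        logging_dir="/tmp/torch_tensorrt_debug",  # hypothetical directory
    ):
        exported_program = torch.export.export(model, inputs)
        trt_gm = torch_tensorrt.dynamo.compile(exported_program, inputs=list(inputs))

Per compile_module above, profiles land under <logging_dir>/engine_visualization and, with save_layer_info=True, layer metadata is written to <logging_dir>/engine_layer_info.json; 'trex' profiling requires the C++ runtime (use_python_runtime=False).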