
Commit 65c9c03

Changed the debug setting (#3551)
1 parent e6d316f commit 65c9c03

11 files changed: +176 −34 lines


py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,9 +14,9 @@
     load_cross_compiled_exported_program,
     save_cross_compiled_exported_program,
 )
-from ._Debugger import Debugger
 from ._exporter import export
 from ._refit import refit_module_weights
 from ._settings import CompilationSettings
 from ._SourceIR import SourceIR
 from ._tracer import trace
+from .debug._Debugger import Debugger
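
Because `Debugger` is re-exported from the package `__init__`, the public import path stays stable even though the class now lives in the new `debug` subpackage. Assuming only the layout shown in this diff, both of the following continue to resolve:

    from torch_tensorrt.dynamo import Debugger                    # unchanged public path
    from torch_tensorrt.dynamo.debug._Debugger import Debugger    # new internal location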

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 54 additions & 8 deletions
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -32,6 +33,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -43,7 +46,6 @@
     get_output_metadata,
     parse_graph_io,
     prepare_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -66,7 +68,7 @@ def cross_compile_for_windows(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -188,7 +190,11 @@ def cross_compile_for_windows(
         )
 
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -299,7 +305,6 @@ def cross_compile_for_windows(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -401,7 +406,7 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -520,6 +525,13 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
+    if debug:
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -641,7 +653,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
        ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -716,12 +727,15 @@ def compile(
     return trt_gm
 
 
+@fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,
+    *,
+    _debugger_settings: Optional[DebuggerConfig] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
@@ -924,6 +938,34 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         trt_modules[name] = trt_module
 
+        if _debugger_settings:
+
+            if _debugger_settings.save_engine_profile:
+                if settings.use_python_runtime:
+                    if _debugger_settings.profile_format == "trex":
+                        logger.warning(
+                            "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
+                        )
+                    trt_module.enable_profiling()
+                else:
+                    path = os.path.join(
+                        _debugger_settings.logging_dir, "engine_visualization"
+                    )
+                    os.makedirs(path, exist_ok=True)
+                    trt_module.enable_profiling(
+                        profiling_results_dir=path,
+                        profile_format=_debugger_settings.profile_format,
+                    )
+
+            if _debugger_settings.save_layer_info:
+                with open(
+                    os.path.join(
+                        _debugger_settings.logging_dir, "engine_layer_info.json"
+                    ),
+                    "w",
+                ) as f:
+                    f.write(trt_module.get_layer_info())
+
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
 
@@ -951,7 +993,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1054,7 +1096,11 @@ def convert_exported_program_to_serialized_trt_engine(
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
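
For callers migrating off the deprecated `debug` flag, the deprecation warnings above point to the `torch_tensorrt.dynamo.Debugger` context manager. The sketch below shows the intended shape of that migration; the constructor keyword names mirror the `DebuggerConfig` fields referenced in this diff (`logging_dir`, `engine_builder_monitor`, `save_engine_profile`, `profile_format`, `save_layer_info`) and are assumptions, since the `Debugger` signature itself is not part of this commit.

    import torch
    import torch_tensorrt

    model = torch.nn.Linear(8, 4).cuda().eval()
    inputs = [torch.randn(2, 8).cuda()]
    exported_program = torch.export.export(model, tuple(inputs))

    # Previously (now only emits a DeprecationWarning):
    #   trt_gm = torch_tensorrt.dynamo.compile(exported_program, inputs=inputs, debug=True)

    # Replacement suggested by the warnings above: wrap compilation in the Debugger
    # context manager. Keyword names are assumptions mirroring the DebuggerConfig
    # fields seen in this diff.
    with torch_tensorrt.dynamo.Debugger(
        logging_dir="/tmp/torch_tensorrt_debug",
        engine_builder_monitor=True,   # gates the TRT builder progress monitor
        save_engine_profile=True,      # enables DETAILED profiling verbosity
        profile_format="trex",         # per the warning above, trex needs the C++ runtime
        save_layer_info=True,          # writes engine_layer_info.json to logging_dir
    ):
        trt_gm = torch_tensorrt.dynamo.compile(exported_program, inputs=inputs)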

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 from torch_tensorrt._enums import EngineCapability, dtype
 
 ENABLED_PRECISIONS = {dtype.f32}
-DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
 ASSUME_DYNAMIC_SHAPE_SUPPORT = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 2 deletions
@@ -7,7 +7,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -101,7 +100,6 @@ class CompilationSettings:
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 2 additions & 6 deletions
@@ -7,8 +7,8 @@
 import torch
 from torch.export import Dim, export
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import DEBUG, default_device
-from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
+from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
 
 logger = logging.getLogger(__name__)
 
@@ -70,10 +70,6 @@ def trace(
     if kwarg_inputs is None:
         kwarg_inputs = {}
 
-    debug = kwargs.get("debug", DEBUG)
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)
     torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 9 additions & 3 deletions
@@ -46,6 +46,8 @@
     to_torch,
 )
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
@@ -70,6 +72,7 @@ class TRTInterpreterResult(NamedTuple):
     requires_output_allocator: bool
 
 
+@cls_supports_debugger
 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
     def __init__(
         self,
@@ -78,12 +81,14 @@ def __init__(
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
         engine_cache: Optional[BaseEngineCache] = None,
+        *,
+        _debugger_settings: Optional[DebuggerConfig] = None,
     ):
         super().__init__(module)
 
         self.logger = TRT_LOGGER
         self.builder = trt.Builder(self.logger)
-
+        self._debugger_settings = _debugger_settings
         flag = 0
         if compilation_settings.use_explicit_typing:
             STRONGLY_TYPED = 1 << (int)(
@@ -204,7 +209,7 @@ def _populate_trt_builder_config(
     ) -> trt.IBuilderConfig:
         builder_config = self.builder.create_builder_config()
 
-        if self.compilation_settings.debug:
+        if self._debugger_settings and self._debugger_settings.engine_builder_monitor:
             builder_config.progress_monitor = TRTBulderMonitor()
 
         if self.compilation_settings.workspace_size != 0:
@@ -215,7 +220,8 @@ def _populate_trt_builder_config(
         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 trt.ProfilingVerbosity.DETAILED
-                if self.compilation_settings.debug
+                if self._debugger_settings
+                and self._debugger_settings.save_engine_profile
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )