@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -32,6 +33,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -43,7 +46,6 @@
     get_output_metadata,
     parse_graph_io,
     prepare_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -66,7 +68,7 @@ def cross_compile_for_windows(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -188,7 +190,11 @@ def cross_compile_for_windows(
         )
 
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -299,7 +305,6 @@ def cross_compile_for_windows(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -401,7 +406,7 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -520,6 +525,13 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
+    if debug:
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
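
For callers that currently pass `debug=True`, the replacement named in the new warning is the `torch_tensorrt.dynamo.Debugger` context manager. A minimal migration sketch follows; the keyword arguments are assumptions modeled on the `DebuggerConfig` fields added in this change, not a signature shown in this diff:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()                # placeholder module
    inputs = [torch.randn(1, 3, 224, 224).cuda()]  # placeholder sample input

    # Before: trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, debug=True)
    # After (sketch): wrap compilation in the Debugger context instead of passing debug=True.
    with torch_tensorrt.dynamo.Debugger(
        log_level="debug",                  # assumed keyword: restores DEBUG-level logging
        logging_dir="/tmp/torchtrt_debug",  # assumed keyword: mirrors DebuggerConfig.logging_dir
    ):
        trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
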
@@ -641,7 +653,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -716,12 +727,15 @@ def compile(
     return trt_gm
 
 
+@fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,
+    *,
+    _debugger_settings: Optional[DebuggerConfig] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
 
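
The new keyword-only `_debugger_settings` parameter is not intended to be supplied by callers; the `fn_supports_debugger` decorator implies it is injected whenever a Debugger context is active. A hypothetical sketch of that injection pattern, not the actual implementation in `_supports_debugger.py` (which this diff does not show):

    from typing import Any, Callable, Optional, TypeVar

    from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig

    F = TypeVar("F", bound=Callable[..., Any])

    # Assumed module-level slot that the Debugger context manager would set on
    # __enter__ and clear on __exit__; not taken from this PR.
    _ACTIVE_DEBUGGER_CONFIG: Optional[DebuggerConfig] = None

    def fn_supports_debugger(fn: F) -> F:
        # Forward the active DebuggerConfig, if any, as the keyword-only
        # `_debugger_settings` argument of the decorated function.
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if _ACTIVE_DEBUGGER_CONFIG is not None:
                kwargs.setdefault("_debugger_settings", _ACTIVE_DEBUGGER_CONFIG)
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]
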
@@ -924,6 +938,34 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         trt_modules[name] = trt_module
 
+        if _debugger_settings:
+
+            if _debugger_settings.save_engine_profile:
+                if settings.use_python_runtime:
+                    if _debugger_settings.profile_format == "trex":
+                        logger.warning(
+                            "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only supports cudagraph visualization."
+                        )
+                    trt_module.enable_profiling()
+                else:
+                    path = os.path.join(
+                        _debugger_settings.logging_dir, "engine_visualization"
+                    )
+                    os.makedirs(path, exist_ok=True)
+                    trt_module.enable_profiling(
+                        profiling_results_dir=path,
+                        profile_format=_debugger_settings.profile_format,
+                    )
+
+            if _debugger_settings.save_layer_info:
+                with open(
+                    os.path.join(
+                        _debugger_settings.logging_dir, "engine_layer_info.json"
+                    ),
+                    "w",
+                ) as f:
+                    f.write(trt_module.get_layer_info())
+
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
 
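
Net effect of this block: when a debugger context is active, artifacts are written under `DebuggerConfig.logging_dir`; engine profiling output goes to an `engine_visualization/` subdirectory (TREX-formatted profiles require the C++ runtime) and, if `save_layer_info` is set, per-engine layer metadata goes to `engine_layer_info.json`. A small sketch for inspecting that file afterwards; the directory value is an assumption, and the JSON schema comes from TensorRT's engine inspector, which varies by TensorRT version:

    import json
    import os

    logging_dir = "/tmp/torchtrt_debug"  # assumed: whatever DebuggerConfig.logging_dir was set to

    # The file is written from trt_module.get_layer_info(); the .json name
    # indicates engine-inspector JSON text.
    with open(os.path.join(logging_dir, "engine_layer_info.json")) as f:
        layer_info = json.load(f)

    # Entries under "Layers" may be plain strings or dicts depending on the
    # profiling verbosity the engine was built with.
    for layer in layer_info.get("Layers", []):
        print(layer if isinstance(layer, str) else layer.get("Name"))
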
@@ -951,7 +993,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1054,7 +1096,11 @@ def convert_exported_program_to_serialized_trt_engine(
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE: