diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d213cca638..dfc2755252 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -35,6 +35,7 @@
     pre_export_lowering,
 )
 from torch_tensorrt.dynamo.utils import (
+    colorize_log,
     get_flat_args_with_check,
     parse_graph_io,
     prepare_inputs,
@@ -88,6 +89,7 @@ def compile(
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
+    color_log: bool = _defaults.COLOR_LOG,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,7 @@ def compile(
         engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
         engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
+        color_log (bool): Colorize logging output if the rich module is available; otherwise do nothing.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -165,6 +168,10 @@ def compile(
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+
+    if color_log:
+        colorize_log()
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -281,6 +288,7 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
+        "color_log": color_log,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -522,6 +530,7 @@ def convert_exported_program_to_serialized_trt_engine(
     calibrator: object = None,
     allow_shape_tensors: bool = False,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
+    color_log: bool = _defaults.COLOR_LOG,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -580,12 +589,16 @@ def convert_exported_program_to_serialized_trt_engine(
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+        color_log (bool): Colorize logging output if the rich module is available; otherwise do nothing.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+    if color_log:
+        colorize_log()
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -653,6 +666,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
         "timing_cache_path": timing_cache_path,
+        "color_log": color_log,
     }
 
     exported_program = pre_export_lowering(exported_program)
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 68e446dab5..5953091742 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,7 @@
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
 CUSTOM_ENGINE_CACHE = None
+COLOR_LOG = False
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 359dc0b3ff..3581058dd1 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -35,6 +35,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     check_module_output,
+    colorize_log,
     get_model_device,
     get_torch_inputs,
     set_log_level,
@@ -277,6 +278,9 @@ def refit_module_weights(
     if settings.debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+    if settings.color_log:
+        colorize_log()
+
     device = to_torch_tensorrt_device(settings.device)
     if arg_inputs:
         if not isinstance(arg_inputs, collections.abc.Sequence):
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index f8886fbd67..138d8b0ade 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -7,6 +7,7 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
+    COLOR_LOG,
     DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
@@ -78,6 +79,7 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
+        color_log (bool): Colorize logging output if the rich module is available; otherwise do nothing.
""" enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -112,6 +114,7 @@ class CompilationSettings: lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES + color_log: bool = COLOR_LOG _SETTINGS_TO_BE_ENGINE_INVARIANT = ( diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 78f7989777..da9df9f7ff 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -7,8 +7,13 @@ import torch from torch.export import Dim, export from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import DEBUG, default_device -from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device +from torch_tensorrt.dynamo._defaults import COLOR_LOG, DEBUG, default_device +from torch_tensorrt.dynamo.utils import ( + colorize_log, + get_torch_inputs, + set_log_level, + to_torch_device, +) logger = logging.getLogger(__name__) @@ -74,6 +79,10 @@ def trace( if debug: set_log_level(logger.parent, logging.DEBUG) + color_log = kwargs.get("color_log", COLOR_LOG) + if color_log: + colorize_log() + device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 605d963a50..7e4331ab8a 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -18,6 +18,7 @@ repair_input_aliasing, ) from torch_tensorrt.dynamo.utils import ( + colorize_log, parse_dynamo_kwargs, prepare_inputs, set_log_level, @@ -39,6 +40,13 @@ def torch_tensorrt_backend( ) or ("debug" in kwargs and kwargs["debug"]): set_log_level(logger.parent, logging.DEBUG) + if ( + "options" in kwargs + and "color_log" in kwargs["options"] + and kwargs["options"]["color_log"] + ) or ("color_log" in kwargs and kwargs["color_log"]): + colorize_log() + DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend return DEFAULT_BACKEND(gm, sample_inputs, **kwargs) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index a85494239e..1c91b0492d 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -227,6 +227,26 @@ def set_log_level(parent_logger: Any, level: Any) -> None: torch.ops.tensorrt.set_logging_level(int(log_level)) +def colorize_log() -> None: + try: + from rich.console import Console + from rich.logging import RichHandler + + logging.basicConfig( + format="%(name)s:%(message)s", + handlers=[ + RichHandler( + console=Console(stderr=True), + show_time=False, + show_path=False, + rich_tracebacks=True, + ) + ], + ) + except ImportError: + pass + + def prepare_inputs( inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], disable_memory_format_check: bool = False,