Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
pre_export_lowering,
)
from torch_tensorrt.dynamo.utils import (
colorize_log,
get_flat_args_with_check,
parse_graph_io,
prepare_inputs,
Expand Down Expand Up @@ -88,6 +89,7 @@ def compile(
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
color_log: bool = _defaults.COLOR_LOG,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -158,13 +160,18 @@ def compile(
engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
color_log (bool): Colorize logging output if rich module is available, otherwise do nothing.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
"""

if debug:
set_log_level(logger.parent, logging.DEBUG)

if color_log:
colorize_log()

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
Expand Down Expand Up @@ -281,6 +288,7 @@ def compile(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"color_log": color_log,
}

settings = CompilationSettings(**compilation_options)
Expand Down Expand Up @@ -522,6 +530,7 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator: object = None,
allow_shape_tensors: bool = False,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
color_log: bool = _defaults.COLOR_LOG,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
Expand Down Expand Up @@ -580,12 +589,16 @@ def convert_exported_program_to_serialized_trt_engine(
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
color_log (bool): Colorize logging output if rich module is available, otherwise do nothing.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
if debug:
set_log_level(logger.parent, logging.DEBUG)

if color_log:
colorize_log()

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
Expand Down Expand Up @@ -653,6 +666,7 @@ def convert_exported_program_to_serialized_trt_engine(
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"timing_cache_path": timing_cache_path,
"color_log": color_log,
}

exported_program = pre_export_lowering(exported_program)
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824
CUSTOM_ENGINE_CACHE = None
COLOR_LOG = False


def default_device() -> Device:
Expand Down
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from torch_tensorrt.dynamo.utils import (
check_module_output,
colorize_log,
get_model_device,
get_torch_inputs,
set_log_level,
Expand Down Expand Up @@ -277,6 +278,9 @@ def refit_module_weights(
if settings.debug:
set_log_level(logger.parent, logging.DEBUG)

if settings.color_log:
colorize_log()

device = to_torch_tensorrt_device(settings.device)
if arg_inputs:
if not isinstance(arg_inputs, collections.abc.Sequence):
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
COLOR_LOG,
DEBUG,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
Expand Down Expand Up @@ -78,6 +79,7 @@ class CompilationSettings:
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
color_log (bool): Colorize logging output if rich module is available, otherwise do nothing.
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -112,6 +114,7 @@ class CompilationSettings:
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
color_log: bool = COLOR_LOG


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
Expand Down
13 changes: 11 additions & 2 deletions py/torch_tensorrt/dynamo/_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
import torch
from torch.export import Dim, export
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._defaults import DEBUG, default_device
from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
from torch_tensorrt.dynamo._defaults import COLOR_LOG, DEBUG, default_device
from torch_tensorrt.dynamo.utils import (
colorize_log,
get_torch_inputs,
set_log_level,
to_torch_device,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,6 +79,10 @@ def trace(
if debug:
set_log_level(logger.parent, logging.DEBUG)

color_log = kwargs.get("color_log", COLOR_LOG)
if color_log:
colorize_log()

device = to_torch_device(kwargs.get("device", default_device()))
torch_arg_inputs = get_torch_inputs(arg_inputs, device)
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
Expand Down
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
repair_input_aliasing,
)
from torch_tensorrt.dynamo.utils import (
colorize_log,
parse_dynamo_kwargs,
prepare_inputs,
set_log_level,
Expand All @@ -39,6 +40,13 @@ def torch_tensorrt_backend(
) or ("debug" in kwargs and kwargs["debug"]):
set_log_level(logger.parent, logging.DEBUG)

if (
"options" in kwargs
and "color_log" in kwargs["options"]
and kwargs["options"]["color_log"]
) or ("color_log" in kwargs and kwargs["color_log"]):
colorize_log()

DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
Expand Down
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,26 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
torch.ops.tensorrt.set_logging_level(int(log_level))


def colorize_log() -> None:
    """Route logging output through a ``rich`` handler for colorized logs.

    Installs a :class:`rich.logging.RichHandler` (writing to stderr) as the
    root logging handler via :func:`logging.basicConfig`. If the optional
    ``rich`` package is not installed, this function does nothing, matching
    the documented contract of the ``color_log`` setting ("Colorize logging
    output if rich module is available, otherwise do nothing").

    NOTE(review): ``logging.basicConfig`` is a no-op when the root logger
    already has handlers configured, so colorization may silently fail to
    take effect if logging was set up earlier in the process — confirm
    whether ``force=True`` is desired here.
    """
    try:
        # Imported lazily so that ``rich`` stays an optional dependency.
        from rich.console import Console
        from rich.logging import RichHandler

        logging.basicConfig(
            format="%(name)s:%(message)s",
            handlers=[
                RichHandler(
                    # stderr matches the default destination of logging handlers.
                    console=Console(stderr=True),
                    show_time=False,
                    show_path=False,
                    rich_tracebacks=True,
                )
            ],
        )
    except ImportError:
        # rich is optional: leave logging untouched, but record why so the
        # absence of color is diagnosable rather than silently swallowed.
        logging.getLogger(__name__).debug(
            "rich is not installed; skipping log colorization"
        )


def prepare_inputs(
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
disable_memory_format_check: bool = False,
Expand Down