diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py index 73be3b2400..7dbaf70571 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, List, Optional +import tempfile +from types import new_class +from typing import Any, Callable, List, Optional, Union import torch from torch.fx import passes @@ -6,30 +8,14 @@ from torch_tensorrt.dynamo._settings import CompilationSettings -def get_draw_fx_graph_pass_lowering( - idx: int, path_prefix: str, post: bool +def _generate_draw_fx_graph_pass( + output_path_prefix: str, name: str ) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]: - from torch_tensorrt.dynamo.lowering.passes import ( - post_lowering_pass_list, - pre_lowering_pass_list, - ) - - PRE_DEBUG_NAME = { - i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list) - } - PRE_DEBUG_NAME[0] = "exported_program" - - POST_DEBUG_NAME = { - i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list) - } - POST_DEBUG_NAME[0] = "after_decomposition" - def draw_fx_graph_pass( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: - DEBUG_NAME = POST_DEBUG_NAME[idx] if post else PRE_DEBUG_NAME[idx] - path = f"{path_prefix}_{DEBUG_NAME}.svg" - g = passes.graph_drawer.FxGraphDrawer(gm, DEBUG_NAME) + path = f"{output_path_prefix}/{name}.svg" + g = passes.graph_drawer.FxGraphDrawer(gm, name) with open(path, "wb") as f: f.write(g.get_dot_graph().create_svg()) return gm @@ -47,8 +33,9 @@ def __init__( ] ] ] = None, + constraints: Optional[List[Callable]] = None ): - super().__init__(passes) + super().__init__(passes, constraints) @classmethod def build_from_passlist( @@ -80,16 +67,48 @@ def add_pass_with_index( def remove_pass_with_index(self, index: int) -> None: del self.passes[index] - def insert_debug_pass( - self, index: List[int], filename_prefix: str, post: bool = True + def insert_debug_pass_before( + self, passes: List[str], output_path_prefix: str=tempfile.gettempdir() ) -> None: + """Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass. + + Args: + passes: List of pass names to insert debug passes before + output_path_prefix: Prefix to use for generated debug files + + Debug passes generate SVG visualizations of the FX graph at specified points + in the pass sequence. + """ + new_pass_list = [] + for ps in self.passes: + if ps.__name__ in passes: + new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}")) + new_pass_list.append(ps) + + self.passes = new_pass_list + self._validated = False + + def insert_debug_pass_after( + self, passes: List[str], output_path_prefix: str=tempfile.gettempdir() + ) -> None: + """Insert debug passes in the PassManager pass sequence after the execution of a particular pass. + + Args: + passes: List of pass names to insert debug passes after + output_path_prefix: Prefix to use for generated debug files + + Debug passes generate SVG visualizations of the FX graph at specified points + in the pass sequence. + """ + new_pass_list = [] + for ps in self.passes: + new_pass_list.append(ps) + if ps.__name__ in passes: + new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}")) - for i in range(len(index)): - debug_pass = get_draw_fx_graph_pass_lowering( - index[i], filename_prefix, post - ) - self.add_pass_with_index(debug_pass, index[i] + i) + self.passes = new_pass_list + self._validated = False def __call__(self, gm: Any, settings: CompilationSettings) -> Any: self.validate()