Skip to content

Simplify pass manager debug system #3530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 48 additions & 29 deletions py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,21 @@
from typing import Any, Callable, List, Optional
import tempfile
from types import new_class
from typing import Any, Callable, List, Optional, Union

import torch
from torch.fx import passes
from torch.fx.passes.pass_manager import PassManager
from torch_tensorrt.dynamo._settings import CompilationSettings


def get_draw_fx_graph_pass_lowering(
idx: int, path_prefix: str, post: bool
def _generate_draw_fx_graph_pass(
output_path_prefix: str, name: str
) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]:
from torch_tensorrt.dynamo.lowering.passes import (
post_lowering_pass_list,
pre_lowering_pass_list,
)

PRE_DEBUG_NAME = {
i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list)
}
PRE_DEBUG_NAME[0] = "exported_program"

POST_DEBUG_NAME = {
i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list)
}
POST_DEBUG_NAME[0] = "after_decomposition"

def draw_fx_graph_pass(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
DEBUG_NAME = POST_DEBUG_NAME[idx] if post else PRE_DEBUG_NAME[idx]
path = f"{path_prefix}_{DEBUG_NAME}.svg"
g = passes.graph_drawer.FxGraphDrawer(gm, DEBUG_NAME)
path = f"{output_path_prefix}/{name}.svg"
g = passes.graph_drawer.FxGraphDrawer(gm, name)
with open(path, "wb") as f:
f.write(g.get_dot_graph().create_svg())
return gm
@@ -47,8 +33,9 @@ def __init__(
]
]
] = None,
constraints: Optional[List[Callable]] = None
):
super().__init__(passes)
super().__init__(passes, constraints)

@classmethod
def build_from_passlist(
@@ -80,16 +67,48 @@ def add_pass_with_index(
def remove_pass_with_index(self, index: int) -> None:
del self.passes[index]

def insert_debug_pass(
self, index: List[int], filename_prefix: str, post: bool = True
def insert_debug_pass_before(
self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
) -> None:
"""Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.

Args:
passes: List of pass names to insert debug passes before
output_path_prefix: Prefix to use for generated debug files

Debug passes generate SVG visualizations of the FX graph at specified points
in the pass sequence.
"""
new_pass_list = []
for ps in self.passes:
if ps.__name__ in passes:
new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}"))
new_pass_list.append(ps)

self.passes = new_pass_list
self._validated = False

def insert_debug_pass_after(
self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
) -> None:
"""Insert debug passes in the PassManager pass sequence after the execution of a particular pass.

Args:
passes: List of pass names to insert debug passes after
output_path_prefix: Prefix to use for generated debug files

Debug passes generate SVG visualizations of the FX graph at specified points
in the pass sequence.
"""
new_pass_list = []
for ps in self.passes:
new_pass_list.append(ps)
if ps.__name__ in passes:
new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}"))

for i in range(len(index)):

debug_pass = get_draw_fx_graph_pass_lowering(
index[i], filename_prefix, post
)
self.add_pass_with_index(debug_pass, index[i] + i)
self.passes = new_pass_list
self._validated = False

def __call__(self, gm: Any, settings: CompilationSettings) -> Any:
self.validate()
Loading
Oops, something went wrong.