From 388cffa8b3559135fb033254512a6b68151d4f81 Mon Sep 17 00:00:00 2001
From: Naren Dasan <naren@narendasan.com>
Date: Thu, 22 May 2025 11:51:35 -0600
Subject: [PATCH] Simplify pass manager debug system

---
 .../dynamo/lowering/passes/pass_manager.py    | 77 ++++++++++++-------
 1 file changed, 48 insertions(+), 29 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
index 73be3b2400..7dbaf70571 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py
@@ -1,4 +1,6 @@
-from typing import Any, Callable, List, Optional
+import tempfile
+from types import new_class
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 from torch.fx import passes
@@ -6,30 +8,14 @@
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
 
-def get_draw_fx_graph_pass_lowering(
-    idx: int, path_prefix: str, post: bool
+def _generate_draw_fx_graph_pass(
+    output_path_prefix: str, name: str
 ) -> Callable[[torch.fx.GraphModule, CompilationSettings], torch.fx.GraphModule]:
-    from torch_tensorrt.dynamo.lowering.passes import (
-        post_lowering_pass_list,
-        pre_lowering_pass_list,
-    )
-
-    PRE_DEBUG_NAME = {
-        i + 1: f"after_{p.__name__}" for i, p in enumerate(pre_lowering_pass_list)
-    }
-    PRE_DEBUG_NAME[0] = "exported_program"
-
-    POST_DEBUG_NAME = {
-        i + 1: f"after_{p.__name__}" for i, p in enumerate(post_lowering_pass_list)
-    }
-    POST_DEBUG_NAME[0] = "after_decomposition"
-
     def draw_fx_graph_pass(
         gm: torch.fx.GraphModule, settings: CompilationSettings
     ) -> torch.fx.GraphModule:
-        DEBUG_NAME = POST_DEBUG_NAME[idx] if post else PRE_DEBUG_NAME[idx]
-        path = f"{path_prefix}_{DEBUG_NAME}.svg"
-        g = passes.graph_drawer.FxGraphDrawer(gm, DEBUG_NAME)
+        path = f"{output_path_prefix}/{name}.svg"
+        g = passes.graph_drawer.FxGraphDrawer(gm, name)
         with open(path, "wb") as f:
             f.write(g.get_dot_graph().create_svg())
         return gm
@@ -47,8 +33,9 @@ def __init__(
                 ]
             ]
         ] = None,
+        constraints: Optional[List[Callable]] = None
     ):
-        super().__init__(passes)
+        super().__init__(passes, constraints)
 
     @classmethod
     def build_from_passlist(
@@ -80,16 +67,48 @@ def add_pass_with_index(
     def remove_pass_with_index(self, index: int) -> None:
         del self.passes[index]
 
-    def insert_debug_pass(
-        self, index: List[int], filename_prefix: str, post: bool = True
+    def insert_debug_pass_before(
+        self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
     ) -> None:
+        """Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
+
+        Args:
+            passes: List of pass names to insert debug passes before
+            output_path_prefix: Prefix to use for generated debug files
+
+        Debug passes generate SVG visualizations of the FX graph at specified points
+        in the pass sequence.
+        """
+        new_pass_list = []
+        for ps in self.passes:
+            if ps.__name__ in passes:
+                new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"before_{ps.__name__}"))
+            new_pass_list.append(ps)
+
+        self.passes = new_pass_list
+        self._validated = False
+
+    def insert_debug_pass_after(
+        self, passes: List[str], output_path_prefix: str=tempfile.gettempdir()
+    ) -> None:
+        """Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
+
+        Args:
+            passes: List of pass names to insert debug passes after
+            output_path_prefix: Prefix to use for generated debug files
+
+        Debug passes generate SVG visualizations of the FX graph at specified points
+        in the pass sequence.
+        """
+        new_pass_list = []
+        for ps in self.passes:
+            new_pass_list.append(ps)
+            if ps.__name__ in passes:
+                new_pass_list.append(_generate_draw_fx_graph_pass(output_path_prefix, f"after_{ps.__name__}"))
 
-        for i in range(len(index)):
 
-            debug_pass = get_draw_fx_graph_pass_lowering(
-                index[i], filename_prefix, post
-            )
-            self.add_pass_with_index(debug_pass, index[i] + i)
+        self.passes = new_pass_list
+        self._validated = False
 
     def __call__(self, gm: Any, settings: CompilationSettings) -> Any:
         self.validate()