From 2fa22dceee088e1118197f31b1e45fe8c5bfc8fa Mon Sep 17 00:00:00 2001
From: Pian Pawakapan
Date: Tue, 14 Oct 2025 18:52:29 -0700
Subject: [PATCH 1/2] Add nn.Module tracking to DebugMode

This change adds the ability to track nn.Module forward calls in
DebugMode. When record_nn_module=True, DebugMode will use ModTracker to
record module hierarchy and display module names in the debug output.
Added test coverage for nested module tracking.

[ghstack-poisoned]
---
 .../tensor/debug/test_debug_mode.py | 37 ++++++++++++++++
 torch/utils/_debug_mode.py          | 44 ++++++++++++++++++-
 2 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py
index 21641ca04e10..a529d8880013 100644
--- a/test/distributed/tensor/debug/test_debug_mode.py
+++ b/test/distributed/tensor/debug/test_debug_mode.py
@@ -521,6 +521,43 @@ def forward(self, x):
         record = debug_mode.operators[3].record["output"]
         self.assertTrue(torch.allclose(record, x1))
 
+    def test_nn_module(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = torch.nn.Linear(4, 4)
+                self.l2 = torch.nn.Linear(4, 4)
+            def forward(self, x):
+                return self.l2(self.l1(x))
+
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.abc = Foo()
+                self.xyz = torch.nn.Linear(4, 4)
+            def forward(self, x):
+                return self.xyz(self.abc(x))
+
+        mod = Bar()
+        inp = torch.randn(4, 4)
+        with DebugMode(record_nn_module=True) as debug_mode:
+            _ = mod(inp)
+
+        self.assertExpectedInline(
+            debug_mode.debug_string(),
+            """\
+  [nn.Mod] Bar
+    [nn.Mod] Bar.abc
+      [nn.Mod] Bar.abc.l1
+        aten::t(t: f32[4, 4])
+        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
+      [nn.Mod] Bar.abc.l2
+        aten::t(t: f32[4, 4])
+        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
+    [nn.Mod] Bar.xyz
+      aten::t(t: f32[4, 4])
+      aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
+        )
 
 instantiate_parametrized_tests(TestDTensorDebugMode)

diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py
index 8f97b2e4fb7a..2a93912b4cb1 100644
--- a/torch/utils/_debug_mode.py
+++ b/torch/utils/_debug_mode.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import contextlib
 import traceback
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -14,6 +14,9 @@
 from torch.utils._pytree import tree_map
 from torch.utils._traceback import CapturedTraceback
 
+if TYPE_CHECKING:
+    from torch.distributed._tools.mod_tracker import ModTracker
+
 __all__ = ["DebugMode", "get_active_debug_mode"]
 
 
@@ -244,6 +247,17 @@ def render(self, attributes: list[str]) -> str:
         return f"{node_str}{log_str}"
 
 
+class _NNModuleCall(_DebugCall):
+    """Designates entering an nn.Module's forward method"""
+
+    def __init__(self, module_name: str, call_depth: int):
+        super().__init__(call_depth, record=None, log=None)
+        self.module_name = module_name
+
+    def render(self, attributes: list[str]) -> str:
+        return f"[nn.Mod] {self.module_name}"
+
+
 def _run_hook(hook, *args):
     out = hook(*args)
     assert isinstance(out, dict) and all(isinstance(k, str) for k in out.keys())
@@ -334,6 +348,7 @@ def __init__(
         record_faketensor=False,
         record_realtensor=True,
         record_tensor_attributes=None,
+        record_nn_module=False,
     ):
         super().__init__()
         import torch.distributed.tensor  # noqa: F401
@@ -344,6 +359,12 @@ def __init__(
         self.record_realtensor = record_realtensor
         self.record_tensor_attributes = record_tensor_attributes or []
+        self.record_nn_module = record_nn_module
+
+        self.module_tracker: Optional["ModTracker"] = None
+        if self.record_nn_module:
+            self.module_tracker_setup()
+
         self.operators = []
         self.call_depth = 0
@@ -403,14 +424,35 @@ def __enter__(self):
         torch._C._push_on_torch_function_stack(self)
 
         super().__enter__()
+        if self.record_nn_module:
+            self.module_tracker.__enter__()  # type: ignore[union-attr]
         return self
 
     # pyrefly: ignore # bad-override
     def __exit__(self, *args):
         super().__exit__(*args)
+        if self.record_nn_module:
+            self.module_tracker.__exit__()  # type: ignore[union-attr]
         if self.record_torchfunction:
             torch._C._pop_torch_function_stack()
 
+    def module_tracker_setup(self):
+        from torch.distributed._tools.mod_tracker import ModTracker
+
+        self.module_tracker = ModTracker()
+
+        # module pre-fw hook: record module call
+        def pre_fw_hook(module, input):
+            fqn = self.module_tracker._get_mod_name(module)  # type: ignore[union-attr]
+            self.operators.append(_NNModuleCall(fqn, self.call_depth + 1))
+            self.call_depth += 1
+
+        # module post-fw hook: decrement call depth
+        def post_fw_hook(module, input, output):
+            self.call_depth -= 1
+
+        self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)
+
     @staticmethod
     @contextlib.contextmanager
     def dispatch_stack_trace(cpp=False):

From 09865cf869fa465b603438ac4accd302b4c2762a Mon Sep 17 00:00:00 2001
From: Pian Pawakapan
Date: Tue, 14 Oct 2025 22:16:03 -0700
Subject: [PATCH 2/2] Update on "[DebugMode][6/N] add nn.Module tracking"

Uses ModTracker to record nn.Module entries, much like CommDebugMode.
Can be switched on with `DebugMode(record_nn_module=True)`:

```
  [nn.Mod] Bar
    [nn.Mod] Bar.abc
      [nn.Mod] Bar.abc.l1
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.abc.l2
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
    [nn.Mod] Bar.xyz
      aten::t(t: f32[4, 4])
      aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
```

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
---
 test/distributed/tensor/debug/test_debug_mode.py | 3 +++
 torch/utils/_debug_mode.py                       | 3 ++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py
index a529d8880013..308032d3887d 100644
--- a/test/distributed/tensor/debug/test_debug_mode.py
+++ b/test/distributed/tensor/debug/test_debug_mode.py
@@ -527,6 +527,7 @@ def __init__(self):
                 super().__init__()
                 self.l1 = torch.nn.Linear(4, 4)
                 self.l2 = torch.nn.Linear(4, 4)
+
             def forward(self, x):
                 return self.l2(self.l1(x))
 
@@ -535,6 +536,7 @@ def __init__(self):
                 super().__init__()
                 self.abc = Foo()
                 self.xyz = torch.nn.Linear(4, 4)
+
             def forward(self, x):
                 return self.xyz(self.abc(x))
 
@@ -559,6 +561,7 @@ def forward(self, x):
       aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
         )
 
+
 instantiate_parametrized_tests(TestDTensorDebugMode)

diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py
index 2a93912b4cb1..b946a01d42f5 100644
--- a/torch/utils/_debug_mode.py
+++ b/torch/utils/_debug_mode.py
@@ -14,6 +14,7 @@
 from torch.utils._pytree import tree_map
 from torch.utils._traceback import CapturedTraceback
 
+
 if TYPE_CHECKING:
     from torch.distributed._tools.mod_tracker import ModTracker
 
@@ -361,7 +362,7 @@ def __init__(
 
         self.record_nn_module = record_nn_module
 
-        self.module_tracker: Optional["ModTracker"] = None
+        self.module_tracker: Optional[ModTracker] = None
         if self.record_nn_module:
             self.module_tracker_setup()
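
---

For anyone trying this stack locally, below is a minimal standalone sketch that mirrors the new test_nn_module test. It assumes both patches above are applied; note that torch.utils._debug_mode is a private module, so DebugMode is internal API here, not a stable interface.

```python
# Minimal sketch of record_nn_module=True, mirroring test_nn_module above.
# Assumes patches 1-2 are applied; torch.utils._debug_mode is private API.
import torch
from torch.utils._debug_mode import DebugMode


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))


mod = Bar()
with DebugMode(record_nn_module=True) as debug_mode:
    mod(torch.randn(4, 4))

# [nn.Mod] entries come from ModTracker's pre-forward hook, which appends
# an _NNModuleCall and bumps call_depth; the post-forward hook decrements
# it, so aten ops nest under the innermost module in the output.
print(debug_mode.debug_string())
```

The depth bookkeeping is why the hooks come in pairs: pre_fw_hook records the module FQN at call_depth + 1, and post_fw_hook restores the depth so sibling modules render at the same level.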