40 changes: 40 additions & 0 deletions test/distributed/tensor/debug/test_debug_mode.py
@@ -330,6 +330,46 @@ def f(x):
        f(x)
        self.assertEqual(len(debug_mode.debug_string()), 0)

    def test_nn_module(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(4, 4)
                self.l2 = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.l2(self.l1(x))

        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.abc = Foo()
                self.xyz = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.xyz(self.abc(x))

        mod = Bar()
        inp = torch.randn(4, 4)
        with DebugMode(record_nn_module=True) as debug_mode:
            _ = mod(inp)

        self.assertExpectedInline(
            debug_mode.debug_string(),
            """\
[nn.Mod] Bar
[nn.Mod] Bar.abc
[nn.Mod] Bar.abc.l1
aten::t(t: f32[4, 4])
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
[nn.Mod] Bar.abc.l2
aten::t(t: f32[4, 4])
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
[nn.Mod] Bar.xyz
aten::t(t: f32[4, 4])
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
        )


instantiate_parametrized_tests(TestDTensorDebugMode)

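For reference, a minimal standalone sketch of the new option outside the test harness. It assumes only the API introduced in this diff (DebugMode with record_nn_module=True, and debug_string()); the Sequential model is just an illustrative stand-in:

import torch
from torch.utils._debug_mode import DebugMode

# Any nn.Module works; Sequential is used here only as a stand-in.
mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

with DebugMode(record_nn_module=True) as debug_mode:
    mod(torch.randn(4, 4))

# Each nn.Module forward is recorded as an "[nn.Mod] <fqn>" entry, with the
# aten ops it dispatches recorded beneath it at a deeper call depth.
print(debug_mode.debug_string())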
45 changes: 44 additions & 1 deletion torch/utils/_debug_mode.py
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Optional
from typing import Optional, TYPE_CHECKING

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -13,6 +13,10 @@
from torch.utils._pytree import tree_map


if TYPE_CHECKING:
    from torch.distributed._tools.mod_tracker import ModTracker


__all__ = ["DebugMode", "get_active_debug_mode"]

REDISTRIBUTE_FUNC = "redistribute_input"
@@ -139,6 +143,17 @@ def render(self, attributes: list[str]) -> str:
        return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})"


class _NNModuleCall(_DebugCall):
    """Designates entering an nn.Module's forward method"""

    def __init__(self, module_name: str, call_depth: int):
        super().__init__(call_depth)
        self.module_name = module_name

    def render(self, attributes: list[str]) -> str:
        return f"[nn.Mod] {self.module_name}"


class DebugMode(TorchDispatchMode):
    def __init__(
        self,
@@ -147,6 +162,7 @@ def __init__(
        record_faketensor=False,
        record_realtensor=True,
        record_tensor_attributes=None,
        record_nn_module=False,
    ):
        super().__init__()
        import torch.distributed.tensor  # noqa: F401
@@ -157,6 +173,12 @@ def __init__(
        self.record_realtensor = record_realtensor
        self.record_tensor_attributes = record_tensor_attributes or []

        self.record_nn_module = record_nn_module

        self.module_tracker: Optional[ModTracker] = None
        if self.record_nn_module:
            self.module_tracker_setup()

        self.operators = []
        self.call_depth = 0

@@ -211,14 +233,35 @@ def __enter__(self):
            torch._C._push_on_torch_function_stack(self)

        super().__enter__()
        if self.record_nn_module:
            self.module_tracker.__enter__()  # type: ignore[attribute, union-attr]
        return self

    # pyrefly: ignore # bad-override
    def __exit__(self, *args):
        super().__exit__(*args)
        if self.record_nn_module:
            self.module_tracker.__exit__()  # type: ignore[attribute, union-attr]
        if self.record_torchfunction:
            torch._C._pop_torch_function_stack()

    def module_tracker_setup(self):
        from torch.distributed._tools.mod_tracker import ModTracker

        self.module_tracker = ModTracker()

        # module pre-fw hook: record module call
        def pre_fw_hook(module, input):
            fqn = self.module_tracker._get_mod_name(module)  # type: ignore[attribute, union-attr]
            self.operators.append(_NNModuleCall(fqn, self.call_depth + 1))
            self.call_depth += 1

        # module post-fw hook: decrement call depth
        def post_fw_hook(module, input, output):
            self.call_depth -= 1

        self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)

    @contextlib.contextmanager
    def record_redistribute_calls(
        self,
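A minimal sketch of the ModTracker user-hook mechanism that module_tracker_setup builds on. The hook signatures and the register_user_hooks / _get_mod_name calls mirror the usage in this diff; that the hooks fire for every module forward run inside the tracker context is an assumption based on how DebugMode enters and exits the tracker above:

import torch
from torch.distributed._tools.mod_tracker import ModTracker

tracker = ModTracker()

def pre_fw_hook(module, input):
    # _get_mod_name returns the module's fully qualified name, as used above.
    print("enter:", tracker._get_mod_name(module))

def post_fw_hook(module, input, output):
    print("exit:", type(module).__name__)

# Same registration call as in module_tracker_setup.
tracker.register_user_hooks(pre_fw_hook, post_fw_hook)

# Entering the tracker context installs the hooks; module forwards executed
# inside it trigger the pre/post hooks (assumption based on the
# __enter__/__exit__ usage in DebugMode above).
with tracker:
    torch.nn.Linear(4, 4)(torch.randn(2, 4))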