
Commit bce274d

pianpwketaf authored and committed
[DebugMode] store stringify args by default (#166347)
DebugMode currently stores dispatch call args & kwargs, which includes all intermediate tensors and more. This quickly OOMed on GPU when trying to debug some torchtitan / llama 8b models. This change defaults to storing the stringified version instead, and adds a flag `DebugMode(store_original_args=True)` for users who want to store the original args as-is (and for BC).

Pull Request resolved: #166347
Approved by: https://github.com/yushangdi
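Below is a minimal usage sketch of the new default versus the opt-in flag, based on the tests in this commit. The import path (torch.utils._debug_mode), the plain-tensor matmul, and the printed checks are illustrative assumptions, not part of this PR.

    # Sketch only: illustrates the default (stringified) storage vs. store_original_args=True.
    import torch
    from torch.utils._debug_mode import DebugMode  # assumed import path for this module

    x = torch.randn(8, 8)
    y = torch.randn(8, 8)

    # Default: recorded calls keep only stringified args/kwargs, so the log does not
    # hold references to intermediate tensors.
    with DebugMode() as debug_mode:
        torch.mm(y, x)
    call = debug_mode.operators[0]
    print(call.args_str)            # stringified args, e.g. something like "t: f32[8, 8], t: f32[8, 8]"
    print(hasattr(call, "args"))    # False: original args were dropped after stringifying

    # Opt-in (and BC): keep the original args objects on each recorded call.
    with DebugMode(store_original_args=True) as debug_mode:
        torch.mm(y, x)
    print(debug_mode.operators[0].args[0] is y)   # True: original tensor kept as-is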
1 parent 55af9f5 commit bce274d

File tree

2 files changed: +88 -16 lines changed


test/distributed/tensor/debug/test_debug_mode.py
Lines changed: 8 additions & 0 deletions

@@ -64,6 +64,10 @@ def test_debug_mode_mm(self):
         self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
         self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
 
+        # check stringification
+        self.assertTrue(hasattr(debug_mode.operators[0], "args_str"))
+        self.assertFalse(hasattr(debug_mode.operators[0], "args"))
+
     def test_debug_string_inside_context(self):
         mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
 
@@ -267,6 +271,7 @@ def test_tensor_attributes(self):
             record_torchfunction=True,
             record_faketensor=True,
             record_tensor_attributes=["a1", "a2"],
+            store_original_args=True,
         ) as debug_mode:
             torch.matmul(y, x)
 
@@ -279,6 +284,9 @@ def test_tensor_attributes(self):
 aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
         )
 
+        self.assertTrue(hasattr(debug_mode.operators[0], "args"))
+        self.assertEqual(id(debug_mode.operators[0].args[0]), id(y))
+
     @parametrize("has_inner_mode", [True, False])
     @parametrize("has_outer_mode", [True, False])
     def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):

torch/utils/_debug_mode.py
Lines changed: 80 additions & 16 deletions

@@ -87,6 +87,14 @@ class _DebugCall:
     def __init__(self, call_depth: int):
         self.call_depth = call_depth
 
+    def stringify_args(self, attributes: list[str]) -> None:
+        """
+        To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs.
+        """
+        raise NotImplementedError(
+            "Subclasses must implement stringify_args(), even if no-op"
+        )
+
     def render(self, attributes: list[str]) -> str:
         raise NotImplementedError("Subclasses must implement string render()")
 
@@ -103,15 +111,35 @@ def __init__(self, op, args: tuple, kwargs: dict, call_depth: int):
         self.args = args
         self.kwargs = kwargs
 
-    def render(self, attributes: list[str]) -> str:
-        args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
+        self.args_str: Optional[str] = None
+        self.kwargs_str: Optional[str] = None
 
+    def stringify_args(self, attributes: list[str]) -> None:
+        self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
         if self.kwargs:
-            kwargs_str = ", " + ", ".join(
+            self.kwargs_str = ", " + ", ".join(
                 f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
             )
         else:
-            kwargs_str = ""
+            self.kwargs_str = ""
+        del self.args
+        del self.kwargs
+
+    def render(self, attributes: list[str]) -> str:
+        if self.args_str is not None:
+            args_str = self.args_str
+        else:
+            args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
+
+        if self.kwargs_str is not None:
+            kwargs_str = self.kwargs_str
+        else:
+            if self.kwargs:
+                kwargs_str = ", " + ", ".join(
+                    f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
+                )
+            else:
+                kwargs_str = ""
 
         if isinstance(self.op, torch._ops.OpOverload):
             op_name = self.op.__qualname__
@@ -124,7 +152,10 @@ def render(self, attributes: list[str]) -> str:
 
     def __iter__(self):
         # for BC; tuple(self) returns (op, args, kwargs, call_depth)
-        yield from [self.op, self.args, self.kwargs, self.call_depth]
+        if self.args_str is not None:
+            yield from [self.op, self.args_str, self.kwargs_str, self.call_depth]
+        else:
+            yield from [self.op, self.args, self.kwargs, self.call_depth]
 
 
 class _RedistributeCall(_DebugCall):
@@ -139,8 +170,18 @@ def __init__(
         self.dst_placement = dst_placement
         self.transform_info_str = transform_info_str
 
+        self.arg_str: Optional[str] = None
+
+    def stringify_args(self, attributes: list[str]) -> None:
+        self.arg_str = f"{_arg_to_str(self.arg, attributes)}"
+        del self.arg
+
     def render(self, attributes: list[str]) -> str:
-        arg_str = f"{_arg_to_str(self.arg, attributes)}"
+        if self.arg_str is not None:
+            arg_str = self.arg_str
+        else:
+            arg_str = f"{_arg_to_str(self.arg, attributes)}"
+
         if self.transform_info_str is not None:  # prioritize over src/dst placements
             placement_str = f"trace: {self.transform_info_str}"
         else:
@@ -151,11 +192,16 @@ def render(self, attributes: list[str]) -> str:
 
     def __iter__(self):
         # for BC; tuple(self) returns (op, placement info, kwargs, call_depth)
+        if self.arg_str is not None:
+            arg = self.arg_str
+        else:
+            arg = self.arg
+
         yield REDISTRIBUTE_FUNC
         if self.transform_info_str:
-            yield [self.arg, self.transform_info_str]
+            yield [arg, self.transform_info_str]
         else:
-            yield [self.arg, self.src_placement, self.dst_placement]
+            yield [arg, self.src_placement, self.dst_placement]
         yield {}
         yield self.call_depth
 
@@ -167,6 +213,9 @@ def __init__(self, module_name: str, call_depth: int):
         super().__init__(call_depth)
         self.module_name = module_name
 
+    def stringify_args(self, attributes: list[str]) -> None:
+        pass  # nothing to stringify
+
     def render(self, attributes: list[str]) -> str:
         return f"[nn.Mod] {self.module_name}"
 
@@ -188,22 +237,34 @@ def __init__(
         record_realtensor=True,
         record_tensor_attributes=None,
         record_nn_module=False,
+        store_original_args=False,
     ):
         super().__init__()
         import torch.distributed.tensor  # noqa: F401
 
         self.supports_higher_order_operators = True
+
+        # Pushes DebugMode onto the torchfunction stack, and records __torch_function__ calls as well.
+        # WARNING: currently incompatible with torch.compile due to dynamo guard failures.
         self.record_torchfunction = record_torchfunction
+        # Records __torch_dispatch__ calls on FakeTensors.
         self.record_faketensor = record_faketensor
+        # Records __torch_dispatch__ calls on real tensors.
         self.record_realtensor = record_realtensor
+        # Optional list[str] of tensor attributes, to be annotated in the string dump.
         self.record_tensor_attributes = record_tensor_attributes or []
-
+        # Uses ModTracker to record nn.Module entrances, as _NNModuleCall entries.
+        # This flag currently has no effect on torch.compiled-regions.
         self.record_nn_module = record_nn_module
 
         self.module_tracker: Optional[ModTracker] = None
         if self.record_nn_module:
             self.module_tracker_setup()
 
+        # If True, stores call args/kwargs in logs, without immediately stringifying.
+        # Defaults to False for memory concerns.
+        self.store_original_args = store_original_args
+
         self.operators = []
         self.call_depth = 0
 
@@ -214,11 +275,16 @@ def __init__(
     def ignore_compile_internals(cls):
         return True
 
+    def _record_call(self, call):
+        if not self.store_original_args:
+            call.stringify_args(self.record_tensor_attributes)
+        self.operators.append(call)
+
     def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
 
-        self.operators.append(_OpCall(func, args, kwargs, self.call_depth))
+        self._record_call(_OpCall(func, args, kwargs, self.call_depth))
 
         try:
             self.call_depth += 1
@@ -232,19 +298,17 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 
         # Record the operation with its call depth
         if torch.distributed.tensor.DTensor in types:
-            self.operators.append(_OpCall(func, args, kwargs, self.call_depth))
+            self._record_call(_OpCall(func, args, kwargs, self.call_depth))
             return NotImplemented
         elif FakeTensor in types or isinstance(
             _get_current_dispatch_mode(), FakeTensorMode
         ):
             if self.record_faketensor:
                 if func != torch.ops.prim.device.default:
-                    self.operators.append(
-                        _OpCall(func, args, kwargs, self.call_depth + 1)
-                    )
+                    self._record_call(_OpCall(func, args, kwargs, self.call_depth + 1))
         elif len(types) == 0:
             if self.record_realtensor:
-                self.operators.append(_OpCall(func, args, kwargs, self.call_depth + 1))
+                self._record_call(_OpCall(func, args, kwargs, self.call_depth + 1))
 
         result = func(*args, **kwargs)
 
@@ -296,7 +360,7 @@ def record_redistribute_calls(
         transform_info_str: Optional[str] = None,
     ):
         try:
-            self.operators.append(
+            self._record_call(
                 _RedistributeCall(
                     arg,
                     src_placement=src_placement,
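For consumers that unpack recorded calls via tuple(...), the __iter__ changes above keep the (op, args, kwargs, call_depth) shape but substitute the cached strings when the originals have been dropped. A small sketch of that backward-compatible behavior; the debug_mode value is assumed to come from a DebugMode context as in the tests above:

    # With default (stringified) storage, `args`/`kwargs` below are the cached
    # args_str/kwargs_str strings; with store_original_args=True they are the
    # original tuple/dict objects, matching the pre-change behavior.
    op, args, kwargs, call_depth = tuple(debug_mode.operators[0])
    print(op, args, kwargs, call_depth)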
