@@ -87,6 +87,14 @@ class _DebugCall:
     def __init__(self, call_depth: int):
         self.call_depth = call_depth
 
+    def stringify_args(self, attributes: list[str]) -> None:
+        """
+        To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes the original args/kwargs.
+        """
+        raise NotImplementedError(
+            "Subclasses must implement stringify_args(), even if no-op"
+        )
+
     def render(self, attributes: list[str]) -> str:
         raise NotImplementedError("Subclasses must implement string render()")
 
@@ -103,15 +111,35 @@ def __init__(self, op, args: tuple, kwargs: dict, call_depth: int):
         self.args = args
         self.kwargs = kwargs
 
-    def render(self, attributes: list[str]) -> str:
-        args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
+        self.args_str: Optional[str] = None
+        self.kwargs_str: Optional[str] = None
 
+    def stringify_args(self, attributes: list[str]) -> None:
+        self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
         if self.kwargs:
-            kwargs_str = ", " + ", ".join(
+            self.kwargs_str = ", " + ", ".join(
                 f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
             )
         else:
-            kwargs_str = ""
+            self.kwargs_str = ""
+        del self.args
+        del self.kwargs
+
+    def render(self, attributes: list[str]) -> str:
+        if self.args_str is not None:
+            args_str = self.args_str
+        else:
+            args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
+
+        if self.kwargs_str is not None:
+            kwargs_str = self.kwargs_str
+        else:
+            if self.kwargs:
+                kwargs_str = ", " + ", ".join(
+                    f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
+                )
+            else:
+                kwargs_str = ""
 
         if isinstance(self.op, torch._ops.OpOverload):
             op_name = self.op.__qualname__
@@ -124,7 +152,10 @@ def render(self, attributes: list[str]) -> str:
 
     def __iter__(self):
         # for BC; tuple(self) returns (op, args, kwargs, call_depth)
-        yield from [self.op, self.args, self.kwargs, self.call_depth]
+        if self.args_str is not None:
+            yield from [self.op, self.args_str, self.kwargs_str, self.call_depth]
+        else:
+            yield from [self.op, self.args, self.kwargs, self.call_depth]
 
 
 class _RedistributeCall(_DebugCall):
@@ -139,8 +170,18 @@ def __init__(
         self.dst_placement = dst_placement
         self.transform_info_str = transform_info_str
 
+        self.arg_str: Optional[str] = None
+
+    def stringify_args(self, attributes: list[str]) -> None:
+        self.arg_str = f"{_arg_to_str(self.arg, attributes)}"
+        del self.arg
+
     def render(self, attributes: list[str]) -> str:
-        arg_str = f"{_arg_to_str(self.arg, attributes)}"
+        if self.arg_str is not None:
+            arg_str = self.arg_str
+        else:
+            arg_str = f"{_arg_to_str(self.arg, attributes)}"
+
         if self.transform_info_str is not None:  # prioritize over src/dst placements
             placement_str = f"trace: {self.transform_info_str}"
         else:
@@ -151,11 +192,16 @@ def render(self, attributes: list[str]) -> str:
 
     def __iter__(self):
         # for BC; tuple(self) returns (op, placement info, kwargs, call_depth)
+        if self.arg_str is not None:
+            arg = self.arg_str
+        else:
+            arg = self.arg
+
         yield REDISTRIBUTE_FUNC
         if self.transform_info_str:
-            yield [self.arg, self.transform_info_str]
+            yield [arg, self.transform_info_str]
         else:
-            yield [self.arg, self.src_placement, self.dst_placement]
+            yield [arg, self.src_placement, self.dst_placement]
         yield {}
         yield self.call_depth
 
@@ -167,6 +213,9 @@ def __init__(self, module_name: str, call_depth: int):
         super().__init__(call_depth)
         self.module_name = module_name
 
+    def stringify_args(self, attributes: list[str]) -> None:
+        pass  # nothing to stringify
+
     def render(self, attributes: list[str]) -> str:
         return f"[nn.Mod] {self.module_name}"
 
@@ -188,22 +237,34 @@ def __init__(
         record_realtensor=True,
         record_tensor_attributes=None,
         record_nn_module=False,
+        store_original_args=False,
     ):
         super().__init__()
         import torch.distributed.tensor  # noqa: F401
 
         self.supports_higher_order_operators = True
+
+        # Pushes DebugMode onto the torchfunction stack, and records __torch_function__ calls as well.
+        # WARNING: currently incompatible with torch.compile due to dynamo guard failures.
         self.record_torchfunction = record_torchfunction
+        # Records __torch_dispatch__ calls on FakeTensors.
         self.record_faketensor = record_faketensor
+        # Records __torch_dispatch__ calls on real tensors.
         self.record_realtensor = record_realtensor
+        # Optional list[str] of tensor attributes, to be annotated in the string dump.
         self.record_tensor_attributes = record_tensor_attributes or []
-
+        # Uses ModTracker to record nn.Module entrances, as _NNModuleCall entries.
+        # This flag currently has no effect on torch.compile'd regions.
         self.record_nn_module = record_nn_module
 
         self.module_tracker: Optional[ModTracker] = None
         if self.record_nn_module:
             self.module_tracker_setup()
 
+        # If True, stores call args/kwargs in logs, without immediately stringifying.
+        # Defaults to False for memory concerns.
+        self.store_original_args = store_original_args
+
         self.operators = []
         self.call_depth = 0
 
@@ -214,11 +275,16 @@ def __init__(
     def ignore_compile_internals(cls):
         return True
 
+    def _record_call(self, call):
+        if not self.store_original_args:
+            call.stringify_args(self.record_tensor_attributes)
+        self.operators.append(call)
+
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
 
-        self.operators.append(_OpCall(func, args, kwargs, self.call_depth))
+        self._record_call(_OpCall(func, args, kwargs, self.call_depth))
 
         try:
             self.call_depth += 1
@@ -232,19 +298,17 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 
         # Record the operation with its call depth
         if torch.distributed.tensor.DTensor in types:
-            self.operators.append(_OpCall(func, args, kwargs, self.call_depth))
+            self._record_call(_OpCall(func, args, kwargs, self.call_depth))
             return NotImplemented
         elif FakeTensor in types or isinstance(
             _get_current_dispatch_mode(), FakeTensorMode
         ):
             if self.record_faketensor:
                 if func != torch.ops.prim.device.default:
-                    self.operators.append(
-                        _OpCall(func, args, kwargs, self.call_depth + 1)
-                    )
+                    self._record_call(_OpCall(func, args, kwargs, self.call_depth + 1))
         elif len(types) == 0:
             if self.record_realtensor:
-                self.operators.append(_OpCall(func, args, kwargs, self.call_depth + 1))
+                self._record_call(_OpCall(func, args, kwargs, self.call_depth + 1))
 
         result = func(*args, **kwargs)
 
@@ -296,7 +360,7 @@ def record_redistribute_calls(
         transform_info_str: Optional[str] = None,
     ):
         try:
-            self.operators.append(
+            self._record_call(
                 _RedistributeCall(
                     arg,
                     src_placement=src_placement,
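Usage sketch (reviewer note, not part of the patch): the snippet below illustrates how the new `store_original_args` flag is expected to interact with `_record_call()`/`stringify_args()`. It assumes `DebugMode` is importable from `torch.utils._debug_mode` and that the existing `debug_string()` helper renders the recorded `operators`; both names are assumptions, not introduced by this diff.

```python
# Minimal sketch, assuming DebugMode lives in torch.utils._debug_mode and
# exposes debug_string(); adjust the import/method names to the actual file.
import torch
from torch.utils._debug_mode import DebugMode

x = torch.randn(8, 8)

# Default (store_original_args=False): _record_call() stringifies args/kwargs
# at record time, so logged calls do not keep the input tensors alive.
with DebugMode(record_torchfunction=True) as debug_mode:
    torch.mm(x, x)
print(debug_mode.debug_string())

# Opt-out: keep the original args/kwargs on each recorded call (more memory,
# but the raw objects stay inspectable; render() then stringifies lazily).
with DebugMode(record_torchfunction=True, store_original_args=True) as debug_mode:
    torch.mm(x, x)
op, args, kwargs, call_depth = tuple(debug_mode.operators[0])  # BC tuple form
```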