diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index 06bdd9305b..4ff6c8187b 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -8,7 +8,7 @@
 from typing import Callable, Optional
 
 import torch
-from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
@@ -33,67 +33,6 @@ def _clear_traced_params_buffers(
         setattr(traced_module, key, buffer)
 
 
-def _restore_state_dict(
-    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
-) -> None:
-    """
-    TODO: move this into torch.export
-    Restores the state dict of the traced module to match the original module exactly.
-    Preserves the original FQNs with dots, creating intermediate empty modules as needed.
-    Ensures that the ordering of parameters/buffers matches the original module.
-    """
-    # Build ID-based lookups for traced module params/buffers
-    traced_params: dict[int, tuple[str, torch.nn.Parameter]] = {}
-    for name, param in traced_module.named_parameters(remove_duplicate=False):
-        traced_params[id(param)] = (name, param)
-
-    traced_buffers: dict[int, tuple[str, torch.Tensor]] = {}
-    for name, buffer in traced_module.named_buffers(remove_duplicate=False):
-        traced_buffers[id(buffer)] = (name, buffer)
-
-    # Build mapping from old names to new names for graph node updates
-    name_mapping: dict[str, str] = {}
-
-    # Restore parameters in the order they appear in original module
-    for orig_name, orig_param in original_module.named_parameters(
-        remove_duplicate=False
-    ):
-        if id(orig_param) in traced_params:
-            # This param exists in traced module - restore it with original FQN
-            traced_name, traced_param = traced_params[id(orig_param)]
-            torch.fx.graph_module._assign_attr(traced_param, traced_module, orig_name)
-            torch.fx.graph_module._del_attr(traced_module, traced_name)
-            name_mapping[traced_name] = orig_name
-        else:
-            # This param doesn't exist in traced module - add it
-            torch.fx.graph_module._assign_attr(orig_param, traced_module, orig_name)
-
-    # Restore buffers in the order they appear in original module
-    for orig_name, orig_buffer in original_module.named_buffers(remove_duplicate=False):
-        if id(orig_buffer) in traced_buffers:
-            # This buffer exists in traced module - restore it with original FQN
-            traced_name, traced_buffer = traced_buffers[id(orig_buffer)]
-            torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name)
-            name_mapping[traced_name] = orig_name
-            torch.fx.graph_module._del_attr(traced_module, traced_name)
-        else:
-            # This buffer doesn't exist in traced module - add it
-            torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name)
-
-    param_names = [v[0] for v in traced_params.values()]
-    buffer_names = [v[0] for v in traced_buffers.values()]
-    const_keys = set(param_names + buffer_names).difference(set(name_mapping.keys()))
-
-    _clear_traced_params_buffers(traced_module, const_keys)
-
-    # Update get_attr nodes in the graph to use the correct FQNs
-    for node in traced_module.graph.nodes:
-        if node.op == "get_attr" and node.target in name_mapping:
-            node.target = name_mapping[node.target]
-
-    traced_module.recompile()
-
-
 def export_joint(
     model, args, kwargs=None
 ) -> tuple[JointWithDescriptors, TracingContext]:
@@ -101,20 +40,16 @@ def export_joint(
         kwargs = {}
     assert isinstance(args, tuple)
     assert isinstance(kwargs, dict)
-    with torch._dynamo.config.patch(
-        install_free_tensors=True
-    ), torch.fx.traceback.preserve_node_meta():
-        # TODO: switch to use the official graph_capture API once it is ready
-        gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
-
-        # Restore the state dict to match the original module
-        _restore_state_dict(model, gm)
-
+    with (
+        # TODO Investigate error on MOE model with use_grouped_mm=False.
+        # For repro, see: https://gist.github.com/zhxchen17/d794ff58236243d9faddf713b9fc6a61
+        torch._dynamo.config.patch(fake_tensor_cache_enabled=False),
+        torch.fx.traceback.preserve_node_meta(),
+    ):
+        gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
     logger.info("Dynamo gm:")
     logger.info(gm.print_readable(print_output=False))
-
-    fake_mode = gm.meta.get("fake_mode", None)
-    tracing_context = TracingContext(fake_mode)
+    tracing_context = gm.meta["tracing_context"]
     with tracing(tracing_context):
         return (
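
For reference, a minimal sketch of how the updated `export_joint` entry point could be exercised after this change. Only the `export_joint` signature, the switch to `dynamo_graph_capture_for_export`, and the `gm.meta["tracing_context"]` contract come from the patch above; the toy module, shapes, and call site are illustrative assumptions, not part of the diff.

```python
# Illustrative only: drive a toy module through export_joint as changed above.
# ToyModel and the input shape are made up; the return contract
# (JointWithDescriptors, TracingContext) follows the signature in the patch.
import torch
import torch.nn as nn

from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).relu()


# args must be a tuple and kwargs (if given) a dict, per the asserts
# in export_joint; the tracing context now comes from the captured
# graph module's meta rather than being rebuilt from a fake mode.
joint_with_descriptors, tracing_context = export_joint(
    ToyModel(), (torch.randn(4, 16),)
)
```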