Skip to content

Commit

Permalink
Propagate tokens in aotautograd (#127028)
Browse files Browse the repository at this point in the history
Test Plan: `buck run mode/dev-nosan //aimp/experimental/pt2:pt2_export -- --model-entity-id 938593492 --output /tmp/938593492.zip --use-torchrec-eager-mp --use-manifold`

Differential Revision: D57750072

Pull Request resolved: #127028
Approved by: https://github.com/tugsbayasgalan
  • Loading branch information
angelayi authored and pytorchmergebot committed May 24, 2024
1 parent 99a11ef commit cb6ef68
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
self.deserialize_outputs(serialized_node, fx_node)
else:
raise SerializeError(
f"Unsupported target type for node {serialized_node}: {target}"
f"Unsupported target type for node {serialized_node}: {type(target)}"
)

fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
Expand Down
2 changes: 1 addition & 1 deletion torch/_functorch/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ def convert(idx, x):
subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
is_train=needs_autograd,
tokens=fw_metadata.tokens,
)

if fw_metadata.num_intermediate_bases > 0:
Expand Down Expand Up @@ -682,7 +683,6 @@ def convert(idx, x):
)
# Once all fw_metadata_wrappers have run, runtime_metadata is fixed
runtime_metadata = fw_metadata

compiled_fn = compiler_fn(
flat_fn, fake_flat_args, aot_config, fw_metadata=runtime_metadata
)
Expand Down

0 comments on commit cb6ef68

Please sign in to comment.