[export] Initial serialization #102125
Serialization TODOs:
- [ ] pytree spec
- [ ] higher order ops
- [ ] node metadata (specifically nn_module_stack/source_fn)
- [ ] shape env
- [ ] graph module metadata?
looks great, lots of questions tho.
Serialization TODOs:
- [ ] pytree spec: #102577
- [ ] higher order ops
- [ ] node metadata (specifically nn_module_stack/source_fn)
- [ ] shape env
- [ ] graph module metadata?

ghstack-source-id: 170e292382340441224db66295232fddf7899e03
Pull Request resolved: #102125
Accepting to unblock, lgtm. See comments though.
 @dataclass
 class Node:
-    target: Operator
+    target: str
Why? Because version info will reside elsewhere?
yup, it'll reside in the toplevel GraphModule
@@ -143,7 +137,7 @@ class Graph:
 @dataclass
 class BackwardSignature:
     gradients_to_parameters: Dict[str, str]
-    gradients_to_userInputs: Dict[str, str]
+    gradients_to_user_inputs: Dict[str, str]
What's "user" supposed to mean here?
Not sure; it was copied from the internal Thrift file.
    buffers: Dict[str, TensorMeta]
    parameters: Dict[str, TensorMeta]
    metadata: Dict[str, str]
    opset_version: Dict[str, int]
Are we keeping the dict here or somewhere else (indexed by a single int)? @larryliu0820
My naive approach is to serialize a dict directly. It will be a mapping from namespace to opset version. Example: {"aten": 3, "custom_namespace": 4}. We can change this in the future if needed.
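A minimal sketch of the namespace-to-version mapping described in this comment. It assumes op targets are serialized as strings whose first dotted component is the namespace (for example `aten.add.Tensor`); that format and the helper `opset_version_for` are illustrative assumptions, not part of this PR.

```python
import json
from typing import Dict

# Namespace -> opset version, using the example values from the comment.
opset_version: Dict[str, int] = {"aten": 3, "custom_namespace": 4}


def opset_version_for(target: str, versions: Dict[str, int]) -> int:
    """Hypothetical helper: find the version for an op via its namespace prefix."""
    namespace = target.split(".", 1)[0]
    if namespace not in versions:
        raise KeyError(f"no opset version recorded for namespace {namespace!r}")
    return versions[namespace]


# A Dict[str, int] is JSON-friendly, so it can be serialized directly.
encoded = json.dumps(opset_version, sort_keys=True).encode("utf-8")
decoded = json.loads(encoded)
aten_version = opset_version_for("aten.add.Tensor", decoded)
```

Keeping the version per namespace rather than per op means the serialized graph stays small; any per-op version table would have to live outside the serialized dict.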
OK, so in any case, no mapping of op names (also str) to versions, yeah? That mapping will exist somewhere else in a table / code?
    signature: GraphSignature
    call_spec: CallSpec
Can you describe what the purpose of `call_spec` is? Won't `graph_module` have the `call_spec` mappings burned in?
Is there a standard transform that uses this information?
call_spec contains the in_spec/out_spec which is used when we're running the graph module eagerly so that the inputs/output formats match how we would run the original function eagerly. We call it in the ExportedProgram's call function: https://github.com/pytorch/pytorch/blob/main/torch/_export/exported_program.py#L94
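To make the role of in_spec/out_spec concrete, here is a toy sketch in plain Python. It is not the pytree implementation PyTorch uses: `Spec`, `flatten`, `unflatten`, `flat_graph`, and `call_with_spec` are all hypothetical names, and `flat_graph` stands in for a graph module that only accepts and returns flat values.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass(frozen=True)
class Spec:
    """Toy tree spec: a leaf, or a container kind with child specs."""
    kind: str  # "leaf", "list", "tuple", or "dict"
    children: Tuple["Spec", ...] = ()
    keys: Tuple[str, ...] = ()  # only used when kind == "dict"


def flatten(obj: Any) -> Tuple[List[Any], Spec]:
    """Split a nested structure into its leaves and a spec describing its shape."""
    if isinstance(obj, dict):
        keys = tuple(obj.keys())
        items = [obj[k] for k in keys]
        kind = "dict"
    elif isinstance(obj, (list, tuple)):
        keys = ()
        items = list(obj)
        kind = "list" if isinstance(obj, list) else "tuple"
    else:
        return [obj], Spec("leaf")
    leaves: List[Any] = []
    child_specs = []
    for item in items:
        sub_leaves, sub_spec = flatten(item)
        leaves.extend(sub_leaves)
        child_specs.append(sub_spec)
    return leaves, Spec(kind, tuple(child_specs), keys)


def unflatten(leaves: List[Any], spec: Spec) -> Any:
    """Rebuild a nested structure from flat leaves and a spec."""
    it = iter(leaves)

    def build(s: Spec) -> Any:
        if s.kind == "leaf":
            return next(it)
        children = [build(c) for c in s.children]
        if s.kind == "dict":
            return dict(zip(s.keys, children))
        return children if s.kind == "list" else tuple(children)

    return build(spec)


def flat_graph(x, y):
    """Stand-in for a graph module: flat inputs in, flat outputs out."""
    return [x + y, x * y]


def call_with_spec(args: Any, in_spec: Spec, out_spec: Spec) -> Any:
    """Flatten the inputs, check them against in_spec, and restore the output shape."""
    flat_args, actual_spec = flatten(args)
    if actual_spec != in_spec:
        raise TypeError("inputs do not match the recorded in_spec")
    return unflatten(flat_graph(*flat_args), out_spec)


# The specs would be recorded at export time; here they are built by hand.
example_args = ({"x": 2, "y": 3},)
_, in_spec = flatten(example_args)
out_spec = Spec("dict", (Spec("leaf"), Spec("leaf")), ("sum", "prod"))
result = call_with_spec(example_args, in_spec, out_spec)
```

In this sketch the graph itself never sees the nested structure; only the wrapper does, which is why the specs have to travel with the serialized program.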
Bit confused about why we also then have flatten / unflatten calls in a graph module...
Mm, it defines the way we pass inputs and receive outputs to the graph module. Zhengxu stated below that it makes more sense to put the GraphSignature in GraphModule, and I think CallSpec should go with wherever the Signature goes.
ghstack-source-id: 91b93b80ebc856e95c83015c12993bccdd6b2161
Pull Request resolved: #102125
    buffers: Dict[str, TensorMeta]
    parameters: Dict[str, TensorMeta]
    metadata: Dict[str, str]
    opset_version: Dict[str, int]
@angelayi I think it makes more sense to me if we put opset_version in ExportedProgram, and put signature in GraphModule. The rationale is to closely mirror the schema in sigmoid.
Summary: v2 of #102125 because of git issues

corresponding deserialization diff: #102716

Implementing serialization of the exported program to a python dataclass, and then from that dataclass to json. This is split into a couple of sections:
- `serialize(ep: ep.ExportedProgram, opset_version: Dict[str, int]) -> Tuple[bytes, bytes]` -- takes an exported program object, a dictionary mapping opset namespaces to versions, and returns the serialized exported program in bytes, and separately the state dict serialized in bytes
- `GraphModuleSerializer` class that serializes torch.fx.GraphModule to the schema.GraphModule dataclass
- `ExportedProgramSerializer` class that serializes torch._export.exported_program.ExportedProgram to the schema.ExportedProgram dataclass

Serialization TODOs:
- [x] pytree spec: #102577
- [ ] higher order ops
- [ ] node metadata (specifically nn_module_stack/source_fn)
- [ ] constraints
- [ ] graph module metadata

The tests are not super comprehensive, but that's because I think it'll be better tested + easier to test once deserialization is implemented.

Pull Request resolved: #102707
Reviewed By: zhxchen17
Differential Revision: D46362466
Pulled By: angelayi
fbshipit-source-id: 1d3fc157a7a5c2e615dbcc7f0e87d76f2f4c43ed
v2 of #102125 because of git issues

corresponding deserialization diff: #102716

Implementing serialization of the exported program to a python dataclass, and then from that dataclass to json. This is split into a couple of sections:
- `serialize(ep: ep.ExportedProgram, opset_version: Dict[str, int]) -> Tuple[bytes, bytes]` -- takes an exported program object, a dictionary mapping opset namespaces to versions, and returns the serialized exported program in bytes, and separately the state dict serialized in bytes
- `GraphModuleSerializer` class that serializes torch.fx.GraphModule to the schema.GraphModule dataclass
- `ExportedProgramSerializer` class that serializes torch._export.exported_program.ExportedProgram to the schema.ExportedProgram dataclass

Serialization TODOs:
- [x] pytree spec: #102577
- [ ] higher order ops
- [ ] node metadata (specifically nn_module_stack/source_fn)
- [ ] constraints
- [ ] graph module metadata

The tests are not super comprehensive, but that's because I think it'll be better tested + easier to test once deserialization is implemented.

Pull Request resolved: #102707
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
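The two-step flow in the summary (program to dataclass, then dataclass to JSON bytes) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the PR's implementation: the `Toy*` dataclasses are simplified stand-ins for the real schema, `toy_serialize` is a hypothetical name, and the state dict is modeled as plain lists of floats rather than tensors.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ToyNode:
    # The target is a plain string; version info lives at the top level instead.
    target: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class ToyGraphModule:
    nodes: List[ToyNode] = field(default_factory=list)


@dataclass
class ToyExportedProgram:
    graph_module: ToyGraphModule
    opset_version: Dict[str, int]


def toy_serialize(
    program: ToyExportedProgram, state_dict: Dict[str, List[float]]
) -> Tuple[bytes, bytes]:
    """Return (program bytes, state dict bytes), serialized separately."""
    program_bytes = json.dumps(asdict(program), sort_keys=True).encode("utf-8")
    state_bytes = json.dumps(state_dict, sort_keys=True).encode("utf-8")
    return program_bytes, state_bytes


program = ToyExportedProgram(
    graph_module=ToyGraphModule(
        nodes=[ToyNode(target="aten.add.Tensor", inputs=["x", "y"], outputs=["z"])]
    ),
    opset_version={"aten": 3},
)
program_bytes, state_bytes = toy_serialize(program, {"linear.weight": [0.5, -1.0]})
```

Going through an intermediate dataclass keeps the on-disk format decoupled from the in-memory objects, which is also what makes a matching deserializer straightforward to test against.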