-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[export] Serialize symbolic values (#103273)
* Modified the SymInt schema to also store the hint of the SymInt if it is represented as a symbol so that when we reconstruct the SymInt, the hint will also exist on the node. * GraphModuleDeserializer.deserialize now also optionally map of symbol names to range. ReplaceSymSizeOpPass should not be needed after #103107 lands Pull Request resolved: #103273 Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
- Loading branch information
1 parent
876695d
commit 8dc6001
Showing
7 changed files
with
367 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import Dict | ||
|
||
import torch | ||
from torch.fx.passes.infra.pass_base import PassBase | ||
|
||
replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = { | ||
torch.ops.aten.sym_size: torch.ops.aten.sym_size.int, | ||
torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int, | ||
torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default, | ||
} | ||
|
||
|
||
class _ReplaceSymSizeOpPass(PassBase): | ||
""" | ||
Replace torch.ops.aten.sym_size with torch.ops.aten.sym_size.int | ||
and torch.ops.aten.sym_stride with torch.ops.aten.sym_stride.int | ||
""" | ||
|
||
def call(self, graph_module): | ||
for module in graph_module.modules(): | ||
if not isinstance(module, torch.fx.GraphModule): | ||
continue | ||
for node in module.graph.nodes: | ||
if node.target in replacements: | ||
node.target = replacements[node.target] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.