Partition modules #98628
Conversation
torch/_dynamo/output_graph.py (Outdated)

@@ -870,13 +870,16 @@ def create_proxy(
        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
            nn_module_stack = nn_module_stack.copy()
            _, last_value = nn_module_stack.popitem()
Does this mean that we lose the last entry in the stack?
I see that it is not "lost", but not clear to me what is happening here.
I'm just renaming the last value in the stack 😅

Originally, if we have some module like x -> self.linear -> self.linear, with the two linear calls pointing to the same class attribute, the nn_module_stack does not tell the two linear calls apart, but the name of rv.node does.

The torch IR will look something like:

placeholder arg_0 has nn_module_stack of "None"
call_module self_linear_0 has nn_module_stack of {'self_linear': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}
call_module self_linear_1 has nn_module_stack of {'self_linear': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}

So I'm renaming the first key to be the same as the node name:

placeholder arg_0 has nn_module_stack of "None"
call_module self_linear_0 has nn_module_stack of {'self_linear_0': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}
call_module self_linear_1 has nn_module_stack of {'self_linear_1': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}
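A minimal sketch of what that renaming amounts to, following the snippet above (only the copy() and popitem() lines appear in the diff; the re-insert and the final meta assignment are my reading of the description, not quoted code):

    nn_module_stack = tx.nn_module_stack
    if nn_module_stack:
        nn_module_stack = nn_module_stack.copy()
        # Pop the most recent (module, class) entry and re-insert it keyed
        # by the node's unique name, so two calls to the same submodule
        # end up with distinguishable nn_module_stack keys.
        _, last_value = nn_module_stack.popitem()
        nn_module_stack[rv.node.name] = last_value
        rv.node.meta["nn_module_stack"] = nn_module_stack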
Force-pushed from f58a8b9 to c13bae1
elif kind == "call_module":
    # For modules we store the class
    rv.node.meta["source_fn"] = rv.node.meta["nn_module_stack"][target][1]
    rv.node.meta["source_fn"] = (
        rv.node.name,
Is rv.node.name the qualified name?
Yeah, it's the unique name of the node in the fx graph, so it will help us handle the case where there are 2 linear module calls side by side in the graph.
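A rough end-to-end illustration of what that buys us (the module M is hypothetical, and the exact export API and printed metadata may differ across versions):

    import torch
    import torch._dynamo

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            # Same submodule called twice; the node names disambiguate them.
            return self.linear(self.linear(x))

    gm, _ = torch._dynamo.export(M(), torch.randn(2, 4))
    for node in gm.graph.nodes:
        # Expect e.g. ("self_linear_0", <class Linear>) vs ("self_linear_1", ...)
        print(node.name, node.meta.get("source_fn"))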
@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
Any reasoning for having the partition be a list of nodes rather than the partitioned graph itself? We can derive the nodes from the graph, and having the graph can help preserve the partition's structure.
You can generate a graph from the list. Do you want Graph instead of List[Node]? If so, why?
Is there a simple API to convert a List[Node] into a Graph? If not, in the case where I might want to use something like subgraph_rewriter afterwards to replace these partition modules, using a graph as the pattern to replace rather than a list of nodes would be easier.
Yeah, there exists fuse_as_graphmodule
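For reference, a sketch of lifting a partition's node list into its own GraphModule (this assumes the fuse_as_graphmodule signature in torch.fx.passes.utils.fuser_utils at the time; gm and partition come from the surrounding context):

    from torch.fx.passes.utils.fuser_utils import fuse_as_graphmodule

    # Fuse the partition's nodes into a standalone GraphModule; also
    # returns the original input/output nodes at the fusion boundary.
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
        gm, partition.nodes, "fused_partition"
    )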
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Looks neat. Thanks Angela!
Force-pushed from 7d739a6 to 75633b2
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Successfully rebased.
Force-pushed from b35be48 to 7434821
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job failed; the first few failures are: Meta Internal-Only Changes Check. Raised by workflow job (details for the Dev Infra team).
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@@ -295,10 +295,10 @@ def get_fused_kernel_name(node_schedule):
    sources = []
    for origin in all_origins:
        if origin.op == "call_function" and "source_fn" in origin.meta:
            if isinstance(origin.meta["source_fn"], str):
                sources.append(origin.meta["source_fn"])
            if isinstance(origin.meta["source_fn"][1], str):
What was this change for?
I modified the "source_fn" metadata to additionally return a unique qualifying name for each function that is called, so that if there are 2 modules called one after the other we can distinguish between the two. This change just keeps inductor compatible with that.
output_nodes: List[Node] = field(default_factory=list)

# Parameters that are being used
params: List[str] = field(default_factory=list)
Why was this a list of strings instead of List[Node]?
I wasn't sure how you wanted the parameters formatted so I just returned a list of the attributes of the parameters. But I can fix this
Ideally we could have NamedParameters = tuple(str, Tensor), such that weight would correspond to the weight tensor. But talking to Sherlock, I remember this was harder. For now this is fine.
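A possible shape for that idea, as a hedged sketch (named_params is a hypothetical helper; it assumes partition.params holds qualified get_attr targets that nn.Module.get_parameter can resolve):

    from typing import List, Tuple
    import torch

    def named_params(
        gm: torch.fx.GraphModule, partition: "SourcePartition"
    ) -> List[Tuple[str, torch.Tensor]]:
        # Map each qualified parameter name to its tensor, approximating
        # the NamedParameters = (str, Tensor) idea above.
        return [(t, gm.get_parameter(t)) for t in partition.params]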
@compatibility(is_backward_compatible=False)
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
This is somewhat loose in that two graphs may be overlapping, right? I would expect you to check that the output nodes of the first partition are the input nodes to the second partition? That might be stricter?
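A sketch of that stricter variant (hypothetical, not what the PR implements): require that some output of the first partition is consumed as an input of the second.

    def strictly_connected(p1, p2) -> bool:
        # p1 and p2 are SourcePartitions; output_nodes / input_nodes are
        # the boundary fields from the dataclass above.
        return any(out in p2.input_nodes for out in p1.output_nodes)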
if source_fn[1] not in wanted_sources:
    continue

diff_modules = modules.setdefault(source_fn[1], {})
Angela, so why are we using source_fn and not nn_module_stack? The main difference I see is that source_fn tracks the leaf module/function, whereas nn_module_stack tracks the entire module hierarchy. I think it is useful to use nn_module_stack so that a node can belong to multiple partitions, but each partition it belongs to must have a strict parent->child relation.
The reason this might be useful is that when modules like LSTM or attention get decomposed, you still get nodes that belong to a higher-level module like Attention. And if I were to quantize the entire attention module, I can.
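A hypothetical illustration of the difference (Attention here is a made-up user module): for a node produced by decomposing an attention block, nn_module_stack keeps the whole hierarchy while source_fn keeps only the leaf.

    # node.meta["nn_module_stack"] might look like:
    {
        "self_attn": ("self.attn", Attention),                      # parent module
        "self_attn_linear": ("self.attn.linear", torch.nn.Linear),  # leaf module
    }
    # while node.meta["source_fn"] keeps only the leaf:
    ("self_attn_linear", torch.nn.Linear)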
input_nodes.add(arg)

if node.op == "get_attr":
    params.add(node.target)
For the reasons I mentioned above, like for the module here, https://fburl.com/owodcrrr, if we have named parameters then they are easier to access. Although I don't know what happens to constants; if they are "burnt in", then it will be harder to figure out what their "names" are.
Added helper functions to match nodes in the graph that are decomposed from their source (leaf modules, or functional ops), as a result of dynamo tracing.
get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]
Args:
Returns:
Example:
Original:
Traced graph:
Result of get_source_partitions:

Also added a helper function to check if two module partitions are connected:
check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool
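A rough usage sketch of the helpers (the toy module M is illustrative, and the import path assumes torch.fx.passes.utils.source_matcher_utils, where this PR adds the utilities; check your version):

    import torch
    import torch._dynamo
    from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(self.linear(x)))

    gm, _ = torch._dynamo.export(M(), torch.randn(2, 3))
    # Partitions grouped by their source: the two Linear calls and the ReLU.
    partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
    # check_subgraphs_connected(p1, p2) from the same module then reports
    # whether partition p1 feeds partition p2.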
cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire