
Partition modules #98628

Closed
wants to merge 9 commits into from

Conversation

@angelayi (Contributor) commented Apr 7, 2023

Added helper functions to match nodes in the graph that were decomposed from their source (leaf modules or functional ops) as a result of dynamo tracing.

get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, List[SourcePartition]]

Args:

  • graph: The graph we want to partition
  • wanted_sources: List of sources that nodes were decomposed from. A source can be a function (e.g. torch.nn.functional.linear) or a leaf module type (e.g. torch.nn.Linear)

Returns:

  • Dictionary mapping sources (e.g. torch.nn.modules.linear.Linear) to a list of SourcePartitions, each corresponding to the nodes that were flattened from a module of that type.
from dataclasses import dataclass, field
from typing import List, Type

from torch.fx import Node

@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
    # Module type
    module_type: Type
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)

Example:

Original:

x -> linear -> linear -> relu -> linear

Traced graph:

.graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
    return [addmm_default_2]

Result of get_source_partitions:

{<class 'torch.nn.modules.linear.Linear'>: [
    SourcePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    SourcePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    SourcePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    SourcePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
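For reference, a minimal sketch of the call that could produce a result like the one above. This assumes the helpers live under torch.fx.passes.utils.source_matcher_utils and that the module below matches the traced graph; both are illustrative rather than taken from the PR.

import torch
import torch._dynamo as dynamo
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)   # called twice, so its params are shared
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear2(self.relu(self.linear(self.linear(x))))

# Trace down to aten ops so the linear/relu modules get decomposed.
gm, _ = dynamo.export(M(), torch.randn(3, 3), aten_graph=True)

# Group the decomposed nodes back by the source they came from.
partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])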

Also added a helper function to check whether two source partitions are connected:
check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool
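Continuing the sketch above, a hedged example of checking connectivity between two of the returned partitions (same assumed import location):

from torch.fx.passes.utils.source_matcher_utils import check_subgraphs_connected

linear_partitions = partitions[torch.nn.Linear]
relu_partitions = partitions[torch.nn.ReLU]

# True when the two partitions are directly connected in the traced graph.
is_connected = check_subgraphs_connected(linear_partitions[1], relu_partitions[0])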

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

pytorch-bot bot commented Apr 7, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98628

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 7434821:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -870,13 +870,16 @@ def create_proxy(
        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
            nn_module_stack = nn_module_stack.copy()
            _, last_value = nn_module_stack.popitem()
Contributor:

Does this mean that we lose the last entry in the stack?

Contributor:

I see that it is not "lost", but it's not clear to me what is happening here.

Contributor Author:

I'm just renaming the last value in the stack 😅

Originally, if we have a module like x -> self.linear -> self.linear, with the two calls pointing to the same module attribute, the nn_module_stack does not tell the two linear calls apart, but the name of rv.node does.

The torch IR will look something like:

placeholder arg_0 has nn_module_stack of "None"
call_module self_linear_0 has nn_module_stack of {'self_linear': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}
call_module self_linear_1 has nn_module_stack of {'self_linear': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}

So I'm renaming that last key to match the node name:

placeholder arg_0 has nn_module_stack of "None"
call_module self_linear_0 has nn_module_stack of {'self_linear_0': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}
call_module self_linear_1 has nn_module_stack of {'self_linear_1': ('self.linear', <class 'torch.nn.modules.linear.Linear'>)}
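In other words, a conceptual sketch of the re-keying (illustrative, not the exact diff):

# `nn_module_stack` is an ordered dict of {key: (qualified_name, module_class)}.
# Re-key its last entry with the node's unique name so two calls to the same
# module attribute no longer collide.
nn_module_stack = nn_module_stack.copy()
_, last_value = nn_module_stack.popitem()
nn_module_stack[rv.node.name] = last_value
rv.node.meta["nn_module_stack"] = nn_module_stack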

    elif kind == "call_module":
        # For modules we store the class
        rv.node.meta["source_fn"] = rv.node.meta["nn_module_stack"][target][1]
        rv.node.meta["source_fn"] = (
            rv.node.name,
Contributor:

Is rv.node.name the qualified name?

Contributor Author:

Yeah, it's the unique name of the node in the FX graph, so it will help us handle the case where there are 2 linear module calls side by side in the graph.
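For illustration, a hedged sketch of how a pass could use the new (unique_name, source) form of "source_fn" to keep back-to-back calls to the same module class separate (gm stands for any dynamo-traced GraphModule and is assumed to exist; the metadata shape is as described in this thread):

from collections import defaultdict

calls_per_source = defaultdict(list)
for node in gm.graph.nodes:
    source_fn = node.meta.get("source_fn")
    if source_fn is None:
        continue
    unique_name, source = source_fn  # e.g. ("self_linear_1", torch.nn.Linear)
    calls_per_source[source].append(unique_name)
# Two adjacent nn.Linear calls now land under distinct names instead of colliding.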

@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
Contributor:

Any reasoning for having the partition be a list of nodes rather than the partitioned graph itself? We can derive the nodes from the graph, and having the graph can help preserve the partition's structure.

Contributor:

You can generate a graph from the list. Do you want Graph instead of List[Node]? If so, why?

Contributor:

Is there a simple API to convert a List[Node] --> Graph? If not, in the case where I might want to use something like subgraph_rewriter afterwards to replace these partitioned modules, using a graph as the replacement pattern rather than a list of nodes would be easier.

Contributor Author:

Yeah, there exists fuse_as_graphmodule.
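For context, a hedged sketch of lifting a partition's node list into its own GraphModule with fuse_as_graphmodule, reusing gm and partitions from the sketch in the PR description above; the import path and return values reflect my understanding of that utility and may differ by version:

from torch.fx.passes.utils.fuser_utils import fuse_as_graphmodule

linear_partition = partitions[torch.nn.Linear][0]

# Returns the fused sub-GraphModule plus the original input/output nodes,
# so the subgraph can be used e.g. as a subgraph_rewriter pattern.
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
    gm, linear_partition.nodes, "fused_linear_0"
)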

@facebook-github-bot (Contributor):
@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai (Contributor) left a comment:

Looks neat. Thanks Angela!

@angelayi (Contributor Author) commented May 3, 2023:

@pytorchbot rebase

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased partition_modules onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout partition_modules && git pull --rebase)

@angelayi (Contributor Author) commented May 3, 2023:

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on May 3, 2023.
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: Meta Internal-Only Changes Check

Details for Dev Infra team. Raised by workflow job.

@facebook-github-bot (Contributor):

@angelayi has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@angelayi (Contributor Author) commented May 3, 2023:

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@@ -295,10 +295,10 @@ def get_fused_kernel_name(node_schedule):
    sources = []
    for origin in all_origins:
        if origin.op == "call_function" and "source_fn" in origin.meta:
            if isinstance(origin.meta["source_fn"], str):
                sources.append(origin.meta["source_fn"])
            if isinstance(origin.meta["source_fn"][1], str):
Contributor:

What was this change for?

Contributor Author:

I modified the "source_fn" metadata to additionally return a unique qualifying name for each function that is called, so that if 2 modules are called one after the other we can distinguish between the two. This change is just to keep Inductor compatible.
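For code outside Inductor, a hedged sketch of handling both the old and the new form of "source_fn" (the helper name is made up):

def source_fn_target(node):
    # Before this PR "source_fn" held just the source (or a string); after it,
    # it holds a (unique_node_name, source) tuple.
    source_fn = node.meta.get("source_fn")
    if source_fn is None:
        return None
    return source_fn[1] if isinstance(source_fn, tuple) else source_fn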

    output_nodes: List[Node] = field(default_factory=list)

    # Parameters that are being used
    params: List[str] = field(default_factory=list)
Contributor:

Why is this a list of strings instead of List[Node]?

Contributor Author:

I wasn't sure how you wanted the parameters formatted, so I just returned a list of the parameters' attribute names. But I can fix this.

Contributor:

Ideally we could have NamedParameters = Tuple[str, Tensor] so that "weight" would correspond to the weight tensor. But from talking to Sherlock, I remember this was harder. For now this is fine.



@compatibility(is_backward_compatible=False)
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
Contributor:

This is somewhat loose in that two graphs may be overlapping, right? I would expect you to check that the output nodes of the first partition are the input nodes to the second partition. That might be stricter.
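A hedged sketch of the stricter check being suggested (the producer/consumer naming and the helper itself are hypothetical):

from torch.fx.passes.utils.source_matcher_utils import SourcePartition  # assumed location

def strictly_connected(producer: SourcePartition, consumer: SourcePartition) -> bool:
    # True only when an output of `producer` is consumed as an input of
    # `consumer`, which rules out partitions that merely overlap.
    return any(out in consumer.input_nodes for out in producer.output_nodes)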

if source_fn[1] not in wanted_sources:
    continue

diff_modules = modules.setdefault(source_fn[1], {})
Contributor:

Angela, why are we using source_fn and not nn_module_stack? The main difference I see is that source_fn tracks the leaf module/function, whereas nn_module_stack tracks the entire module hierarchy. I think it is useful to use nn_module_stack so that a node can belong to multiple partitions, but each partition it belongs to must have a strict parent->child relation.

The reason this might be useful is that when modules like LSTM or attention get decomposed, you still get nodes that belong to a higher-level module like Attention. And if I wanted to quantize the entire attention module, I could.
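For illustration, a hedged sketch of the difference (all names and classes below are made up): source_fn records only the leaf that produced a node, while nn_module_stack keeps the whole ownership chain, which is what would let a node be grouped under a higher-level module such as an attention block.

import torch

# Hypothetical metadata attached to one aten node produced while tracing the
# output projection inside an attention block.
nn_module_stack = {
    "self_attn": ("self.attn", torch.nn.MultiheadAttention),
    "self_attn_out_proj": ("self.attn.out_proj", torch.nn.Linear),
}
source_fn = ("addmm_default", torch.nn.Linear)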

input_nodes.add(arg)

if node.op == "get_attr":
    params.add(node.target)
Contributor:

For the reasons I mentioned above, like for the module here, https://fburl.com/owodcrrr, if we have named parameters then it is easier to access them. Although I don't know what happens to constants; if they are "burnt in" then it will be harder to figure out what their "names" are.
