Commit
[fx] Add matrix multiplication fusion pass (#50151)
Summary:
Pull Request resolved: #50151

**Summary**
This commit adds a graph transformation pass that merges several matrix multiplications that use the same RHS operand into one large matrix multiplication. The LHS operands from all of the smaller matrix multiplications are concatenated together and used as an input in the large matrix multiply, and the result is split in order to obtain the same products as the original set of matrix multiplications.

**Test Plan**
This commit adds a simple unit test with two matrix multiplications that share the same RHS operand.

`python test/test_fx_experimental.py -k merge_matmul -v`

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25809409

Pulled By: SplitInfinity

fbshipit-source-id: fb55c044a54dea9f07b71aa60d44b7a8f3966ed0
1 parent 838e73d · commit 11cdb91
Showing 2 changed files with 343 additions and 0 deletions.
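To make the transformation concrete, here is a minimal sketch (not part of the commit; tensor names and shapes are illustrative) of the identity the pass exploits: concatenating the LHS operands along the first dimension, performing a single matmul, and splitting the result row-wise reproduces the individual products.

```python
import torch

# Two matmuls that share the same RHS operand `c`.
a = torch.randn(3, 4)
b = torch.randn(5, 4)
c = torch.randn(4, 2)

# Original: two separate matrix multiplications.
out_a = torch.matmul(a, c)  # shape (3, 2)
out_b = torch.matmul(b, c)  # shape (5, 2)

# Merged: concatenate the LHS operands, do one large matmul,
# then split the result using the LHS first dimensions.
merged = torch.matmul(torch.cat([a, b]), c)  # shape (8, 2)
split_a, split_b = torch.split(merged, [3, 5])

assert torch.allclose(out_a, split_a)
assert torch.allclose(out_b, split_b)
```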
@@ -0,0 +1,220 @@
import torch

from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.symbolic_trace import symbolic_trace

import itertools
import operator

from typing import Dict, List

def get_first_dim(t: torch.Tensor) -> int:
    """
    A free function primarily for use in the merge_matmul graph transformation below
    that returns the first dimension of a Tensor. This is necessary because torch.Tensor.shape
    is an attribute (and cannot be the target of a call_function node) and also helps save
    a getitem op in the graph.

    Arguments:
        t: The tensor to get the first dimension of.

    Returns:
        The first dimension of t.
    """
    return t.shape[0]

def legalize_graph(gm: GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically
    sorted order of its input GraphModule, so that this order is restored before further
    transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.
    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as a list of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = Graph()
    value_remap: Dict[Node, Node] = {}

    # Copy over all nodes with no dependencies.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

    # Remove the copied-over nodes from the adjacency list.
    for copied_node in value_remap.keys():
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        for node, deps in dependencies.items():
            all_deps_copied = True
            for dep in deps:
                if dep not in value_remap:
                    all_deps_copied = False

            if all_deps_copied:
                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
                copied_this_round.append(node)

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph

def may_depend_on(a: Node, b: Node, search_depth: int = 6):
    """
    Determine if one node depends on another in a torch.fx.Graph.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: In the case of an indirect dependency, this function
                      searches up to this many nodes away in search of a
                      data dependency. If none is found, the function
                      makes the conservative assumption that there is a
                      dependency.

    Returns:
        True if a may depend on b, False if it definitely does not.
    """
    # Equivalence is defined as dependence.
    if a == b:
        return True

    # If a has no inputs, it cannot depend on b.
    if len(a.all_input_nodes) == 0:
        return False

    # If the search depth has been exhausted and no conclusion has been
    # reached, assume that there is a data dependency.
    if search_depth == 0:
        return True

    # Recursively check all inputs of a.
    for inp in a.all_input_nodes:
        if may_depend_on(inp, b, search_depth - 1):
            return True

    return False

def are_nodes_independent(nodes: List[Node]):
    """
    Check if all of the given nodes are pairwise-data independent.

    Arguments:
        nodes: The nodes to check for data dependencies.

    Returns:
        True if no pair in nodes has a data dependency, False otherwise.
    """
    # For each pair in nodes:
    for i, j in itertools.combinations(nodes, 2):
        if may_depend_on(i, j) or may_depend_on(j, i):
            return False

    return True

def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that share the same
    right-hand side operand into one large matrix multiplication.

                   ____      _________        _________
          ----    |    |    |         |     M|  A * C  |
        M| A  |  T| B  | * K|    C    | =    |---------|
          ---- ,  |    |    |         |     T|  B * C  |
           K       ----      ---------        ---------
                     K           R                R
    """
    gm = symbolic_trace(in_mod)

    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
    # the matmul of which they are the LHS/RHS.
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_sizes = [
            gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs
        ]
        merge_mm_split = gm.graph.call_function(
            torch.split, (merge_mm, merge_mm_sizes), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent
        # split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

        # All of the new nodes created above were inserted at the end, so we need to sort
        # the nodes topologically to make sure all definitions precede uses.
        legalize_graph(gm)

    gm.recompile()
    gm.graph.lint(in_mod)
    return gm
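For reference, a minimal usage sketch, assuming `merge_matmul` from the file above is in scope; the module and tensor shapes are illustrative and are not taken from the commit's test file (which is not shown here):

```python
import torch

# Hypothetical example module; assumes merge_matmul (defined above) is importable.
class TwoMatmuls(torch.nn.Module):
    # Two data-independent matmuls that share the RHS operand `c`,
    # so the pass should fuse them into one matmul plus a split.
    def forward(self, a, b, c):
        return torch.matmul(a, c), torch.matmul(b, c)

mod = TwoMatmuls()
merged = merge_matmul(mod)  # traces mod and returns a transformed GraphModule

a, b, c = torch.randn(3, 4), torch.randn(5, 4), torch.randn(4, 2)
for expected, actual in zip(mod(a, b, c), merged(a, b, c)):
    assert torch.allclose(expected, actual)
```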