From 11cdb910b4af2b4abb6a9c98a325f6f378347fba Mon Sep 17 00:00:00 2001 From: Meghan Lele Date: Wed, 6 Jan 2021 21:46:56 -0800 Subject: [PATCH] [fx] Add matrix multiplication fusion pass (#50151) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50151 **Summary** This commit adds a graph transformation pass that merges several matrix multiplications that use the same RHS operand into one large matrix multiplication. The LHS operands from all of the smaller matrix multiplications are concatenated together and used as an input in the large matrix multiply, and the result is split in order to obtain the same products as the original set of matrix multiplications. **Test Plan** This commit adds a simple unit test with two matrix multiplications that share the same RHS operand. `python test/test_fx_experimental.py -k merge_matmul -v` Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D25809409 Pulled By: SplitInfinity fbshipit-source-id: fb55c044a54dea9f07b71aa60d44b7a8f3966ed0 --- test/test_fx_experimental.py | 123 ++++++++++++++ torch/fx/experimental/merge_matmul.py | 220 ++++++++++++++++++++++++++ 2 files changed, 343 insertions(+) create mode 100644 torch/fx/experimental/merge_matmul.py diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 6e9c877b8de6..ac71d6037591 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -21,6 +21,7 @@ PartitionMode ) from torch.fx.experimental.fuser import fuse +from torch.fx.experimental import merge_matmul try: from torchvision.models import resnet18 @@ -844,6 +845,128 @@ def forward(self, a): for p_name in para_list: assert p_name in node.attrs_for_lowering + def test_merge_matmuls(self): + """ + A collection of test cases for torch.fx.experimental.merge_matmul, + a graph transformation that merges matrix multiplication operations. + """ + # Utility function for counting matmuls for test assertions. + def _count_matmuls(mod): + gm = torch.fx.symbolic_trace(mod) + + num_matmuls = 0 + for node in gm.graph.nodes: + if node.target == torch.matmul: + num_matmuls += 1 + + return num_matmuls + + # Simple test case in which there are two matmuls of the same size to merge. + class SimpleMergeMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, x, y): + a = torch.matmul(x, self.rhs) + b = torch.matmul(y, self.rhs) + return a + b + + # Initialize inputs. + a = torch.randn(3, 3) + b = torch.randn(3, 3) + + # Initialize RHS for matmuls. + rhs = torch.randn(3, 4) + + # Construct SimpleMergeMatmulModule and call merge_matmul on it. + module = SimpleMergeMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(a, b) + after = opt_module(a, b) + before.allclose(after) + + # Basic graph structure check; original module should have 2 matmuls + # and optimized module should have 1. + self.assertEqual(_count_matmuls(module), 2) + self.assertEqual(_count_matmuls(opt_module), 1) + + # Test case in which there are multiple matmuls of different sizes to merge. + class FiveMergeMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, a, b, c, d, e): + s = torch.Tensor((0)) + matmuls = [] + + # For some reason using a list comprehension or for-loop for this + # doesn't work. + matmuls.append(torch.matmul(a, self.rhs)) + matmuls.append(torch.matmul(b, self.rhs)) + matmuls.append(torch.matmul(c, self.rhs)) + matmuls.append(torch.matmul(d, self.rhs)) + matmuls.append(torch.matmul(e, self.rhs)) + + for m in matmuls: + s += torch.sum(m) + + return s + + # Initialize inputs. + inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] + + # Initialize RHS. + rhs = torch.randn(5, 4) + + # Construct FiveMergeMatmulModule and call merge_matmul on it. + module = FiveMergeMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(*inputs) + after = opt_module(*inputs) + before.allclose(after) + + # Basic graph structure check; original module should have len(inputs) matmuls + # and optimized module should have 1. + self.assertEqual(_count_matmuls(module), len(inputs)) + self.assertEqual(_count_matmuls(opt_module), 1) + + # Simple test case in which two matmuls cannot be merged due to a data dependency between + # the LHS operands. + class UnmergeableMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, x): + a = torch.matmul(x, self.rhs) + a_abs = torch.abs(a) + b = torch.matmul(a_abs.transpose(1, 0), self.rhs) + return b + + # Initialize inputs. + a = torch.randn(3, 3) + + # Initialize RHS for matmuls. + rhs = torch.randn(3, 4) + + # Construct UnmergeableMatmulModule and call merge_matmul on it. + module = UnmergeableMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(a) + after = opt_module(a) + before.allclose(after) + + # Basic graph structure check; the number of matrix multiplcations should not have changed. + self.assertEqual(_count_matmuls(module), 2) + self.assertEqual(_count_matmuls(opt_module), 2) if __name__ == "__main__": run_tests() diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py new file mode 100644 index 000000000000..b72bbe633dd9 --- /dev/null +++ b/torch/fx/experimental/merge_matmul.py @@ -0,0 +1,220 @@ +import torch + +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx.symbolic_trace import symbolic_trace + +import itertools +import operator + +from typing import Dict, List + + +def get_first_dim(t: torch.Tensor) -> int: + """ + A free function primarily for use in the merge_matmul graph transformation below + that returns the first dimension of a Tensor. This is necessary because torch.Tensor.shape + is an attribute (and cannot be the target of a call_function node) and also helps save + a getitem op in the graph. + + Arguments: + t: The tensor to get the first dimension of. + + Returns: + The first dimension of t. + """ + return t.shape[0] + + +def legalize_graph(gm: GraphModule): + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + """ + # Build an adjacency list representation of node dependencies in the graph. This also + # serves as a list of nodes that still need to be inserted into the new, topologically + # sorted graph. + dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes} + + # Construct a new graph that will contain all nodes in topologically sorted order. + new_graph = Graph() + value_remap: Dict[Node, Node] = {} + + # Copy over all nodes with no dependencies. + for node, deps in dependencies.items(): + if not deps: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + + # Remove the copied over nodes from the adjacency list. + for copied_node in value_remap.keys(): + del dependencies[copied_node] + + # While there are still nodes to insert into the new graph: + while dependencies: + copied_this_round = [] + + # Copy over all nodes whose dependencies already exist in the new graph. + for node, deps in dependencies.items(): + all_deps_copied = True + for dep in deps: + if dep not in value_remap: + all_deps_copied = False + + if all_deps_copied: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + copied_this_round.append(node) + + # Delete all nodes copied over in this iteration from dependencies. + for copied_node in copied_this_round: + del dependencies[copied_node] + + # Replace the old graph with the new, topologically sorted one. + gm.graph = new_graph + + +def may_depend_on(a: Node, b: Node, search_depth: int = 6): + """ + Determine if one node depends on another in a torch.fx.Graph. + + Arguments: + a: The node that may have a dependency on b. + b: The node that a may have a dependency on. + search_depth: In the case of an indirect dependency, this function + searches upto this many nodes away in search of a + data dependency. If none is found, the function + makes the conservative assumption that there is a + dependency. + + Returns: + True if a may depend on b, False if it definitely does not. + """ + # Equivalence is defined as dependence. + if a == b: + return True + + # If a has no inputs, it cannot depend on b. + if len(a.all_input_nodes) == 0: + return False + + # If the search depth has been exhausted and no conclusion has been + # reached, assume that there is a data dependency. + if search_depth == 0: + return True + + # Recursively check all inputs of a. + for inp in a.all_input_nodes: + if may_depend_on(inp, b, search_depth - 1): + return True + + return False + + +def are_nodes_independent(nodes: List[Node]): + """ + Check if all of the given nodes are pairwise-data independent. + + Arguments: + nodes: The nodes to check for data dependencies. + + Returns: + True if any pair in nodes has a data dependency. + """ + # For each pair in nodes: + for i, j in itertools.combinations(nodes, 2): + if may_depend_on(i, j) or may_depend_on(j, i): + return False + + return True + + +def merge_matmul(in_mod: torch.nn.Module): + """ + A graph transformation that merges matrix multiplication operations that share the same right-hand + side operand into one large matrix multiplication. + ____ _________ _________ + ---- | | | | M| A * C | + M| A | T| B | * K| C | = |---------| + ---- , | | | | T| B * C | + K ---- --------- --------- + K R R + """ + gm = symbolic_trace(in_mod) + + rhs_users: Dict[Node, List[Node]] = {} + lhs_users: Dict[Node, List[Node]] = {} + + # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to + # the matmul of which they are the LHS/RHS. + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not torch.matmul: + continue + + lhs, rhs = node.args + + # TODO: Properly handle aliasing caused by get_attr. For now, + # use the attribute name as the operand if the node is a + # get_attr. + lhs = lhs.target if lhs.op == "get_attr" else lhs + rhs = rhs.target if rhs.op == "get_attr" else rhs + + lhs_users.setdefault(lhs, []).append(node) + rhs_users.setdefault(rhs, []).append(node) + + for rhs, mms in rhs_users.items(): + # There must be at least matmuls for a merge to make sense. + if len(mms) < 2: + continue + + # All matmuls must not depend on each other directly or indirectly + # in order for the merge to be possible. + if not are_nodes_independent(mms): + continue + + lhs_vals = [mm.args[0] for mm in mms] + + # Merge the matmul. + # Collect a list of LHS operands and the single RHS operand. + lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] + rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs + + # Concatenate all the LHS operands. + merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) + + # Multiply the concatenated LHS operands with the one RHS. This will produce + # the same results as all the individual matmuls involving rhs in the original graph, + # but they will all be concatenated together. + merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) + + # Split the result of the merged matmul using the shapes of the LHS operands + # to ascertain how large each chunk should be. + merge_mm_sizes = [ + gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs + ] + merge_mm_split = gm.graph.call_function( + torch.split, (merge_mm, merge_mm_sizes), {} + ) + merge_mm_res = [ + gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) + for out in range(len(lhs)) + ] + + # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. + for old, new in zip(mms, merge_mm_res): + old.replace_all_uses_with(new) + gm.graph.erase_node(old) + + # All of the new nodes created above were inserted at the end, so we need to sort + # the nodes topologically to make sure all definitions precede uses. + legalize_graph(gm) + + gm.recompile() + gm.graph.lint(in_mod) + return gm