diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 6e9c877b8de6..ac71d6037591 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -21,6 +21,7 @@
     PartitionMode
 )
 from torch.fx.experimental.fuser import fuse
+from torch.fx.experimental import merge_matmul
 
 try:
     from torchvision.models import resnet18
@@ -844,6 +845,128 @@ def forward(self, a):
             for p_name in para_list:
                 assert p_name in node.attrs_for_lowering
 
+    def test_merge_matmuls(self):
+        """
+        A collection of test cases for torch.fx.experimental.merge_matmul,
+        a graph transformation that merges matrix multiplication operations.
+        """
+        # Utility function for counting matmuls for test assertions.
+        def _count_matmuls(mod):
+            gm = torch.fx.symbolic_trace(mod)
+
+            num_matmuls = 0
+            for node in gm.graph.nodes:
+                if node.target == torch.matmul:
+                    num_matmuls += 1
+
+            return num_matmuls
+
+        # Simple test case in which there are two matmuls of the same size to merge.
+        class SimpleMergeMatmulModule(torch.nn.Module):
+            def __init__(self, rhs):
+                super().__init__()
+                self.rhs = rhs
+
+            def forward(self, x, y):
+                a = torch.matmul(x, self.rhs)
+                b = torch.matmul(y, self.rhs)
+                return a + b
+
+        # Initialize inputs.
+        a = torch.randn(3, 3)
+        b = torch.randn(3, 3)
+
+        # Initialize RHS for matmuls.
+        rhs = torch.randn(3, 4)
+
+        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
+        module = SimpleMergeMatmulModule(rhs)
+        opt_module = merge_matmul.merge_matmul(module)
+
+        # Numerical correctness check.
+        before = module(a, b)
+        after = opt_module(a, b)
+        self.assertTrue(before.allclose(after))
+
+        # Basic graph structure check; original module should have 2 matmuls
+        # and optimized module should have 1.
+        self.assertEqual(_count_matmuls(module), 2)
+        self.assertEqual(_count_matmuls(opt_module), 1)
+
+        # Test case in which there are multiple matmuls of different sizes to merge.
+        class FiveMergeMatmulModule(torch.nn.Module):
+            def __init__(self, rhs):
+                super().__init__()
+                self.rhs = rhs
+
+            def forward(self, a, b, c, d, e):
+                s = torch.zeros(1)
+                matmuls = []
+
+                # For some reason, using a list comprehension or for-loop for this
+                # doesn't work, so the matmuls are appended one at a time.
+                matmuls.append(torch.matmul(a, self.rhs))
+                matmuls.append(torch.matmul(b, self.rhs))
+                matmuls.append(torch.matmul(c, self.rhs))
+                matmuls.append(torch.matmul(d, self.rhs))
+                matmuls.append(torch.matmul(e, self.rhs))
+
+                for m in matmuls:
+                    s += torch.sum(m)
+
+                return s
+
+        # Initialize inputs.
+        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]
+
+        # Initialize RHS.
+        rhs = torch.randn(5, 4)
+
+        # Construct FiveMergeMatmulModule and call merge_matmul on it.
+        module = FiveMergeMatmulModule(rhs)
+        opt_module = merge_matmul.merge_matmul(module)
+
+        # Numerical correctness check.
+        before = module(*inputs)
+        after = opt_module(*inputs)
+        self.assertTrue(before.allclose(after))
+
+        # Basic graph structure check; original module should have len(inputs) matmuls
+        # and optimized module should have 1.
+        self.assertEqual(_count_matmuls(module), len(inputs))
+        self.assertEqual(_count_matmuls(opt_module), 1)
+
+        # Simple test case in which two matmuls cannot be merged due to a data dependency between
+        # the LHS operands.
+        class UnmergeableMatmulModule(torch.nn.Module):
+            def __init__(self, rhs):
+                super().__init__()
+                self.rhs = rhs
+
+            def forward(self, x):
+                a = torch.matmul(x, self.rhs)
+                a_abs = torch.abs(a)
+                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
+                return b
+
+        # Initialize inputs.
+        a = torch.randn(3, 3)
+
+        # Initialize RHS for matmuls.
+        rhs = torch.randn(3, 4)
+
+        # Construct UnmergeableMatmulModule and call merge_matmul on it.
+        module = UnmergeableMatmulModule(rhs)
+        opt_module = merge_matmul.merge_matmul(module)
+
+        # Numerical correctness check.
+        before = module(a)
+        after = opt_module(a)
+        self.assertTrue(before.allclose(after))
+
+        # Basic graph structure check; the number of matrix multiplications should not have changed.
+        self.assertEqual(_count_matmuls(module), 2)
+        self.assertEqual(_count_matmuls(opt_module), 2)
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py
new file mode 100644
index 000000000000..b72bbe633dd9
--- /dev/null
+++ b/torch/fx/experimental/merge_matmul.py
@@ -0,0 +1,220 @@
+import torch
+
+from torch.fx.graph import Graph
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import Node
+from torch.fx.symbolic_trace import symbolic_trace
+
+import itertools
+import operator
+
+from typing import Dict, List
+
+
+def get_first_dim(t: torch.Tensor) -> int:
+    """
+    A free function, primarily for use in the merge_matmul graph transformation below,
+    that returns the first dimension of a Tensor. This is necessary because torch.Tensor.shape
+    is an attribute (and cannot be the target of a call_function node) and also helps save
+    a getitem op in the graph.
+
+    Arguments:
+        t: The tensor to get the first dimension of.
+
+    Returns:
+        The first dimension of t.
+    """
+    return t.shape[0]
+
+
+def legalize_graph(gm: GraphModule):
+    """
+    Replace the graph of the given GraphModule with one that contains the same nodes as the
+    original, but in topologically sorted order.
+
+    This is used by the merge_matmul transformation below, which disturbs the topologically
+    sorted order of its input GraphModule; this pass restores that order before further
+    transformation.
+
+    Arguments:
+        gm: The graph module to topologically sort. It is modified in-place.
+    """
+    # Build an adjacency list representation of node dependencies in the graph. This also
+    # serves as a list of nodes that still need to be inserted into the new, topologically
+    # sorted graph.
+    dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}
+
+    # Construct a new graph that will contain all nodes in topologically sorted order.
+    new_graph = Graph()
+    value_remap: Dict[Node, Node] = {}
+
+    # Copy over all nodes with no dependencies.
+    for node, deps in dependencies.items():
+        if not deps:
+            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
+
+    # Remove the copied over nodes from the adjacency list.
+    for copied_node in value_remap.keys():
+        del dependencies[copied_node]
+
+    # While there are still nodes to insert into the new graph:
+    while dependencies:
+        copied_this_round = []
+
+        # Copy over all nodes whose dependencies already exist in the new graph.
+        for node, deps in dependencies.items():
+            all_deps_copied = True
+            for dep in deps:
+                if dep not in value_remap:
+                    all_deps_copied = False
+
+            if all_deps_copied:
+                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
+                copied_this_round.append(node)
+
+        # Delete all nodes copied over in this iteration from dependencies.
+        for copied_node in copied_this_round:
+            del dependencies[copied_node]
+
+    # Replace the old graph with the new, topologically sorted one.
+    gm.graph = new_graph
+
+
+def may_depend_on(a: Node, b: Node, search_depth: int = 6):
+    """
+    Determine if one node depends on another in a torch.fx.Graph.
+
+    Arguments:
+        a: The node that may have a dependency on b.
+        b: The node that a may have a dependency on.
+        search_depth: In the case of an indirect dependency, this function
+                      searches up to this many nodes away in search of a
+                      data dependency. If none is found, the function
+                      makes the conservative assumption that there is a
+                      dependency.
+
+    Returns:
+        True if a may depend on b, False if it definitely does not.
+    """
+    # A node depends on itself by definition.
+    if a == b:
+        return True
+
+    # If a has no inputs, it cannot depend on b.
+    if len(a.all_input_nodes) == 0:
+        return False
+
+    # If the search depth has been exhausted and no conclusion has been
+    # reached, assume that there is a data dependency.
+    if search_depth == 0:
+        return True
+
+    # Recursively check all inputs of a.
+    for inp in a.all_input_nodes:
+        if may_depend_on(inp, b, search_depth - 1):
+            return True
+
+    return False
+
+
+def are_nodes_independent(nodes: List[Node]):
+    """
+    Check if all of the given nodes are pairwise data-independent.
+
+    Arguments:
+        nodes: The nodes to check for data dependencies.
+
+    Returns:
+        True if no pair of nodes has a data dependency; False otherwise.
+    """
+    # For each pair in nodes:
+    for i, j in itertools.combinations(nodes, 2):
+        if may_depend_on(i, j) or may_depend_on(j, i):
+            return False
+
+    return True
+
+
+def merge_matmul(in_mod: torch.nn.Module):
+    """
+    A graph transformation that merges matrix multiplication operations that share the same
+    right-hand side operand into one large matrix multiplication.
+
+                    ____      _________      _________
+          ----     |    |    |         |   M|  A * C  |
+        M| A  | , T| B  | * K|    C    | =  |---------|
+          ----     |    |    |         |   T|  B * C  |
+           K        ----      ---------      ---------
+                     K            R              R
+    """
+    gm = symbolic_trace(in_mod)
+
+    rhs_users: Dict[Node, List[Node]] = {}
+    lhs_users: Dict[Node, List[Node]] = {}
+
+    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
+    # the matmuls of which they are the LHS/RHS.
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target is not torch.matmul:
+            continue
+
+        lhs, rhs = node.args
+
+        # TODO: Properly handle aliasing caused by get_attr. For now,
+        # use the attribute name as the operand if the node is a
+        # get_attr.
+        lhs = lhs.target if lhs.op == "get_attr" else lhs
+        rhs = rhs.target if rhs.op == "get_attr" else rhs
+
+        lhs_users.setdefault(lhs, []).append(node)
+        rhs_users.setdefault(rhs, []).append(node)
+
+    for rhs, mms in rhs_users.items():
+        # There must be at least two matmuls for a merge to make sense.
+        if len(mms) < 2:
+            continue
+
+        # The matmuls must not depend on each other, directly or indirectly,
+        # in order for the merge to be possible.
+        if not are_nodes_independent(mms):
+            continue
+
+        lhs_vals = [mm.args[0] for mm in mms]
+
+        # Merge the matmuls.
+        # Collect a list of LHS operands and the single RHS operand.
+        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
+        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs
+
+        # Concatenate all the LHS operands.
+        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})
+
+        # Multiply the concatenated LHS operands with the one RHS. This will produce
+        # the same results as all the individual matmuls involving rhs in the original graph,
+        # but they will all be concatenated together.
+        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})
+
+        # Split the result of the merged matmul using the shapes of the LHS operands
+        # to ascertain how large each chunk should be.
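+        # Note that the chunk sizes are themselves nodes in the graph (computed via
+        # get_first_dim at runtime), so the transformation does not need to know the
+        # concrete shapes of the LHS operands at transformation time.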
+        merge_mm_sizes = [
+            gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs
+        ]
+        merge_mm_split = gm.graph.call_function(
+            torch.split, (merge_mm, merge_mm_sizes), {}
+        )
+        merge_mm_res = [
+            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
+            for out in range(len(lhs))
+        ]
+
+        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk
+        # from the merged matmul.
+        for old, new in zip(mms, merge_mm_res):
+            old.replace_all_uses_with(new)
+            gm.graph.erase_node(old)
+
+    # All of the new nodes created above were inserted at the end, so we need to sort
+    # the nodes topologically to make sure all definitions precede uses.
+    legalize_graph(gm)
+
+    gm.recompile()
+    gm.graph.lint(in_mod)
+    return gm
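
Reviewer note: a minimal usage sketch of the new pass, not part of the diff. The
module name and shapes below are hypothetical; the calls mirror the API added above
and the tests in test_fx_experimental.py:

    import torch
    from torch.fx.experimental import merge_matmul

    class TwoMatmuls(torch.nn.Module):
        def __init__(self, rhs):
            super().__init__()
            self.rhs = rhs

        def forward(self, x, y):
            # Both matmuls share self.rhs, so merge_matmul can fuse them into
            # a single cat -> matmul -> split sequence.
            return torch.matmul(x, self.rhs) + torch.matmul(y, self.rhs)

    mod = TwoMatmuls(torch.randn(3, 4))
    opt = merge_matmul.merge_matmul(mod)

    x, y = torch.randn(3, 3), torch.randn(3, 3)
    assert torch.allclose(mod(x, y), opt(x, y))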