[fx] Add matrix multiplication fusion pass (#50151)
Summary:
Pull Request resolved: #50151

**Summary**
This commit adds a graph transformation pass that merges several matrix
multiplications that use the same RHS operand into one large matrix
multiplication. The LHS operands from all of the smaller matrix multiplications
are concatenated together and used as the LHS of the large matrix multiplication,
and the result is split row-wise to obtain the same products as the original
set of matrix multiplications.
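
Conceptually, the pass relies on the identity cat([A, B]) @ C == cat([A @ C, B @ C]). A minimal sketch of that equivalence in plain PyTorch (the tensor names and shapes below are illustrative only, not part of the pass):

import torch

# Two separate matmuls that share the same RHS operand.
x = torch.randn(3, 5)
y = torch.randn(7, 5)
rhs = torch.randn(5, 4)
a = torch.matmul(x, rhs)
b = torch.matmul(y, rhs)

# One merged matmul over the concatenated LHS operands, split back into per-input chunks.
merged = torch.matmul(torch.cat([x, y]), rhs)
a2, b2 = torch.split(merged, [x.shape[0], y.shape[0]])

assert torch.allclose(a, a2) and torch.allclose(b, b2)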

**Test Plan**
This commit adds a simple unit test with two matrix multiplications that share
the same RHS operand.

`python test/test_fx_experimental.py -k merge_matmul -v`

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25809409

Pulled By: SplitInfinity

fbshipit-source-id: fb55c044a54dea9f07b71aa60d44b7a8f3966ed0
Meghan Lele authored and facebook-github-bot committed Jan 7, 2021
1 parent 838e73d commit 11cdb91
Showing 2 changed files with 343 additions and 0 deletions.
123 changes: 123 additions & 0 deletions test/test_fx_experimental.py
@@ -21,6 +21,7 @@
PartitionMode
)
from torch.fx.experimental.fuser import fuse
from torch.fx.experimental import merge_matmul

try:
from torchvision.models import resnet18
@@ -844,6 +845,128 @@ def forward(self, a):
for p_name in para_list:
assert p_name in node.attrs_for_lowering

def test_merge_matmuls(self):
"""
A collection of test cases for torch.fx.experimental.merge_matmul,
a graph transformation that merges matrix multiplication operations.
"""
# Utility function for counting matmuls for test assertions.
def _count_matmuls(mod):
gm = torch.fx.symbolic_trace(mod)

num_matmuls = 0
for node in gm.graph.nodes:
if node.target == torch.matmul:
num_matmuls += 1

return num_matmuls

# Simple test case in which there are two matmuls of the same size to merge.
class SimpleMergeMatmulModule(torch.nn.Module):
def __init__(self, rhs):
super().__init__()
self.rhs = rhs

def forward(self, x, y):
a = torch.matmul(x, self.rhs)
b = torch.matmul(y, self.rhs)
return a + b

# Initialize inputs.
a = torch.randn(3, 3)
b = torch.randn(3, 3)

# Initialize RHS for matmuls.
rhs = torch.randn(3, 4)

# Construct SimpleMergeMatmulModule and call merge_matmul on it.
module = SimpleMergeMatmulModule(rhs)
opt_module = merge_matmul.merge_matmul(module)

# Numerical correctness check.
before = module(a, b)
after = opt_module(a, b)
self.assertTrue(before.allclose(after))

# Basic graph structure check; original module should have 2 matmuls
# and optimized module should have 1.
self.assertEqual(_count_matmuls(module), 2)
self.assertEqual(_count_matmuls(opt_module), 1)

# Test case in which there are multiple matmuls of different sizes to merge.
class FiveMergeMatmulModule(torch.nn.Module):
def __init__(self, rhs):
super().__init__()
self.rhs = rhs

def forward(self, a, b, c, d, e):
s = torch.zeros(1)
matmuls = []

# For some reason using a list comprehension or for-loop for this
# doesn't work.
matmuls.append(torch.matmul(a, self.rhs))
matmuls.append(torch.matmul(b, self.rhs))
matmuls.append(torch.matmul(c, self.rhs))
matmuls.append(torch.matmul(d, self.rhs))
matmuls.append(torch.matmul(e, self.rhs))

for m in matmuls:
s += torch.sum(m)

return s

# Initialize inputs.
inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

# Initialize RHS.
rhs = torch.randn(5, 4)

# Construct FiveMergeMatmulModule and call merge_matmul on it.
module = FiveMergeMatmulModule(rhs)
opt_module = merge_matmul.merge_matmul(module)

# Numerical correctness check.
before = module(*inputs)
after = opt_module(*inputs)
self.assertTrue(before.allclose(after))

# Basic graph structure check; original module should have len(inputs) matmuls
# and optimized module should have 1.
self.assertEqual(_count_matmuls(module), len(inputs))
self.assertEqual(_count_matmuls(opt_module), 1)

# Simple test case in which two matmuls cannot be merged due to a data dependency between
# the LHS operands.
class UnmergeableMatmulModule(torch.nn.Module):
def __init__(self, rhs):
super().__init__()
self.rhs = rhs

def forward(self, x):
a = torch.matmul(x, self.rhs)
a_abs = torch.abs(a)
b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
return b

# Initialize inputs.
a = torch.randn(3, 3)

# Initialize RHS for matmuls.
rhs = torch.randn(3, 4)

# Construct UnmergeableMatmulModule and call merge_matmul on it.
module = UnmergeableMatmulModule(rhs)
opt_module = merge_matmul.merge_matmul(module)

# Numerical correctness check.
before = module(a)
after = opt_module(a)
self.assertTrue(before.allclose(after))

# Basic graph structure check; the number of matrix multiplications should not have changed.
self.assertEqual(_count_matmuls(module), 2)
self.assertEqual(_count_matmuls(opt_module), 2)

if __name__ == "__main__":
run_tests()
220 changes: 220 additions & 0 deletions torch/fx/experimental/merge_matmul.py
@@ -0,0 +1,220 @@
import torch

from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.symbolic_trace import symbolic_trace

import itertools
import operator

from typing import Dict, List


def get_first_dim(t: torch.Tensor) -> int:
"""
A free function primarily for use in the merge_matmul graph transformation below
that returns the first dimension of a Tensor. This is necessary because torch.Tensor.shape
is an attribute (and cannot be the target of a call_function node) and also helps save
a getitem op in the graph.
Arguments:
t: The tensor to get the first dimension of.
Returns:
The first dimension of t.
"""
return t.shape[0]


def legalize_graph(gm: GraphModule):
"""
Replace the graph of the given GraphModule with one that contains the same nodes as the
original, but in topologically sorted order.
This is used by the merge_matmul transformation below, which disturbs the topologically sorted
order of its input GraphModule, so that this order is restored before further transformation.
Arguments:
gm: The graph module to topologically sort. It is modified in-place.
"""
# Build an adjacency list representation of node dependencies in the graph. This also
# serves as a list of nodes that still need to be inserted into the new, topologically
# sorted graph.
dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}

# Construct a new graph that will contain all nodes in topologically sorted order.
new_graph = Graph()
value_remap: Dict[Node, Node] = {}

# Copy over all nodes with no dependencies.
for node, deps in dependencies.items():
if not deps:
value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

# Remove the copied over nodes from the adjacency list.
for copied_node in value_remap.keys():
del dependencies[copied_node]

# While there are still nodes to insert into the new graph:
while dependencies:
copied_this_round = []

# Copy over all nodes whose dependencies already exist in the new graph.
for node, deps in dependencies.items():
all_deps_copied = True
for dep in deps:
if dep not in value_remap:
all_deps_copied = False

if all_deps_copied:
value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
copied_this_round.append(node)

# Delete all nodes copied over in this iteration from dependencies.
for copied_node in copied_this_round:
del dependencies[copied_node]

# Replace the old graph with the new, topologically sorted one.
gm.graph = new_graph


def may_depend_on(a: Node, b: Node, search_depth: int = 6):
"""
Determine if one node depends on another in a torch.fx.Graph.
Arguments:
a: The node that may have a dependency on b.
b: The node that a may have a dependency on.
search_depth: In the case of an indirect dependency, this function
searches up to this many nodes away in search of a
data dependency. If none is found, the function
makes the conservative assumption that there is a
dependency.
Returns:
True if a may depend on b, False if it definitely does not.
"""
# A node trivially depends on itself.
if a == b:
return True

# If a has no inputs, it cannot depend on b.
if len(a.all_input_nodes) == 0:
return False

# If the search depth has been exhausted and no conclusion has been
# reached, assume that there is a data dependency.
if search_depth == 0:
return True

# Recursively check all inputs of a.
for inp in a.all_input_nodes:
if may_depend_on(inp, b, search_depth - 1):
return True

return False


def are_nodes_independent(nodes: List[Node]):
"""
Check if all of the given nodes are pairwise-data independent.
Arguments:
nodes: The nodes to check for data dependencies.
Returns:
True if no pair of nodes in the list has a data dependency, False otherwise.
"""
# For each pair in nodes:
for i, j in itertools.combinations(nodes, 2):
if may_depend_on(i, j) or may_depend_on(j, i):
return False

return True


def merge_matmul(in_mod: torch.nn.Module):
"""
A graph transformation that merges matrix multiplication operations that share the same right-hand
side operand into one large matrix multiplication.
Given A (M x K), B (T x K) and a shared RHS C (K x R):

    cat([A, B]) @ C == cat([A @ C, B @ C])

That is, stacking A and B into a single (M + T) x K matrix, multiplying it by C
once, and splitting the (M + T) x R result row-wise yields the same products
as the individual matmuls A @ C and B @ C.
"""
gm = symbolic_trace(in_mod)

rhs_users: Dict[Node, List[Node]] = {}
lhs_users: Dict[Node, List[Node]] = {}

# Populate rhs_users and lhs_users, which map each LHS/RHS matmul operand to the
# list of matmul nodes in which it appears as the LHS/RHS.
for node in gm.graph.nodes:
if node.op != "call_function" or node.target is not torch.matmul:
continue

lhs, rhs = node.args

# TODO: Properly handle aliasing caused by get_attr. For now,
# use the attribute name as the operand if the node is a
# get_attr.
lhs = lhs.target if lhs.op == "get_attr" else lhs
rhs = rhs.target if rhs.op == "get_attr" else rhs

lhs_users.setdefault(lhs, []).append(node)
rhs_users.setdefault(rhs, []).append(node)

for rhs, mms in rhs_users.items():
# There must be at least two matmuls for a merge to make sense.
if len(mms) < 2:
continue

# All matmuls must not depend on each other directly or indirectly
# in order for the merge to be possible.
if not are_nodes_independent(mms):
continue

lhs_vals = [mm.args[0] for mm in mms]

# Merge the matmul.
# Collect a list of LHS operands and the single RHS operand.
lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

# Concatenate all the LHS operands.
merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

# Multiply the concatenated LHS operands with the one RHS. This will produce
# the same results as all the individual matmuls involving rhs in the original graph,
# but they will all be concatenated together.
merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})

# Split the result of the merged matmul using the shapes of the LHS operands
# to ascertain how large each chunk should be.
merge_mm_sizes = [
gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs
]
merge_mm_split = gm.graph.call_function(
torch.split, (merge_mm, merge_mm_sizes), {}
)
merge_mm_res = [
gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
for out in range(len(lhs))
]

# Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
for old, new in zip(mms, merge_mm_res):
old.replace_all_uses_with(new)
gm.graph.erase_node(old)

# All of the new nodes created above were inserted at the end, so we need to sort
# the nodes topologically to make sure all definitions precede uses.
legalize_graph(gm)

gm.recompile()
gm.graph.lint(in_mod)
return gm
