161 changes: 88 additions & 73 deletions backends/cadence/aot/replace_ops.py
@@ -21,7 +21,6 @@
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import (
get_shape,
get_tensor_from_attr,
get_zero_point,
is_node_with_op,
quantize_tensor_multiplier,
@@ -321,90 +320,106 @@ def call_operator(self, op, args, kwargs, meta):


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceAddMMWithLinearPass(ExportPass):
class ReplaceAddMMWithLinearPass(RemoveOrReplacePassInterface):
"""
This pass replaces the addmm op with the linear op.

AddMM computes: beta*bias + alpha*mm(mat1, mat2)
Linear computes: mat1 @ weight.T + bias

"""

def __init__(self):
super().__init__()
self.counter = 0
@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.addmm.default]

def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
for node in graph.nodes:
# We are only interested in addmm nodes
if node.target != exir_ops.edge.aten.addmm.default:
continue
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# The addmm op has three concrete args: bias, mat1, mat2
assert len(node.args) >= 3
(bias, mat1, mat2) = node.args[0:3]

# The addmm op has three concrete args: input, mat1, mat2
assert len(node.args) >= 3
(bias, mat1, mat2) = node.args[0:3]
# The other two args are optional scale args
beta = node.kwargs.get("beta", 1.0)
alpha = node.kwargs.get("alpha", 1.0)

# AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
# it to linear op by multiplying beta to bias, and alpha to mat2.t().
# However, the following two conditions must hold:
# a. If bias is not a param, then beta must be 1.0
# b. If mat2 is not a param, then mat2 must be a transpose op. Also,
# the input to the transpose must be a param, or alpha must be 1.0.
fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
fit_mat2 = is_node_with_op(mat2, "get_attr")
transposed_mat2 = False
if (
not fit_mat2
and is_node_with_op(mat2, "call_function")
and mat2.target == exir_ops.edge.aten.transpose_copy.int
):
mat2, transposed_mat2 = mat2.args[0], True
fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0
# The other two args are optional scale args
beta = float(node.kwargs.get("beta", 1.0))
alpha = float(node.kwargs.get("alpha", 1.0))

if not fit_bias or not fit_mat2:
continue
bias, mat1, mat2 = cast(
tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node],
(bias, mat1, mat2),
)

graph = node.graph

# Handle transpose: if mat2 is a transpose op, extract the original tensor
transposed_mat2 = False
if (
mat2.op == "call_function"
and mat2.target == exir_ops.edge.aten.transpose_copy.int
):
# mat2 is already transposed, so we use the input to the transpose
mat2 = cast(torch.fx.Node, mat2.args[0])
transposed_mat2 = True

# Multiply bias by beta if needed
if beta != 1.0:
# Create a scaled bias using element-wise multiplication in the graph
with graph.inserting_before(node):
beta_scalar = graph.call_function(
exir_ops.edge.aten.full.default,
args=([1], beta),
kwargs={"dtype": torch.float32},
)
beta_scalar.meta = node.meta
bias = graph.call_function(
exir_ops.edge.aten.mul.Tensor,
args=(bias, beta_scalar),
)

# Multiply bias by beta
if beta != 1.0:
assert is_node_with_op(bias, "get_attr")
bias_tensor = get_tensor_from_attr(graph_module, bias)
assert isinstance(bias_tensor, torch.Tensor)
bias_tensor = beta * bias_tensor
with graph.inserting_before(node):
bias_name = f"_bias_addmm_to_linear_{self.counter}"
graph_module.register_buffer(bias_name, bias_tensor)
bias = graph.get_attr(bias_name)

# Use associativity of scalar multiplication, and multiply alpha to mat2
if is_node_with_op(mat2, "get_attr"):
mat2_tensor = get_tensor_from_attr(graph_module, mat2)
assert isinstance(mat2_tensor, torch.Tensor)
mat2_tensor = alpha * mat2_tensor
# transpose mat2
mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
with graph.inserting_before(node):
mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
graph_module.register_buffer(mat2_name, mat2_tensor)
mat2 = graph.get_attr(mat2_name)

# Construct the linear node
linear_args = (mat1, mat2, bias)
# Copy metadata from the original addmm node
bias.meta = node.meta

# Multiply mat2 by alpha if needed
if alpha != 1.0:
with graph.inserting_before(node):
linear_node = graph.call_function(
exir_ops.edge.aten.linear.default, args=linear_args
alpha_scalar = graph.call_function(
exir_ops.edge.aten.full.default,
args=([1], alpha),
kwargs={"dtype": torch.float32},
)
alpha_scalar.meta = node.meta
mat2 = graph.call_function(
exir_ops.edge.aten.mul.Tensor,
args=(mat2, alpha_scalar),
)
linear_node.meta = node.meta
# Replace all the uses of the addmm op with linear op
node.replace_all_uses_with(linear_node)
self.counter += 1

graph_module.recompile()
graph_module.graph.eliminate_dead_code()
# Copy metadata from the original addmm node
mat2.meta = node.meta

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.replace_addmm_with_linear(graph_module)
result = super().call(graph_module)
return result
# Transpose mat2 if it wasn't already transposed
if not transposed_mat2:
with graph.inserting_before(node):
mat2 = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(mat2, -1, -2),
)

# Copy metadata from the original addmm node
mat2.meta = node.meta

# Construct the linear node: linear(input, weight, bias)
# linear computes: input @ weight.T + bias
linear_args = (mat1, mat2, bias)
with graph.inserting_before(node):
linear_node = graph.call_function(
exir_ops.edge.aten.linear.default,
args=linear_args,
)

# Copy metadata from the original addmm node
linear_node.meta = node.meta

# Replace all uses of the addmm op with linear op
node.replace_all_uses_with(linear_node)
return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
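
For reviewers, the algebraic identity the pass relies on can be checked numerically in plain PyTorch, independent of the graph machinery. A minimal sketch (illustrative only, not part of this PR) showing that folding beta into the bias and alpha into the transposed mat2 reproduces addmm:

```python
import torch

# Identity behind the pass:
#   addmm(bias, mat1, mat2, beta=b, alpha=a) == linear(mat1, (a * mat2).T, b * bias)
M, K, N = 14, 12, 10
bias = torch.randn(N)
mat1 = torch.randn(M, K)
mat2 = torch.randn(K, N)
beta, alpha = 2.0, 3.0

out_addmm = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
# linear expects a weight of shape (out_features, in_features), hence the transpose
out_linear = torch.nn.functional.linear(mat1, (alpha * mat2).t(), beta * bias)

print(torch.allclose(out_addmm, out_linear, rtol=1e-5, atol=1e-6))  # True
```

The comparison uses allclose rather than bit-exact equality: the two computations order their floating-point multiplies differently, which is also why the test helper below gains rtol/atol parameters.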
91 changes: 66 additions & 25 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -65,7 +65,19 @@ def validate(
modified: torch.fx.GraphModule,
inputs: tuple[torch.Tensor, ...] | list[torch.Tensor],
pass_name: str,
rtol: float = 1e-5,
atol: float = 1e-6,
) -> None:
"""Validate that two graph modules produce numerically equivalent outputs.

Args:
original: The original graph module before the pass
modified: The modified graph module after the pass
inputs: Input tensors to run through both graphs
pass_name: Name of the pass being validated (for error messages)
rtol: Relative tolerance for allclose comparison
atol: Absolute tolerance for allclose comparison
"""
original.eval()
modified.eval()
with torch.no_grad():
@@ -74,10 +86,17 @@

flat_orig_out, _ = pytree.tree_flatten(orig_out)
flat_mod_out, _ = pytree.tree_flatten(mod_out)
if not all(pytree.tree_map(torch.equal, flat_orig_out, flat_mod_out)):
raise AssertionError(
f"Pass validation failed with exact match for pass {pass_name}. Original graph {original} and modified graph {modified}"
)

# Check that outputs match within tolerance
for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)):
if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol):
max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item()
raise AssertionError(
f"Pass validation failed for pass {pass_name}. "
f"Output tensor {i} differs by max {max_diff:.6e}. "
f"Expected rtol={rtol}, atol={atol}. "
f"Original output: {orig_tensor}, Modified output: {mod_tensor}"
)
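
The switch from torch.equal to torch.allclose is deliberate: the rewritten graphs apply beta and alpha through explicit mul nodes and a separate transpose, so they perform the same math in a different floating-point order, and bit-exact comparison can fail spuriously. A tiny standalone illustration of the effect (not part of this PR):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000)

# Mathematically identical, but on the left each product is rounded before
# summing, while the right scales the already-rounded sum:
a = (3.0 * x).sum()
b = 3.0 * x.sum()

print(torch.equal(a, b))     # typically False: different rounding order
print(torch.allclose(a, b))  # True: equal within tolerance
```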


class TestReplaceOpsPasses(unittest.TestCase):
@@ -840,10 +859,10 @@ def test_replace_scalar_tensor_with_full(
def test_replace_linear_with_fully_connected(self) -> None:
shape, in_channels, out_channels = (1, 14), 14, 128
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
weights = builder.placeholder(
"weights", torch.randn([out_channels, in_channels], dtype=torch.float32)
)
x_input = torch.randn(*shape, dtype=torch.float32)
weights_input = torch.randn([out_channels, in_channels], dtype=torch.float32)
x = builder.placeholder("x", x_input)
weights = builder.placeholder("weights", weights_input)
permute_copy = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(weights, [1, 0]),
@@ -854,14 +873,31 @@ def test_replace_linear_with_fully_connected(self) -> None:
)
builder.output([mm])
original_gm = builder.get_graph_module()

gm = cast(
PassResult, ReplacePermuteWithTransposePass()(original_gm)
).graph_module
gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module

gm_before_linear = copy.deepcopy(gm)
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
self.assertTrue(pass_result.modified)
gm = pass_result.graph_module

inputs = [x_input, weights_input]
validate(gm_before_linear, gm, inputs, "ReplaceAddMMWithLinearPass")
gm_before_fc = copy.deepcopy(gm)
graph_after_passes = cast(
PassResult, ReplaceLinearWithFullyConnectedOpPass()(gm)
).graph_module

validate(
gm_before_fc,
graph_after_passes,
inputs,
"ReplaceLinearWithFullyConnectedOpPass",
)

self.assertIsNotNone(graph_after_passes)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.full.default),
@@ -878,21 +914,17 @@ def test_replace_linear_with_fully_connected(self) -> None:
0,
)
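
The tests chain ReplacePermuteWithTransposePass ahead of ReplaceAddMMWithLinearPass so the addmm pass sees a transpose_copy it can fold away instead of inserting a second transpose. A sketch of that chaining pattern as a hypothetical helper, assuming the imports already present at the top of this file:

```python
def lower_addmm_to_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Canonicalize permute_copy([1, 0]) into transpose_copy first ...
    gm = cast(PassResult, ReplacePermuteWithTransposePass()(gm)).graph_module
    # ... then rewrite addmm into linear, folding the transpose into the weight.
    return cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module
```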

@expand(
[
[(4, 16, 256), 256, 512, True],
[(7, 17, 12), 12, 34, False],
]
)
@expand([[1.0, 1.0], [2.0, 3.0]])
@torch.no_grad()
def test_replace_addmm_with_linear(
self, shape: Tuple[int], in_features: int, out_features: int, bias: bool
) -> None:
M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0
def test_replace_addmm_with_linear(self, alpha: float, beta: float) -> None:
M, K, N = 14, 12, 10
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(N, dtype=torch.float32))
y = builder.placeholder("y", torch.randn([M, K], dtype=torch.float32))
z = builder.placeholder("z", torch.randn([N, K], dtype=torch.float32))
x_input = torch.randn(N, dtype=torch.float32)
y_input = torch.randn([M, K], dtype=torch.float32)
z_input = torch.randn([N, K], dtype=torch.float32)
x = builder.placeholder("x", x_input)
y = builder.placeholder("y", y_input)
z = builder.placeholder("z", z_input)
permute_copy = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default,
args=(z, [1, 0]),
@@ -904,12 +936,21 @@ def test_replace_addmm_with_linear(
)
builder.output([addmm])
original_gm = builder.get_graph_module()

gm = cast(
PassResult, ReplacePermuteWithTransposePass()(original_gm)
).graph_module
graph_after_passes = cast(
PassResult, ReplaceAddMMWithLinearPass()(gm)
).graph_module

gm_before_linear = copy.deepcopy(gm)
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
self.assertTrue(pass_result.modified)
graph_after_passes = pass_result.graph_module

inputs = [x_input, y_input, z_input]
validate(
gm_before_linear, graph_after_passes, inputs, "ReplaceAddMMWithLinearPass"
)

self.assertIsNotNone(graph_after_passes)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.linear.default),