Freeze fuse two mms (#111232)
Improves llama_v2 perf locally from 1.48x to 1.55x.

A good future rewrite would be to unify the freezing batching with the other batching rules that @yanboliang & co were working on. I want to wait for the forthcoming pre-dispatch changes to settle down first.

Pull Request resolved: #111232
Approved by: https://github.com/Chillee
eellison authored and pytorchmergebot committed Oct 19, 2023
1 parent cb856b0 commit 652f4c6
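
The rewrite itself is easy to sanity-check outside the compiler. Below is a minimal standalone sketch (not part of this commit; names and shapes are illustrative) of the transformation the freezing pass performs: two matmuls against same-shaped frozen weights become a single matmul against the concatenated weight, split back apart afterwards.

    import torch

    # illustrative inputs, not taken from the commit
    inp = torch.rand(10, 10)
    w1, w2 = torch.rand(10, 10), torch.rand(10, 10)

    # unfused: two separate matmuls
    out1, out2 = inp @ w1, inp @ w2

    # fused: one matmul against cat(w1, w2) along dim=1, then split the result
    fused = (inp @ torch.cat((w1, w2), dim=1)).chunk(2, dim=1)

    torch.testing.assert_close(out1, fused[0])
    torch.testing.assert_close(out2, fused[1])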
Showing 2 changed files with 56 additions and 2 deletions.
36 changes: 35 additions & 1 deletion test/inductor/test_inductor_freezing.py
@@ -174,6 +174,16 @@ def __init__(self):
             def forward(self, x):
                 return x @ self.t1, x @ self.t2, x @ self.t3

+        class MM2(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.t1 = torch.nn.Parameter(torch.rand(10, 10))
+                self.t2 = torch.nn.Parameter(torch.rand(10, 10))
+
+            def forward(self, x):
+                return x @ self.t1, x @ self.t2
+
         class AddMM(MM):
             def __init__(self):
                 super().__init__()
@@ -192,7 +202,12 @@ def forward(self, x):
             ]
         ]

-        for mod in [MM().to(self.device), AddMM().to(self.device)][1:]:
+        for mod_fn in [
+            lambda: MM().to(self.device),
+            lambda: MM2().to(self.device),
+            lambda: AddMM().to(self.device),
+        ]:
+            mod = mod_fn()
             inp = torch.rand([10, 10]).to(self.device)

             @torch.compile()
@@ -209,6 +224,25 @@ def foo(mod, inp):
                 ).run(code[0])
                 self.assertEqual(out_eager, out)

+            mod2 = mod_fn()
+            mod2.t1 = torch.nn.Parameter(torch.rand([10, 15], device=self.device))
+            mod2.t2 = torch.nn.Parameter(torch.rand([10, 20], device=self.device))
+
+            if hasattr(mod2, "b1"):
+                mod2.b1 = torch.nn.Parameter(torch.rand([15], device=self.device))
+                mod2.b2 = torch.nn.Parameter(torch.rand([20], device=self.device))
+
+            # not fused: the weight shapes differ, so the mms stay separate
+            count = 3 if hasattr(mod2, "t3") else 2
+
+            with torch.no_grad():
+                out_eager = mod2(inp)
+                out, code = run_and_get_code(foo, mod2, inp)
+                FileCheck().check_not(kernel_invoke).check_count(
+                    "mm(", count=count, exactly=True
+                ).run(code[0])
+                self.assertEqual(out_eager, out)
+
     def test_error_on_eager(self):
         mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

22 changes: 21 additions & 1 deletion torch/_inductor/fx_passes/freezing_patterns.py
@@ -123,8 +123,10 @@ def check_concat_weights(match):
        weights = [
            match.kwargs["w1"],
            match.kwargs["w2"],
-            match.kwargs["w3"],
        ]
+        if "w3" in match.kwargs:
+            weights.append(match.kwargs["w3"])
+
        return all(
            w.op == "get_attr" and w.meta["val"].shape == weights[0].meta["val"].shape
            for w in weights
@@ -148,6 +150,24 @@ def matmul_replacement(inp, w1, w2, w3):
        exclusive_arg_names=("w1", "w2", "w3"),
    )

+    def matmul_fuse_pattern_two(inp, w1, w2):
+        return (inp @ w1, inp @ w2)
+
+    def matmul_replacement_two(inp, w1, w2):
+        cat_t = torch.cat((w1, w2), dim=1)
+        mm = inp @ cat_t
+        return mm.chunk(2, dim=1)
+
+    register_replacement(
+        matmul_fuse_pattern_two,
+        matmul_replacement_two,
+        [val(), val(), val()],
+        inference_graph,
+        pass_patterns[0],
+        extra_check=check_concat_weights,
+        exclusive_arg_names=("w1", "w2"),
+    )
+
    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
        return (
            aten.addmm(b1, inp, w1),
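For intuition on why check_concat_weights insists on identical weight shapes, here is a sketch (not part of the commit): chunk(2, dim=1) splits the fused output evenly, so the split only lines up with the original outputs when w1 and w2 have the same number of columns. With the mismatched shapes the new test uses, the split lands in the wrong place:

    import torch

    inp = torch.rand(10, 10)
    w1, w2 = torch.rand(10, 15), torch.rand(10, 20)

    # 35 fused output columns get chunked into 18 + 17, not 15 + 20
    fused = (inp @ torch.cat((w1, w2), dim=1)).chunk(2, dim=1)
    print(fused[0].shape, fused[1].shape)  # torch.Size([10, 18]) torch.Size([10, 17])

That equal-shape check is what makes the test's [10, 15] / [10, 20] parameters fall through to the unfused path.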
