From 48750e8423eb2dab9eed9f1261a51c439d001694 Mon Sep 17 00:00:00 2001
From: Bert Maher
Date: Mon, 25 Sep 2023 07:38:03 -0700
Subject: [PATCH] [inductor] Decompose addmm if it's a dot product on cpu

Generated code for dot product is often faster (on CPU) than dispatching
to aten, since it avoids op dispatch overhead and allows fusion with
surrounding ops, which in turn avoids allocations.

Differential Revision: [D49595876](https://our.internmc.facebook.com/intern/diff/D49595876/)

ghstack-source-id: 201785775
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110010
---
 torch/_inductor/decomposition.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index da407f65ebb75..306216c952042 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -202,6 +202,18 @@ def bmm(self, batch2):
     return NotImplemented
 
 
+@register_decomposition([aten.addmm])
+@pw_cast_for_opmath
+def addmm(self, mat1, mat2, beta=1, alpha=1):
+    if self.device.type == "cpu":
+        if mat1.size(0) == 1 and mat2.size(-1) == 1:
+            out = torch.sum(
+                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
+            ).unsqueeze(0)
+            return alpha * out + beta * self
+    return NotImplemented
+
+
 @register_decomposition([aten.mm])
 def mm(self, input2):
     # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
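
For reference (not part of the patch), a minimal standalone sketch of the rewrite this decomposition performs: when mat1 is 1xK and mat2 is Kx1, addmm reduces to an elementwise multiply followed by a sum, which the compiler can fuse with surrounding pointwise ops instead of calling an external GEMM. Shapes and tensor names below are illustrative.

```python
import torch

# Shapes that trigger the decomposition: mat1 is (1, K), mat2 is (K, 1).
K = 128
bias = torch.randn(1, 1)
mat1 = torch.randn(1, K)
mat2 = torch.randn(K, 1)
alpha, beta = 1.0, 1.0

# Reference result via the aten addmm kernel.
expected = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)

# Decomposed form from the patch: the dot product expressed as mul + sum.
out = torch.sum(
    mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
).unsqueeze(0)
decomposed = alpha * out + beta * bias

torch.testing.assert_close(decomposed, expected)
```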