Fixed minor issues for bmm/mm decomposition (#109836)
Summary:

* Fixed minor issues in the bmm/mm decompositions: compare self.device.type against "cpu" rather than the torch.device object (see the sketch below)
* Enabled addmm for inductor
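
For context, a minimal illustrative sketch (not part of the commit) of why comparing a torch.device object against the string "cpu" is not the same as comparing the device type string, which is what the fixed check relies on:

    import torch

    t = torch.randn(2, 3)  # a CPU tensor

    # t.device is a torch.device object, not a string; on the PyTorch
    # versions this change targets, the direct string comparison does not
    # match, so the CPU fast path was never taken.
    print(t.device == "cpu")       # torch.device vs. str comparison
    print(t.device.type == "cpu")  # True: .type is the plain string "cpu"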

Test Plan: ci

Reviewed By: mikekgfb

Differential Revision: D49522332
chenyang78 authored and facebook-github-bot committed Sep 28, 2023
1 parent e20c35a commit bde7066
Showing 1 changed file with 5 additions and 2 deletions.
torch/_inductor/decomposition.py: 7 changes (5 additions & 2 deletions)
@@ -179,8 +179,9 @@ def round_dec(x, decimals=0):
 
 
 @register_decomposition([aten.bmm])
+@pw_cast_for_opmath
 def bmm(self, batch2):
-    if self.device == "cpu":
+    if self.device.type == "cpu":
         if self.size(1) == 1 and batch2.size(-1) == 1:
             return torch.sum(
                 self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
@@ -189,15 +190,17 @@ def bmm(self, batch2):
 
 
 @register_decomposition([aten.mm])
+@pw_cast_for_opmath
 def mm(self, input2):
     # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
     # todo: Look into why and fix it (hopefully)
     if config.coordinate_descent_tuning:
         if self.shape[0] == 1 or input2.shape[1] == 1:
             return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
-    if self.device == "cpu":
+    if self.device.type == "cpu":
         if (
             self.size(-1) == 1
+            and self.size(0) > 0
             and input2.size(0) == 1
             and (self.dtype == input2.dtype)
             and ((torch.numel(self) + torch.numel(input2)) <= 32)
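
As a rough standalone check (not part of the commit) that the coordinate-descent-tuning branch of mm above computes the same result as a plain matmul for the matrix-vector shapes it targets:

    import torch

    # Shapes hit by the branch: either self has one row or input2 has one column.
    a = torch.randn(8, 16)
    b = torch.randn(16, 1)

    # (8, 16, 1) * (1, 16, 1) broadcasts to (8, 16, 1); summing over dim=1
    # reduces the shared dimension, matching torch.mm's (8, 1) output.
    decomposed = (a.unsqueeze(2) * b.unsqueeze(0)).sum(dim=1)
    reference = torch.mm(a, b)
    print(torch.allclose(decomposed, reference, atol=1e-5))

The same reduction also covers the single-row case (self.shape[0] == 1), where the broadcast instead runs over the columns of input2.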
