diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 4d36e5d7646fc..73bbf75ba5d03 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -179,8 +179,9 @@ def round_dec(x, decimals=0): @register_decomposition([aten.bmm]) +@pw_cast_for_opmath def bmm(self, batch2): - if self.device == "cpu": + if self.device.type == "cpu": if self.size(1) == 1 and batch2.size(-1) == 1: return torch.sum( self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True @@ -201,15 +202,17 @@ def addmm(self, mat1, mat2, beta=1, alpha=1): @register_decomposition([aten.mm]) +@pw_cast_for_opmath def mm(self, input2): # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. # todo: Look into why and fix it (hopefully) if config.coordinate_descent_tuning: if self.shape[0] == 1 or input2.shape[1] == 1: return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) - if self.device == "cpu": + if self.device.type == "cpu": if ( self.size(-1) == 1 + and self.size(0) > 0 and input2.size(0) == 1 and (self.dtype == input2.dtype) and ((torch.numel(self) + torch.numel(input2)) <= 32)