diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 98b33303dbf8b..dab470de2ea5d 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -294,6 +294,20 @@ def fmax(self, other): return torch.where(torch.isnan(other) | (other < self), self, other) +@register_decomposition(aten.amax) +def amax(self, dim=None, keepdim=False): + if self.dtype == torch.bool: + return torch.any(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition(aten.amin) +def amin(self, dim=None, keepdim=False): + if self.dtype == torch.bool: + return torch.all(self, dim=dim, keepdim=keepdim) + return NotImplemented + + @register_decomposition([aten.narrow_copy]) def narrow_copy(self, dim, start, length): return torch.narrow(self, dim, start, length).clone()