Skip to content

Commit

Permalink
pow cleanup part 1 (#4726)
Browse files Browse the repository at this point in the history
use _broadcasted to convert 3 cases into 1. const simplification should be handled by const folding.
  • Loading branch information
chenyuxyz committed May 25, 2024
1 parent f7201b6 commit 85e5722
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,18 +2417,19 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
if x in [3,2,1]: return functools.reduce(lambda acc,_: acc * self, range(int(x)-1), self)
if x == 0.5: return self.sqrt()
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()

base, exponent = self._broadcasted(x, reverse=reverse)
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the exponent)
sign = (exponent * math.pi).cos()
# we only need to correct the sign if the base is negative
base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
base_sign = ((base.sign()) - 1) / -2
# we need 0 to be positive so we need to correct base_sign when the base is 0
base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
# inject nan if the base is negative and the power is not an integer
to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else \
int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
base_sign = base_sign - (1.5 * (1 - (base.sign().abs())))
# inject nan if the base is negative and the exponent is not an integer
to_nan = (exponent != exponent.trunc()).detach() * base_sign
inject_nan = (-to_nan * 2 + 1).log().add(1)
return ret.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)

def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
"""
Expand Down

0 comments on commit 85e5722

Please sign in to comment.