Skip to content

Commit

Permalink
fix sign return type
Browse files Browse the repository at this point in the history
  • Loading branch information
ngimel committed May 13, 2023
1 parent da02ccc commit c6fe6a5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
7 changes: 7 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3111,6 +3111,13 @@ def fn(x):
(torch.randn([1, 2, 6, 6]),),
)

def test_sign_dtype(self):
def fn(x):
y = torch.sign(x)
return torch.tanh(y)

self.common(fn, (torch.randn([1, 2, 6, 6]),))

def test_fmod(self):
def fn(a, b):
return torch.fmod(a, b), torch.fmod(3.0 * a, b) - 2.0
Expand Down
6 changes: 0 additions & 6 deletions torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,6 @@ def reciprocal(x):
def square(x):
return ops.mul(x, x)

@staticmethod
def sign(x):
left = ops.where(ops.lt("0", x), "1", "0")
right = ops.where(ops.lt(x, "0"), "1", "0")
return ops.sub(left, right)

@staticmethod
def bitwise_not(x):
return f"~{ExprPrinter.paren(x)}"
Expand Down
7 changes: 7 additions & 0 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,13 @@ def floordiv(a, b):
rem = f"{a} % {b}"
return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"

@staticmethod
def sign(x):
left = ops.where(ops.lt("0", x), 1, 0)
right = ops.where(ops.lt(x, "0"), 1, 0)
sub = ops.sub(left, right)
return f"{sub}.to({x}.dtype)"

@staticmethod
def trunc(x):
return f"tl.math.trunc({x})"
Expand Down

0 comments on commit c6fe6a5

Please sign in to comment.