Skip to content

Commit

Permalink
Correctly codegen math.inf in Inductor (#114159)
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #114159
Approved by: https://github.com/lezcano
  • Loading branch information
ezyang authored and pytorchmergebot committed Nov 21, 2023
1 parent c47d2b8 commit 2abfb8e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,26 @@ def f(x):

f(torch.tensor([3], device=device))

@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
)
def test_float_item_inf(self, device):
@torch.compile(fullgraph=True)
def f(x):
return x.item() == math.inf

f(torch.tensor([3.0], device=device))

@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
)
def test_float_item_neginf(self, device):
@torch.compile(fullgraph=True)
def f(x):
return x.item() == -math.inf

f(torch.tensor([3.0], device=device))

@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch._inductor.config.patch(implicit_fallbacks=True)
def test_item_to_inputs_kernel_nobreak(self, device):
Expand Down
6 changes: 6 additions & 0 deletions torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ def _print_Pow(self, expr):
else: # exp == 0
return "1"

def _print_Infinity(self, expr):
return "math.inf"

def _print_NegativeInfinity(self, expr):
return "-math.inf"

def _print_Relational(self, expr):
return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

Expand Down

0 comments on commit 2abfb8e

Please sign in to comment.