From 2abfb8ec7d7a3970097c12caabe1ccb7a05bb5d5 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 21 Nov 2023 10:13:43 -0500 Subject: [PATCH] Correctly codegen math.inf in Inductor (#114159) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/114159 Approved by: https://github.com/lezcano --- .../test_torchinductor_dynamic_shapes.py | 20 +++++++++++++++++++ torch/_inductor/codegen/common.py | 6 ++++++ 2 files changed, 26 insertions(+) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 459059a7434c0..3745afa72bbd2 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -235,6 +235,26 @@ def f(x): f(torch.tensor([3], device=device)) + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + def test_float_item_inf(self, device): + @torch.compile(fullgraph=True) + def f(x): + return x.item() == math.inf + + f(torch.tensor([3.0], device=device)) + + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + def test_float_item_neginf(self, device): + @torch.compile(fullgraph=True) + def f(x): + return x.item() == -math.inf + + f(torch.tensor([3.0], device=device)) + @torch._dynamo.config.patch(capture_scalar_outputs=True) @torch._inductor.config.patch(implicit_fallbacks=True) def test_item_to_inputs_kernel_nobreak(self, device): diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 920dc318d7155..be949e8f92a98 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -311,6 +311,12 @@ def _print_Pow(self, expr): else: # exp == 0 return "1" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))