From a75a29b7f29456c08f821d345a32f9ddb6689312 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 24 Oct 2025 13:18:26 -0700 Subject: [PATCH 1/4] fix AMD journal errors --- helion/_testing.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/helion/_testing.py b/helion/_testing.py index ad524ad89..faba866e5 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -770,7 +770,31 @@ def normalize_codegen_variants(code: str) -> str: # libdevice.sqrt( -> tl.sqrt_rn( code = re.sub(r"\blibdevice\.sqrt\s*\(", "tl.sqrt_rn(", code) # tl.sqrt( -> tl.sqrt_rn( - return re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code) + code = re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code) + + # Normalize rsqrt variants + # libdevice.rsqrt( -> tl.rsqrt( + code = re.sub(r"\blibdevice\.rsqrt\s*\(", "tl.rsqrt(", code) + + # Normalize maximum variants + # tl.maximum(a, b, tl.PropagateNan.ALL) -> triton_helpers.maximum(a, b) + code, num_triton_helpers_replacements = re.subn( + r"\btl\.maximum\s*\(([^,]+),\s*([^,]+),\s*tl\.PropagateNan\.ALL\s*\)", + r"triton_helpers.maximum(\1, \2)", + code, + ) + + triton_helpers_import = "from torch._inductor.runtime import triton_helpers" + if num_triton_helpers_replacements > 0 and triton_helpers_import not in code: + # Insert after __future__ imports if present, otherwise at start + code = re.sub( + r"^(from __future__ import .*?\n)?", + rf"\1{triton_helpers_import}\n", + code, + count=1, + ) + + return code @staticmethod def normalize_source_comment_structure(code: str) -> str: From 6795da264cc78e5475dfb92ab66b16e4ad12f6a1 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 24 Oct 2025 21:39:06 -0700 Subject: [PATCH 2/4] up --- helion/_testing.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/helion/_testing.py b/helion/_testing.py index faba866e5..3da4c211a 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -786,13 +786,22 @@ def normalize_codegen_variants(code: str) -> str: triton_helpers_import = "from torch._inductor.runtime import triton_helpers" if num_triton_helpers_replacements > 0 and triton_helpers_import not in code: - # Insert after __future__ imports if present, otherwise at start - code = re.sub( - r"^(from __future__ import .*?\n)?", - rf"\1{triton_helpers_import}\n", - code, - count=1, - ) + # Insert before helion imports if present, otherwise after __future__ + if "from helion" in code: + code = re.sub( + r"(^from helion)", + rf"{triton_helpers_import}\n\1", + code, + count=1, + flags=re.MULTILINE, + ) + else: + code = re.sub( + r"^(from __future__ import .*?\n)?", + rf"\1{triton_helpers_import}\n", + code, + count=1, + ) return code From ba336bc0b10f75c73e7aea1c1865fc8e083b3a61 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 24 Oct 2025 23:26:52 -0700 Subject: [PATCH 3/4] up --- helion/_testing.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/helion/_testing.py b/helion/_testing.py index 3da4c211a..2b644d120 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -786,22 +786,14 @@ def normalize_codegen_variants(code: str) -> str: triton_helpers_import = "from torch._inductor.runtime import triton_helpers" if num_triton_helpers_replacements > 0 and triton_helpers_import not in code: - # Insert before helion imports if present, otherwise after __future__ - if "from helion" in code: - code = re.sub( - r"(^from helion)", - rf"{triton_helpers_import}\n\1", - code, - count=1, - flags=re.MULTILINE, - ) - else: - code = re.sub( - r"^(from __future__ import .*?\n)?", - rf"\1{triton_helpers_import}\n", - code, - count=1, - ) + # Insert right after `import triton.language as tl` + code = re.sub( + r"(^import triton\.language as tl$)", + rf"\1\n{triton_helpers_import}", + code, + count=1, + flags=re.MULTILINE, + ) return code From ec2d7acbac802a3729884e40d5c7a3abccecbb51 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 24 Oct 2025 23:42:53 -0700 Subject: [PATCH 4/4] up --- helion/_testing.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/helion/_testing.py b/helion/_testing.py index 2b644d120..47f9c09bd 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -776,16 +776,31 @@ def normalize_codegen_variants(code: str) -> str: # libdevice.rsqrt( -> tl.rsqrt( code = re.sub(r"\blibdevice\.rsqrt\s*\(", "tl.rsqrt(", code) + total_num_triton_helpers_replacements = 0 + # Normalize maximum variants # tl.maximum(a, b, tl.PropagateNan.ALL) -> triton_helpers.maximum(a, b) - code, num_triton_helpers_replacements = re.subn( + code, num_replacements = re.subn( r"\btl\.maximum\s*\(([^,]+),\s*([^,]+),\s*tl\.PropagateNan\.ALL\s*\)", r"triton_helpers.maximum(\1, \2)", code, ) + total_num_triton_helpers_replacements += num_replacements + + # Normalize minimum variants + # tl.minimum(a, b, tl.PropagateNan.ALL) -> triton_helpers.minimum(a, b) + code, num_replacements = re.subn( + r"\btl\.minimum\s*\(([^,]+),\s*([^,]+),\s*tl\.PropagateNan\.ALL\s*\)", + r"triton_helpers.minimum(\1, \2)", + code, + ) + total_num_triton_helpers_replacements += num_replacements triton_helpers_import = "from torch._inductor.runtime import triton_helpers" - if num_triton_helpers_replacements > 0 and triton_helpers_import not in code: + if ( + total_num_triton_helpers_replacements > 0 + and triton_helpers_import not in code + ): # Insert right after `import triton.language as tl` code = re.sub( r"(^import triton\.language as tl$)",