diff --git a/helion/_testing.py b/helion/_testing.py index ad524ad89..47f9c09bd 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -770,7 +770,47 @@ def normalize_codegen_variants(code: str) -> str: # libdevice.sqrt( -> tl.sqrt_rn( code = re.sub(r"\blibdevice\.sqrt\s*\(", "tl.sqrt_rn(", code) # tl.sqrt( -> tl.sqrt_rn( - return re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code) + code = re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code) + + # Normalize rsqrt variants + # libdevice.rsqrt( -> tl.rsqrt( + code = re.sub(r"\blibdevice\.rsqrt\s*\(", "tl.rsqrt(", code) + + total_num_triton_helpers_replacements = 0 + + # Normalize maximum variants + # tl.maximum(a, b, tl.PropagateNan.ALL) -> triton_helpers.maximum(a, b) + code, num_replacements = re.subn( + r"\btl\.maximum\s*\(([^,]+),\s*([^,]+),\s*tl\.PropagateNan\.ALL\s*\)", + r"triton_helpers.maximum(\1, \2)", + code, + ) + total_num_triton_helpers_replacements += num_replacements + + # Normalize minimum variants + # tl.minimum(a, b, tl.PropagateNan.ALL) -> triton_helpers.minimum(a, b) + code, num_replacements = re.subn( + r"\btl\.minimum\s*\(([^,]+),\s*([^,]+),\s*tl\.PropagateNan\.ALL\s*\)", + r"triton_helpers.minimum(\1, \2)", + code, + ) + total_num_triton_helpers_replacements += num_replacements + + triton_helpers_import = "from torch._inductor.runtime import triton_helpers" + if ( + total_num_triton_helpers_replacements > 0 + and triton_helpers_import not in code + ): + # Insert right after `import triton.language as tl` + code = re.sub( + r"(^import triton\.language as tl$)", + rf"\1\n{triton_helpers_import}", + code, + count=1, + flags=re.MULTILINE, + ) + + return code @staticmethod def normalize_source_comment_structure(code: str) -> str: