Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,47 @@ def normalize_codegen_variants(code: str) -> str:
# libdevice.sqrt( -> tl.sqrt_rn(
code = re.sub(r"\blibdevice\.sqrt\s*\(", "tl.sqrt_rn(", code)
# tl.sqrt( -> tl.sqrt_rn(
return re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code)
code = re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code)

# Normalize rsqrt variants
# libdevice.rsqrt( -> tl.rsqrt(
code = re.sub(r"\blibdevice\.rsqrt\s*\(", "tl.rsqrt(", code)
Comment on lines +776 to +777
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad pasta?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one actually shows up in AMD tests:

2025-10-24T16:25:05.3039413Z E   -     v_8 = tl.rsqrt(v_7)
2025-10-24T16:25:05.3039489Z E   ?           -
2025-10-24T16:25:05.3039637Z E   +     v_8 = libdevice.rsqrt(v_7)
2025-10-24T16:25:05.3039708Z E   ?            ++++++++


total_num_triton_helpers_replacements = 0

# Normalize maximum variants
# tl.maximum(a, b, tl.PropagateNan.ALL) -> triton_helpers.maximum(a, b)
code, num_replacements = re.subn(
r"\btl\.maximum\s*\(([^,]+),\s*([^,]+),\s*tl\.PropagateNan\.ALL\s*\)",
r"triton_helpers.maximum(\1, \2)",
code,
)
total_num_triton_helpers_replacements += num_replacements

# Normalize minimum variants
# tl.minimum(a, b, tl.PropagateNan.ALL) -> triton_helpers.minimum(a, b)
code, num_replacements = re.subn(
r"\btl\.minimum\s*\(([^,]+),\s*([^,]+),\s*tl\.PropagateNan\.ALL\s*\)",
r"triton_helpers.minimum(\1, \2)",
code,
)
total_num_triton_helpers_replacements += num_replacements

triton_helpers_import = "from torch._inductor.runtime import triton_helpers"
if (
total_num_triton_helpers_replacements > 0
and triton_helpers_import not in code
):
# Insert right after `import triton.language as tl`
code = re.sub(
r"(^import triton\.language as tl$)",
rf"\1\n{triton_helpers_import}",
code,
count=1,
flags=re.MULTILINE,
)

return code

@staticmethod
def normalize_source_comment_structure(code: str) -> str:
Expand Down
Loading