Skip to content

Conversation

@yf225
Copy link
Contributor

@yf225 yf225 commented Oct 29, 2025

Our users commonly invoke Helion kernel from a very deep code stack, and it's been very difficult to extract a small repro.

With HELION_PRINT_REPRO=1, we will print the minimal Helion kernel repro script to console, thus making it much easier to repro issues.

Example output to stderr:

import helion
import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided

@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[None]), static_shapes=True)
def kernel1(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    m, n = x.shape
    for tile_m, tile_n in hl.tile([m, n]):
        out[tile_m, tile_n] = x[tile_m, tile_n] + 1
    return out

def helion_repro_caller():
    torch.manual_seed(0)
    x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE)
    return kernel1(x)

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 29, 2025
@yf225 yf225 changed the title Add HELION_PRINT_REPRO=1 to print Helion kernel + caller code to console Add HELION_PRINT_REPRO=1 to print Helion kernel + caller code to console Oct 29, 2025
@yf225 yf225 changed the title Add HELION_PRINT_REPRO=1 to print Helion kernel + caller code to console Add HELION_PRINT_REPRO=1 to print Helion kernel and caller code to console Oct 29, 2025

return self._run(*args)

def _print_repro(
Copy link
Contributor

@oulgen oulgen Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i dont think you need any of this, i added a mode to to_triton_code that prints repro?

there's even a unit test for it

Copy link
Contributor Author

@yf225 yf225 Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah so I actually want to get the Helion kernel (with config and caller) not the generated Triton kernel.

what internal folks usually provide is a buck command that's many layers deep, and it's very hard to extract out a clean Helion kernel repro. With this env var, it will print the Helion kernel with the right input tensors, so that I can just focus on debugging the minimal repro.

@yf225 yf225 force-pushed the env_var_print_helion_repro branch 7 times, most recently from 61c47cb to 05ddf4a Compare October 30, 2025 01:22
@yf225 yf225 marked this pull request as ready for review October 30, 2025 01:25
@yf225 yf225 requested review from jansel and oulgen October 30, 2025 01:25
@yf225 yf225 force-pushed the env_var_print_helion_repro branch from 05ddf4a to 2f765f1 Compare October 30, 2025 01:26
@yf225 yf225 changed the title Add HELION_PRINT_REPRO=1 to print Helion kernel and caller code to console Add HELION_PRINT_REPRO=1 to print Helion kernel repro script to console Oct 30, 2025
@yf225 yf225 force-pushed the env_var_print_helion_repro branch 3 times, most recently from d56b64a to b7c954f Compare October 30, 2025 21:03
@yf225 yf225 force-pushed the env_var_print_helion_repro branch from b7c954f to 5968d9a Compare October 30, 2025 21:04
@yf225 yf225 merged commit bfa223a into main Oct 30, 2025
14 checks passed
@yf225 yf225 deleted the env_var_print_helion_repro branch October 30, 2025 21:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants