|
9 | 9 | import operator |
10 | 10 | import re |
11 | 11 | import sys |
| 12 | +import textwrap |
12 | 13 | import types |
13 | 14 | from typing import TYPE_CHECKING |
14 | 15 | from typing import Callable |
@@ -641,8 +642,88 @@ def __call__(self, *args: object) -> _R: |
641 | 642 | self.format_kernel_decorator(self._config, self.settings) |
642 | 643 | ] = 1 |
643 | 644 |
|
| 645 | + if self.settings.print_repro: |
| 646 | + self._print_repro(args) |
| 647 | + |
644 | 648 | return self._run(*args) |
645 | 649 |
|
| 650 | + def _print_repro( |
| 651 | + self, args: tuple[object, ...], config: Config | None = None |
| 652 | + ) -> None: |
| 653 | + effective_config = config or self._config |
| 654 | + assert effective_config is not None |
| 655 | + |
| 656 | + # Get kernel source |
| 657 | + try: |
| 658 | + raw_source = inspect.getsource(self.kernel.fn) |
| 659 | + source_lines = textwrap.dedent(raw_source).splitlines() |
| 660 | + # Skip decorator lines |
| 661 | + start_idx = 0 |
| 662 | + while start_idx < len(source_lines) and source_lines[ |
| 663 | + start_idx |
| 664 | + ].lstrip().startswith("@"): |
| 665 | + start_idx += 1 |
| 666 | + kernel_body = "\n".join(source_lines[start_idx:]) |
| 667 | + except (OSError, TypeError): |
| 668 | + kernel_body = f"# Source unavailable for {self.kernel.fn.__module__}.{self.kernel.fn.__qualname__}" |
| 669 | + |
| 670 | + # Format decorator |
| 671 | + decorator = self.format_kernel_decorator(effective_config, self.settings) |
| 672 | + |
| 673 | + # Build output |
| 674 | + output_lines = [ |
| 675 | + "# === HELION KERNEL REPRO ===", |
| 676 | + "import helion", |
| 677 | + "import helion.language as hl", |
| 678 | + "import torch", |
| 679 | + "from torch._dynamo.testing import rand_strided", |
| 680 | + "", |
| 681 | + decorator, |
| 682 | + kernel_body, |
| 683 | + ] |
| 684 | + |
| 685 | + # Generate caller function |
| 686 | + if args: |
| 687 | + |
| 688 | + def _render_input_arg_assignment(name: str, value: object) -> list[str]: |
| 689 | + if isinstance(value, torch.Tensor): |
| 690 | + shape = tuple(int(d) for d in value.shape) |
| 691 | + stride = tuple(int(s) for s in value.stride()) |
| 692 | + device = str(value.device) |
| 693 | + dtype = str(value.dtype) |
| 694 | + |
| 695 | + lines = [ |
| 696 | + f"{name} = rand_strided({shape!r}, {stride!r}, dtype={dtype}, device={device!r})" |
| 697 | + ] |
| 698 | + |
| 699 | + if value.requires_grad: |
| 700 | + lines.append(f"{name}.requires_grad_(True)") |
| 701 | + return lines |
| 702 | + |
| 703 | + return [f"{name} = {value!r}"] |
| 704 | + |
| 705 | + sig_param_names = list(self.kernel.signature.parameters.keys()) |
| 706 | + assert len(args) == len(sig_param_names) |
| 707 | + |
| 708 | + output_lines.extend(["", "def helion_repro_caller():"]) |
| 709 | + output_lines.append(" torch.manual_seed(0)") |
| 710 | + arg_names = [] |
| 711 | + |
| 712 | + for i, value in enumerate(args): |
| 713 | + var_name = sig_param_names[i] |
| 714 | + arg_names.append(var_name) |
| 715 | + |
| 716 | + # Add assignment lines with indentation |
| 717 | + for line in _render_input_arg_assignment(var_name, value): |
| 718 | + output_lines.append(f" {line}") |
| 719 | + |
| 720 | + # Add return statement |
| 721 | + call_args = ", ".join(arg_names) |
| 722 | + output_lines.append(f" return {self.kernel.name}({call_args})") |
| 723 | + |
| 724 | + output_lines.append("# === END HELION KERNEL REPRO ===") |
| 725 | + print("\n".join(output_lines), file=sys.stderr) |
| 726 | + |
646 | 727 |
|
647 | 728 | class _KernelDecorator(Protocol): |
648 | 729 | def __call__( |
|
0 commit comments