From fe3c1302d519ae55b211a9c501ecb107152e7e97 Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Mon, 29 Apr 2024 16:20:37 -0700
Subject: [PATCH] [inductor] add triton code to SchedulerNode.debug_str
 (#125091)

Here is an example print: https://gist.github.com/shunting314/75c161368a833a535bd0d240b8099d7e

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125091
Approved by: https://github.com/jansel
ghstack dependencies: #125090
---
 .../codegen/cuda_combined_scheduling.py |  5 ++++
 torch/_inductor/codegen/triton.py       | 14 +++++++---
 torch/_inductor/scheduler.py            | 26 ++++++++++++++++++-
 3 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py
index eceadeb4c7d56..3eac88881fab5 100644
--- a/torch/_inductor/codegen/cuda_combined_scheduling.py
+++ b/torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -75,3 +75,8 @@ def codegen_foreach(self, *args, **kwargs):
 
     def benchmark_fused_nodes(self, nodes):
         return self._triton_scheduling.benchmark_fused_nodes(nodes)
+
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
+        return self._triton_scheduling.generate_kernel_code_from_nodes(
+            nodes, benchmark_kernel
+        )
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 152621453c79f..17bb6e1a89a13 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -3921,8 +3921,7 @@ def flush(self):
     def ready_to_flush(self) -> bool:
         return False
 
-    @preserve_rng_state()
-    def benchmark_fused_nodes(self, nodes):
+    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
         @dataclasses.dataclass
         class LastUsageHolder:
             n: Any
@@ -3954,18 +3953,25 @@ def __del__(self):
             )
 
             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-            with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):
+            with config.patch(
+                "benchmark_kernel", benchmark_kernel
+            ), V.set_kernel_handler(kernel):
                 src_code = kernel.codegen_kernel()
         else:
             template_node = nodes[0]
             epilogue_nodes = nodes[1:]
 
-            with config.patch("benchmark_kernel", True):
+            with config.patch("benchmark_kernel", benchmark_kernel):
                 src_code = self.codegen_template(
                     template_node, epilogue_nodes, only_gen_src_code=True
                 )
 
         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+        return src_code
+
+    @preserve_rng_state()
+    def benchmark_fused_nodes(self, nodes):
+        src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
         mod = PyCodeCache.load(src_code)
 
         def cache_file_path():
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 41a988ec221d1..b827d72530726 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -743,6 +743,20 @@ def debug_str_extra(self) -> str:
         if isinstance(self._body, ir.LoopBody):
             lines.append(f"class {name}_loop_body:")
             lines.append(textwrap.indent(self._body.debug_str(), "    "))
+
+        if ir.is_triton(self.node.get_device()):
+            backend = self.scheduler.get_backend(self.node.get_device())
+            V.graph.scheduler.current_device = self.node.get_device()
+
+            # Don't increment the kernel count when generating the debug
+            # string; doing so would confuse unit tests that check the
+            # number of generated kernels.
+            old_generated_kernel_count = metrics.generated_kernel_count
+            triton_code = backend.generate_kernel_code_from_nodes((self,)).strip()
+            metrics.generated_kernel_count = old_generated_kernel_count
+
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, "    "))
         return "\n".join(lines)
 
     def get_ranges(self):
@@ -900,6 +914,16 @@ def debug_str_extra(self) -> str:
             f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}"
             for i, node in enumerate(self.snodes)
         ]
+        device = self.snodes[0].node.get_device()
+        if ir.is_triton(device):
+            backend = self.scheduler.get_backend(device)
+            V.graph.scheduler.current_device = device
+            old_generated_kernel_count = metrics.generated_kernel_count
+            triton_code = backend.generate_kernel_code_from_nodes(self.snodes).strip()
+            metrics.generated_kernel_count = old_generated_kernel_count
+            lines.append(f"{self.get_name()} Triton code:")
+            lines.append(textwrap.indent(triton_code, "    "))
+
         return textwrap.indent("\n".join(lines).rstrip(), "    ")
 
     def set_last_usage(
@@ -1271,6 +1295,7 @@ class Scheduler:
     @dynamo_timed
     def __init__(self, nodes):
         super().__init__()
+        V.graph.scheduler = self
         self.backends = {}
         self.fuse_cache = {}
         self.post_grad_graph_id = next(_post_grad_graph_counter)
@@ -1734,7 +1759,6 @@ def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]:
         """
         assert len(nodes) > 0
         device = nodes[0].get_device()
-        V.graph.scheduler = self
         self.current_device = device
         backend = self.get_backend(device)
         return backend.benchmark_fused_nodes(nodes)
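
Usage sketch: one way to surface the new debug output in practice. This assumes
Inductor's debug tracing (TORCH_COMPILE_DEBUG=1, or equivalently
config.trace.enabled) writes each SchedulerNode's debug_str() into its
per-graph trace dumps; the torch_compile_debug/.../ir_post_fusion.txt path and
the availability of a CUDA device are assumptions, not guaranteed by this
patch. See the example print linked above for the actual output format.

    import torch
    import torch._inductor.config as inductor_config

    # Turn on Inductor's debug trace dumps (the same switch that
    # TORCH_COMPILE_DEBUG=1 flips).
    inductor_config.trace.enabled = True

    @torch.compile
    def f(x):
        return (x.sin() + x.cos()).sum()

    # Assumes a CUDA device, since the Triton code path is taken for
    # triton-backed devices (see the ir.is_triton checks above).
    f(torch.randn(1024, device="cuda"))

    # The per-graph trace dumps (e.g. ir_post_fusion.txt, an assumed filename)
    # are built from SchedulerNode.debug_str(); with this patch each Triton
    # node's entry also carries the kernel source produced by
    # generate_kernel_code_from_nodes(..., benchmark_kernel=False), with the
    # generated-kernel counter saved and restored around codegen.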