[inductor] Add CPU-side profiler event for triton kernels w/ python wrapper (#106351)

This allows you to view the original kernel names in profiler traces (e.g. to cross-reference the triton kernel implementation in the python wrapper code / TORCH_COMPILE_DEBUG logs). `torch._inductor.config.unique_kernel_names=True` exposes these names too, but leaving `unique_kernel_names=False` improves triton cache reuse, since every kernel is then defined under the generic name `triton_`.
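
For reference, a minimal sketch of how the new event surfaces (assumes a CUDA build with triton available; the exact kernel name, e.g. `triton_poi_fused_add_cos_sin_0` here, can change between inductor versions):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def fn(x, y):
    return (x + y).sin().cos()

fn_opt = torch.compile(fn)
x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))

# Warm up first so compilation doesn't land in the trace.
for _ in range(2):
    fn_opt(x, y)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    fn_opt(x, y)

# The CPU-side record_function event carries the descriptive kernel name,
# even with torch._inductor.config.unique_kernel_names left at False.
print([e.name for e in prof.events() if "triton" in e.name])

# Optionally dump a trace viewable in chrome://tracing:
prof.export_chrome_trace("trace.json")
```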

Another benefit of this approach is that we can attach additional information to this profiler event in the future: for example, input shapes/strides (i.e. `record_shapes=True` in the profiler), or possibly paths to the files where the generated code was dumped.

<img width="435" alt="Screenshot 2023-07-31 at 5 34 25 PM" src="https://github.com/pytorch/pytorch/assets/5067123/839b752f-3907-4f29-9038-9d1822222b45">

^ In the trace above, the pink `triton_poi_fused_add_cos_sin_0` kernel is the new trace event added by this PR.

**Performance impact**: [dashboard run](https://hud.pytorch.org/benchmark/compilers?startTime=Thu%2C%2010%20Aug%202023%2000%3A52%3A06%20GMT&stopTime=Thu%2C%2017%20Aug%202023%2000%3A52%3A06%20GMT&granularity=hour&suite=torchbench&mode=inference&dtype=bfloat16&lBranch=gh/davidberard98/216/orig&lCommit=90c4212a7993c3660e7ea53bcd9d21160be31d1a&rBranch=main&rCommit=35cca799ff42182a1b7f1ee4d0225ee879b7c924). There are some regressions, including a 1.72x -> 1.71x on huggingface and 1.30x -> 1.29x on dynamic; however, locally I can't reproduce the results on any of the individual models (differences look like they are within noise). I think the perf impact is likely < 1% overall.

Differential Revision: [D47941809](https://our.internmc.facebook.com/intern/diff/D47941809)

Pull Request resolved: #106351
Approved by: https://github.com/eellison, https://github.com/albanD
ghstack dependencies: #107195
davidberard98 authored and pytorchmergebot committed Aug 22, 2023
1 parent 614b865 commit ba5eeed
Showing 3 changed files with 73 additions and 5 deletions.
41 changes: 41 additions & 0 deletions test/inductor/test_profiler.py
@@ -6,6 +6,8 @@
 import torch._dynamo.test_case
 import torch._inductor.utils
 
+from torch.profiler import ProfilerActivity
+
 from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM
 
 HAS_TRITON = torch._inductor.utils.has_triton()
@@ -44,6 +46,45 @@ def nameMatchesLaunchKernel(event_name):
             )
         )
 
+    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
+    def test_inductor_profiling_kernel_names(self):
+        """
+        We expect a record_function event to be added on the CPU side, surrounding
+        the launch of each triton kernel.
+        """
+
+        def fn(x, y):
+            return (x + y).sin().cos()
+
+        fn_opt = torch.compile(fn)
+
+        x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))
+
+        for _ in range(2):
+            fn_opt(x, y)
+
+        with torch.profiler.profile(activities=[ProfilerActivity.CPU]) as prof:
+            fn_opt(x, y)
+
+        # The name of the kernel is expected to match the name of the kernel in debug
+        # files etc. The name could change in the future, but it seems reasonable that
+        # the name should always contain "triton" and "sin" - "sin" because this
+        # kernel contains a sin op. If this changes in the future, feel free to change
+        # the assertion here.
+        # As of time of writing, the kernel name was "triton_poi_fused_add_cos_sin_0".
+        # Debugging tips: you can add prof.export_chrome_trace("test.json") inline in
+        # this test, and then view test.json in chrome://tracing to see the trace.
+        self.assertTrue(
+            any(
+                (
+                    hasattr(event, "name")
+                    and "sin" in event.name
+                    and "triton" in event.name
+                )
+                for event in prof.events()
+            )
+        )
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
6 changes: 6 additions & 0 deletions torch/_inductor/codegen/triton.py
@@ -1878,6 +1878,7 @@ def codegen_kernel(self, name=None):
             "constants": {},
             "mutated_arg_names": mutated_args,
             "autotune_hints": set(self.autotune_hints),
+            "kernel_name": "DESCRIPTIVE_KRNL_NAME",
         }
 
         for tree in self.range_trees:
@@ -2462,6 +2463,11 @@ def define_kernel(self, src_code, node_schedule):
             # use the original src_code as the key
             wrapper.src_to_kernel[src_code] = kernel_name
             subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"
+
+            # DESCRIPTIVE_KRNL_NAME is used for profiling purposes; it shows the full kernel name
+            # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set
+            # to "triton_" to maximize caching opportunities (when unique_kernel_names = False).
+            src_code = src_code.replace("DESCRIPTIVE_KRNL_NAME", kernel_name)
             src_code = src_code.replace("KERNEL_NAME", subs_name)
 
             # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
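
To make the two-placeholder scheme concrete, here is a minimal, self-contained sketch; the template string below is invented for illustration (real inductor-generated source is much larger), but the two `replace` calls mirror the ones in the diff above:

```python
# Hypothetical miniature of inductor-generated triton source. KERNEL_NAME names
# the compiled triton function; DESCRIPTIVE_KRNL_NAME only feeds the CPU-side
# profiler event via the kernel metadata.
src_code = '''
@triton.jit
def KERNEL_NAME(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    ...  # elementwise body elided

meta = {"kernel_name": "DESCRIPTIVE_KRNL_NAME"}
'''

kernel_name = "triton_poi_fused_add_cos_sin_0"
unique_kernel_names = False  # the default

subs_name = kernel_name if unique_kernel_names else "triton_"
src_code = src_code.replace("DESCRIPTIVE_KRNL_NAME", kernel_name)
src_code = src_code.replace("KERNEL_NAME", subs_name)
print(src_code)
# The function is still defined under the generic "triton_" (better cache
# reuse), while the profiler event reports "triton_poi_fused_add_cos_sin_0".
```
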
31 changes: 26 additions & 5 deletions torch/_inductor/triton_heuristics.py
@@ -14,6 +14,8 @@
 from typing import Any, Callable, List, Optional, Set, Tuple
 
 import torch
+
+import torch.autograd.profiler as autograd_profiler
 from torch._dynamo.utils import dynamo_timed
 
 from . import config
@@ -154,6 +156,11 @@ def __init__(
             is_mm=False, name=self.fn.__name__, size_hints=size_hints
         )
 
+        # pre-create the profiler context manager to reduce latency
+        self.record_function_ctx = torch._C._profiler._RecordFunctionFast(
+            self.meta.get("kernel_name", "triton kernel")
+        )
+
     def precompile(self, warm_cache_only_with_cc=None):
         with self.lock:
             if self.launchers:
@@ -399,11 +406,25 @@ def run(self, *args, grid, stream):
             launcher.config.pre_hook(
                 {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
            )
-        return launcher(
-            *args,
-            grid=grid,
-            stream=stream,
-        )
+
+        # guard the record_function_ctx and only call it if profiling is currently
+        # in progress, to reduce latency when profiler is not turned on. Note that
+        # the "if" statement (instead of, say, a contextlib.nullcontext) is intentional;
+        # it is faster than entering and exiting a context manager, even if the context
+        # manager is a nullcontext.
+        if autograd_profiler._is_profiler_enabled:
+            with self.record_function_ctx:
+                return launcher(
+                    *args,
+                    grid=grid,
+                    stream=stream,
+                )
+        else:
+            return launcher(
+                *args,
+                grid=grid,
+                stream=stream,
+            )
 
 
 def _find_names(obj):
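
The comment above argues that a plain `if` beats even a `contextlib.nullcontext` on the hot path. A stdlib-only microbenchmark of the two shapes (an illustrative sketch, not part of the PR):

```python
import timeit
from contextlib import nullcontext

enabled = False  # stand-in for autograd_profiler._is_profiler_enabled

def guarded():
    # Shape used by this PR: branch first, enter a context manager only
    # when profiling is actually on.
    if enabled:
        with nullcontext():
            return 1
    return 1

def always_ctx():
    # Alternative shape: unconditionally enter a (null) context manager.
    with nullcontext():
        return 1

# On typical CPython builds the guarded version is cheaper, since it skips
# the __enter__/__exit__ calls entirely when the profiler is off.
print("if-guard:   ", timeit.timeit(guarded, number=1_000_000))
print("nullcontext:", timeit.timeit(always_ctx, number=1_000_000))
```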
