From 3476821b0c3dacb268b9d876848d5dfb3f211c71 Mon Sep 17 00:00:00 2001 From: Peiying Hua Date: Thu, 20 Nov 2025 14:32:51 -0800 Subject: [PATCH] Allow specifiying the use of persistent kernel (#5129) Summary: X-link: https://github.com/meta-pytorch/tritonbench/pull/654 X-link: https://github.com/facebookresearch/FBGEMM/pull/2130 Added environment argument "use_persistent" (default is False) to explicitly turn off non-persistent kernel and use persistent kernel. Throws error when both "use_persistent" and "no_use_persistent" are specified in the arguments. Example usage: Persistent kernel-- buck2 run mode/{dev-nosan,amd-gpu} -c xlog.level=WARNING -m ovr_config//triton:trunk -m rocm7 -c fbcode.nvcc_arch=mi350 -c fbcode.enable_gpu_sections=true pytorch/tritonbench:run -- --op fp8_gemm_rowwise --no_use_tma --use_persistent Non-persistent kernel-- buck2 run mode/{dev-nosan,amd-gpu} -c xlog.level=WARNING -m ovr_config//triton:trunk -m rocm7 -c fbcode.nvcc_arch=mi350 -c fbcode.enable_gpu_sections=true pytorch/tritonbench:run -- --op fp8_gemm_rowwise --no_use_tma --no_use_persistent When both specified in the arguments: buck2 run mode/{dev-nosan,amd-gpu} -c xlog.level=WARNING -m ovr_config//triton:trunk -m rocm7 -c fbcode.nvcc_arch=mi350 -c fbcode.enable_gpu_sections=true pytorch/tritonbench:run -- --op fp8_gemm_rowwise --no_use_tma --use_persistent --no_use_persistent IT WILL THROW ERROR: Cannot specify both '--use_persistent' and '--no_use_persistent' at the same time. These options are mutually exclusive. Please use only one. Reviewed By: adamomainz, njriasan, jwfromm Differential Revision: D86579911 --- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 7899854c3c..37f0e663bd 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -1212,6 +1212,8 @@ def matmul_fp8_row( imprecise_acc: bool = False, tma_persistent: bool = True, no_use_persistent: Optional[bool] = None, + # add an option to explicitly require the use of persistent process + use_persistent: Optional[bool] = None, use_warp_specialization: bool = False, ) -> torch.Tensor: """ @@ -1232,12 +1234,16 @@ def matmul_fp8_row( Returns: torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :]) """ - if no_use_persistent is None: + if use_persistent: + no_use_persistent = False + elif no_use_persistent is None: # Default True for AMD and False for Nvidia. if torch.version.hip is not None: no_use_persistent = True else: no_use_persistent = False + # if use_persistent is explicitly requested, set o_use_persistent to False + # Get datatypes and constants to use. pt_fp8_dtype, _, _, _ = get_fp8_constants() # Handle 3D+ a shape