From 4d30fe98701da1c91597b613bcb64c2f337181ce Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Nov 2025 11:48:20 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 7 +-- .../mx_formats/test_inference_workflow.py | 16 ++--- .../mx_formats/test_mx_serialization.py | 6 +- test/prototype/mx_formats/test_mx_tensor.py | 4 +- .../moe_training/scaled_grouped_mm.py | 8 +-- torchao/prototype/mx_formats/README.md | 22 +++---- torchao/prototype/mx_formats/__init__.py | 2 - torchao/prototype/mx_formats/config.py | 62 +++++++------------ .../mx_formats/inference_workflow.py | 17 +++-- torchao/prototype/mx_formats/mx_linear.py | 28 ++++----- torchao/prototype/mx_formats/mx_tensor.py | 43 ++++++------- .../quantize_/common/kernel_preference.py | 8 +++ 12 files changed, 97 insertions(+), 126 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 675c7f166f..dc732dc77a 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -38,9 +38,6 @@ ) import torchao -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, @@ -439,13 +436,13 @@ def run( config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, ) elif recipe_name == "mxfp4_cutlass": config = MXFPInferenceConfig( activation_dtype=torch.float4_e2m1fn_x2, weight_dtype=torch.float4_e2m1fn_x2, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + kernel_preference=KernelPreference.AUTO, ) elif recipe_name == "nvfp4": config = NVFP4InferenceConfig( diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 2f6e411ff7..8dad950c4c 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -12,15 +12,13 @@ import torch.nn as nn from torch.profiler import ProfilerActivity, profile -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) from torchao.quantization import quantize_ +from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm from torchao.utils import ( @@ -105,15 +103,13 @@ def test_inference_workflow_mx( m_mx = copy.deepcopy(m) if emulate: - kernel_choice = MXGemmKernelChoice.EMULATED - elif elem_dtype == torch.float4_e2m1fn_x2: - kernel_choice = MXGemmKernelChoice.CUTLASS + kernel_choice = KernelPreference.EMULATED else: - kernel_choice = MXGemmKernelChoice.CUBLAS + kernel_choice = KernelPreference.AUTO config = MXFPInferenceConfig( activation_dtype=elem_dtype, weight_dtype=elem_dtype, - gemm_kernel_choice=kernel_choice, + kernel_preference=kernel_choice, ) quantize_(m_mx, config=config) if compile: @@ -254,7 +250,7 @@ def test_slice_and_copy_similar_to_vllm(self): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + kernel_preference=KernelPreference.EMULATED, ) self._test_slice_and_copy_similar_to_vllm(config) @@ 
-267,7 +263,7 @@ def test_narrow_similar_to_vllm(self): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + kernel_preference=KernelPreference.EMULATED, ) self._test_narrow_similar_to_vllm(config) diff --git a/test/prototype/mx_formats/test_mx_serialization.py b/test/prototype/mx_formats/test_mx_serialization.py index d04d23f46c..930dc1dfaa 100644 --- a/test/prototype/mx_formats/test_mx_serialization.py +++ b/test/prototype/mx_formats/test_mx_serialization.py @@ -12,15 +12,13 @@ import torch import torch.nn as nn -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) from torchao.quantization import quantize_ +from torchao.quantization.quantize_.common import KernelPreference from torchao.utils import ( is_sm_at_least_100, torch_version_at_least, @@ -46,7 +44,7 @@ def test_serialization(recipe_name): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + kernel_preference=KernelPreference.EMULATED, ) else: assert recipe_name == "nvfp4", "unsupported" diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 66f8998ea8..2b8c72ff91 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -12,7 +12,6 @@ from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck -from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, @@ -25,6 +24,7 @@ to_dtype, ) from torchao.prototype.mx_formats.utils import from_blocked, to_blocked +from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.utils import compute_error from torchao.utils import ( is_sm_at_least_89, @@ -375,7 +375,7 @@ def test_exponent_nan_out(elem_dtype): elem_dtype, block_size, torch.float, - MXGemmKernelChoice.EMULATED, + KernelPreference.EMULATED, None, False, ) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index c7705fec18..3a4ad43b4f 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -27,12 +27,12 @@ ) from torchao.prototype.mx_formats.config import ( MXFP8Dim1CastKernelChoice, - MXGemmKernelChoice, ScaleCalculationMode, ) from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0 from torchao.prototype.mx_formats.mx_tensor import to_mx from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper +from torchao.quantization.quantize_.common import KernelPreference logger: logging.Logger = logging.getLogger(__name__) @@ -412,7 +412,7 @@ def backward(ctx, grad_out: torch.Tensor): block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=grad_out.dtype, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + kernel_preference=KernelPreference.AUTO, # Not used cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scale_calculation_mode, ) @@ -428,7 +428,7 @@ def backward(ctx, grad_out: torch.Tensor): block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=A.dtype, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + 
kernel_preference=KernelPreference.AUTO, # Not used cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scale_calculation_mode, ) @@ -475,7 +475,7 @@ def _to_mxfp8_dim1_3d( block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=B_reshaped.dtype, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + kernel_preference=KernelPreference.AUTO, # Not used cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scaling_mode, ) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 8922be949b..6c36c2eaed 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -74,13 +74,13 @@ Below is a toy training loop. For an example real training loop, see our torchti import torch from torchao.quantization import quantize_ import torchao.prototype.mx_formats -from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice, ScaleCalculationMode +from torchao.prototype.mx_formats import MXLinearConfig, ScaleCalculationMode +from torchao.quantization.quantize_.common import KernelPreference -# on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels -gemm_kernel_choice = MXGemmKernelChoice.CUBLAS -# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS -# on older NVIDIA gpus, you can run training with emulated MX gemm -# gemm_kernel_choice = MXGemmKernelChoice.EMULATED +# low precision gemm, requires CUDA capability 10.0+ +kernel_preference = KernelPreference.AUTO +# or, emulated gemm +# kernel_preference = KernelPreference.EMULATED scale_calculation_mode = ScaleCalculationMode.FLOOR # other supported modes: RCEIL, CEIL, EVEN @@ -89,7 +89,7 @@ m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() config = MXLinearConfig( elem_dtype=torch.float8_e4m3fn, block_size=32, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scale_calculation_mode=scale_calculation_mode, ) quantize_(m, config) @@ -107,14 +107,12 @@ import torch import torch.nn as nn from torchao.quantization import quantize_ import torchao.prototype.mx_formats -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) +from torchao.quantization.quantize_.common import KernelPreference m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda") x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16) @@ -125,7 +123,7 @@ m_mxfp8 = copy.deepcopy(m) config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, ) quantize_(m_mxfp8, config=config) m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True) @@ -137,7 +135,7 @@ m_mxfp4 = copy.deepcopy(m) config = MXFPInferenceConfig( activation_dtype=torch.float4_e2m1fn_x2, weight_dtype=torch.float4_e2m1fn_x2, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + kernel_preference=KernelPreference.AUTO, ) quantize_(m_mxfp4, config=config) m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True) diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py index c7a4c47f9d..8d1455d6f3 100644 --- a/torchao/prototype/mx_formats/__init__.py +++ b/torchao/prototype/mx_formats/__init__.py @@ -1,5 +1,4 @@ from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, MXLinearConfig, MXLinearRecipeName, ) @@ -16,7 +15,6 @@ import 
torchao.prototype.mx_formats.mx_linear # noqa: F401 __all__ = [ - "MXGemmKernelChoice", "MXLinearConfig", "MXLinearRecipeName", "MXFPInferenceConfig", diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 388af07874..d57b91b85f 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -15,20 +15,7 @@ DTYPE_TO_SHORT_STR, SUPPORTED_ELEM_DTYPES, ) - - -class MXGemmKernelChoice(Enum): - # always available - MX operands are dequantized and a high precision - # gemm is run - EMULATED = "emulated" - - # available only when CUDA capability is greater than or equal to 10.0 - CUTLASS = "cutlass" - - # available only when CUDA capability is greater than or equal to 10.0 - # available on recent versions of PyTorch nightly, with https://github.com/pytorch/pytorch/pull/147548 - # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873 - CUBLAS = "cublas" +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference class MXFP8Dim1CastKernelChoice(Enum): @@ -85,22 +72,17 @@ def _validate_elem_dtype(elem_dtype): ) -def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): - if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: - assert block_size == 32, ( - f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}" - ) - valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2] - assert elem_dtype in valid_dtypes, ( - f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" - ) - elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS: - assert block_size in [16, 32], ( - f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}" - ) - valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2] - assert elem_dtype in valid_dtypes, ( - f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" +def _validate_kernel_preference(kernel_preference, block_size, elem_dtype): + if kernel_preference == KernelPreference.AUTO: + if elem_dtype in (torch.float8_e4m3fn, torch.float4_e2m1fn_x2): + assert block_size == 32, f"block_size must be 32, got {block_size}" + else: + raise AssertionError( + f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}" + ) + else: + assert kernel_preference == KernelPreference.EMULATED, ( + f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}" ) @@ -135,9 +117,9 @@ class MXLinearConfig(AOBaseConfig): elem_dtype_weight_override: Optional[Any] = None elem_dtype_grad_output_override: Optional[Any] = None - # defines the gemm kernel choice, if the chosen kernel is not supported + # defines the kernel preference, if the chosen kernel is not supported # on the given hardware an exception will be thrown - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED + kernel_preference: KernelPreference = KernelPreference.EMULATED # define which kernel to use for mxfp8 casting # TODO(1945): remove this config option once torch.compile gives us @@ -150,15 +132,15 @@ class MXLinearConfig(AOBaseConfig): def __post_init__(self): _validate_elem_dtype(self.elem_dtype) - _validate_gemm_kernel_choice( - self.gemm_kernel_choice, self.block_size, self.elem_dtype + _validate_kernel_preference( + self.kernel_preference, self.block_size, self.elem_dtype ) if self.elem_dtype_weight_override is not None: _validate_elem_dtype(self.elem_dtype_weight_override) - assert self.gemm_kernel_choice == 
MXGemmKernelChoice.EMULATED, "unsupported" + assert self.kernel_preference == KernelPreference.EMULATED, "unsupported" if self.elem_dtype_grad_output_override is not None: _validate_elem_dtype(self.elem_dtype_grad_output_override) - assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported" + assert self.kernel_preference == KernelPreference.EMULATED, "unsupported" _validate_mxfp8_cast_kernel_choice( self.mxfp8_cast_kernel_choice, self.scale_calculation_mode ) @@ -182,12 +164,12 @@ def from_recipe_name( return MXLinearConfig() elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS: return MXLinearConfig( - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, ) elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL: return MXLinearConfig( - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=ScaleCalculationMode.RCEIL, ) @@ -196,7 +178,7 @@ def from_recipe_name( elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS: return MXLinearConfig( elem_dtype=torch.float4_e2m1fn_x2, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + kernel_preference=KernelPreference.AUTO, ) else: raise AssertionError(f"unknown recipe_name {recipe_name}") @@ -212,7 +194,7 @@ def short_str(self) -> str: ) if self.elem_dtype_grad_output_override is not None: s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}" - s += f", kernel={self.gemm_kernel_choice.value}" + s += f", kernel={self.kernel_preference.value}" s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}" if self.scale_calculation_mode != ScaleCalculationMode.FLOOR: s += f", scale_calculation_mode={self.scale_calculation_mode}" diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 2ff4eedf5f..5991d8557e 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -10,12 +10,9 @@ import torch from torchao.core.config import AOBaseConfig -from torchao.prototype.mx_formats import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.config import ( _validate_elem_dtype, - _validate_gemm_kernel_choice, + _validate_kernel_preference, ) from torchao.prototype.mx_formats.mx_tensor import ( MXTensor, @@ -29,6 +26,7 @@ per_tensor_amax_to_scale, ) from torchao.quantization.quant_api import _quantization_type +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -80,7 +78,7 @@ class MXFPInferenceConfig(AOBaseConfig): weight_dtype: torch.dtype = torch.float8_e4m3fn # Which kernel to run for mm - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS + kernel_preference: KernelPreference = KernelPreference.AUTO def __post_init__(self): assert self.activation_dtype == self.weight_dtype, ( @@ -88,8 +86,8 @@ def __post_init__(self): ) _validate_elem_dtype(self.activation_dtype) _validate_elem_dtype(self.weight_dtype) - _validate_gemm_kernel_choice( - self.gemm_kernel_choice, self.block_size, self.weight_dtype + _validate_kernel_preference( + self.kernel_preference, self.block_size, self.weight_dtype ) @@ -109,7 +107,7 @@ def _mx_inference_linear_transform( act_quant_kwargs = QuantizeTensorToMXKwargs( elem_dtype=config.activation_dtype, 
block_size=config.block_size, - gemm_kernel_choice=config.gemm_kernel_choice, + kernel_preference=config.kernel_preference, is_swizzled_scales=True, ) @@ -118,7 +116,7 @@ def _mx_inference_linear_transform( weight, config.weight_dtype, block_size=config.block_size, - gemm_kernel_choice=config.gemm_kernel_choice, + kernel_preference=config.kernel_preference, act_quant_kwargs=act_quant_kwargs, is_swizzled_scales=True, ) @@ -211,7 +209,6 @@ def _nvfp4_inference_linear_transform( MXTensor, NVFP4Tensor, NVFP4MMConfig, - MXGemmKernelChoice, QuantizeTensorToMXKwargs, QuantizeTensorToNVFP4Kwargs, ScaleCalculationMode, diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 19d658a6fc..8b9d1576c4 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -14,12 +14,12 @@ from torchao.prototype.mx_formats.config import ( MXFP8Dim1CastKernelChoice, - MXGemmKernelChoice, MXLinearConfig, ScaleCalculationMode, ) from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -44,7 +44,7 @@ def forward( w_elem_dtype: Any, grad_elem_dtype: Any, block_size: int, - gemm_kernel_choice: MXGemmKernelChoice, + kernel_preference: KernelPreference, mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice, scale_calculation_mode: ScaleCalculationMode, ): @@ -53,7 +53,7 @@ def forward( ctx.w_elem_dtype = w_elem_dtype ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size - ctx.gemm_kernel_choice = gemm_kernel_choice + ctx.kernel_preference = kernel_preference ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice ctx.scale_calculation_mode = scale_calculation_mode @@ -65,14 +65,14 @@ def forward( input_hp_r, in_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) weight_mx_dim0 = MXTensor.to_mx( weight_hp, w_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) @@ -87,7 +87,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): w_elem_dtype = ctx.w_elem_dtype grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size - gemm_kernel_choice = ctx.gemm_kernel_choice + kernel_preference = ctx.kernel_preference mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice scale_calculation_mode = ctx.scale_calculation_mode @@ -102,7 +102,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_output_hp_r, grad_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) @@ -112,7 +112,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): block_size, w_elem_dtype, weight_hp.dtype, - gemm_kernel_choice, + kernel_preference, mxfp8_cast_kernel_choice, scale_calculation_mode, ) @@ -122,7 +122,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): weight_hp_t_c, w_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) @@ -137,7 +137,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): block_size, grad_elem_dtype, 
grad_output_hp_r.dtype, - gemm_kernel_choice, + kernel_preference, mxfp8_cast_kernel_choice, scale_calculation_mode, ) @@ -146,7 +146,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) @@ -156,7 +156,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): block_size, in_elem_dtype, input_hp_r.dtype, - gemm_kernel_choice, + kernel_preference, mxfp8_cast_kernel_choice, scale_calculation_mode, ) @@ -166,7 +166,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): input_hp_r.t().contiguous(), in_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() @@ -215,7 +215,7 @@ def forward(self, x): config.elem_dtype_weight_override or config.elem_dtype, config.elem_dtype_grad_output_override or config.elem_dtype, config.block_size, - config.gemm_kernel_choice, + config.kernel_preference, config.mxfp8_cast_kernel_choice, config.scale_calculation_mode, ) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 7a1b5a160b..74f37bc2df 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -29,7 +29,7 @@ from torch.utils._pytree import tree_map import torchao.ops -from torchao.prototype.mx_formats.config import MXGemmKernelChoice, ScaleCalculationMode +from torchao.prototype.mx_formats.config import ScaleCalculationMode from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP6_E2M3, @@ -69,6 +69,7 @@ from torchao.quantization.quantize_.common import ( QuantizeTensorKwargs, ) +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference from torchao.utils import TorchAOBaseTensor, fill_defaults aten = torch.ops.aten @@ -87,7 +88,7 @@ class QuantizeTensorToMXKwargs(QuantizeTensorKwargs): elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn block_size: int = 32 scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED + kernel_preference: KernelPreference = KernelPreference.EMULATED is_swizzled_scales: bool = False @@ -438,7 +439,7 @@ class MXTensor(TorchAOBaseTensor): "_elem_dtype", "block_size", "_orig_dtype", - "_gemm_kernel_choice", + "kernel_preference", "act_quant_kwargs", "_is_swizzled_scales", ] @@ -450,7 +451,7 @@ def __new__( elem_dtype, block_size, orig_dtype, - gemm_kernel_choice, + kernel_preference, act_quant_kwargs, is_swizzled_scales, ): @@ -487,7 +488,7 @@ def __new__( self._elem_dtype = elem_dtype self.block_size = block_size self._orig_dtype = orig_dtype - self._gemm_kernel_choice = gemm_kernel_choice + self.kernel_preference = kernel_preference self.act_quant_kwargs = act_quant_kwargs self._is_swizzled_scales = is_swizzled_scales return self @@ -497,7 +498,7 @@ def __repr__(self): return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}, _is_swizzled_scales={self._is_swizzled_scales}" # noqa: E501 def _quantization_type(self): - return f"{self._elem_dtype=}, {self.block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" + return f"{self._elem_dtype=}, {self.block_size=}, {self._orig_dtype=}, {self.kernel_preference=}, {self.act_quant_kwargs=}" def 
dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: @@ -534,7 +535,7 @@ def to_mx( block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, # TODO(future PR): switch default gemm to cublas - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, + kernel_preference: KernelPreference = KernelPreference.EMULATED, act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None, is_swizzled_scales: bool = False, ): @@ -551,7 +552,7 @@ def to_mx( elem_dtype, block_size, data_hp.dtype, - gemm_kernel_choice, + kernel_preference, act_quant_kwargs, is_swizzled_scales, ) @@ -569,7 +570,7 @@ def to_mx( elem_dtype, block_size, data_hp.dtype, - gemm_kernel_choice, + kernel_preference, act_quant_kwargs, is_swizzled_scales, ) @@ -589,8 +590,8 @@ def _(func, types, args, kwargs): def _get_gemm_choice( - choice_a: Optional[MXGemmKernelChoice], choice_b: Optional[MXGemmKernelChoice] -) -> MXGemmKernelChoice: + choice_a: Optional[KernelPreference], choice_b: Optional[KernelPreference] +) -> KernelPreference: if choice_a is not None and choice_b is not None: assert choice_a == choice_b, ( "Both MXTensor inputs must have the same gemm config if specified" @@ -620,13 +621,13 @@ def _addmm_mx_dispatch( k.elem_dtype, k.block_size, k.scaling_mode, - k.gemm_kernel_choice, + k.kernel_preference, k.is_swizzled_scales, ) - gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) + gemm_choice = _get_gemm_choice(a.kernel_preference, b.kernel_preference) - if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): + if gemm_choice == KernelPreference.AUTO: # real MX gemm backed by torchao's CUTLASS kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] assert a.qdata.is_contiguous() @@ -648,10 +649,6 @@ def _addmm_mx_dispatch( if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn - assert gemm_choice is MXGemmKernelChoice.CUBLAS, ( - "CUBLAS is the only supported kernel choice for MX FP8 operations" - ) - res = torch._scaled_mm( a.qdata, b.qdata, @@ -663,7 +660,6 @@ def _addmm_mx_dispatch( else: assert a._elem_dtype == torch.float4_e2m1fn_x2 assert b._elem_dtype == torch.float4_e2m1fn_x2 - assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported" # FP4 operations res = torchao.ops.mx_fp4_bf16( a.qdata, b.qdata, a_scale_block, b_scale_block @@ -673,6 +669,7 @@ def _addmm_mx_dispatch( res = res + bias else: + assert gemm_choice == KernelPreference.EMULATED, "unimplemented" # emulated MX gemm a_hp = a.dequantize(a._orig_dtype) b_hp = b.dequantize(b._orig_dtype) @@ -738,7 +735,7 @@ def mx_t(func, types, args, kwargs): old._elem_dtype, old.block_size, old._orig_dtype, - old._gemm_kernel_choice, + old.kernel_preference, old.act_quant_kwargs, old._is_swizzled_scales, ) @@ -779,7 +776,7 @@ def mx_view_op(func, types, args, kwargs): args[0]._elem_dtype, args[0].block_size, args[0]._orig_dtype, - args[0]._gemm_kernel_choice, + args[0].kernel_preference, args[0].act_quant_kwargs, args[0]._is_swizzled_scales, ) @@ -804,7 +801,7 @@ def mx_slice(func, types, args, kwargs): x._elem_dtype, x.block_size, x._orig_dtype, - x._gemm_kernel_choice, + x.kernel_preference, x.act_quant_kwargs, x._is_swizzled_scales, ), @@ -838,7 +835,7 @@ def mx_select(func, types, args, kwargs): old_mx_tensor._elem_dtype, old_mx_tensor.block_size, old_mx_tensor._orig_dtype, - old_mx_tensor._gemm_kernel_choice, + old_mx_tensor.kernel_preference, old_mx_tensor.act_quant_kwargs, 
old_mx_tensor._is_swizzled_scales, ) diff --git a/torchao/quantization/quantize_/common/kernel_preference.py b/torchao/quantization/quantize_/common/kernel_preference.py index 8f53f55c6a..45ae4d2ab6 100644 --- a/torchao/quantization/quantize_/common/kernel_preference.py +++ b/torchao/quantization/quantize_/common/kernel_preference.py @@ -30,5 +30,13 @@ class KernelPreference(str, Enum): """ FBGEMM = "fbgemm" + """Emulates gemm_lowp(A, B) with gemm_fp32(A.dequantize(), B.dequantize()). + Intended use cases are: + 1. Running CI for product logic on hardware which does not support the + actual lowp gemm. + 2. Debugging kernel numerics issues. + """ + EMULATED = "emulated" + torch.serialization.add_safe_globals([KernelPreference])
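
For reference, a minimal before/after sketch of the API change this patch makes, assuming a toy nn.Linear module on a CUDA device; the module paths, config fields, and enum values are taken from the README and test changes in the diff above, everything else is illustrative.

    # Minimal migration sketch: MXGemmKernelChoice -> KernelPreference.
    import copy

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
    from torchao.quantization import quantize_
    from torchao.quantization.quantize_.common import KernelPreference

    m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")

    # before this patch:
    #   config = MXFPInferenceConfig(
    #       activation_dtype=torch.float8_e4m3fn,
    #       weight_dtype=torch.float8_e4m3fn,
    #       gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
    #   )

    # after this patch: KernelPreference.AUTO selects the low precision gemm kernel,
    # KernelPreference.EMULATED dequantizes and runs a high precision gemm instead.
    m_mxfp8 = copy.deepcopy(m)
    config = MXFPInferenceConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
        kernel_preference=KernelPreference.AUTO,
    )
    quantize_(m_mxfp8, config=config)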