Commit 9de70ea

4/x: mx cleanup: use kernel_preference instead of gemm_kernel_choice
Summary: Moves MX workflows to use the torchao-wide `kernel_preference`:

* MXGemmKernelChoice.CUBLAS -> KernelPreference.AUTO
* MXGemmKernelChoice.CUTLASS -> KernelPreference.AUTO
* MXGemmKernelChoice.EMULATED -> KernelPreference.EMULATED

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 726b651
ghstack-comment-id: 3572455799
Pull-Request: #3385
1 parent 36275f4 commit 9de70ea
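As context for the diffs below, here is a minimal before/after sketch of the change at an inference call site, assembled from the README and test hunks in this commit; the module shape and dtypes are illustrative:

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference

# Old API (removed in this commit):
#   from torchao.prototype.mx_formats.config import MXGemmKernelChoice
#   config = MXFPInferenceConfig(..., gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)

# New API: KernelPreference.AUTO lets torchao choose the low precision gemm kernel;
# KernelPreference.EMULATED dequantizes and runs a high precision gemm instead.
m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    kernel_preference=KernelPreference.AUTO,
)
quantize_(m, config=config)
```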

File tree

12 files changed, +97 -126 lines


benchmarks/float8/float8_inference_roofline.py

Lines changed: 2 additions & 5 deletions
@@ -38,9 +38,6 @@
 )
 
 import torchao
-from torchao.prototype.mx_formats.config import (
-    MXGemmKernelChoice,
-)
 from torchao.prototype.mx_formats.inference_workflow import (
     MXFPInferenceConfig,
     NVFP4InferenceConfig,
@@ -439,13 +436,13 @@ def run(
         config = MXFPInferenceConfig(
             activation_dtype=torch.float8_e4m3fn,
             weight_dtype=torch.float8_e4m3fn,
-            gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+            kernel_preference=KernelPreference.AUTO,
         )
     elif recipe_name == "mxfp4_cutlass":
         config = MXFPInferenceConfig(
             activation_dtype=torch.float4_e2m1fn_x2,
             weight_dtype=torch.float4_e2m1fn_x2,
-            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+            kernel_preference=KernelPreference.AUTO,
         )
     elif recipe_name == "nvfp4":
         config = NVFP4InferenceConfig(

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 6 additions & 10 deletions
@@ -12,15 +12,13 @@
 import torch.nn as nn
 from torch.profiler import ProfilerActivity, profile
 
-from torchao.prototype.mx_formats.config import (
-    MXGemmKernelChoice,
-)
 from torchao.prototype.mx_formats.inference_workflow import (
     MXFPInferenceConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
 from torchao.quantization import quantize_
+from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
 from torchao.utils import (
@@ -105,15 +103,13 @@ def test_inference_workflow_mx(
     m_mx = copy.deepcopy(m)
 
     if emulate:
-        kernel_choice = MXGemmKernelChoice.EMULATED
-    elif elem_dtype == torch.float4_e2m1fn_x2:
-        kernel_choice = MXGemmKernelChoice.CUTLASS
+        kernel_choice = KernelPreference.EMULATED
     else:
-        kernel_choice = MXGemmKernelChoice.CUBLAS
+        kernel_choice = KernelPreference.AUTO
     config = MXFPInferenceConfig(
         activation_dtype=elem_dtype,
         weight_dtype=elem_dtype,
-        gemm_kernel_choice=kernel_choice,
+        kernel_preference=kernel_choice,
     )
     quantize_(m_mx, config=config)
     if compile:
@@ -254,7 +250,7 @@ def test_slice_and_copy_similar_to_vllm(self):
         config = MXFPInferenceConfig(
             activation_dtype=torch.float8_e4m3fn,
             weight_dtype=torch.float8_e4m3fn,
-            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+            kernel_preference=KernelPreference.EMULATED,
         )
         self._test_slice_and_copy_similar_to_vllm(config)
 
@@ -267,7 +263,7 @@ def test_narrow_similar_to_vllm(self):
         config = MXFPInferenceConfig(
             activation_dtype=torch.float8_e4m3fn,
             weight_dtype=torch.float8_e4m3fn,
-            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+            kernel_preference=KernelPreference.EMULATED,
         )
         self._test_narrow_similar_to_vllm(config)

test/prototype/mx_formats/test_mx_serialization.py

Lines changed: 2 additions & 4 deletions
@@ -12,15 +12,13 @@
 import torch
 import torch.nn as nn
 
-from torchao.prototype.mx_formats.config import (
-    MXGemmKernelChoice,
-)
 from torchao.prototype.mx_formats.inference_workflow import (
     MXFPInferenceConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
 from torchao.quantization import quantize_
+from torchao.quantization.quantize_.common import KernelPreference
 from torchao.utils import (
     is_sm_at_least_100,
     torch_version_at_least,
@@ -46,7 +44,7 @@ def test_serialization(recipe_name):
         config = MXFPInferenceConfig(
             activation_dtype=torch.float8_e4m3fn,
             weight_dtype=torch.float8_e4m3fn,
-            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+            kernel_preference=KernelPreference.EMULATED,
         )
     else:
         assert recipe_name == "nvfp4", "unsupported"

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,6 @@
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 
-from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
@@ -25,6 +24,7 @@
     to_dtype,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     is_sm_at_least_89,
@@ -375,7 +375,7 @@ def test_exponent_nan_out(elem_dtype):
         elem_dtype,
         block_size,
         torch.float,
-        MXGemmKernelChoice.EMULATED,
+        KernelPreference.EMULATED,
         None,
         False,
     )

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 4 additions & 4 deletions
@@ -27,12 +27,12 @@
 )
 from torchao.prototype.mx_formats.config import (
     MXFP8Dim1CastKernelChoice,
-    MXGemmKernelChoice,
     ScaleCalculationMode,
 )
 from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper
+from torchao.quantization.quantize_.common import KernelPreference
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -412,7 +412,7 @@ def backward(ctx, grad_out: torch.Tensor):
         block_size,
         elem_dtype=torch.float8_e4m3fn,
         hp_dtype=grad_out.dtype,
-        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        kernel_preference=KernelPreference.AUTO,  # Not used
         cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
         scale_calculation_mode=scale_calculation_mode,
     )
@@ -428,7 +428,7 @@ def backward(ctx, grad_out: torch.Tensor):
         block_size,
         elem_dtype=torch.float8_e4m3fn,
         hp_dtype=A.dtype,
-        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        kernel_preference=KernelPreference.AUTO,  # Not used
         cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
         scale_calculation_mode=scale_calculation_mode,
     )
@@ -475,7 +475,7 @@ def _to_mxfp8_dim1_3d(
         block_size,
         elem_dtype=torch.float8_e4m3fn,
         hp_dtype=B_reshaped.dtype,
-        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+        kernel_preference=KernelPreference.AUTO,  # Not used
         cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
         scale_calculation_mode=scaling_mode,
     )

torchao/prototype/mx_formats/README.md

Lines changed: 10 additions & 12 deletions
@@ -74,13 +74,13 @@ Below is a toy training loop. For an example real training loop, see our torchti
 import torch
 from torchao.quantization import quantize_
 import torchao.prototype.mx_formats
-from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice, ScaleCalculationMode
+from torchao.prototype.mx_formats import MXLinearConfig, ScaleCalculationMode
+from torchao.quantization.quantize_.common import KernelPreference
 
-# on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels
-gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
-# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
-# on older NVIDIA gpus, you can run training with emulated MX gemm
-# gemm_kernel_choice = MXGemmKernelChoice.EMULATED
+# low precision gemm, requires CUDA capability 10.0+
+kernel_preference = KernelPreference.AUTO
+# or, emulated gemm
+# kernel_preference = KernelPreference.EMULATED
 
 scale_calculation_mode = ScaleCalculationMode.FLOOR
 # other supported modes: RCEIL, CEIL, EVEN
@@ -89,7 +89,7 @@ m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 config = MXLinearConfig(
     elem_dtype=torch.float8_e4m3fn,
     block_size=32,
-    gemm_kernel_choice=gemm_kernel_choice,
+    kernel_preference=kernel_preference,
     scale_calculation_mode=scale_calculation_mode,
 )
 quantize_(m, config)
@@ -107,14 +107,12 @@ import torch
 import torch.nn as nn
 from torchao.quantization import quantize_
 import torchao.prototype.mx_formats
-from torchao.prototype.mx_formats.config import (
-    MXGemmKernelChoice,
-)
 from torchao.prototype.mx_formats.inference_workflow import (
     MXFPInferenceConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
+from torchao.quantization.quantize_.common import KernelPreference
 
 m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
 x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
@@ -125,7 +123,7 @@ m_mxfp8 = copy.deepcopy(m)
 config = MXFPInferenceConfig(
     activation_dtype=torch.float8_e4m3fn,
     weight_dtype=torch.float8_e4m3fn,
-    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+    kernel_preference=KernelPreference.AUTO,
 )
 quantize_(m_mxfp8, config=config)
 m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True)
@@ -137,7 +135,7 @@ m_mxfp4 = copy.deepcopy(m)
 config = MXFPInferenceConfig(
     activation_dtype=torch.float4_e2m1fn_x2,
     weight_dtype=torch.float4_e2m1fn_x2,
-    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+    kernel_preference=KernelPreference.AUTO,
 )
 quantize_(m_mxfp4, config=config)
 m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True)

torchao/prototype/mx_formats/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,4 @@
 from torchao.prototype.mx_formats.config import (
-    MXGemmKernelChoice,
     MXLinearConfig,
     MXLinearRecipeName,
 )
@@ -16,7 +15,6 @@
 import torchao.prototype.mx_formats.mx_linear  # noqa: F401
 
 __all__ = [
-    "MXGemmKernelChoice",
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXFPInferenceConfig",

torchao/prototype/mx_formats/config.py

Lines changed: 22 additions & 40 deletions
@@ -15,20 +15,7 @@
     DTYPE_TO_SHORT_STR,
     SUPPORTED_ELEM_DTYPES,
 )
-
-
-class MXGemmKernelChoice(Enum):
-    # always available - MX operands are dequantized and a high precision
-    # gemm is run
-    EMULATED = "emulated"
-
-    # available only when CUDA capability is greater than or equal to 10.0
-    CUTLASS = "cutlass"
-
-    # available only when CUDA capability is greater than or equal to 10.0
-    # available on recent versions of PyTorch nightly, with https://github.com/pytorch/pytorch/pull/147548
-    # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
-    CUBLAS = "cublas"
+from torchao.quantization.quantize_.common.kernel_preference import KernelPreference
 
 
 class MXFP8Dim1CastKernelChoice(Enum):
@@ -85,22 +72,17 @@ def _validate_elem_dtype(elem_dtype):
     )
 
 
-def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
-    if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
-        )
-        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
-        assert elem_dtype in valid_dtypes, (
-            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
-        )
-    elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size in [16, 32], (
-            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
-        )
-        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
-        assert elem_dtype in valid_dtypes, (
-            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
+def _validate_kernel_preference(kernel_preference, block_size, elem_dtype):
+    if kernel_preference == KernelPreference.AUTO:
+        if elem_dtype in (torch.float8_e4m3fn, torch.float4_e2m1fn_x2):
+            assert block_size == 32, f"block_size must be 32, got {block_size}"
+        else:
+            raise AssertionError(
+                f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
+            )
+    else:
+        assert kernel_preference == KernelPreference.EMULATED, (
+            f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
         )
 
 
@@ -135,9 +117,9 @@ class MXLinearConfig(AOBaseConfig):
     elem_dtype_weight_override: Optional[Any] = None
     elem_dtype_grad_output_override: Optional[Any] = None
 
-    # defines the gemm kernel choice, if the chosen kernel is not supported
+    # defines the kernel preference, if the chosen kernel is not supported
     # on the given hardware an exception will be thrown
-    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
+    kernel_preference: KernelPreference = KernelPreference.EMULATED
 
     # define which kernel to use for mxfp8 casting
     # TODO(1945): remove this config option once torch.compile gives us
@@ -150,15 +132,15 @@ class MXLinearConfig(AOBaseConfig):
 
     def __post_init__(self):
         _validate_elem_dtype(self.elem_dtype)
-        _validate_gemm_kernel_choice(
-            self.gemm_kernel_choice, self.block_size, self.elem_dtype
+        _validate_kernel_preference(
+            self.kernel_preference, self.block_size, self.elem_dtype
         )
         if self.elem_dtype_weight_override is not None:
             _validate_elem_dtype(self.elem_dtype_weight_override)
-            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
+            assert self.kernel_preference == KernelPreference.EMULATED, "unsupported"
         if self.elem_dtype_grad_output_override is not None:
             _validate_elem_dtype(self.elem_dtype_grad_output_override)
-            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
+            assert self.kernel_preference == KernelPreference.EMULATED, "unsupported"
         _validate_mxfp8_cast_kernel_choice(
             self.mxfp8_cast_kernel_choice, self.scale_calculation_mode
         )
@@ -182,12 +164,12 @@ def from_recipe_name(
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
             return MXLinearConfig(
-                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+                kernel_preference=KernelPreference.AUTO,
                 mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
             )
        elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
             return MXLinearConfig(
-                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+                kernel_preference=KernelPreference.AUTO,
                 mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
             )
@@ -196,7 +178,7 @@ def from_recipe_name(
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
             return MXLinearConfig(
                 elem_dtype=torch.float4_e2m1fn_x2,
-                gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+                kernel_preference=KernelPreference.AUTO,
             )
         else:
             raise AssertionError(f"unknown recipe_name {recipe_name}")
@@ -212,7 +194,7 @@ def short_str(self) -> str:
         )
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
-        s += f", kernel={self.gemm_kernel_choice.value}"
+        s += f", kernel={self.kernel_preference.value}"
         s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
             s += f", scale_calculation_mode={self.scale_calculation_mode}"
