     DTYPE_TO_SHORT_STR,
     SUPPORTED_ELEM_DTYPES,
 )
-
-
-class MXGemmKernelChoice(Enum):
-    # always available - MX operands are dequantized and a high precision
-    # gemm is run
-    EMULATED = "emulated"
-
-    # available only when CUDA capability is greater than or equal to 10.0
-    CUTLASS = "cutlass"
-
-    # available only when CUDA capability is greater than or equal to 10.0
-    # available on recent versions of PyTorch nightly, with https://github.com/pytorch/pytorch/pull/147548
-    # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
-    CUBLAS = "cublas"
+from torchao.quantization.quantize_.common.kernel_preference import KernelPreference


 class MXFP8Dim1CastKernelChoice(Enum):
@@ -85,22 +72,17 @@ def _validate_elem_dtype(elem_dtype):
     )


-def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
-    if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
-        assert block_size == 32, (
-            f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
-        )
-        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
-        assert elem_dtype in valid_dtypes, (
-            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
-        )
-    elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-        assert block_size in [16, 32], (
-            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
-        )
-        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
-        assert elem_dtype in valid_dtypes, (
-            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
+def _validate_kernel_preference(kernel_preference, block_size, elem_dtype):
+    if kernel_preference == KernelPreference.AUTO:
+        if elem_dtype in (torch.float8_e4m3fn, torch.float4_e2m1fn_x2):
+            assert block_size == 32, f"block_size must be 32, got {block_size}"
+        else:
+            raise AssertionError(
+                f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
+            )
+    else:
+        assert kernel_preference == KernelPreference.EMULATED, (
+            f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
         )


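For context, a minimal sketch of what the replacement validator accepts (not part of the diff; the torchao.prototype.mx_formats.config import path, and a PyTorch build exposing torch.float4_e2m1fn_x2, are assumptions based on the hunks above):

import torch

# Assumed import paths; they may differ in your checkout.
from torchao.prototype.mx_formats.config import _validate_kernel_preference
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

# AUTO: only fp8/fp4 element dtypes with block_size == 32 pass validation.
_validate_kernel_preference(KernelPreference.AUTO, 32, torch.float8_e4m3fn)

# EMULATED: accepted regardless of block size / element dtype combination.
_validate_kernel_preference(KernelPreference.EMULATED, 32, torch.float8_e4m3fn)

try:
    # AUTO with a block size other than 32 trips the assert.
    _validate_kernel_preference(KernelPreference.AUTO, 16, torch.float8_e4m3fn)
except AssertionError as e:
    print(e)  # block_size must be 32, got 16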
@@ -135,9 +117,9 @@ class MXLinearConfig(AOBaseConfig):
     elem_dtype_weight_override: Optional[Any] = None
     elem_dtype_grad_output_override: Optional[Any] = None

-    # defines the gemm kernel choice, if the chosen kernel is not supported
+    # defines the kernel preference, if the chosen kernel is not supported
     # on the given hardware an exception will be thrown
-    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
+    kernel_preference: KernelPreference = KernelPreference.EMULATED

     # define which kernel to use for mxfp8 casting
     # TODO(1945): remove this config option once torch.compile gives us
@@ -150,15 +132,15 @@ class MXLinearConfig(AOBaseConfig):

     def __post_init__(self):
         _validate_elem_dtype(self.elem_dtype)
-        _validate_gemm_kernel_choice(
-            self.gemm_kernel_choice, self.block_size, self.elem_dtype
+        _validate_kernel_preference(
+            self.kernel_preference, self.block_size, self.elem_dtype
         )
         if self.elem_dtype_weight_override is not None:
             _validate_elem_dtype(self.elem_dtype_weight_override)
-            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
+            assert self.kernel_preference == KernelPreference.EMULATED, "unsupported"
         if self.elem_dtype_grad_output_override is not None:
             _validate_elem_dtype(self.elem_dtype_grad_output_override)
-            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
+            assert self.kernel_preference == KernelPreference.EMULATED, "unsupported"
         _validate_mxfp8_cast_kernel_choice(
             self.mxfp8_cast_kernel_choice, self.scale_calculation_mode
         )
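A short illustration of the __post_init__ checks above (a sketch, not part of the diff; import paths are assumptions): per-tensor dtype overrides remain restricted to the emulated path.

import torch

from torchao.prototype.mx_formats.config import MXLinearConfig  # assumed path
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

# Default kernel_preference is EMULATED, so a dtype override validates cleanly.
cfg = MXLinearConfig(elem_dtype_weight_override=torch.float8_e4m3fn)

# Combining an override with KernelPreference.AUTO trips the "unsupported" assert.
try:
    MXLinearConfig(
        elem_dtype_weight_override=torch.float8_e4m3fn,
        kernel_preference=KernelPreference.AUTO,
    )
except AssertionError as e:
    print(e)  # unsupported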
@@ -182,12 +164,12 @@ def from_recipe_name(
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
             return MXLinearConfig(
-                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+                kernel_preference=KernelPreference.AUTO,
                 mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
             )
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
             return MXLinearConfig(
-                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+                kernel_preference=KernelPreference.AUTO,
                 mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
             )
@@ -196,7 +178,7 @@ def from_recipe_name(
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
             return MXLinearConfig(
                 elem_dtype=torch.float4_e2m1fn_x2,
-                gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+                kernel_preference=KernelPreference.AUTO,
             )
         else:
             raise AssertionError(f"unknown recipe_name {recipe_name}")
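As a usage note, the recipe constructors above now resolve to KernelPreference.AUTO rather than naming a specific gemm backend; a hedged example (import paths assumed from the diff):

from torchao.prototype.mx_formats.config import MXLinearConfig, MXLinearRecipeName  # assumed path
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

cfg = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS)
assert cfg.kernel_preference == KernelPreference.AUTO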
@@ -212,7 +194,7 @@ def short_str(self) -> str:
             )
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
-        s += f", kernel={self.gemm_kernel_choice.value}"
+        s += f", kernel={self.kernel_preference.value}"
         s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
             s += f", scale_calculation_mode={self.scale_calculation_mode}"
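Finally, the short_str change above means the printed kernel field now reflects the kernel preference; a minimal sketch (assuming the same import paths as in the earlier examples):

from torchao.prototype.mx_formats.config import MXLinearConfig  # assumed path
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

cfg = MXLinearConfig(kernel_preference=KernelPreference.AUTO)
# Expected to include "kernel=" followed by the preference's value among the other fields.
print(cfg.short_str())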