7 changes: 2 additions & 5 deletions benchmarks/float8/float8_inference_roofline.py
@@ -38,9 +38,6 @@
)

import torchao
from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
)
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
NVFP4InferenceConfig,
@@ -439,13 +436,13 @@ def run(
config = MXFPInferenceConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
kernel_preference=KernelPreference.AUTO,
)
elif recipe_name == "mxfp4_cutlass":
config = MXFPInferenceConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
kernel_preference=KernelPreference.AUTO,
)
elif recipe_name == "nvfp4":
config = NVFP4InferenceConfig(
16 changes: 6 additions & 10 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -12,15 +12,13 @@
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
)
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
from torchao.utils import (
@@ -105,15 +103,13 @@ def test_inference_workflow_mx(
m_mx = copy.deepcopy(m)

if emulate:
Contributor review comment: nit: can just use KernelPreference in parametrize as well I think
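A rough, untested sketch of what that could look like (parametrizing directly over `KernelPreference` instead of the `emulate` bool; names and signatures are assumptions based on this diff, not part of the change itself):

```python
# Untested sketch: drive the test with KernelPreference values directly
import pytest
import torch

from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization.quantize_.common import KernelPreference


@pytest.mark.parametrize(
    "kernel_preference", [KernelPreference.AUTO, KernelPreference.EMULATED]
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
def test_inference_workflow_mx(elem_dtype, kernel_preference):
    config = MXFPInferenceConfig(
        activation_dtype=elem_dtype,
        weight_dtype=elem_dtype,
        kernel_preference=kernel_preference,
    )
    # ... quantize_(m_mx, config=config) and the rest of the test body stay as-is ...
```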

kernel_choice = MXGemmKernelChoice.EMULATED
elif elem_dtype == torch.float4_e2m1fn_x2:
kernel_choice = MXGemmKernelChoice.CUTLASS
kernel_choice = KernelPreference.EMULATED
else:
kernel_choice = MXGemmKernelChoice.CUBLAS
kernel_choice = KernelPreference.AUTO
config = MXFPInferenceConfig(
activation_dtype=elem_dtype,
weight_dtype=elem_dtype,
gemm_kernel_choice=kernel_choice,
kernel_preference=kernel_choice,
)
quantize_(m_mx, config=config)
if compile:
@@ -254,7 +250,7 @@ def test_slice_and_copy_similar_to_vllm(self):
config = MXFPInferenceConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
kernel_preference=KernelPreference.EMULATED,
)
self._test_slice_and_copy_similar_to_vllm(config)

@@ -267,7 +263,7 @@ def test_narrow_similar_to_vllm(self):
config = MXFPInferenceConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
kernel_preference=KernelPreference.EMULATED,
)
self._test_narrow_similar_to_vllm(config)

6 changes: 2 additions & 4 deletions test/prototype/mx_formats/test_mx_serialization.py
@@ -12,15 +12,13 @@
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
)
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference
from torchao.utils import (
is_sm_at_least_100,
torch_version_at_least,
@@ -46,7 +44,7 @@ def test_serialization(recipe_name):
config = MXFPInferenceConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
kernel_preference=KernelPreference.EMULATED,
)
else:
assert recipe_name == "nvfp4", "unsupported"
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -12,7 +12,6 @@
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
@@ -25,6 +24,7 @@
to_dtype,
)
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.utils import compute_error
from torchao.utils import (
is_sm_at_least_89,
@@ -375,7 +375,7 @@ def test_exponent_nan_out(elem_dtype):
elem_dtype,
block_size,
torch.float,
MXGemmKernelChoice.EMULATED,
KernelPreference.EMULATED,
None,
False,
)
8 changes: 4 additions & 4 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -27,12 +27,12 @@
)
from torchao.prototype.mx_formats.config import (
MXFP8Dim1CastKernelChoice,
MXGemmKernelChoice,
ScaleCalculationMode,
)
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper
from torchao.quantization.quantize_.common import KernelPreference

logger: logging.Logger = logging.getLogger(__name__)

@@ -412,7 +412,7 @@ def backward(ctx, grad_out: torch.Tensor):
block_size,
elem_dtype=torch.float8_e4m3fn,
hp_dtype=grad_out.dtype,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used
kernel_preference=KernelPreference.AUTO, # Not used
cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
scale_calculation_mode=scale_calculation_mode,
)
@@ -428,7 +428,7 @@ def backward(ctx, grad_out: torch.Tensor):
block_size,
elem_dtype=torch.float8_e4m3fn,
hp_dtype=A.dtype,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used
kernel_preference=KernelPreference.AUTO, # Not used
cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
scale_calculation_mode=scale_calculation_mode,
)
@@ -475,7 +475,7 @@ def _to_mxfp8_dim1_3d(
block_size,
elem_dtype=torch.float8_e4m3fn,
hp_dtype=B_reshaped.dtype,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used
kernel_preference=KernelPreference.AUTO, # Not used
cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
scale_calculation_mode=scaling_mode,
)
22 changes: 10 additions & 12 deletions torchao/prototype/mx_formats/README.md
@@ -74,13 +74,13 @@ Below is a toy training loop. For an example real training loop, see our torchti
import torch
from torchao.quantization import quantize_
import torchao.prototype.mx_formats
from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice, ScaleCalculationMode
from torchao.prototype.mx_formats import MXLinearConfig, ScaleCalculationMode
from torchao.quantization.quantize_.common import KernelPreference

# on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels
gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
# on older NVIDIA gpus, you can run training with emulated MX gemm
# gemm_kernel_choice = MXGemmKernelChoice.EMULATED
# low precision gemm, requires CUDA capability 10.0+
kernel_preference = KernelPreference.AUTO
# or, emulated gemm
# kernel_preference = KernelPreference.EMULATED

scale_calculation_mode = ScaleCalculationMode.FLOOR
# other supported modes: RCEIL, CEIL, EVEN
@@ -89,7 +89,7 @@ m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(
elem_dtype=torch.float8_e4m3fn,
block_size=32,
gemm_kernel_choice=gemm_kernel_choice,
kernel_preference=kernel_preference,
scale_calculation_mode=scale_calculation_mode,
)
quantize_(m, config)
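For downstream code that still passes `gemm_kernel_choice`, the update is essentially a one-argument swap; a minimal before/after sketch, assuming only the APIs shown in this diff:

```python
import torch
from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization.quantize_.common import KernelPreference

# before (old API removed by this PR):
# from torchao.prototype.mx_formats import MXGemmKernelChoice
# config = MXLinearConfig(
#     elem_dtype=torch.float8_e4m3fn,
#     block_size=32,
#     gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
# )

# after (new API):
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    kernel_preference=KernelPreference.AUTO,
)
```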
@@ -107,14 +105,12 @@ import torch
import torch.nn as nn
from torchao.quantization import quantize_
import torchao.prototype.mx_formats
from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
)
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
from torchao.quantization.quantize_.common import KernelPreference

m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
@@ -125,7 +123,7 @@ m_mxfp8 = copy.deepcopy(m)
config = MXFPInferenceConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
kernel_preference=KernelPreference.AUTO,
)
quantize_(m_mxfp8, config=config)
m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True)
@@ -137,7 +135,7 @@ m_mxfp4 = copy.deepcopy(m)
config = MXFPInferenceConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
kernel_preference=KernelPreference.AUTO,
)
quantize_(m_mxfp4, config=config)
m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True)
2 changes: 0 additions & 2 deletions torchao/prototype/mx_formats/__init__.py
@@ -1,5 +1,4 @@
from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
MXLinearConfig,
MXLinearRecipeName,
)
@@ -16,7 +15,6 @@
import torchao.prototype.mx_formats.mx_linear # noqa: F401

__all__ = [
"MXGemmKernelChoice",
"MXLinearConfig",
"MXLinearRecipeName",
"MXFPInferenceConfig",
62 changes: 22 additions & 40 deletions torchao/prototype/mx_formats/config.py
@@ -15,20 +15,7 @@
DTYPE_TO_SHORT_STR,
SUPPORTED_ELEM_DTYPES,
)


class MXGemmKernelChoice(Enum):
# always available - MX operands are dequantized and a high precision
# gemm is run
EMULATED = "emulated"

# available only when CUDA capability is greater than or equal to 10.0
CUTLASS = "cutlass"

# available only when CUDA capability is greater than or equal to 10.0
# available on recent versions of PyTorch nightly, with https://github.com/pytorch/pytorch/pull/147548
# note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
CUBLAS = "cublas"
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference


class MXFP8Dim1CastKernelChoice(Enum):
@@ -85,22 +72,17 @@ def _validate_elem_dtype(elem_dtype):
)


def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
assert block_size == 32, (
f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
)
valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
assert elem_dtype in valid_dtypes, (
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
)
elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
assert block_size in [16, 32], (
f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
)
valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
assert elem_dtype in valid_dtypes, (
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
def _validate_kernel_preference(kernel_preference, block_size, elem_dtype):
if kernel_preference == KernelPreference.AUTO:
if elem_dtype in (torch.float8_e4m3fn, torch.float4_e2m1fn_x2):
assert block_size == 32, f"block_size must be 32, got {block_size}"
else:
raise AssertionError(
f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
)
else:
assert kernel_preference == KernelPreference.EMULATED, (
f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}"
)
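Illustrative only, not part of the diff: a small sketch of how the new helper is expected to behave, based on the branches above (the private import is an assumption for demonstration purposes):

```python
import torch

from torchao.prototype.mx_formats.config import _validate_kernel_preference
from torchao.quantization.quantize_.common import KernelPreference

# passes: AUTO requires an fp8/fp4 element dtype and block_size == 32
_validate_kernel_preference(KernelPreference.AUTO, 32, torch.float8_e4m3fn)

# passes: EMULATED adds no extra block_size/dtype restriction in this helper
_validate_kernel_preference(KernelPreference.EMULATED, 32, torch.float8_e5m2)

# raises AssertionError: AUTO with an unsupported block_size
# _validate_kernel_preference(KernelPreference.AUTO, 16, torch.float8_e4m3fn)
```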


@@ -135,9 +117,9 @@ class MXLinearConfig(AOBaseConfig):
elem_dtype_weight_override: Optional[Any] = None
elem_dtype_grad_output_override: Optional[Any] = None

# defines the gemm kernel choice, if the chosen kernel is not supported
# defines the kernel preference, if the chosen kernel is not supported
# on the given hardware an exception will be thrown
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
kernel_preference: KernelPreference = KernelPreference.EMULATED

# define which kernel to use for mxfp8 casting
# TODO(1945): remove this config option once torch.compile gives us
@@ -150,15 +132,15 @@

def __post_init__(self):
_validate_elem_dtype(self.elem_dtype)
_validate_gemm_kernel_choice(
self.gemm_kernel_choice, self.block_size, self.elem_dtype
_validate_kernel_preference(
self.kernel_preference, self.block_size, self.elem_dtype
)
if self.elem_dtype_weight_override is not None:
_validate_elem_dtype(self.elem_dtype_weight_override)
assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
assert self.kernel_preference == KernelPreference.EMULATED, "unsupported"
if self.elem_dtype_grad_output_override is not None:
_validate_elem_dtype(self.elem_dtype_grad_output_override)
assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
assert self.kernel_preference == KernelPreference.EMULATED, "unsupported"
_validate_mxfp8_cast_kernel_choice(
self.mxfp8_cast_kernel_choice, self.scale_calculation_mode
)
@@ -182,12 +164,12 @@ def from_recipe_name(
return MXLinearConfig()
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
return MXLinearConfig(
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
kernel_preference=KernelPreference.AUTO,
mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
)
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
return MXLinearConfig(
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
kernel_preference=KernelPreference.AUTO,
mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
scale_calculation_mode=ScaleCalculationMode.RCEIL,
)
@@ -196,7 +178,7 @@
elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
return MXLinearConfig(
elem_dtype=torch.float4_e2m1fn_x2,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
kernel_preference=KernelPreference.AUTO,
)
else:
raise AssertionError(f"unknown recipe_name {recipe_name}")
@@ -212,7 +194,7 @@ def short_str(self) -> str:
)
if self.elem_dtype_grad_output_override is not None:
s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
s += f", kernel={self.gemm_kernel_choice.value}"
s += f", kernel={self.kernel_preference.value}"
s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
s += f", scale_calculation_mode={self.scale_calculation_mode}"