From 0bcef98e0d2b5e8de09a620fdfbfdade4322e501 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Nov 2025 12:02:50 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 6 ++-- test/integration/test_vllm.py | 4 +-- .../mx_formats/test_inference_workflow.py | 8 ++--- .../mx_formats/test_mx_serialization.py | 4 +-- torchao/prototype/mx_formats/README.md | 6 ++-- torchao/prototype/mx_formats/__init__.py | 4 +-- .../mx_formats/inference_workflow.py | 29 ++----------------- 7 files changed, 19 insertions(+), 42 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index dc732dc77a..e1eaa43c1b 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -39,7 +39,7 @@ import torchao from torchao.prototype.mx_formats.inference_workflow import ( - MXFPInferenceConfig, + MXDynamicActivationMXWeightConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) @@ -433,13 +433,13 @@ def run( kernel_preference=KernelPreference.TORCH, ) elif recipe_name == "mxfp8_cublas": - config = MXFPInferenceConfig( + config = MXDynamicActivationMXWeightConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, kernel_preference=KernelPreference.AUTO, ) elif recipe_name == "mxfp4_cutlass": - config = MXFPInferenceConfig( + config = MXDynamicActivationMXWeightConfig( activation_dtype=torch.float4_e2m1fn_x2, weight_dtype=torch.float4_e2m1fn_x2, kernel_preference=KernelPreference.AUTO, diff --git a/test/integration/test_vllm.py b/test/integration/test_vllm.py index 32a7a8b405..9371d83a99 100644 --- a/test/integration/test_vllm.py +++ b/test/integration/test_vllm.py @@ -41,7 +41,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig from vllm import LLM, SamplingParams -from torchao.prototype.mx_formats import MXFPInferenceConfig +from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig from torchao.quantization.granularity import PerRow, PerTensor from torchao.quantization.quant_api import ( CutlassInt4PackedLayout, @@ -70,7 +70,7 @@ def get_tests() -> List[TorchAoConfig]: Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout()) ) ] - SM100_TESTS = [TorchAoConfig(MXFPInferenceConfig())] + SM100_TESTS = [TorchAoConfig(MXDynamicActivationMXWeightConfig())] # Check CUDA availability first if not torch.cuda.is_available(): diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 8dad950c4c..fd03eec29f 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -13,7 +13,7 @@ from torch.profiler import ProfilerActivity, profile from torchao.prototype.mx_formats.inference_workflow import ( - MXFPInferenceConfig, + MXDynamicActivationMXWeightConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) @@ -106,7 +106,7 @@ def test_inference_workflow_mx( kernel_choice = KernelPreference.EMULATED else: kernel_choice = KernelPreference.AUTO - config = MXFPInferenceConfig( + config = MXDynamicActivationMXWeightConfig( activation_dtype=elem_dtype, weight_dtype=elem_dtype, kernel_preference=kernel_choice, @@ -247,7 +247,7 @@ class VLLMIntegrationTestCase(TorchAOIntegrationTestCase): reason="torch.compile requires PyTorch 2.8+", ) def test_slice_and_copy_similar_to_vllm(self): - config = MXFPInferenceConfig( + config = MXDynamicActivationMXWeightConfig( 
activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, kernel_preference=KernelPreference.EMULATED, @@ -260,7 +260,7 @@ def test_slice_and_copy_similar_to_vllm(self): reason="torch.compile requires PyTorch 2.8+", ) def test_narrow_similar_to_vllm(self): - config = MXFPInferenceConfig( + config = MXDynamicActivationMXWeightConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, kernel_preference=KernelPreference.EMULATED, diff --git a/test/prototype/mx_formats/test_mx_serialization.py b/test/prototype/mx_formats/test_mx_serialization.py index 930dc1dfaa..9649da98f7 100644 --- a/test/prototype/mx_formats/test_mx_serialization.py +++ b/test/prototype/mx_formats/test_mx_serialization.py @@ -13,7 +13,7 @@ import torch.nn as nn from torchao.prototype.mx_formats.inference_workflow import ( - MXFPInferenceConfig, + MXDynamicActivationMXWeightConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) @@ -41,7 +41,7 @@ def test_serialization(recipe_name): fname = None with tempfile.NamedTemporaryFile(delete=False, mode="w") as f: if recipe_name == "mxfp8": - config = MXFPInferenceConfig( + config = MXDynamicActivationMXWeightConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, kernel_preference=KernelPreference.EMULATED, diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 6c36c2eaed..f4a9ff1045 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -108,7 +108,7 @@ import torch.nn as nn from torchao.quantization import quantize_ import torchao.prototype.mx_formats from torchao.prototype.mx_formats.inference_workflow import ( - MXFPInferenceConfig, + MXDynamicActivationMXWeightConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) @@ -120,7 +120,7 @@ x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16) # mxfp8 m_mxfp8 = copy.deepcopy(m) -config = MXFPInferenceConfig( +config = MXDynamicActivationMXWeightConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, kernel_preference=KernelPreference.AUTO, @@ -132,7 +132,7 @@ y_mxfp8 = m_mxfp8(x) # mxfp4 m_mxfp4 = copy.deepcopy(m) -config = MXFPInferenceConfig( +config = MXDynamicActivationMXWeightConfig( activation_dtype=torch.float4_e2m1fn_x2, weight_dtype=torch.float4_e2m1fn_x2, kernel_preference=KernelPreference.AUTO, diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py index 8d1455d6f3..bd0d05eee1 100644 --- a/torchao/prototype/mx_formats/__init__.py +++ b/torchao/prototype/mx_formats/__init__.py @@ -5,7 +5,7 @@ # Note: Prototype and subject to change from torchao.prototype.mx_formats.inference_workflow import ( - MXFPInferenceConfig, + MXDynamicActivationMXWeightConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) @@ -17,7 +17,7 @@ __all__ = [ "MXLinearConfig", "MXLinearRecipeName", - "MXFPInferenceConfig", + "MXDynamicActivationMXWeightConfig", "NVFP4InferenceConfig", "NVFP4MMConfig", ] diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 5991d8557e..6cd011fc0a 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -36,39 +36,16 @@ ) -# TODO The naming for these configs is a little weird, rename before moving to public API -# Note: This API is extra prototype and will change in the future @dataclass -class MXFPInferenceConfig(AOBaseConfig): +class MXDynamicActivationMXWeightConfig(AOBaseConfig): """ MX 
Format Inference Quantization This module provides support for running inference with float8 quantization using MX formats. - The quantization flow works as follows: - - 1. Weight Quantization: - - In _mx_inference_linear_transform(), the module's weight is converted to an MXTensor - - The weight is quantized to the specified dtype (float8_e4m3fn by default) - - This happens when quantize_() is called with an MXFPInferenceConfig - - 2. Activation Quantization: - - A callable (_input_activation_quant_func_mxfp) is defined that will quantize - activations during inference to the same dtype - - This function is passed to to_linear_activation_quantized() along with the - already-quantized weight - - 3. Runtime Flow: - - When the quantized module is called, the input goes through the LinearActivationQuantizedTensor - - The input (activation) is quantized just-in-time using the provided function - - The MX quantized activation and MX weight are used together in F.linear Requirements: - NVIDIA SM100+ hardware (Blackwell or newer) is required for execution - PyTorch 2.5+ for proper serialization support - - See also: - - LinearActivationQuantizedTensor in torchao.quantization.quant_api - - MXTensor in torchao.prototype.mx_formats.mx_tensor """ block_size: int = 32 @@ -95,9 +72,9 @@ def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -@register_quantize_module_handler(MXFPInferenceConfig) +@register_quantize_module_handler(MXDynamicActivationMXWeightConfig) def _mx_inference_linear_transform( - module: torch.nn.Module, config: MXFPInferenceConfig + module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig ): weight = module.weight
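---

For reviewers trying the renamed API, a minimal usage sketch based on the README snippet in this patch. The `KernelPreference` import path and the toy Linear shapes are assumptions for illustration; `KernelPreference.EMULATED` follows the tests touched by this patch rather than the SM100+ path noted in the class docstring.

    import copy

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
    from torchao.quantization import quantize_
    # Assumed import location for KernelPreference; adjust to where it lives in your torchao version.
    from torchao.quantization.quantize_.common import KernelPreference

    # Toy bf16 model and input; shapes are illustrative only.
    m = nn.Sequential(nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda"))
    x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)

    # mxfp8: dynamically quantize activations and weights to float8_e4m3fn
    # with the default 32-element MX blocks (block_size=32 in the config above).
    m_mxfp8 = copy.deepcopy(m)
    config = MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
        kernel_preference=KernelPreference.EMULATED,
    )
    quantize_(m_mxfp8, config)
    y_mxfp8 = m_mxfp8(x)

Swapping both dtypes to torch.float4_e2m1fn_x2 gives the mxfp4 variant, as in the README hunk above.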