From 3c5d971866c55f4602c6ddf3e8f5f829d371a53b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 18 Sep 2025 15:25:04 -0700 Subject: [PATCH] Remove FbgemmConfig and remaining Fbgemm tensors Summary: This is used for prototype previously, not used now, we now expose fbgemm kernels through Int4WeightOnlyConfig (for int4) and Float8DynamicActivationFloat8WeightConfig (for FP8) Not considering this BC breaking since we haven't publicized the API yet Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- docs/source/torchao_vllm_integration.md | 2 +- test/core/test_config.py | 2 - test/dtypes/test_affine_quantized.py | 7 - torchao/_models/llama/generate.py | 19 -- torchao/dtypes/__init__.py | 1 - torchao/dtypes/fbgemm_fp8_tensor.py | 268 ------------------------ torchao/quantization/__init__.py | 2 - torchao/quantization/quant_api.py | 83 -------- 8 files changed, 1 insertion(+), 383 deletions(-) delete mode 100644 torchao/dtypes/fbgemm_fp8_tensor.py diff --git a/docs/source/torchao_vllm_integration.md b/docs/source/torchao_vllm_integration.md index 870a6c2958..dbe3e6ef05 100644 --- a/docs/source/torchao_vllm_integration.md +++ b/docs/source/torchao_vllm_integration.md @@ -171,7 +171,7 @@ class MyNewQuantConfig(AOBaseConfig): VERSION: ClassVar[int] = 1 class MyQuantizedTensor(TorchAOBaseTensor): - """Example based on FbgemmFp8Tensor - stores quantized data + scale""" + """Example based on Float8Tensor - stores quantized data + scale""" tensor_data_attrs = ["quantized_data", "scale"] tensor_attributes = ["dtype"] diff --git a/test/core/test_config.py b/test/core/test_config.py index 0bf975fa3b..0df31194ac 100644 --- a/test/core/test_config.py +++ b/test/core/test_config.py @@ -24,7 +24,6 @@ AWQStep, ) from torchao.quantization.quant_api import ( - FbgemmConfig, Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationInt4WeightConfig, Float8WeightOnlyConfig, @@ -92,7 +91,6 @@ ), AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING), AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"), - FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256]), ] diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 56f42a8043..83f32c8420 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -24,9 +24,7 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) -from torchao.float8.config import e4m3_dtype from torchao.quantization import ( - FbgemmConfig, Float8WeightOnlyConfig, GemliteUIntXWeightOnlyConfig, Int4DynamicActivationInt4WeightConfig, @@ -44,7 +42,6 @@ is_fbcode, is_ROCM, is_sm_at_least_89, - is_sm_at_least_90, ) is_cusparselt_available = ( @@ -100,10 +97,6 @@ def get_quantization_functions( if is_sm_at_least_89(): base_functions.append(Float8WeightOnlyConfig()) - if is_sm_at_least_90(): - base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16)) - base_functions.append(FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16)) - return base_functions diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 81e1ff2815..da1b848bcb 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -434,25 +434,6 @@ def ffn_or_attn_only(mod, fqn): model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1), ) - elif "fbgemm" in quantization and "int4" in quantization: - from torchao.quantization import FbgemmConfig - - _, precision, group_size = 
quantization.split("-") - group_size = int(group_size) - block_size = [1, group_size] - assert precision == "int4", f"FbegemmConfig({precision=}) not supported yet" - quantize_( - model, - FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size), - ) - elif "fbgemm" in quantization and "fp8" in quantization: - from torchao.float8.config import e4m3_dtype - from torchao.quantization import FbgemmConfig - - quantize_( - model, - FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16), - ) elif "int4dq-" in quantization: from torchao.dtypes import CutlassInt4PackedLayout diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 575e154091..07f03c7ed9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -8,7 +8,6 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) -from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8 from .floatx import ( CutlassSemiSparseLayout, Float8Layout, diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py deleted file mode 100644 index 6f007c9339..0000000000 --- a/torchao/dtypes/fbgemm_fp8_tensor.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Optional - -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.utils import ( - TorchAOBaseTensor, - fill_defaults, -) - -__all__ = [ - "to_fbgemm_fp8", - "FbgemmFp8Tensor", -] - -aten = torch.ops.aten - - -class FbgemmFp8Tensor(TorchAOBaseTensor): - """ - TODO: needs padding for cutlass kernels - """ - - tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"] - tensor_attributes = ["dtype"] - - def __new__(cls, float8_data, scale, activation_scale_ub, dtype): - shape = float8_data.shape - kwargs = {} - kwargs["device"] = float8_data.device - kwargs["dtype"] = dtype - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, float8_data, scale, activation_scale_ub, dtype): - self.float8_data = float8_data - self.scale = scale - self.activation_scale_ub = activation_scale_ub - - def __tensor_flatten__(self): - return self.tensor_data_attrs, [ - getattr(self, attr) for attr in self.tensor_attributes - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - return cls( - *[tensor_data_dict[name] for name in cls.tensor_data_attrs], - *tensor_attributes, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], - *[getattr(self, attr) for attr in self.tensor_attributes], - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, " - f"activation_scale_ub={self.activation_scale_ub}, " - f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - def _quantization_type(self): - return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}" - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.float8_data.to(device), - self.scale.to(device), - self.activation_scale_ub.to(device), - self.dtype, - 
) - - @classmethod - def from_float( - cls, - w: torch.Tensor, - activation_scale_ub: Optional[float] = None, - ): - if activation_scale_ub is None: - activation_scale_ub = 1200.0 - - activation_scale_ub = torch.tensor( - [activation_scale_ub], - dtype=torch.float, - device=w.device, - ) - wq, w_scale = torch.ops.triton.quantize_fp8_row(w) - # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w) - dtype = w.dtype - del w - return FbgemmFp8Tensor( - wq, - w_scale, - activation_scale_ub=activation_scale_ub, - dtype=dtype, - ) - - -implements = FbgemmFp8Tensor.implements - - -@implements([torch.nn.functional.linear, aten.linear.default]) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - orig_act_size = input_tensor.size() - orig_out_features = weight_tensor.shape[-2] - - # not used - num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device) - xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - input_tensor, num_tokens, weight_tensor.activation_scale_ub - ) - - a_data = xq - b_data = weight_tensor.float8_data - - res = torch.ops.fbgemm.f8f8bf16_rowwise( - a_data, - b_data, - x_scale, - weight_tensor.scale, - use_fast_accum=True, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: - res = res + bias - - return res - - -@implements(torch.bmm) -def _(func, types, args, kwargs): - input_tensor, weight_tensor = ( - args[0], - args[1], - ) - orig_act_size = input_tensor.size() - # not used - num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device) - xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - input_tensor, num_tokens, weight_tensor.activation_scale_ub - ) - - a_data = xq - b_data = weight_tensor.float8_data - orig_out_features = b_data.shape[-2] - - res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( - a_data, - b_data, - x_scale, - weight_tensor.scale, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - return res - - -@implements([aten.detach.default, aten.alias.default]) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -def _same_metadata(self: "FbgemmFp8Tensor", src: "FbgemmFp8Tensor") -> bool: - return ( - isinstance(self, FbgemmFp8Tensor) - and isinstance(src, FbgemmFp8Tensor) - and self.shape == src.shape - and self.float8_data.shape == src.float8_data.shape - and self.scale.shape == src.scale.shape - and self.activation_scale_ub.shape == src.activation_scale_ub.shape - and self.dtype == src.dtype - ) - - -@implements(aten.copy_.default) -def _(func, types, args, kwargs): - self = args[0] - src = args[1] - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" - ) - - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - """Only supports slicing for dim == 1 and dim == 2 - original tensor shape has dimension (N, K) - float8_data has dimension (N, K) - scale (per row quantization) has dimension: (N,) - - since float8_data has the same dimension as original tensor, we can directly slice that - for scale, 
we'll do a slice when dim is 0, and don't need to do anything for dim 1 - - Note that we need to call slice on the float8_data and scale directly because slice - is an operation that need to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_fp8` - for - """ - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1 - assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" - if end >= self.shape[dim]: - end = self.shape[dim] - - assert self.float8_data.ndim == 2, ( - f"Expected packed weight to have dim 2, got {self.float8_data.dim}" - ) - - # Always slice the float8_data - sliced_data = aten.slice.Tensor( - self.float8_data, dim, start, end, step - ).contiguous() - - if dim == 0: - # scale has dimension (N,) where N is the dim 0 of `self` - # so we do the same slice on scale for dimension 0 - sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step) - else: - # since scale is per row, slicing along the dim == 1 dimension does - # not change the scale - sliced_scale = self.scale - - return return_and_correct_aliasing( - func, - args, - kwargs, - FbgemmFp8Tensor( - sliced_data, sliced_scale, self.activation_scale_ub, dtype=self.dtype - ), - ) - - -to_fbgemm_fp8 = FbgemmFp8Tensor.from_float - - -# Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True` -torch.serialization.add_safe_globals([FbgemmFp8Tensor]) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 407a83bcd7..b32868b684 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -43,7 +43,6 @@ ) from .quant_api import ( CutlassInt4PackedLayout, - FbgemmConfig, Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationInt4WeightConfig, @@ -161,7 +160,6 @@ "GemliteUIntXWeightOnlyConfig", "AOPerModuleConfig", "ModuleFqnToConfig", - "FbgemmConfig", # tensor subclasses "Int4Tensor", "Int4PlainInt32Tensor", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ef4b247819..3a6ecc08a7 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -46,7 +46,6 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, - to_fbgemm_fp8, to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( @@ -93,7 +92,6 @@ ) from torchao.utils import ( _ConfigDeprecationWrapper, - _is_fbgemm_genai_gpu_available, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -160,7 +158,6 @@ "Int8DynActInt4WeightQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", - "FbgemmConfig", ] LAYOUT_TO_ZERO_POINT_DOMAIN = { @@ -2312,86 +2309,6 @@ def _fpx_weight_only_transform( return module -@dataclass -class FbgemmConfig(AOBaseConfig): - """Quantization Config for fbgemm-genai kernels - Args: - input_dtype (torch.dtype): input dtype of the kernel - weight_dtype (torch.dtype): weight dtype of the kernel - output_dtype (torch.dtype): output dtype of the kernel - group_size (int): The group size for weight - preshuffle (bool): whether preshuffle the weights or not - """ - - input_dtype: torch.dtype - weight_dtype: torch.dtype - output_dtype: torch.dtype - block_size: Optional[List[int]] = None - activation_scale_ub: float = 1200.0 - preshuffle: bool = False - - -@register_quantize_module_handler(FbgemmConfig) -def _(module: torch.nn.Module, config: FbgemmConfig) -> 
torch.nn.Module: - if not _is_fbgemm_genai_gpu_available(): - raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - - _SUPPORTED_DTYPES = { - (torch.bfloat16, torch.int4, torch.bfloat16), - (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16), - } - - if ( - (config.input_dtype == torch.bfloat16) - and (config.weight_dtype == torch.int4) - and (config.output_dtype == torch.bfloat16) - ): - if config.preshuffle: - weight = Int4PreshuffledTensor.from_hp( - module.weight, - config.block_size, - activation_dtype=torch.bfloat16, - ) - else: - weight = Int4Tensor.from_hp( - module.weight, - config.block_size, - ) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - if ( - (config.input_dtype == e4m3_dtype) - and (config.weight_dtype == torch.int4) - and (config.output_dtype == torch.bfloat16) - ): - if config.preshuffle: - weight = Int4PreshuffledTensor.from_hp( - module.weight, - config.block_size, - activation_dtype=torch.float8_e4m3fn, - ) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - elif ( - (config.input_dtype == e4m3_dtype) - and (config.weight_dtype == e4m3_dtype) - and (config.output_dtype == torch.bfloat16) - ): - weight = to_fbgemm_fp8( - module.weight, - config.activation_scale_ub, - ) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - else: - raise NotImplementedError( - f"{config} is not supported. supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}" - ) - - @dataclass class ModuleFqnToConfig(AOBaseConfig): """Per module configurations for torchao quantize_ API
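
For reference, the quantization paths that FbgemmConfig previously selected are now reached through the public configs named in the summary. Below is a minimal usage sketch; the model, device, and group_size are illustrative only, and whether the fbgemm kernels are actually picked up still depends on the installed fbgemm-gpu-genai package, hardware support, and config options not shown here.

    import torch

    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        Int4WeightOnlyConfig,
        quantize_,
    )

    # Toy model for illustration; any nn.Module with Linear layers works the same way.
    model = torch.nn.Sequential(
        torch.nn.Linear(256, 256, dtype=torch.bfloat16, device="cuda")
    )

    # bf16 activation / int4 weight, previously FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, ...)
    quantize_(model, Int4WeightOnlyConfig(group_size=128))

    # fp8 dynamic activation / fp8 weight, previously FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16);
    # shown commented out since the model above is already quantized in place.
    # quantize_(model, Float8DynamicActivationFloat8WeightConfig())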