diff --git a/docs/source/torchao_vllm_integration.md b/docs/source/torchao_vllm_integration.md
index 870a6c2958..dbe3e6ef05 100644
--- a/docs/source/torchao_vllm_integration.md
+++ b/docs/source/torchao_vllm_integration.md
@@ -171,7 +171,7 @@ class MyNewQuantConfig(AOBaseConfig):
     VERSION: ClassVar[int] = 1
 
 class MyQuantizedTensor(TorchAOBaseTensor):
-    """Example based on FbgemmFp8Tensor - stores quantized data + scale"""
+    """Example based on Float8Tensor - stores quantized data + scale"""
 
     tensor_data_attrs = ["quantized_data", "scale"]
     tensor_attributes = ["dtype"]
diff --git a/test/core/test_config.py b/test/core/test_config.py
index 0bf975fa3b..0df31194ac 100644
--- a/test/core/test_config.py
+++ b/test/core/test_config.py
@@ -24,7 +24,6 @@
     AWQStep,
 )
 from torchao.quantization.quant_api import (
-    FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
     Float8WeightOnlyConfig,
@@ -92,7 +91,6 @@
     ),
     AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING),
     AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"),
-    FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256]),
 ]
diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 56f42a8043..83f32c8420 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -24,9 +24,7 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
     Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
@@ -44,7 +42,6 @@
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
-    is_sm_at_least_90,
 )
 
 is_cusparselt_available = (
@@ -100,10 +97,6 @@ def get_quantization_functions(
     if is_sm_at_least_89():
         base_functions.append(Float8WeightOnlyConfig())
 
-    if is_sm_at_least_90():
-        base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16))
-        base_functions.append(FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16))
-
     return base_functions
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 81e1ff2815..da1b848bcb 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -434,25 +434,6 @@ def ffn_or_attn_only(mod, fqn):
                 model,
                 Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1),
             )
-        elif "fbgemm" in quantization and "int4" in quantization:
-            from torchao.quantization import FbgemmConfig
-
-            _, precision, group_size = quantization.split("-")
-            group_size = int(group_size)
-            block_size = [1, group_size]
-            assert precision == "int4", f"FbegemmConfig({precision=}) not supported yet"
-            quantize_(
-                model,
-                FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size),
-            )
-        elif "fbgemm" in quantization and "fp8" in quantization:
-            from torchao.float8.config import e4m3_dtype
-            from torchao.quantization import FbgemmConfig
-
-            quantize_(
-                model,
-                FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16),
-            )
         elif "int4dq-" in quantization:
             from torchao.dtypes import CutlassInt4PackedLayout
 
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 575e154091..07f03c7ed9 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -8,7 +8,6 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py
deleted file mode 100644
index 6f007c9339..0000000000
--- a/torchao/dtypes/fbgemm_fp8_tensor.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-from typing import Optional
-
-import torch
-from torch.utils._python_dispatch import return_and_correct_aliasing
-
-from torchao.utils import (
-    TorchAOBaseTensor,
-    fill_defaults,
-)
-
-__all__ = [
-    "to_fbgemm_fp8",
-    "FbgemmFp8Tensor",
-]
-
-aten = torch.ops.aten
-
-
-class FbgemmFp8Tensor(TorchAOBaseTensor):
-    """
-    TODO: needs padding for cutlass kernels
-    """
-
-    tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
-    tensor_attributes = ["dtype"]
-
-    def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
-        shape = float8_data.shape
-        kwargs = {}
-        kwargs["device"] = float8_data.device
-        kwargs["dtype"] = dtype
-        kwargs["requires_grad"] = False
-        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-
-    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
-        self.float8_data = float8_data
-        self.scale = scale
-        self.activation_scale_ub = activation_scale_ub
-
-    def __tensor_flatten__(self):
-        return self.tensor_data_attrs, [
-            getattr(self, attr) for attr in self.tensor_attributes
-        ]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        return cls(
-            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
-            *tensor_attributes,
-        )
-
-    def _apply_fn_to_data(self, fn):
-        return self.__class__(
-            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
-            *[getattr(self, attr) for attr in self.tensor_attributes],
-        )
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, "
-            f"activation_scale_ub={self.activation_scale_ub}, "
-            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
-        )
-
-    def _quantization_type(self):
-        return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"
-
-    def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        device = kwargs.pop("device")
-        return self.__class__(
-            self.float8_data.to(device),
-            self.scale.to(device),
-            self.activation_scale_ub.to(device),
-            self.dtype,
-        )
-
-    @classmethod
-    def from_float(
-        cls,
-        w: torch.Tensor,
-        activation_scale_ub: Optional[float] = None,
-    ):
-        if activation_scale_ub is None:
-            activation_scale_ub = 1200.0
-
-        activation_scale_ub = torch.tensor(
-            [activation_scale_ub],
-            dtype=torch.float,
-            device=w.device,
-        )
-        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
-        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-        dtype = w.dtype
-        del w
-        return FbgemmFp8Tensor(
-            wq,
-            w_scale,
-            activation_scale_ub=activation_scale_ub,
-            dtype=dtype,
-        )
-
-
-implements = FbgemmFp8Tensor.implements
-
-
-@implements([torch.nn.functional.linear, aten.linear.default])
-def _(func, types, args, kwargs):
-    input_tensor, weight_tensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
-    )
-    orig_act_size = input_tensor.size()
-    orig_out_features = weight_tensor.shape[-2]
-
-    # not used
-    num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
-    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-        input_tensor, num_tokens, weight_tensor.activation_scale_ub
-    )
-
-    a_data = xq
-    b_data = weight_tensor.float8_data
-
-    res = torch.ops.fbgemm.f8f8bf16_rowwise(
-        a_data,
-        b_data,
-        x_scale,
-        weight_tensor.scale,
-        use_fast_accum=True,
-    )
-    res = res.reshape(*orig_act_size[:-1], orig_out_features)
-    if bias is not None:
-        res = res + bias
-
-    return res
-
-
-@implements(torch.bmm)
-def _(func, types, args, kwargs):
-    input_tensor, weight_tensor = (
-        args[0],
-        args[1],
-    )
-    orig_act_size = input_tensor.size()
-    # not used
-    num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
-    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-        input_tensor, num_tokens, weight_tensor.activation_scale_ub
-    )
-
-    a_data = xq
-    b_data = weight_tensor.float8_data
-    orig_out_features = b_data.shape[-2]
-
-    res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
-        a_data,
-        b_data,
-        x_scale,
-        weight_tensor.scale,
-    )
-    res = res.reshape(*orig_act_size[:-1], orig_out_features)
-    return res
-
-
-@implements([aten.detach.default, aten.alias.default])
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-    )
-
-
-@implements(aten.clone.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-    )
-
-
-def _same_metadata(self: "FbgemmFp8Tensor", src: "FbgemmFp8Tensor") -> bool:
-    return (
-        isinstance(self, FbgemmFp8Tensor)
-        and isinstance(src, FbgemmFp8Tensor)
-        and self.shape == src.shape
-        and self.float8_data.shape == src.float8_data.shape
-        and self.scale.shape == src.scale.shape
-        and self.activation_scale_ub.shape == src.activation_scale_ub.shape
-        and self.dtype == src.dtype
-    )
-
-
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
-    )
-
-
-@implements(aten.slice.Tensor)
-def _(func, types, args, kwargs):
-    """Only supports slicing for dim == 1 and dim == 2
-    original tensor shape has dimension (N, K)
-    float8_data has dimension (N, K)
-    scale (per row quantization) has dimension: (N,)
-
-    since float8_data has the same dimension as original tensor, we can directly slice that
-    for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1
-
-    Note that we need to call slice on the float8_data and scale directly because slice
-    is an operation that need to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_fp8`
-    for
-    """
-    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-    assert step == 1
-    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
-    if end >= self.shape[dim]:
-        end = self.shape[dim]
-
-    assert self.float8_data.ndim == 2, (
-        f"Expected packed weight to have dim 2, got {self.float8_data.dim}"
-    )
-
-    # Always slice the float8_data
-    sliced_data = aten.slice.Tensor(
-        self.float8_data, dim, start, end, step
-    ).contiguous()
-
-    if dim == 0:
-        # scale has dimension (N,) where N is the dim 0 of `self`
-        # so we do the same slice on scale for dimension 0
-        sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step)
-    else:
-        # since scale is per row, slicing along the dim == 1 dimension does
-        # not change the scale
-        sliced_scale = self.scale
-
-    return return_and_correct_aliasing(
-        func,
-        args,
-        kwargs,
-        FbgemmFp8Tensor(
-            sliced_data, sliced_scale, self.activation_scale_ub, dtype=self.dtype
-        ),
-    )
-
-
-to_fbgemm_fp8 = FbgemmFp8Tensor.from_float
-
-
-# Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([FbgemmFp8Tensor])
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 407a83bcd7..b32868b684 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -43,7 +43,6 @@
 )
 from .quant_api import (
     CutlassInt4PackedLayout,
-    FbgemmConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
@@ -161,7 +160,6 @@
     "GemliteUIntXWeightOnlyConfig",
     "AOPerModuleConfig",
     "ModuleFqnToConfig",
-    "FbgemmConfig",
     # tensor subclasses
     "Int4Tensor",
     "Int4PlainInt32Tensor",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ef4b247819..3a6ecc08a7 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -46,7 +46,6 @@
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
-    to_fbgemm_fp8,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -93,7 +92,6 @@
 )
 from torchao.utils import (
     _ConfigDeprecationWrapper,
-    _is_fbgemm_genai_gpu_available,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -160,7 +158,6 @@
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
-    "FbgemmConfig",
 ]
 
 LAYOUT_TO_ZERO_POINT_DOMAIN = {
@@ -2312,86 +2309,6 @@ def _fpx_weight_only_transform(
     return module
 
 
-@dataclass
-class FbgemmConfig(AOBaseConfig):
-    """Quantization Config for fbgemm-genai kernels
-    Args:
-       input_dtype (torch.dtype): input dtype of the kernel
-       weight_dtype (torch.dtype): weight dtype of the kernel
-       output_dtype (torch.dtype): output dtype of the kernel
-       group_size (int): The group size for weight
-       preshuffle (bool): whether preshuffle the weights or not
-    """
-
-    input_dtype: torch.dtype
-    weight_dtype: torch.dtype
-    output_dtype: torch.dtype
-    block_size: Optional[List[int]] = None
-    activation_scale_ub: float = 1200.0
-    preshuffle: bool = False
-
-
-@register_quantize_module_handler(FbgemmConfig)
-def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
-    if not _is_fbgemm_genai_gpu_available():
-        raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
-
-    _SUPPORTED_DTYPES = {
-        (torch.bfloat16, torch.int4, torch.bfloat16),
-        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16),
-    }
-
-    if (
-        (config.input_dtype == torch.bfloat16)
-        and (config.weight_dtype == torch.int4)
-        and (config.output_dtype == torch.bfloat16)
-    ):
-        if config.preshuffle:
-            weight = Int4PreshuffledTensor.from_hp(
-                module.weight,
-                config.block_size,
-                activation_dtype=torch.bfloat16,
-            )
-        else:
-            weight = Int4Tensor.from_hp(
-                module.weight,
-                config.block_size,
-            )
-        module.weight = torch.nn.Parameter(weight, requires_grad=False)
-        module.extra_repr = types.MethodType(_linear_extra_repr, module)
-        return module
-    if (
-        (config.input_dtype == e4m3_dtype)
-        and (config.weight_dtype == torch.int4)
-        and (config.output_dtype == torch.bfloat16)
-    ):
-        if config.preshuffle:
-            weight = Int4PreshuffledTensor.from_hp(
-                module.weight,
-                config.block_size,
-                activation_dtype=torch.float8_e4m3fn,
-            )
-        module.weight = torch.nn.Parameter(weight, requires_grad=False)
-        module.extra_repr = types.MethodType(_linear_extra_repr, module)
-        return module
-    elif (
-        (config.input_dtype == e4m3_dtype)
-        and (config.weight_dtype == e4m3_dtype)
-        and (config.output_dtype == torch.bfloat16)
-    ):
-        weight = to_fbgemm_fp8(
-            module.weight,
-            config.activation_scale_ub,
-        )
-        module.weight = torch.nn.Parameter(weight, requires_grad=False)
-        module.extra_repr = types.MethodType(_linear_extra_repr, module)
-        return module
-    else:
-        raise NotImplementedError(
-            f"{config} is not supported. supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}"
-        )
-
-
 @dataclass
 class ModuleFqnToConfig(AOBaseConfig):
     """Per module configurations for torchao quantize_ API