From 4f3902e2a067886a728262958103e85d06e9b652 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 2 Dec 2024 17:45:16 -0800 Subject: [PATCH 01/11] [WIP] Enable float8 attention support (q/k/v) Summary: att, right now we need to manually add quantize call for q/k/v before sdpa op, but we can explore other APIs in the future Test Plan: TBD Reviewers: Subscribers: Tasks: Tags: --- .../_models/sam2/modeling/sam/transformer.py | 11 ++ torchao/dtypes/affine_quantized_tensor.py | 4 +- torchao/dtypes/affine_quantized_tensor_ops.py | 107 +++++++++++------- torchao/quantization/quant_api.py | 23 ++-- 4 files changed, 94 insertions(+), 51 deletions(-) diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index 2e3d85ccd4..06b2c956fe 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -17,6 +17,7 @@ from torchao._models.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from torchao._models.sam2.modeling.sam2_utils import MLP from torchao._models.sam2.utils.misc import get_sdpa_settings +from torchao.quantization.quant_api import _float8_symmetric_per_token_quant warnings.simplefilter(action="ignore", category=FutureWarning) # Check whether Flash Attention is available (and use it by default) @@ -263,6 +264,11 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) + # quantize q/k/v + q = _float8_symmetric_per_token_quant(q) + k = _float8_symmetric_per_token_quant(k) + v = _float8_symmetric_per_token_quant(v) + dropout_p = self.dropout_p if self.training else 0.0 # # Attention # try: @@ -323,6 +329,11 @@ def forward( k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) + # quantize q/k/v + q = _float8_symmetric_per_token_quant(q) + k = _float8_symmetric_per_token_quant(k) + v = _float8_symmetric_per_token_quant(v) + # Apply rotary position encoding w = h = math.sqrt(q.shape[-2]) self.freqs_cis = self.freqs_cis.to(q.device) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 93d2766d1e..af4ff01911 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -510,9 +510,9 @@ def from_hp_to_intx( ) -###################################################### +############################################### # Layout and TensorImpl Subclass Registration # -###################################################### +############################################### register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index bd7ff7d333..b6eeedd914 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -177,45 +177,6 @@ def _(func, types, args, kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) -@implements(torch.nn.functional.embedding) -def _(func, types, args, kwargs): - # new_arg1 = args[1].dequantize() - # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) - assert isinstance( - args[1].tensor_impl, PlainAQTTensorImpl - ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" - assert ( - kwargs["padding_idx"] is None - and 
kwargs["max_norm"] is None - and not kwargs["scale_grad_by_freq"] - and not kwargs["sparse"] - and kwargs["norm_type"] == 2.0 - ) - idx = args[0] - int_data, scale, zero_point = args[1].tensor_impl.get_plain() - - sliced_data, sliced_scale, sliced_zero_point = ( - int_data[idx], - scale[idx], - zero_point[idx], - ) - # Block size is expecting 2 dimensions [1, group size] but - # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so - # we need to increase block size to correct dim - new_blocks = idx.dim() - 1 - return dequantize_affine( - sliced_data, - new_blocks * [1] + list(args[1].block_size), - sliced_scale, - sliced_zero_point, - sliced_data.dtype, - args[1].quant_min, - args[1].quant_max, - args[1].zero_point_domain, - output_dtype=sliced_scale.dtype, - ) - - @implements(aten.addmm.default) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -277,6 +238,74 @@ def _(func, types, args, kwargs): return func(input_tensor, weight_tensor) +@implements(torch.nn.functional.embedding) +def _(func, types, args, kwargs): + # new_arg1 = args[1].dequantize() + # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) + assert isinstance( + args[1].tensor_impl, PlainAQTTensorImpl + ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" + assert ( + kwargs["padding_idx"] is None + and kwargs["max_norm"] is None + and not kwargs["scale_grad_by_freq"] + and not kwargs["sparse"] + and kwargs["norm_type"] == 2.0 + ) + idx = args[0] + int_data, scale, zero_point = args[1].tensor_impl.get_plain() + + sliced_data, sliced_scale, sliced_zero_point = ( + int_data[idx], + scale[idx], + zero_point[idx], + ) + # Block size is expecting 2 dimensions [1, group size] but + # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so + # we need to increase block size to correct dim + new_blocks = idx.dim() - 1 + return dequantize_affine( + sliced_data, + new_blocks * [1] + list(args[1].block_size), + sliced_scale, + sliced_zero_point, + sliced_data.dtype, + args[1].quant_min, + args[1].quant_max, + args[1].zero_point_domain, + output_dtype=sliced_scale.dtype, + ) + + +@implements([torch.nn.functional.scaled_dot_product_attention]) +def _(func, types, args, kwargs): + # for libc10.so + import torch + from hopper.flash_attn_interface import flash_attn_func + q, k, v = args[:3] + q_tensor_impl = q.tensor_impl + assert not q_tensor_impl.transposed + q_float8_data = q_tensor_impl.float8_data + q_scale = q.scale + + k_tensor_impl = k.tensor_impl + assert not k_tensor_impl.transposed + k_float8_data = k_tensor_impl.float8_data + k_scale = k.scale + + v_tensor_impl = v.tensor_impl + assert not v_tensor_impl.transposed + v_float8_data = v_tensor_impl.float8_data + v_scale = v.scale + + softmax_scale = kwargs.get("scale", None) + dropout_p = kwargs.get("dropout_p", None) + assert dropout_p is None or dropout_p == 0.0, "dropout_p should be set to 0.0 during inference" + + breakpoint() + return flash_attn_func(q_float8_data, k_float8_data, v_float8_data, softmax_scale=softmax_scale, descale_q=q_scale, descale_k=k_scale, descale_v=v_scale) + + @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 96ccb1889c..9e0a11b2c9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -803,6 +803,17 @@ def 
int8_dynamic_activation_int8_semi_sparse_weight(): return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) +def _float8_symmetric_per_token_quant(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): + from torchao.dtypes import to_affine_quantized_floatx + + return to_affine_quantized_floatx( + input_float=x, + block_size=_get_per_token_block_size(x), + target_dtype=dtype, + scale_dtype=None, + _layout=Float8Layout(mm_config=None), + ) + def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. @@ -814,17 +825,8 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): The actual matmul will be computed in original precision of the weight tensor. """ - from torchao.dtypes import to_affine_quantized_floatx - def apply_float8wo_quant(weight): - block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=None, - _layout=Float8Layout(mm_config=None), - ) + return _float8_symmetric_per_token_quant(weight, weight_dtype) return _get_linear_subclass_inserter(apply_float8wo_quant) @@ -1172,5 +1174,6 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, _input_activation_quant_func_fp8, + _float8_symmetric_per_token_quant, ] ) From 4db8febbaf68bd8c55829e6f3bd9b402dd0fb6b7 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 4 Dec 2024 16:57:16 -0800 Subject: [PATCH 02/11] add float8 from fa3 to aqt --- dev-requirements.txt | 2 + test/dtypes/test_affine_quantized_float.py | 142 ++++++++++++++++++ .../_models/sam2/modeling/sam/transformer.py | 25 ++- torchao/dtypes/affine_quantized_tensor_ops.py | 41 ++--- torchao/dtypes/floatx/float8_layout.py | 72 +++++++++ torchao/dtypes/uintx/semi_sparse_layout.py | 3 +- torchao/quantization/README.md | 30 ++++ torchao/quantization/quant_api.py | 22 ++- 8 files changed, 306 insertions(+), 31 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index cb7d7e1152..47963ec163 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -14,6 +14,8 @@ matplotlib pandas fire # QOL for commandline scripts tabulate # QOL for printing tables to stdout +einops # for testing flash attention 3 + # Custom CUDA Extensions ninja diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 4d8312b427..464fd1b1ed 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -288,7 +288,149 @@ def test_fp8_weight_dimension_warning(self): ) +# need to install einops +import math + +from einops import rearrange, repeat + + +# copied from https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/tests/test_util.py#L185 +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: 
float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores /= softcap + scores = scores.tanh() + scores *= softcap + if key_padding_mask is not None: + scores.masked_fill_( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + ) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + key_leftpad=key_leftpad, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill( + torch.all(local_mask, dim=-1, keepdim=True), 0.0 + ) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + ) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + if key_padding_mask is not None: + output.masked_fill_( + rearrange( + torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1" + ), + 0.0, + ) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +class TestAffineQuantizedFloat8Attention(common_utils.TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_float8_attention(self): + import torch.nn.functional as F + + from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant + + class MyModel(torch.nn.Module): + def forward(self, q, k, v, float8_quantize=False): + if float8_quantize: + q = _float8_symmetric_per_tensor_quant(q) + k = _float8_symmetric_per_tensor_quant(k) + v = _float8_symmetric_per_tensor_quant(v) + return F.scaled_dot_product_attention(q, k, v) + + # note: last headdim must be 64, 128, 256 + q = torch.randn([64, 8, 8, 64], 
dtype=torch.bfloat16, device="cuda") + k = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") + v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") + + m = MyModel().eval() + # it differs a lot from the non-quantized implementation + # sqnr = -2.5 + # ref = m(q, k, v) + + # but matches the custom attention implementation in flash attention repo + ref = attention_ref(q, k, v)[0] + quantized = m(q, k, v, True) + assert compute_error(ref, quantized) > 25.0 + + common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) if __name__ == "__main__": pytest.main([__file__]) + common_utils.run_tests() diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index 06b2c956fe..66c39cf5ba 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -17,7 +17,6 @@ from torchao._models.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from torchao._models.sam2.modeling.sam2_utils import MLP from torchao._models.sam2.utils.misc import get_sdpa_settings -from torchao.quantization.quant_api import _float8_symmetric_per_token_quant warnings.simplefilter(action="ignore", category=FutureWarning) # Check whether Flash Attention is available (and use it by default) @@ -25,6 +24,9 @@ # A fallback setting to allow all available kernels if Flash Attention fails ALLOW_ALL_KERNELS = False +# whether to turn on float8 quantization for sdpa or not +_QUANTIZE_ATTN = False + def sdp_kernel_context(dropout_p): """ @@ -265,9 +267,21 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: v = self._separate_heads(v, self.num_heads) # quantize q/k/v - q = _float8_symmetric_per_token_quant(q) - k = _float8_symmetric_per_token_quant(k) - v = _float8_symmetric_per_token_quant(v) + if _QUANTIZE_ATTN: + from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant + original_head_dim = list(q.shape)[-1] + padded = False + # padding: + if q.shape[-1] == 32: + q = F.pad(q, (0, 32)) + k = F.pad(k, (0, 32)) + v = F.pad(v, (0, 32)) + padded = True + + if q.shape[-1] in [64, 128, 256]: + q = _float8_symmetric_per_tensor_quant(q) + k = _float8_symmetric_per_tensor_quant(k) + v = _float8_symmetric_per_tensor_quant(v) dropout_p = self.dropout_p if self.training else 0.0 # # Attention @@ -287,6 +301,9 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) # TODO: This scale should not be needed. But without it compile causes a NaN. 
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, scale=(1.0 / math.sqrt(q.size(-1)))) + if _QUANTIZE_ATTN and padded: + out = out[:, :, :, :original_head_dim] + out = out.to(v.dtype) out = self._recombine_heads(out) out = self.out_proj(out) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index b6eeedd914..f7b84b764b 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -11,6 +11,8 @@ _linear_fp8_act_fp8_weight_impl, _linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl, + _sdpa_float8_check, + _sdpa_float8_impl, ) from torchao.dtypes.floatx.floatx_tensor_core_layout import ( _linear_f16_bf16_act_floatx_weight_check, @@ -277,33 +279,22 @@ def _(func, types, args, kwargs): ) -@implements([torch.nn.functional.scaled_dot_product_attention]) +@implements(torch.nn.functional.scaled_dot_product_attention) def _(func, types, args, kwargs): - # for libc10.so - import torch - from hopper.flash_attn_interface import flash_attn_func q, k, v = args[:3] - q_tensor_impl = q.tensor_impl - assert not q_tensor_impl.transposed - q_float8_data = q_tensor_impl.float8_data - q_scale = q.scale - - k_tensor_impl = k.tensor_impl - assert not k_tensor_impl.transposed - k_float8_data = k_tensor_impl.float8_data - k_scale = k.scale - - v_tensor_impl = v.tensor_impl - assert not v_tensor_impl.transposed - v_float8_data = v_tensor_impl.float8_data - v_scale = v.scale - - softmax_scale = kwargs.get("scale", None) - dropout_p = kwargs.get("dropout_p", None) - assert dropout_p is None or dropout_p == 0.0, "dropout_p should be set to 0.0 during inference" - - breakpoint() - return flash_attn_func(q_float8_data, k_float8_data, v_float8_data, softmax_scale=softmax_scale, descale_q=q_scale, descale_k=k_scale, descale_v=v_scale) + if not _sdpa_float8_check(q, k, v): + # dequantize and call original op + if hasattr(q, "dequantize"): + q = q.dequantize() + if hasattr(k, "dequantize"): + k = k.dequantize() + if hasattr(v, "dequantize"): + v = v.dequantize() + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, *args[3:], **kwargs + ) + else: + return _sdpa_float8_impl(k, q, v, kwargs) @implements(aten.detach.default) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index dd995fb157..af02d3634d 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -311,3 +311,75 @@ def _linear_fp_act_fp8_weight_impl( bias: Optional[torch.Tensor], ): return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) + + +def _sdpa_float8_check( + q: Union[torch.Tensor, "AffineQuantizedTensor"], + k: Union[torch.Tensor, "AffineQuantizedTensor"], + v: Union[torch.Tensor, "AffineQuantizedTensor"], +) -> bool: + def is_per_tensor_float8_aqt(t): + # tensor is float8 quantized affine quantized tensor + return ( + isinstance(t, AffineQuantizedTensor) + and isinstance(t._layout, Float8Layout) + and t.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and (t.shape == t.block_size) + ) + + return ( + is_per_tensor_float8_aqt(q) + and is_per_tensor_float8_aqt(k) + and is_per_tensor_float8_aqt(v) + ) + + +def _sdpa_float8_impl( + q: Union[torch.Tensor, "AffineQuantizedTensor"], + k: Union[torch.Tensor, "AffineQuantizedTensor"], + v: Union[torch.Tensor, "AffineQuantizedTensor"], + kwargs, +) -> torch.Tensor: + # requires build from source + # 
https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release + # for libc10.so + import torch + from hopper.flash_attn_interface import flash_attn_func + + q_tensor_impl = q.tensor_impl + assert not q_tensor_impl.transposed + q_float8_data = q_tensor_impl.float8_data + # change from scalar to tensor of size [1] + q_scale = q_tensor_impl.scale + q_scale = torch.tensor([q_scale], device=q_scale.device) + + k_tensor_impl = k.tensor_impl + assert not k_tensor_impl.transposed + k_float8_data = k_tensor_impl.float8_data + k_scale = k_tensor_impl.scale + k_scale = torch.tensor([k_scale], device=k_scale.device) + + v_tensor_impl = v.tensor_impl + assert not v_tensor_impl.transposed + v_float8_data = v_tensor_impl.float8_data + v_scale = v_tensor_impl.scale + v_scale = torch.tensor([v_scale], device=v_scale.device) + + dropout_p = kwargs.get("dropout_p", None) + assert ( + dropout_p is None or dropout_p == 0.0 + ), "dropout_p should be set to 0.0 during inference" + causal = kwargs.get("causal", False) + + gqa_parallel = False + out = flash_attn_func( + q_float8_data, + k_float8_data, + v_float8_data, + causal=causal, + descale_q=q_scale, + descale_k=k_scale, + descale_v=v_scale, + ) + + return out[0] diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index d832731657..a554fd9bc6 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -44,6 +44,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( # must pad row, col = tmp.shape from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( @@ -51,7 +52,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( tmp_padded.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, - ).t()[:row, :] + ).t()[:row, :] y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 3fc2cb5ef0..70da8148d6 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -360,6 +360,36 @@ We have kernels that do 8-bit dynamic quantization of activations and uintx grou You try can out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `torchao/_models/llama/generate.py`. +### Float8 `scaled_dot_product_attention` Support +We also have initial support for per tensor float8 `scaled_dot_product_attention`, using flash attention 3 (optimized for H100 GPUs). To use the feature: + +1. Install from source: +https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release + +2. 
Modify the model to quantize q/k/v to per tensor float8 tensors:
+```
+from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant
+import torch.nn.functional as F
+
+class MyModel(torch.nn.Module):
+    def forward(self, q, k, v, float8_quantize=False):
+        if float8_quantize:
+            q = _float8_symmetric_per_tensor_quant(q)
+            k = _float8_symmetric_per_tensor_quant(k)
+            v = _float8_symmetric_per_tensor_quant(v)
+        return F.scaled_dot_product_attention(q, k, v)
+
+# note: the last dimension (headdim) must be 64, 128, or 256
+q = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
+k = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
+v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda")
+```
+See `test_float8_attention` in `test/dtypes/test_affine_quantized_float.py` on the full test.
+
+Note that right now the float8 attention implementation differs a lot from the original unquantized version, but matches more closely with their reference attention implementation in [flash attention repo](https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/tests/test_util.py#L185) we still need to invetigate why.
+
+We might be adding new variations of attention implementation in the future (per row, per column, per block scaling etc.).
+
 ### Automatic Inductor Configuration
 The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues.
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9e0a11b2c9..4fc5513462 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -803,7 +803,9 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) -def _float8_symmetric_per_token_quant(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): +def _float8_symmetric_per_token_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +): from torchao.dtypes import to_affine_quantized_floatx return to_affine_quantized_floatx( @@ -814,6 +816,23 @@ def _float8_symmetric_per_token_quant(x: torch.Tensor, dtype: torch.dtype = torc _layout=Float8Layout(mm_config=None), ) + +def _float8_symmetric_per_tensor_quant( + x: torch.Tensor, + dtype: torch.dtype = torch.float8_e4m3fn, + mm_config: Optional[Float8MMConfig] = None, +): + from torchao.dtypes import to_affine_quantized_floatx + + return to_affine_quantized_floatx( + input_float=x, + block_size=tuple(x.shape), + target_dtype=dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) + + def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. @@ -825,6 +844,7 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): The actual matmul will be computed in original precision of the weight tensor. """ + def apply_float8wo_quant(weight): return _float8_symmetric_per_token_quant(weight, weight_dtype) From 6df8f41dfcb75af85efaee4af8adc8e27d598e7f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 4 Dec 2024 17:09:09 -0800 Subject: [PATCH 03/11] formatting --- test/dtypes/test_affine_quantized_float.py | 43 +++++++++++++++++-- .../_models/sam2/modeling/sam/transformer.py | 6 +-- torchao/dtypes/floatx/float8_layout.py | 1 - 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 464fd1b1ed..5a202cb939 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -9,6 +9,7 @@ import copy import io +import math import random import unittest from contextlib import nullcontext @@ -17,6 +18,7 @@ import pytest import torch +from einops import rearrange, repeat from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils @@ -288,13 +290,46 @@ def test_fp8_weight_dimension_warning(self): ) -# need to install einops -import math +# copied from https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/tests/test_util.py#L185 -from einops import rearrange, repeat + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, + key_leftpad=None, +): + row_idx = rearrange( + torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" + ) + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if 
query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) -# copied from https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/tests/test_util.py#L185 def attention_ref( q, k, diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index 66c39cf5ba..ff454cd3c6 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -267,6 +267,7 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: v = self._separate_heads(v, self.num_heads) # quantize q/k/v + padded = False if _QUANTIZE_ATTN: from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant original_head_dim = list(q.shape)[-1] @@ -346,11 +347,6 @@ def forward( k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) - # quantize q/k/v - q = _float8_symmetric_per_token_quant(q) - k = _float8_symmetric_per_token_quant(k) - v = _float8_symmetric_per_token_quant(v) - # Apply rotary position encoding w = h = math.sqrt(q.shape[-2]) self.freqs_cis = self.freqs_cis.to(q.device) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index af02d3634d..a3f5c5975f 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -371,7 +371,6 @@ def _sdpa_float8_impl( ), "dropout_p should be set to 0.0 during inference" causal = kwargs.get("causal", False) - gqa_parallel = False out = flash_attn_func( q_float8_data, k_float8_data, From 57ac986dc54a70ead1a495e957861e2498e28404 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 4 Dec 2024 18:16:26 -0800 Subject: [PATCH 04/11] add eval support for llama --- torchao/_models/llama/model.py | 40 +++++++++++++------ .../_models/sam2/modeling/sam/transformer.py | 3 +- torchao/dtypes/affine_quantized_tensor_ops.py | 3 +- torchao/dtypes/floatx/float8_layout.py | 17 +++++--- 4 files changed, 44 insertions(+), 19 deletions(-) diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 74cad30cbd..6f1942514d 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -13,6 +13,8 @@ from torch.nn import functional as F from torchao.utils import find_multiple +_QUANTIZE_ATTN = True + # TODO remove suplerfluous arg def prepare_inputs_for_model(inps, max_new_tokens=1): # this is because input from lm-eval is 2d @@ -85,7 +87,7 @@ def from_name(cls, name: str): ), } -# this is a model specific variable that controls whether index_put is used for the kv_cache update, +# this is a model specific variable that controls whether index_put is used for the kv_cache update, # it is needed for GPTQ but otherwise attenuates perf so the default is to not use it use_index_put_for_kv_cache = False @@ -124,7 +126,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtyp self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8)) self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) - + def update(self, input_pos, 
k_val, v_val): # quantize current k_val and store it in the cache q_k_val, k_scale = quantize_activation_per_token_absmax(k_val) @@ -138,7 +140,7 @@ def update(self, input_pos, k_val, v_val): self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1) v_out = self.v_cache*self.v_cache_scale v_out[:, :, input_pos] = v_val - + return k_out, v_out @classmethod @@ -194,16 +196,16 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_ else: b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) self.freqs_cis = precompute_freqs_cis( - self.config.block_size, - self.config.dim // self.config.n_head, - self.config.rope_base, - dtype, + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + dtype, use_scaled=self.config.use_scaled_rope ) def reset_caches(self): """Reset caches. - + The caches used by training stage and inference stage may be different, reset them before switching. """ self.max_batch_size = -1 @@ -215,7 +217,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: """Forward pass of the model. Args: - idx (`torch.LongTensor` of shape `(batch_size, seq_length)`): + idx (`torch.LongTensor` of shape `(batch_size, seq_length)`): Indices of input sequence tokens in the vocabulary. input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. @@ -227,7 +229,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: """ assert self.freqs_cis is not None, "Caches must be initialized first" - if input_pos is None: + if input_pos is None: mask = None freqs_cis = self.freqs_cis[:idx.shape[1]] else: @@ -311,11 +313,25 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_po k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + # quantize q/k/v with per tensor float8 quantization + padded = False + if _QUANTIZE_ATTN: + from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant + original_dtype = v.dtype + if q.shape[-1] in [64, 128, 256]: + q = _float8_symmetric_per_tensor_quant(q) + k = _float8_symmetric_per_tensor_quant(k) + v = _float8_symmetric_per_tensor_quant(v) + if mask is not None: y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) else: y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True) + if _QUANTIZE_ATTN: + y = y.to(original_dtype) + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) y = self.wo(y) @@ -371,8 +387,8 @@ def apply_scaling(freqs: torch.Tensor): return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) def precompute_freqs_cis( - seq_len: int, - n_elem: int, + seq_len: int, + n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16, use_scaled: bool=False diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index ff454cd3c6..d475c7aea2 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -266,11 +266,12 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) - # quantize q/k/v + # quantize q/k/v with per tensor float8 quantization padded = False if _QUANTIZE_ATTN: from 
torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant original_head_dim = list(q.shape)[-1] + original_dtype = v.dtype padded = False # padding: if q.shape[-1] == 32: diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index f7b84b764b..5702945405 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -282,7 +282,7 @@ def _(func, types, args, kwargs): @implements(torch.nn.functional.scaled_dot_product_attention) def _(func, types, args, kwargs): q, k, v = args[:3] - if not _sdpa_float8_check(q, k, v): + if not _sdpa_float8_check(q, k, v, kwargs): # dequantize and call original op if hasattr(q, "dequantize"): q = q.dequantize() @@ -294,6 +294,7 @@ def _(func, types, args, kwargs): q, k, v, *args[3:], **kwargs ) else: + print("calling float8 impl") return _sdpa_float8_impl(k, q, v, kwargs) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index a3f5c5975f..3a5eb05df2 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -317,20 +317,27 @@ def _sdpa_float8_check( q: Union[torch.Tensor, "AffineQuantizedTensor"], k: Union[torch.Tensor, "AffineQuantizedTensor"], v: Union[torch.Tensor, "AffineQuantizedTensor"], + kwargs, ) -> bool: - def is_per_tensor_float8_aqt(t): + + def is_compatible_per_tensor_float8_aqt(t): # tensor is float8 quantized affine quantized tensor return ( isinstance(t, AffineQuantizedTensor) and isinstance(t._layout, Float8Layout) and t.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (t.shape == t.block_size) + and (t.shape == t.block_size) and + t.shape[-1] in [64, 128, 256] ) + dropout_p = kwargs.get("dropout_p", 0.0) + return ( - is_per_tensor_float8_aqt(q) - and is_per_tensor_float8_aqt(k) - and is_per_tensor_float8_aqt(v) + is_compatible_per_tensor_float8_aqt(q) and + is_compatible_per_tensor_float8_aqt(k) and + is_compatible_per_tensor_float8_aqt(v) and + "attn_mask" not in kwargs and + dropout_p == 0.0 ) From 01ffb209d5489154f3c68e12d38bd7378c76046a Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 4 Dec 2024 18:18:10 -0800 Subject: [PATCH 05/11] ruff --- torchao/dtypes/floatx/float8_layout.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 3a5eb05df2..a277022cc9 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -319,25 +319,24 @@ def _sdpa_float8_check( v: Union[torch.Tensor, "AffineQuantizedTensor"], kwargs, ) -> bool: - def is_compatible_per_tensor_float8_aqt(t): # tensor is float8 quantized affine quantized tensor return ( isinstance(t, AffineQuantizedTensor) and isinstance(t._layout, Float8Layout) and t.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (t.shape == t.block_size) and - t.shape[-1] in [64, 128, 256] + and (t.shape == t.block_size) + and t.shape[-1] in [64, 128, 256] ) dropout_p = kwargs.get("dropout_p", 0.0) return ( - is_compatible_per_tensor_float8_aqt(q) and - is_compatible_per_tensor_float8_aqt(k) and - is_compatible_per_tensor_float8_aqt(v) and - "attn_mask" not in kwargs and - dropout_p == 0.0 + is_compatible_per_tensor_float8_aqt(q) + and is_compatible_per_tensor_float8_aqt(k) + and is_compatible_per_tensor_float8_aqt(v) + and "attn_mask" not in kwargs + and dropout_p == 0.0 ) From 
9ce7a62981b615293251d9c753e73621d8454493 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 4 Dec 2024 18:42:27 -0800 Subject: [PATCH 06/11] guard import error --- torchao/dtypes/affine_quantized_tensor_ops.py | 5 ++-- torchao/dtypes/floatx/float8_layout.py | 25 +++++++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 5702945405..9273b9de57 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -282,7 +282,7 @@ def _(func, types, args, kwargs): @implements(torch.nn.functional.scaled_dot_product_attention) def _(func, types, args, kwargs): q, k, v = args[:3] - if not _sdpa_float8_check(q, k, v, kwargs): + if not _sdpa_float8_check(q, k, v, args, kwargs): # dequantize and call original op if hasattr(q, "dequantize"): q = q.dequantize() @@ -294,8 +294,7 @@ def _(func, types, args, kwargs): q, k, v, *args[3:], **kwargs ) else: - print("calling float8 impl") - return _sdpa_float8_impl(k, q, v, kwargs) + return _sdpa_float8_impl(k, q, v, args, kwargs) @implements(aten.detach.default) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index a277022cc9..2b24e1aaa9 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -317,6 +317,7 @@ def _sdpa_float8_check( q: Union[torch.Tensor, "AffineQuantizedTensor"], k: Union[torch.Tensor, "AffineQuantizedTensor"], v: Union[torch.Tensor, "AffineQuantizedTensor"], + args, kwargs, ) -> bool: def is_compatible_per_tensor_float8_aqt(t): @@ -344,13 +345,27 @@ def _sdpa_float8_impl( q: Union[torch.Tensor, "AffineQuantizedTensor"], k: Union[torch.Tensor, "AffineQuantizedTensor"], v: Union[torch.Tensor, "AffineQuantizedTensor"], + args, kwargs, ) -> torch.Tensor: - # requires build from source - # https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release - # for libc10.so - import torch - from hopper.flash_attn_interface import flash_attn_func + try: + # requires build from source + # https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release + # for libc10.so + import torch + from hopper.flash_attn_interface import flash_attn_func + except ImportError: + # fallback + # dequantize and call original op + if hasattr(q, "dequantize"): + q = q.dequantize() + if hasattr(k, "dequantize"): + k = k.dequantize() + if hasattr(v, "dequantize"): + v = v.dequantize() + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, *args[3:], **kwargs + ) q_tensor_impl = q.tensor_impl assert not q_tensor_impl.transposed From 4f985187c7b0715f2011d4e9104f584c907480e2 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 5 Dec 2024 15:47:00 -0800 Subject: [PATCH 07/11] fix numeric error --- test/dtypes/test_affine_quantized_float.py | 168 ++---------------- .../_models/sam2/modeling/sam/transformer.py | 3 + torchao/dtypes/affine_quantized_tensor_ops.py | 23 ++- torchao/dtypes/floatx/float8_layout.py | 26 ++- torchao/quantization/README.md | 4 +- 5 files changed, 43 insertions(+), 181 deletions(-) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 5a202cb939..29ba2277e9 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -9,7 +9,6 @@ import copy import io -import math import random import unittest from 
contextlib import nullcontext @@ -18,7 +17,6 @@ import pytest import torch -from einops import rearrange, repeat from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing._internal import common_utils @@ -290,149 +288,6 @@ def test_fp8_weight_dimension_warning(self): ) -# copied from https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/tests/test_util.py#L185 - - -def construct_local_mask( - seqlen_q, - seqlen_k, - window_size=(-1, -1), # -1 means infinite window size - query_padding_mask=None, - key_padding_mask=None, - device=None, - key_leftpad=None, -): - row_idx = rearrange( - torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" - ) - col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - if key_leftpad is not None: - key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") - col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) - col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) - sk = ( - seqlen_k - if key_padding_mask is None - else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - ) - sq = ( - seqlen_q - if query_padding_mask is None - else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") - ) - if window_size[0] < 0: - return col_idx > row_idx + sk - sq + window_size[1] - else: - sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk - return torch.logical_or( - col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - col_idx < row_idx + sk - sq - window_size[0], - ) - - -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - attn_bias=None, - dropout_p=0.0, - dropout_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - softcap=0.0, - upcast=True, - reorder_ops=False, - key_leftpad=None, -): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - causal: whether to apply causal masking - window_size: (int, int), left and right window size - upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) - without changing the math. This is to estimate the numerical error from operation - reordering. 
- Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout - """ - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) - d = q.shape[-1] - if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) - else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) - if softcap > 0: - scores /= softcap - scores = scores.tanh() - scores *= softcap - if key_padding_mask is not None: - scores.masked_fill_( - rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") - ) - if window_size[0] >= 0 or window_size[1] >= 0: - local_mask = construct_local_mask( - seqlen_q, - seqlen_k, - window_size, - query_padding_mask, - key_padding_mask, - q.device, - key_leftpad=key_leftpad, - ) - scores.masked_fill_(local_mask, float("-inf")) - if attn_bias is not None: - scores = scores + attn_bias - attention = torch.softmax(scores, dim=-1).to(v.dtype) - # Some rows might be completely masked out so we fill them with zero instead of NaN - if window_size[0] >= 0 or window_size[1] >= 0: - attention = attention.masked_fill( - torch.all(local_mask, dim=-1, keepdim=True), 0.0 - ) - # We want to mask here so that the attention matrix doesn't have any NaNs - # Otherwise we'll get NaN in dV - if query_padding_mask is not None: - attention = attention.masked_fill( - rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 - ) - dropout_scaling = 1.0 / (1 - dropout_p) - # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling - # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) - if dropout_mask is not None: - attention_drop = attention.masked_fill(~dropout_mask, 0.0) - else: - attention_drop = attention - output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - if key_padding_mask is not None: - output.masked_fill_( - rearrange( - torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1" - ), - 0.0, - ) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) - - class TestAffineQuantizedFloat8Attention(common_utils.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_float8_attention(self): @@ -443,29 +298,36 @@ def test_float8_attention(self): class MyModel(torch.nn.Module): def forward(self, q, k, v, float8_quantize=False): if float8_quantize: + # F.scaled_dot_product_attention is using (batch_size, nheads, seqlen, headdim) + # while flash attention kernel has (batch_size, seqlen, nheads, headdim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) q = _float8_symmetric_per_tensor_quant(q) k = _float8_symmetric_per_tensor_quant(k) v = _float8_symmetric_per_tensor_quant(v) + return F.scaled_dot_product_attention(q, k, v) - # note: last headdim must be 64, 128, 256 + # note: last dim headdim must be 64, 128 or 256 q = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") k = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") m = MyModel().eval() - # it differs a lot from the non-quantized implementation - # sqnr = -2.5 - # 
ref = m(q, k, v) - # but matches the custom attention implementation in flash attention repo - ref = attention_ref(q, k, v)[0] + # bfloat16 ref result + ref = m(q, k, v) + + # float8 quantized result quantized = m(q, k, v, True) - assert compute_error(ref, quantized) > 25.0 + + sqnr = compute_error(ref, quantized) + assert sqnr > 25.0, f"Got sqnr: {sqnr}" common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) common_utils.run_tests() diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index d475c7aea2..3d7bc4e248 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -281,6 +281,9 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: padded = True if q.shape[-1] in [64, 128, 256]: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) q = _float8_symmetric_per_tensor_quant(q) k = _float8_symmetric_per_tensor_quant(k) v = _float8_symmetric_per_tensor_quant(v) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 9273b9de57..a2319a9647 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -91,6 +91,12 @@ class QuantizedLinearNotImplementedError(NotImplementedError): pass +class QuantizedSDPANotImplementedError(NotImplementedError): + """Thin wrapper around NotImplementedError to make it easier to catch this error during dispatch""" + + pass + + @staticmethod def _quantized_linear_op(input_tensor, weight_tensor, bias): for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): @@ -282,19 +288,12 @@ def _(func, types, args, kwargs): @implements(torch.nn.functional.scaled_dot_product_attention) def _(func, types, args, kwargs): q, k, v = args[:3] - if not _sdpa_float8_check(q, k, v, args, kwargs): - # dequantize and call original op - if hasattr(q, "dequantize"): - q = q.dequantize() - if hasattr(k, "dequantize"): - k = k.dequantize() - if hasattr(v, "dequantize"): - v = v.dequantize() - return torch.nn.functional.scaled_dot_product_attention( - q, k, v, *args[3:], **kwargs - ) + if _sdpa_float8_check(q, k, v, args, kwargs): + return _sdpa_float8_impl(q, k, v, args, kwargs) else: - return _sdpa_float8_impl(k, q, v, args, kwargs) + raise QuantizedSDPANotImplementedError( + "No specialized dispatch found for quantized sdpa" + ) @implements(aten.detach.default) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 2b24e1aaa9..b3d120880d 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -349,22 +349,12 @@ def _sdpa_float8_impl( kwargs, ) -> torch.Tensor: try: - # requires build from source - # https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release # for libc10.so import torch from hopper.flash_attn_interface import flash_attn_func - except ImportError: - # fallback - # dequantize and call original op - if hasattr(q, "dequantize"): - q = q.dequantize() - if hasattr(k, "dequantize"): - k = k.dequantize() - if hasattr(v, "dequantize"): - v = v.dequantize() - return torch.nn.functional.scaled_dot_product_attention( - q, k, v, *args[3:], **kwargs + except ImportError as e: + raise ImportError( + f"please install FlashAttention 3 before using float8 sdpa: 
https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#flashattention-3-beta-release, original import error {e}" ) q_tensor_impl = q.tensor_impl @@ -392,14 +382,20 @@ def _sdpa_float8_impl( ), "dropout_p should be set to 0.0 during inference" causal = kwargs.get("causal", False) - out = flash_attn_func( + out, _ = flash_attn_func( q_float8_data, k_float8_data, v_float8_data, causal=causal, + window_size=(-1, -1), descale_q=q_scale, descale_k=k_scale, descale_v=v_scale, ) - return out[0] + # F.scaled_dot_product_attention is using (batch_size, nheads, seqlen, headdim) + # while flash attention kernel has (batch_size, seqlen, nheads, headdim) + # so we need to transpose output to match the expected dimension + out = out.transpose(1, 2) + + return out diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 70da8148d6..8ba6c39578 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -305,7 +305,9 @@ Note that the workaround will not be needed after https://github.com/pytorch/pyt Note that the workaround is also required for `torch.compile` with `freezing` (`torch._inductor.config.freezing=True`) until https://github.com/pytorch/pytorch/pull/136265 is fixed. -## Other Available Quantization Techniques +## [Prototype Features] Other Available Quantization Techniques + +Note: APIs in this section are prototype and subject to change. ### KV Cache Quantization We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference. From 07a07c1c50e55226ff2de2c75c62cd20839b3fc3 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 5 Dec 2024 15:55:25 -0800 Subject: [PATCH 08/11] fix --- dev-requirements.txt | 1 - torchao/_models/llama/model.py | 2 +- torchao/quantization/README.md | 7 ++++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 47963ec163..f30eac73a8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -14,7 +14,6 @@ matplotlib pandas fire # QOL for commandline scripts tabulate # QOL for printing tables to stdout -einops # for testing flash attention 3 # Custom CUDA Extensions diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 6f1942514d..af834479f1 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -13,7 +13,7 @@ from torch.nn import functional as F from torchao.utils import find_multiple -_QUANTIZE_ATTN = True +_QUANTIZE_ATTN = False # TODO remove suplerfluous arg def prepare_inputs_for_model(inps, max_new_tokens=1): diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 8ba6c39578..7255340161 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -376,6 +376,9 @@ import torch.nn.functional as F class MyModel(torch.nn.Module): def forward(self, q, k, v, float8_quantize=False): if float8_quantize: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) q = _float8_symmetric_per_tensor_quant(q) k = _float8_symmetric_per_tensor_quant(k) v = _float8_symmetric_per_tensor_quant(v) @@ -388,9 +391,7 @@ v = torch.randn([64, 8, 8, 64], dtype=torch.bfloat16, device="cuda") ``` See `test_float8_attention` in `test/dtypes/test_affine_quantized_float.py` on the full test. 
-Note that right now the float8 attention implementation differs a lot from the original unquantized version, but matches more closely with their reference attention implementation in [flash attention repo](https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/tests/test_util.py#L185) we still need to invetigate why. - -We might be adding new variations of attention implementation in the future (per row, per column, per block scaling etc.). +We might be adding new variations of attention implementation in the future (per row, per column, per block scaling etc.), and supporting arguments like `attn_mask`. ### Automatic Inductor Configuration The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. From e35a2a6eb3746212dcb35fc1fd387408fe0ccd34 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 5 Dec 2024 16:42:43 -0800 Subject: [PATCH 09/11] fix --- dev-requirements.txt | 1 - test/dtypes/test_affine_quantized_float.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index f30eac73a8..cb7d7e1152 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -15,7 +15,6 @@ pandas fire # QOL for commandline scripts tabulate # QOL for printing tables to stdout - # Custom CUDA Extensions ninja diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 29ba2277e9..58184c07cd 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -329,5 +329,5 @@ def forward(self, q, k, v, float8_quantize=False): common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) if __name__ == "__main__": - # pytest.main([__file__]) + pytest.main([__file__]) common_utils.run_tests() From f3a8e8dcd324f7d36d8753973cc9514e95c55d01 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 5 Dec 2024 21:32:39 -0800 Subject: [PATCH 10/11] skip fa8 test in CI --- test/dtypes/test_affine_quantized_float.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 58184c07cd..cdadeae4dc 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -288,6 +288,8 @@ def test_fp8_weight_dimension_warning(self): ) +@unittest.skip("Only running locally so we don't need to add installation of fa3 " + "hopper kernels to CI, we'll probably copy paste kernel in the future") class TestAffineQuantizedFloat8Attention(common_utils.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def 
test_float8_attention(self): From 289562677f07f442197faae165f18695dfcd745b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 6 Dec 2024 16:16:38 -0800 Subject: [PATCH 11/11] ruff --- test/dtypes/test_affine_quantized_float.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index cdadeae4dc..9460672add 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -288,8 +288,10 @@ def test_fp8_weight_dimension_warning(self): ) -@unittest.skip("Only running locally so we don't need to add installation of fa3 " - "hopper kernels to CI, we'll probably copy paste kernel in the future") +@unittest.skip( + "Only running locally so we don't need to add installation of fa3 " + "hopper kernels to CI, we'll probably copy paste kernel in the future" +) class TestAffineQuantizedFloat8Attention(common_utils.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_float8_attention(self):
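
For readers without an FA3 build, here is a minimal numerical sketch of what the float8 SDPA path in this series computes: q/k/v each get a single symmetric per-tensor scale and are cast to `torch.float8_e4m3fn`, and the kernel consumes the raw float8 data together with the scales (`descale_q/k/v`), which is close to dequantizing and running `F.scaled_dot_product_attention` in higher precision. The scale formula (`amax / finfo.max`) and the helper names below are illustrative assumptions for this sketch, not APIs added by this PR.

```python
import torch
import torch.nn.functional as F


def quantize_fp8_per_tensor(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # one symmetric scale for the whole tensor; this is what the FA3 kernel
    # receives as descale_q / descale_k / descale_v
    scale = (x.abs().amax().float() / torch.finfo(dtype).max).clamp(min=1e-12)
    data = (x.float() / scale).clamp(torch.finfo(dtype).min, torch.finfo(dtype).max).to(dtype)
    return data, scale


def sdpa_fp8_reference(q, k, v):
    # emulate the descale step by dequantizing before the matmuls; the fused FA3
    # kernel instead keeps the fp8 inputs and applies the scales internally
    deq = []
    for t in (q, k, v):
        data, scale = quantize_fp8_per_tensor(t)
        deq.append(data.to(torch.float32) * scale)
    return F.scaled_dot_product_attention(*deq).to(q.dtype)


if torch.cuda.is_available():
    q = torch.randn(64, 8, 8, 64, dtype=torch.bfloat16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    ref = F.scaled_dot_product_attention(q, k, v)
    approx = sdpa_fp8_reference(q, k, v)  # expect a high SQNR vs. ref, as in test_float8_attention
```

The actual dispatch guarded by `_sdpa_float8_check` additionally requires a head dimension of 64, 128, or 256, no `attn_mask`, and zero dropout; anything else raises and falls back to the unquantized path.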