diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index df11b71e66..9f6e3ee6ef 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -1157,6 +1157,12 @@ def test_per_row_config_before_dim(self):
         assert config_deser.granularity[0].dim == -1
         assert config_deser.granularity[1].dim == -1
 
+    @common_utils.parametrize("dim", [-2, -1])
+    def test_chunk(self, dim):
+        x = torch.randn(16, 5120, 16384, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        self._test_chunk_similar_to_vllm_llama4(x_fp8, dim)
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
 
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 733d7a17a5..0e0028b324 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -956,6 +956,72 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new_tensor)
 
 
+@implements(aten.split.Tensor)
+def _(func, types, args, kwargs):
+    tensor, split_size_or_sections, dim = args
+    assert isinstance(split_size_or_sections, int), "unimplemented"
+
+    # 2D case
+    #
+    # orig
+    #   qdata.shape  [M, K]
+    #   scale.shape  [M, 1]
+    #   block_size   [1, K]
+    #
+    # split with size (K // 2) across dim -1:
+    #   qdata.shape  [M, K // 2], [M, K // 2]
+    #   scale.shape  [M, 1], [M, 1]
+    #   block_size   [1, K // 2], [1, K // 2]
+    #
+    # split with size (M // 2) across dim 0:
+    #   qdata.shape  [M // 2, K], [M // 2, K]
+    #   scale.shape  [M // 2, 1], [M // 2, 1]
+    #   block_size   [1, K], [1, K]
+
+    # split the qdata
+    new_qdatas = func(tensor.qdata, split_size_or_sections, dim)
+    num_chunks = len(new_qdatas)
+
+    # split the scale
+    new_scales = []
+    new_block_sizes = []
+    if tensor.scale.shape[dim] == 1 and tensor.block_size[dim] == tensor.shape[dim]:
+        # reuse the scale for every chunk, shrink block_size along `dim`
+        for new_qdata in new_qdatas:
+            new_scales.append(tensor.scale)
+            new_block_size = list(tensor.block_size)
+            new_block_size[dim] = new_qdata.shape[dim]
+            new_block_sizes.append(new_block_size)
+
+    elif tensor.scale.shape[dim] == tensor.shape[dim] and tensor.block_size[dim] == 1:
+        # keep the block size, split the scale
+        new_scales = func(tensor.scale, split_size_or_sections, dim)
+        for _ in range(num_chunks):
+            new_block_sizes.append(tensor.block_size)
+
+    else:
+        raise AssertionError(
+            f"`aten.split.Tensor` with {dim=} and {tensor.scale.shape=} is not yet implemented"
+        )
+
+    new_tensors_list = []
+    for idx in range(num_chunks):
+        new_tensor = tensor.__class__(
+            new_qdatas[idx],
+            new_scales[idx],
+            new_block_sizes[idx],
+            tensor.mm_config,
+            tensor.act_quant_kwargs,
+            tensor.kernel_preference,
+            tensor.dtype,
+        )
+        new_tensor = return_and_correct_aliasing(func, args, kwargs, new_tensor)
+        new_tensors_list.append(new_tensor)
+
+    new_tensors_tuple = tuple(new_tensors_list)
+    return new_tensors_tuple
+
+
 Float8Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index 10315d45f5..b371b21f06 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -641,6 +641,15 @@ def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
         quantize_(l, config)
         _w_slice = l.weight[0]
 
+    def _test_chunk_similar_to_vllm_llama4(self, ao_tensor, dim):
+        # source code in vLLM LLaMa 4:
+        # https://github.com/vllm-project/vllm/blob/34553b9d2702dd2a27a578fec819e88e76dcbfb4/vllm/model_executor/models/llama4.py#L455
+        ao_tensor_chunked = ao_tensor.chunk(2, dim=dim)
+        ao_tensor_unchunked = torch.cat(ao_tensor_chunked, dim=dim)
+        torch.testing.assert_close(
+            ao_tensor.dequantize(), ao_tensor_unchunked.dequantize(), atol=0, rtol=0
+        )
+
 
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
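
Note (not part of the patch): a minimal sketch of the scale/block_size propagation rules documented in the comment block of the new `aten.split.Tensor` override, written with plain tensors rather than `Float8Tensor`. It assumes per-row scaling, i.e. `scale.shape == [M, 1]` and `block_size == [1, K]`; the variable names below are illustrative only.

import torch

M, K = 4, 8
x = torch.randn(M, K)
scale = x.abs().amax(dim=-1, keepdim=True)  # per-row scale, shape [M, 1]
block_size = [1, K]                         # each scale covers one full row

# Split across dim -1 (the K dimension): every chunk keeps sharing the same
# per-row scale, and only block_size shrinks along that dim.
chunks_k = torch.split(x, K // 2, dim=-1)
scales_k = [scale for _ in chunks_k]
block_sizes_k = [[1, c.shape[-1]] for c in chunks_k]

# Split across dim 0 (the M dimension): the scale is split the same way as
# the data, and block_size is unchanged.
chunks_m = torch.split(x, M // 2, dim=0)
scales_m = torch.split(scale, M // 2, dim=0)
block_sizes_m = [list(block_size) for _ in chunks_m]

assert all(s.shape == (c.shape[0], 1) for s, c in zip(scales_m, chunks_m))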