[CPU] add Float8OpaqueTensor for dynamic float8 act float8 weight #3075
base: main
Changes from all commits: d460134, cf8dc09, 4333727, 6e1c2a2, 7980de8
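For context before the diffs, here is a minimal usage sketch of what this PR enables, assembled from the test file's `get_config` helper below (the bare `torch.nn.Linear` module is just an illustration; the test file quantizes a linear the same way):

```python
import torch
from torchao import quantize_
from torchao.quantization import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

model = torch.nn.Linear(256, 256)  # any float linear module, on CPU
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=PerRow(),
        # "opaque" selects the new CPU Float8OpaqueTensor packing path
        float8_packing_format="opaque",
    ),
)
```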
New test file (164 added lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)


def get_config(granularity):
    return Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=granularity,
        float8_packing_format="opaque",
    )


class ToyLinearModel(torch.nn.Module):
    def __init__(self, K=64, N=32, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (
            torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
            * 0.1,
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TestFloat8OpaqueTensor(TestCase):
    """Test cases for Float8OpaqueTensor on CPU"""

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    @common_utils.parametrize(
        "x_granularity",
        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
    )
    @common_utils.parametrize(
        "w_granularity",
        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
    )
    def test_dynamic_float8_linear(
        self, dtype, x_dim, bias, bs, x_granularity, w_granularity
    ):
        if isinstance(x_granularity, PerGroup):
            if not isinstance(w_granularity, PerGroup):
                return
            if w_granularity.group_size != x_granularity.group_size:
                return
        device = "cpu"
        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config([x_granularity, w_granularity]),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [4, 128])
    def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs):
        device = "cpu"
        # This shape is not supported by the cpp kernel, so the reference path is used.
        m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config(PerRow()),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    def test_module_path(self, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype)
        quantize_(linear, get_config(PerRow()))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Float8OpaqueTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Float8OpaqueTensor'>",
            )


common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)


if __name__ == "__main__":
    run_tests()
```
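Note that the parametrization above exercises both mixed and per-group granularities, which the `_normalize_granularity` changes below make legal on this path. A sketch of the two shapes of config the tests build, reusing `m` and the `get_config` helper from the file above:

```python
# Different granularities for activation and weight
# (e.g. per-tensor activation with per-row weight):
quantize_(m, get_config([PerTensor(), PerRow()]))

# Matching per-group quantization for both activation and weight;
# per the tests, the group sizes must be equal:
quantize_(m, get_config([PerGroup(64), PerGroup(64)]))
```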
Changes to `_normalize_granularity` in the second file of the diff (the hunk continues after the review thread below):

```diff
@@ -14,6 +14,7 @@
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
+    PerGroup,
     PerRow,
     PerTensor,
 )
@@ -204,28 +205,41 @@ def _normalize_granularity(
             list[FP8Granularity],
         ]
     ],
+    supported_granularities: tuple[FP8Granularity] = (PerTensor, PerRow),
+    support_different_granularities: bool = False,
```
An inline review thread is attached to the two new parameters (several comments are truncated in the rendered page):

- Reviewer: This is weird; I think `_normalize_granularity` should only do normalization, not also validation.
- Author: I feel the same, actually. Where should we put the validation? Thanks.
- Reviewer: This would probably be clearer if you do it in a separate PR, that is, move the original …
- Author: Sounds good. Will do. Thanks.
- Reviewer: How about version=1? Call …
- Author: How shall we do validation at these locations? Hi @jerryzh168, do you have any suggestions for this? Thanks.
- Reviewer: I think just create …
- Reviewer: Oh, actually it looks like …
- Reviewer: I think you can do the following: `torchao/quantization/quantize_/workflows/float8/utils.py`, `torchao/quantization/quantize_/workflows/float8/float8_tensor.py`, `torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py`
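A minimal sketch of the split the reviewers discuss, assuming normalization and validation become separate helpers (the name `_validate_granularities` and its exact checks are assumptions, not part of this PR):

```python
from torchao.quantization.granularity import PerGroup, PerRow, PerTensor


def _normalize_granularity(granularity):
    """Only normalize: turn None / a single granularity / a pair into a 2-tuple."""
    if granularity is None:
        return (PerTensor(), PerTensor())
    if isinstance(granularity, (tuple, list)):
        return tuple(granularity)
    return (granularity, granularity)


def _validate_granularities(granularity, supported=(PerTensor, PerRow)):
    """Separate validation step, so each config can apply its own rules."""
    act, weight = granularity  # expects the normalized 2-tuple
    if not (isinstance(act, supported) and isinstance(weight, supported)):
        raise ValueError(f"Invalid granularity types: {granularity}")
    if isinstance(act, PerGroup):
        # mirror the PR's rule: PerGroup activation needs a matching PerGroup weight
        if not isinstance(weight, PerGroup) or act.group_size != weight.group_size:
            raise ValueError(
                "PerGroup activation requires PerGroup weight with the same group_size"
            )
```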
The hunk continues:

```diff
 ) -> Tuple[FP8Granularity, FP8Granularity]:
     processed_granularity = None
     if granularity is None:
         processed_granularity = (PerTensor(), PerTensor())
-    elif isinstance(granularity, (PerTensor, PerRow)):
+    elif isinstance(granularity, supported_granularities):
         processed_granularity = (granularity, granularity)
     elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
         if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
+            isinstance(granularity[0], supported_granularities)
+            and isinstance(granularity[1], supported_granularities)
         ):
             raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
+                f"Invalid granularity types: {granularity}, only {supported_granularities} are supported."
             )
-        if not isinstance(granularity[0], type(granularity[1])):
+        if not support_different_granularities and not isinstance(
+            granularity[0], type(granularity[1])
+        ):
             raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+                f"Different granularities for activation and weight are not supported: {granularity}, only {supported_granularities} are supported."
             )
+        if isinstance(granularity[0], PerGroup):
+            if not isinstance(granularity[1], PerGroup):
+                raise ValueError(
+                    "When granularity for activation is PerGroup, granularity for weight must be PerGroup, too."
+                )
+            if granularity[0].group_size != granularity[1].group_size:
+                raise ValueError(
+                    f"Group sizes for activation and weight must be the same, got {granularity[0].group_size} and {granularity[1].group_size}."
+                )
         processed_granularity = tuple(granularity)
     else:
         raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
+            f"Invalid granularity specification: {granularity}, only {supported_granularities} are supported."
         )
     return processed_granularity
```
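With the new signature, the CPU opaque path would presumably widen both arguments when it calls the helper; a hedged sketch of such a call (the actual call site is not shown in this diff):

```python
# Accept PerGroup and allow activation/weight granularities to differ:
act_gran, weight_gran = _normalize_granularity(
    [PerGroup(32), PerGroup(32)],
    supported_granularities=(PerTensor, PerRow, PerGroup),
    support_different_granularities=True,
)
```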