Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
20041de
introduce new int8 quantization API
namgyu-youn Oct 24, 2025
6e1e06f
refactor ops and update test cases
namgyu-youn Oct 26, 2025
7354407
update granularity slicing support
namgyu-youn Oct 26, 2025
e0d3794
add 3D support to api, build linear variants test
namgyu-youn Oct 31, 2025
ffe9e7e
add kernel detection test case
namgyu-youn Oct 31, 2025
a9a6318
refactor kernel test
namgyu-youn Oct 31, 2025
4988ae0
update linear variant, kernel detection test
namgyu-youn Nov 4, 2025
0a3bdd7
update default granularity, kernel test
namgyu-youn Nov 11, 2025
43469f0
fix quantization ops
namgyu-youn Nov 11, 2025
29074ea
merge test cases with cleanup
namgyu-youn Nov 12, 2025
d0a75b3
update `block_size` args to `granularity`
namgyu-youn Nov 16, 2025
4dd0f9e
update expected kernel test
namgyu-youn Nov 16, 2025
900af9d
use Granularity for slicing logic instead of block_size
namgyu-youn Nov 20, 2025
69eccac
update tensor slicing for per-tensor/row/block scales
namgyu-youn Nov 20, 2025
c498862
fix __new__/__init__ signatures and formatting
namgyu-youn Nov 20, 2025
8c13748
reland toy linear model
namgyu-youn Nov 22, 2025
4757d4a
update casting logic
namgyu-youn Nov 22, 2025
f1968fb
add block_size attribute, separate version 1 from 2
namgyu-youn Nov 22, 2025
a0d94a2
fix activation kwargs
namgyu-youn Nov 22, 2025
3441751
add todo for future update
namgyu-youn Nov 24, 2025
9906ff5
Add PerBlock to safe globals (#3370)
andrewor14 Nov 23, 2025
59b4333
Merge branch 'main' into int8-quant-api
namgyu-youn Nov 25, 2025
e940d2b
empty commit to trigger CI
namgyu-youn Nov 25, 2025
fb29f5e
Move float8_opaque_tensor to prototype (#3365)
Xia-Weiwen Nov 24, 2025
aabab76
revert scale slicing, implement slicing logic for Int8Tensor directly
namgyu-youn Nov 25, 2025
0760eac
drop unrelated commit
namgyu-youn Nov 26, 2025
d83f615
fix after rebase
namgyu-youn Nov 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/quantization_overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ First we want to lay out the torchao stack::

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor
---------------------------------------------------------------------------------------------
Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by dervied dtpype and packing forma
- scaled int4
- preshuffled (special format to optimize for loading)
- float8 act + int4 weight dynamic quantization and int4 weight only quantization
* - Int8Tensor
- plain

.. note::
We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
Expand Down
217 changes: 217 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal import common_utils

from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
def setUp(self):
super().setUp()

self.test_shape = (32, 20)
self.dtype = torch.bfloat16
self.batch_size = 32

torch.manual_seed(42)

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
def test_creation_and_attributes(self, config):
"""Test tensor creation, dtypes, and ranges"""
linear = torch.nn.Linear(
self.test_shape[1],
self.test_shape[0],
bias=False,
dtype=self.dtype,
device="cuda",
)
quantize_(linear, config)

w = linear.weight

self.assertEqual(w.shape, self.test_shape)
self.assertEqual(w.qdata.dtype, torch.int8)
self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128), # 2D
((32, 128), 64, 256), # 3D
],
)
def test_int8_linear_variants(
self,
dtype: torch.dtype,
config,
compile: bool,
sizes: tuple,
):
"""Test linear operation supports including shape and compile"""
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
model_q = copy.deepcopy(model)

quantize_(model_q, config)

self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
self.assertEqual(model_q.linear2.weight.scale.ndim, 1)

if compile:
model_q = torch.compile(model_q, fullgraph=True)

output_fp = model(input_tensor)
output_quantized = model_q(input_tensor)

assert compute_error(output_fp, output_quantized) > 20, (
f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
)

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize("device", ["cpu", "cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_slice(self, config, device, dtype):
"""Test tensor slicing with per-row quantization"""
tensor_size = 256
slice_sizes = (64, 128)

dummy = torch.nn.Linear(
tensor_size, tensor_size, bias=False, dtype=dtype, device=device
)
quantize_(dummy, config)

weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0])
weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1])

self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))
self.assertEqual(weight2.scale, dummy.weight.scale)
with self.assertRaises(NotImplementedError):
_ = dummy.weight[::2]

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
],
)
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_index_select(self, config, granularity):
"""test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
N, K = 256, 512
x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
linear.weight.data = x

config = config(version=2, granularity=granularity)
quantize_(linear, config)

x_int8 = linear.weight
x_int8_0 = x_int8[0]

# Test dequantization consistency
torch.testing.assert_close(
x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
)

# Test block_size granularity
if isinstance(granularity, PerRow):
self.assertEqual(
list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K]
)
elif isinstance(granularity, PerTensor):
self.assertEqual(
list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K]
)

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
def test_dequantization_accuracy(self, config):
"""Test dequantization accuracy separately"""
linear = torch.nn.Linear(
256, 512, bias=False, dtype=torch.bfloat16, device="cuda"
)
weight_fp = copy.deepcopy(linear.weight)
quantize_(linear, config)

tensor = linear.weight
dequantized = tensor.dequantize()
self.assertEqual(dequantized.shape, weight_fp.shape)
assert compute_error(dequantized, weight_fp) > 20, (
f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}"
)

def test_available_gpu_kernels(self):
"""Check which GPU kernels are used"""
torch.compiler.reset()

M, K, N = 128, 256, 512
m = torch.nn.Sequential(
torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
)

config = Int8DynamicActivationInt8WeightConfig(version=2)
quantize_(m, config)

m = torch.compile(m)
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

out, code = run_and_get_code(m, x)

# Check expected kernels are present
FileCheck().check_count("triton_per_fused", 1).check_count(
"extern_kernels._int_mm", 1
).check_count("triton_poi_fused", 1).run(code[0])


if __name__ == "__main__":
common_utils.run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
Int8Tensor,
IntxOpaqueTensor,
IntxUnpackedToInt8Tensor,
)
Expand Down Expand Up @@ -171,6 +172,7 @@
"IntxOpaqueTensor",
"IntxUnpackedToInt8Tensor",
"Int4TilePackedTo4dTensor",
"Int8Tensor",
"Float8Tensor",
# smooth quant - subject to change
"get_scale",
Expand Down
Loading