93 changes: 93 additions & 0 deletions test/test_ops.py
@@ -40,7 +40,14 @@
except RuntimeError:
    pytest.skip("torchao.ops not available")

from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _dequantize_affine_float8,
    _quantize_affine_float8,
)
from torchao.quantization.utils import (
    get_block_size,
    get_groupwise_affine_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
@@ -901,5 +908,91 @@ def _test_scaled_embedding_bag_cpu_helper(
    torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


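# Both the int8 and the fp8 path of the fused _scaled_embedding_bag CPU kernel
# go through the same helper; only the weight dtype differs.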
@pytest.mark.skipif(
    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
    reason="cpp kernels not built",
)
@pytest.mark.parametrize(
    "multi_hot, batch_size, vector_size, index_type",
    EMBEDINGBAG_TEST_PARAMS,
    ids=str,
)
def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
    _test_scaled_embedding_bag_cpu_helper(
        multi_hot, batch_size, vector_size, index_type, torch.int8
    )


@pytest.mark.skipif(
    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
    reason="cpp kernels not built",
)
@pytest.mark.parametrize(
    "multi_hot, batch_size, vector_size, index_type",
    EMBEDINGBAG_TEST_PARAMS,
    ids=str,
)
def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
    _test_scaled_embedding_bag_cpu_helper(
        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
    )


@pytest.mark.skipif(
    "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
    or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
    reason="cpp kernels not built",
)
@pytest.mark.skipif(
    not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+"
)
@pytest.mark.parametrize("shape", [(64, 64), (256, 256)])
@pytest.mark.parametrize("bs", [1, 160])
@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)])
@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)])
def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity):
    in_feature, out_feature = shape
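    # Skip combinations the kernel does not cover: a group size must be smaller
    # than in_feature, and per-group activations require per-group weights.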
    if isinstance(x_granularity, PerGroup):
        if x_granularity.group_size >= in_feature:
            return
        if not isinstance(w_granularity, PerGroup):
            return
    if isinstance(w_granularity, PerGroup):
        if w_granularity.group_size >= in_feature:
            return
    m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval()
    b = m.bias
    x = torch.randn(bs, in_feature)
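    # Quantize the activation to fp8 (e4m3) at the requested granularity.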
    x_block_size = get_block_size(x.shape, x_granularity)
    x_scale = _choose_scale_float8(
        x,
        float8_dtype=torch.float8_e4m3fn,
        block_size=x_block_size,
    )
    x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)

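    # Quantize the weight the same way, at its own granularity.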
    w = m.weight.detach()
    w_block_size = get_block_size(w.shape, w_granularity)
    w_scale = _choose_scale_float8(
        w,
        float8_dtype=torch.float8_e4m3fn,
        block_size=w_block_size,
    )
    w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)

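    # Reference result: dequantize both sides and run a plain linear.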
    x_dq = _dequantize_affine_float8(x_fp8, x_scale)
    w_dq = _dequantize_affine_float8(w_fp8, w_scale)
    ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype)

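    # Pack the fp8 weight into the kernel's layout, then run the fused CPU op.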
    packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
    y = torch.ops.torchao.float8_linear_cpu(
        x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
    )

    torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    pytest.main(sys.argv)