test/dtypes/test_affine_quantized.py (50 changes: 26 additions & 24 deletions)
@@ -39,6 +39,7 @@
from torchao.utils import (
check_cpu_version,
check_xpu_version,
+ get_current_accelerator_device,
is_fbcode,
is_ROCM,
is_sm_at_least_89,
@@ -47,10 +48,11 @@
is_cusparselt_available = (
hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
)
+ _DEVICE = get_current_accelerator_device()


def get_quantization_functions(
- do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
+ do_sparse: bool, do_int4: bool, device: str = _DEVICE, int4_zp_int: bool = False
):
base_functions = [
Int8WeightOnlyConfig(),
@@ -105,9 +107,9 @@ class TestAffineQuantized(TestCase):
["xpu"] if torch.xpu.is_available() else []
)

- @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+ @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_tensor_core_layout_transpose(self):
- linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+ linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
@@ -169,7 +171,7 @@ def _apply(module, config_or_subclass_inserter):
ql = _apply(linear, apply_quant)
ql.to(device)

- @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+ @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_register_new_dispatch(self):
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -206,10 +208,10 @@ def apply_uint6_weight_only_quant(linear):
)
return linear

- linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+ linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
apply_uint6_weight_only_quant(linear)

- example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+ example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
with self.assertRaisesRegex(
AssertionError, "dispatching to my impl for uint6 weight only quant"
):
@@ -234,11 +236,11 @@ def test_print_quantized_module(self):

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"apply_quant", get_quantization_functions(False, True, "cuda", False)
"apply_quant", get_quantization_functions(False, True, _DEVICE, False)
)
def test_test_copy__apply(self, apply_quant):
- linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
- linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+ linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+ linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)

if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
@@ -249,20 +251,20 @@ def test_test_copy__apply(self, apply_quant):
ql = apply_quant(linear)
ql2 = apply_quant(linear2)

- example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+ example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
output = ql(example_input)
ql2.weight.copy_(ql.weight)
ql2.bias = ql.bias
output2 = ql2(example_input)
self.assertEqual(output, output2)

- @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+ @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
@common_utils.parametrize(
"apply_quant", get_quantization_functions(False, True, "cuda", False)
"apply_quant", get_quantization_functions(False, True, _DEVICE, False)
)
def test_copy__mismatch_metadata(self, apply_quant):
- linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
- linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
+ linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+ linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device=_DEVICE)

if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
@@ -336,7 +338,7 @@ def test_alias(self, device, dtype):
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
_ = dummy.weight[...]

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
@skip_if_rocm("ROCm enablement in progress")
@@ -350,9 +352,9 @@ def test_slice_int4wo(self, device, dtype):
_ = dummy.weight.narrow(0, 0, 64)
_ = dummy.weight.narrow(1, 0, 128)

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
- @skip_if_no_cuda()
+ @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
@skip_if_no_gemlite()
def test_slice_gemlite(self, device, dtype):
# in_feature not divisible by 1024
@@ -433,7 +435,7 @@ def dequant(input_layer, in_features, orig_shape):
)
self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_matmul(self, device, dtype):
x = torch.randn(53, 2048)
@@ -450,14 +452,14 @@ def test_matmul(self, device, dtype):
# make sure it runs
torch.matmul(x, w.t())

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
@skip_if_rocm("ROCm enablement in progress")
def test_slice_and_copy_int4wo(self, device, dtype):
- l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+ l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
l.weight = torch.nn.Parameter(
- torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+ torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
)
quantize_(l, Int4WeightOnlyConfig(version=1))
param = l.weight
@@ -474,7 +476,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
assert param.data.dequantize()[0][0] == 0

# dummy_l has random input (shouldn't be 0)
- dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+ dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
quantize_(dummy_l, Int4WeightOnlyConfig(version=1))
quantized = dummy_l.weight
quantized = quantized.narrow(0, 0, 512)
@@ -484,9 +486,9 @@
# making sure param.data is updated
assert param.data.dequantize()[0][0] != 0

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.bfloat16])
- @skip_if_no_cuda()
+ @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
@skip_if_rocm("ROCm enablement in progress")
def test_mm_int4wo(self, device, dtype):
weight = torch.randn(512, 1024).to(device).to(dtype)
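Note on the migration pattern: the whole diff pivots on get_current_accelerator_device() from torchao.utils, which is resolved once into the module-level _DEVICE constant. As a rough sketch only, and an assumption about its behavior rather than torchao's actual implementation, such a helper can be built on the torch.accelerator API (PyTorch 2.6+):

import torch

def get_current_accelerator_device():
    # Assumed behavior, not torchao's actual code: return the active
    # accelerator's device type string ("cuda", "xpu", ...) if one is
    # present, otherwise fall back to "cpu".
    if torch.accelerator.is_available():
        # current_accelerator() returns a torch.device, e.g. device(type="cuda")
        return torch.accelerator.current_accelerator().type
    return "cpu"

Resolving _DEVICE once at import time is what lets the same test file run on both CUDA and XPU machines: the hardcoded "cuda" strings become device=_DEVICE, and the torch.cuda.is_available() skip guards become torch.accelerator.is_available(), which covers any registered accelerator backend rather than CUDA alone.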