test/dtypes/test_bitpacking.py (16 additions, 8 deletions)
```diff
@@ -8,9 +8,11 @@
 from torch.utils._triton import has_triton
 
 from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu
+from torchao.utils import get_current_accelerator_device
 
 bit_widths = (1, 2, 3, 4, 5, 6, 7)
 dimensions = (0, -1, 1)
+_DEVICE = get_current_accelerator_device()
 
 
 @pytest.fixture(autouse=True)
@@ -30,40 +32,46 @@ def test_CPU(bit_width, dim):
     assert unpacked.allclose(test_tensor)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
 def test_GPU(bit_width, dim):
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
 def test_compile(bit_width, dim):
     torch._dynamo.config.specialize_int = True
     torch.compile(pack, fullgraph=True)
     torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)
 
 
 # these test cases are for the example pack walk through in the bitpacking.py file
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_pack_example():
     test_tensor = torch.tensor(
         [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
-    ).cuda()
+    ).to(_DEVICE)
     shard_4, shard_2 = pack(test_tensor, 6)
     print(shard_4, shard_2)
-    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
-    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+    assert (
+        torch.tensor([0, 105, 151, 37], dtype=torch.uint8).to(_DEVICE).allclose(shard_4)
+    )
+    assert torch.tensor([39, 146], dtype=torch.uint8).to(_DEVICE).allclose(shard_2)
     unpacked = unpack([shard_4, shard_2], 6)
     assert unpacked.allclose(test_tensor)
```
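The expected shard values in `test_pack_example` are not arbitrary: `pack(test_tensor, 6)` splits each 6-bit value into its low 4 bits and high 2 bits, then packs several elements per output byte. A minimal sketch that reproduces the asserted values in plain Python; the element pairing below is inferred from the expected outputs, and the authoritative layout is the walkthrough in `bitpacking.py`:

```python
# Sketch of how the expected shard values in test_pack_example arise.
# Assumption: the 4-bit shard combines element i with element i + n//2,
# and the 2-bit shard combines elements i, i + n//4, i + n//2, i + 3n//4.
vals = [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22]
low4 = [v & 0xF for v in vals]   # [0, 9, 7, 5, 0, 6, 9, 2]
high2 = [v >> 4 for v in vals]   # [3, 2, 1, 0, 2, 1, 0, 2]

# Two 4-bit values per byte: low nibble from i, high nibble from i + 4.
shard_4 = [low4[i] | (low4[i + 4] << 4) for i in range(4)]
# Four 2-bit values per byte, at bit offsets 0, 2, 4, 6.
shard_2 = [
    high2[i] | (high2[i + 2] << 2) | (high2[i + 4] << 4) | (high2[i + 6] << 6)
    for i in range(2)
]

assert shard_4 == [0, 105, 151, 37]  # matches the test's expected tensor
assert shard_2 == [39, 146]
```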

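Both test files now import `get_current_accelerator_device` from `torchao.utils`. Its body is outside this diff; a plausible sketch in terms of the public `torch.accelerator` API (PyTorch 2.6+) might look like this, though the real implementation may differ:

```python
import torch


def get_current_accelerator_device() -> torch.device:
    # Hypothetical sketch; the real helper lives in torchao/utils.py.
    # torch.accelerator abstracts over backends (CUDA, XPU, MPS, ...);
    # current_accelerator() returns None when no accelerator is present.
    acc = torch.accelerator.current_accelerator()
    return acc if acc is not None else torch.device("cpu")
```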
test/dtypes/test_floatx.py (7 additions, 6 deletions)
```diff
@@ -33,10 +33,11 @@
     quantize_,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import is_fbcode
+from torchao.utils import get_current_accelerator_device, is_fbcode
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 _Floatx_DTYPES = [(3, 2), (2, 2)]
+_DEVICE = get_current_accelerator_device()
+_DEVICES = ["cpu"] + ([_DEVICE] if torch.accelerator.is_available() else [])
 
 
 class TestFloatxTensorCoreAQTTensorImpl(TestCase):
@@ -87,7 +88,7 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
         )
         torch.testing.assert_close(actual, expected)
 
-    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), reason="GPU not available")
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     def test_to_copy_device(self, ebits, mbits):
         from torchao.quantization.quant_primitives import (
@@ -101,8 +102,8 @@ def test_to_copy_device(self, ebits, mbits):
         _layout = FloatxTensorCoreLayout(ebits, mbits)
         floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(
             x, scale, None, _layout
-        ).cuda()
-        assert floatx_tensor_impl.device.type == "cuda"
+        ).to(_DEVICE)
+        assert floatx_tensor_impl.device.type == _DEVICE.type
         floatx_tensor_impl = floatx_tensor_impl.cpu()
         assert floatx_tensor_impl.device.type == "cpu"
 
@@ -114,7 +115,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
     @skip_if_rocm("ROCm enablement in progress")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
-        device = "cuda"
+        device = _DEVICE
 
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
         fpx_linear = copy.deepcopy(linear)
```
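Taken together, the edits in both files follow one pattern: a CUDA-only skip gate plus `.cuda()` becomes a backend-agnostic gate plus `.to(_DEVICE)`. A minimal sketch of the resulting idiom, assuming a PyTorch build with `torch.accelerator` support (the test name below is illustrative):

```python
import pytest
import torch

from torchao.utils import get_current_accelerator_device

_DEVICE = get_current_accelerator_device()


# One skip gate and one device object cover CUDA, XPU, MPS, etc.,
# where the old code was pinned to torch.cuda / .cuda().
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
def test_roundtrip_device_agnostic():
    t = torch.randint(0, 16, (8,), dtype=torch.uint8).to(_DEVICE)
    assert t.device.type == _DEVICE.type
```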