Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions test/quantization/test_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error
from torchao.utils import get_current_accelerator_device

torch.manual_seed(0)

_DEVICE = get_current_accelerator_device()


class TestGPTQ(TestCase):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_gptq_quantizer_int4_weight_only(self):
from torchao._models._eval import (
LMEvalInputRecorder,
Expand All @@ -33,7 +36,7 @@ def test_gptq_quantizer_int4_weight_only(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

precision = torch.bfloat16
device = "cuda"
device = _DEVICE
checkpoint_path = Path(
"../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
)
Expand Down Expand Up @@ -80,15 +83,15 @@ def test_gptq_quantizer_int4_weight_only(self):
)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)

model = quantizer.quantize(model, *inputs).cuda()
model = quantizer.quantize(model, *inputs).to(_DEVICE)

model.reset_caches()
with torch.device("cuda"):
with torch.device(_DEVICE):
model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)

limit = 1
result = TransformerEvalWrapper(
model.cuda(),
model.to(_DEVICE),
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
Expand All @@ -104,7 +107,7 @@ def test_gptq_quantizer_int4_weight_only(self):


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_add_tensors(self):
from torchao.quantization.GPTQ import MultiTensor

Expand All @@ -116,7 +119,7 @@ def test_multitensor_add_tensors(self):
self.assertTrue(torch.equal(mt.values[0], tensor1))
self.assertTrue(torch.equal(mt.values[1], tensor2))

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_pad_unpad(self):
from torchao.quantization.GPTQ import MultiTensor

Expand All @@ -127,7 +130,7 @@ def test_multitensor_pad_unpad(self):
mt.unpad()
self.assertEqual(mt.count, 1)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_inplace_operation(self):
from torchao.quantization.GPTQ import MultiTensor

Expand All @@ -138,7 +141,7 @@ def test_multitensor_inplace_operation(self):


class TestMultiTensorInputRecorder(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_input_recorder(self):
from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder

Expand All @@ -159,7 +162,7 @@ def test_multitensor_input_recorder(self):
self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
self.assertEqual(MT_input[3], torch.float)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_gptq_with_input_recorder(self):
from torchao.quantization.GPTQ import (
Int4WeightOnlyGPTQQuantizer,
Expand All @@ -170,7 +173,7 @@ def test_gptq_with_input_recorder(self):

config = ModelArgs(n_layer=2)

with torch.device("cuda"):
with torch.device(_DEVICE):
model = Transformer(config)
model.setup_caches(max_batch_size=2, max_seq_length=100)
idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
Expand All @@ -191,7 +194,14 @@ def test_gptq_with_input_recorder(self):

args = input_recorder.get_recorded_inputs()

quantizer = Int4WeightOnlyGPTQQuantizer()
if _DEVICE.type == "xpu":
from torchao.dtypes import Int4XPULayout

quantizer = Int4WeightOnlyGPTQQuantizer(
device=torch.device("xpu"), layout=Int4XPULayout()
)
else:
quantizer = Int4WeightOnlyGPTQQuantizer()

quantizer.quantize(model, *args)

Expand Down
12 changes: 8 additions & 4 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@
from torchao.utils import (
check_cpu_version,
check_xpu_version,
get_current_accelerator_device,
is_fbcode,
)

_SEED = 1234
torch.manual_seed(_SEED)

_DEVICE = get_current_accelerator_device()


# Helper function to run a function twice
# and verify that the result is the same.
Expand Down Expand Up @@ -592,16 +595,17 @@ def test_choose_qparams_tensor_asym_eps(self):
self.assertEqual(scale, eps)

@unittest.skipIf(
not torch.cuda.is_available(), "skipping when cuda is not available"
not torch.accelerator.is_available(), "skipping when gpu is not available"
)
def test_get_group_qparams_symmetric_memory(self):
"""Check the memory usage of the op"""
weight = torch.randn(1024, 1024).to(device="cuda")
original_mem_use = torch.cuda.memory_allocated()
weight = torch.randn(1024, 1024).to(device=_DEVICE)
device_module = torch.get_device_module(_DEVICE)
original_mem_use = device_module.memory_allocated()
n_bit = 4
groupsize = 128
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
after_choose_qparams_mem_use = torch.cuda.memory_allocated()
after_choose_qparams_mem_use = device_module.memory_allocated()
self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)

def test_raises(self):
Expand Down
Loading