62 changes: 59 additions & 3 deletions test/quantization/test_quant_api.py
@@ -8,6 +8,7 @@
# This test takes a long time to run
import unittest
import torch
import os
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
prepare_pt2e,
@@ -18,9 +19,10 @@
get_symmetric_quantization_config,
)

from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.quant_api import apply_dynamic_quant
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
apply_dynamic_quant,
apply_weight_only_int8_quant,
Quantizer,
TwoStepQuantizer,
)
@@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
compiled = m(*example_inputs)
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int8_wo_quant_save_load(self):
m = M().eval().cpu()
apply_weight_only_int8_quant(m)
example_inputs = m.example_inputs()
ref = m(*example_inputs)
_TMP_FN = "_test.pt"
torch.save(m.state_dict(), _TMP_FN)

state_dict = torch.load(_TMP_FN)
os.remove(_TMP_FN)
m2 = M().eval()
apply_weight_only_int8_quant(m2)
m2.load_state_dict(state_dict)
m2 = m2.to(device="cuda")
example_inputs = map(lambda x: x.cuda(), example_inputs)
res = m2(*example_inputs)

torch.testing.assert_close(ref, res.cpu())

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_8da4w_quantizer(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -300,7 +322,6 @@ def test_gptq_quantizer_gpt_fast(self):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -357,6 +378,41 @@ def test_gptq_quantizer_int4wo(self):
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
)
groupsize = 128
quantizer = Int4WeightOnlyQuantizer(
groupsize,
)
model = quantizer.quantize(model).cuda()
result = TransformerEvalWrapper(
model,
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
).run_eval(
["wikitext"],
1,
)
assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_eval_wrapper(self):
from torchao.quantization.GPTQ import TransformerEvalWrapper
111 changes: 98 additions & 13 deletions torchao/quantization/GPTQ.py
@@ -28,6 +28,7 @@
groupwise_affine_quantize_tensor_from_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
pack_tinygemm_scales_and_zeros,
groupwise_affine_quantize_tensor,
)
aten = torch.ops.aten

@@ -65,8 +66,8 @@

__all__ = [
"MultiInput",
"WeightOnlyInt4Linear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
] + add_ons

if lm_eval_available:
@@ -117,7 +118,10 @@ def __init__(

@property
def eot_token_id(self):
return self._tokenizer.eos_id()
try:
return self._tokenizer.eos_id()
except:
return self._tokenizer.eos_id

@property
def max_length(self):
@@ -139,7 +143,10 @@ def tok_encode(self, string: str, **kwargs):
# TODO: verify this for multi-batch as well
tokens = self._tokenizer.encode(string)
if hasattr(self._tokenizer, "bos_id"):
tokens = [self._tokenizer.bos_id()] + tokens
try:
tokens = [self._tokenizer.bos_id()] + tokens
except:
tokens = [self._tokenizer.bos_id] + tokens
return tokens

def tok_decode(self, tokens):
@@ -747,6 +754,12 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
pass

def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
k_divisible_by_groupsize = k % groupsize == 0
if inner_k_tiles is not None:
k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
return k_divisible_by_groupsize

def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
@@ -767,7 +780,7 @@ def __init__(
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
) -> None:
super().__init__()
self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
if self.padding:
from model import find_multiple
self.origin_in_features = in_features
@@ -806,14 +819,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)


def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
k_divisible_by_groupsize = k % groupsize == 0
if inner_k_tiles is not None:
k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
return k_divisible_by_groupsize

def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None):

for name, child in module.named_children():
@@ -826,6 +831,83 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
else:
replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func)

class Int4WeightOnlyQuantizer(Quantizer):
def __init__(
self,
groupsize: int = 256,
padding_allowed: bool = True,
inner_k_tiles: Optional[int] = 8,
) -> None:
super().__init__()
assert inner_k_tiles in [2, 4, 8]
assert groupsize in [32, 64, 128, 256]

self.inner_k_tiles = inner_k_tiles
self.groupsize: int = groupsize
self.padding_allowed: bool = padding_allowed

@torch.no_grad()
def _create_quantized_state_dict(
self, model: torch.nn.Module
) -> Dict[str, torch.Tensor]:
cur_state_dict = model.state_dict()
for fqn, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
# assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")

assert (
in_features % self.groupsize == 0
), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

weight = mod.weight.data
if not _check_linear_int4_k(
in_features, self.groupsize, self.inner_k_tiles
):
if self.padding_allowed:
from .utils import find_multiple
import torch.nn.functional as F
print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
padded_in_features = find_multiple(in_features, 1024)
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
else:
print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
"and that groupsize and inner_k_tiles*16 evenly divide into it")
continue
(
w_int4x8,
scales_and_zeros
) = groupwise_affine_quantize_tensor(
weight,
4, # n_bit
self.groupsize,
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda")
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda")
return cur_state_dict

def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
replace_linear_int4(
model,
self.groupsize,
self.inner_k_tiles,
self.padding_allowed,
)
return model

def quantize(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
state_dict = self._create_quantized_state_dict(model)
model = self._convert_for_runtime(model)
# TODO: make it strict
model.load_state_dict(state_dict, strict=False)
return model

class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
def __init__(
self,
@@ -868,7 +950,10 @@ def __init__(
# TODO: this is the gpt-fast version, merge with the main version later
def make_names_and_values_dict_func(q, qparams):
k = q.shape[1]
new_k = find_multiple(k, 1024)
if not _check_linear_int4_k(k, groupsize):
new_k = find_multiple(k, 1024)
else:
new_k = k
# how much we need to pad the weight
delta_k = new_k - q.shape[1]
q = q.to(torch.int32)
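For context, a minimal usage sketch of the new Int4WeightOnlyQuantizer added above, mirroring the test_quantizer_int4wo test; `model` and `example_inputs` here are placeholders for any bias-free nn.Linear-based model, not part of this PR:

# Sketch only: assumes CUDA is available and `model` / `example_inputs` are defined elsewhere.
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

quantizer = Int4WeightOnlyQuantizer(groupsize=128, padding_allowed=True, inner_k_tiles=8)
model = quantizer.quantize(model)   # packs int4 weights, swaps eligible Linear layers for WeightOnlyInt4Linear, loads the new state dict
out = model.cuda()(*example_inputs)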
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -42,4 +42,6 @@
"compute_error",
"get_model_size_in_bytes",
"WeightOnlyInt8QuantLinear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
]
2 changes: 2 additions & 0 deletions torchao/quantization/quant_api.py
@@ -32,6 +32,7 @@
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)


@@ -45,6 +46,7 @@
"Quantizer",
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer"
]

if TORCH_VERSION_AFTER_2_3:
1 change: 0 additions & 1 deletion torchao/quantization/quant_primitives.py
@@ -383,7 +383,6 @@ def pack_tinygemm_scales_and_zeros(scales, zeros):

def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
assert scales_and_zeros.dtype == torch.float
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)


5 changes: 2 additions & 3 deletions torchao/quantization/weight_only.py
@@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
scales = kwargs.pop("scales")
super().__init__(*args, **kwargs)

self.w_int8 = w_int8

self.scales = scales
self.register_buffer("w_int8", w_int8)
self.register_buffer("scales", scales)

def forward(self, x, *args, **kwargs):
"""
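Note on the weight_only.py change above: registering w_int8 and scales via register_buffer (instead of assigning them as plain attributes) puts them in the module's state_dict, which is what the new test_int8_wo_quant_save_load relies on to round-trip quantized weights through torch.save/torch.load. A minimal sketch of the difference, with an illustrative module name:

# Sketch only: `Demo` is a made-up module for illustration.
import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(2, dtype=torch.int8)                     # plain attribute: absent from state_dict
        self.register_buffer("w_int8", torch.zeros(2, dtype=torch.int8))  # buffer: saved and loaded with state_dict

print(Demo().state_dict().keys())  # odict_keys(['w_int8'])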