4 changes: 2 additions & 2 deletions examples/models/llama2/export_llama_lib.py
@@ -374,8 +374,8 @@ def build_args_parser() -> argparse.ArgumentParser:
"--spin_qmode",
type=str,
default=None,
choices=["8da4w"],
help="Quantization mode for SpinQuant. Only support 8da4w right now.",
choices=["8da4w", "8da4w_output_8da8w"],
help="Quantization mode for SpinQuant. Only support 8da4w and 8da4w_output_8da8w right now.",
)

parser.add_argument(
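For reference, a minimal standalone sketch of the updated argument. It mirrors the `parser.add_argument` call above rather than importing the full export parser, whose other options are omitted here:

import argparse

# Standalone sketch mirroring the updated --spin_qmode argument above; the
# real build_args_parser() defines many more options that are omitted here.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--spin_qmode",
    type=str,
    default=None,
    choices=["8da4w", "8da4w_output_8da8w"],
    help="Quantization mode for SpinQuant.",
)

args = parser.parse_args(["--spin_qmode", "8da4w_output_8da8w"])
print(args.spin_qmode)  # 8da4w_output_8da8w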
17 changes: 16 additions & 1 deletion examples/models/llama2/model.py
@@ -192,6 +192,10 @@ def __init__(self, **kwargs):
elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
print("Using SPIN quantization.")
assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
assert self.args.spin_qmode in [
"8da4w",
"8da4w_output_8da8w",
], f"Quantization mode {self.args.spin_qmode} is not compatible with SpinQuant."
assert hasattr(
self.args, "spin_group_size"
), "spin_group_size must be specified"
@@ -209,11 +213,22 @@ def __init__(self, **kwargs):
"bf16": torch.bfloat16,
}

# Transform the output layer first if needed.
if self.args.spin_qmode == "8da4w_output_8da8w":
from .source_transformation.spin_quant import (
transform_output_linear_for_spinquant,
)

self.model_ = transform_output_linear_for_spinquant(
module=self.model_,
checkpoint=checkpoint,
dtype=mapping[self.args.dtype_override],
)

self.model_ = transform_linear_for_spinquant(
self.model_,
checkpoint,
self.args.spin_group_size,
self.args.spin_qmode,
mapping[self.args.dtype_override],
)

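Taken together, the SpinQuant branch above now applies two source transformations in order for the new mode. A simplified sketch of that flow, using the function names from this diff (the wrapper function here is hypothetical and not part of model.py):

from executorch.examples.models.llama2.source_transformation.spin_quant import (
    transform_linear_for_spinquant,
    transform_output_linear_for_spinquant,
)

def apply_spinquant_transforms(model, checkpoint, args, dtype):
    # Hypothetical helper summarizing the branch above, not the verbatim code.
    if args.spin_qmode == "8da4w_output_8da8w":
        # Swap the output projection for a per-channel int8 dynamic linear first.
        model = transform_output_linear_for_spinquant(
            module=model, checkpoint=checkpoint, dtype=dtype
        )
    # Then replace the remaining linear layers with group-wise 8da4w linears.
    model = transform_linear_for_spinquant(
        model, checkpoint, args.spin_group_size, args.spin_qmode, dtype
    )
    return model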
65 changes: 56 additions & 9 deletions examples/models/llama2/source_transformation/spin_quant.py
@@ -20,7 +20,7 @@
from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

from .quantize import QuantizedGroupEmbedding
from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
@@ -129,20 +129,16 @@ def transform_linear_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
quantization_mode: str,
dtype: torch.dtype,
) -> torch.nn.Module:
"""
Transform the model to be able to load SpinQuant checkpoints that
are quantized with the given group size and quantization mode.
are quantized with the given group size and quantization mode for
linear layers.
"""

if group_size not in [32, 64, 128, 256]:
raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
if quantization_mode not in ["8da4w"]:
raise ValueError(
f"Quantization mode {quantization_mode} is not compatible with SpinQuant."
)
_replace_linear_with_linear_8da4w_for_spin_quant(
module,
checkpoint,
@@ -153,6 +149,53 @@
return module


def _replace_output_linear_with_linear_int8_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
scales_key = f"{cur_fqn}.scale"
if (
isinstance(child, nn.Linear)
and scales_key in checkpoint
and "output" in cur_fqn
):
assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
assert checkpoint[scales_key].dtype == dtype
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt8WeightLinear(
device=child.weight.device,
in_features=child.in_features,
out_features=child.out_features,
precision=dtype,
bias=False,
)
return new_linear

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_output_linear_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
) -> torch.nn.Module:
"""
Transform the model to be able to load SpinQuant checkpoints that
have the output layer quantized per-channel.
"""
_replace_output_linear_with_linear_int8_for_spinquant(
module,
checkpoint,
dtype,
)
return module


def _replace_embedding_with_quantized_group_embedding_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
@@ -233,8 +276,10 @@ def sanitize_checkpoint_from_spinquant(
module_name = new_key[0 : new_key.rfind(".")]
sub_module = module.get_submodule(module_name)
assert sub_module is not None
assert isinstance(sub_module, Int8DynActInt4WeightLinear) or isinstance(
sub_module, QuantizedGroupEmbedding
assert (
isinstance(sub_module, Int8DynActInt4WeightLinear)
or isinstance(sub_module, QuantizedGroupEmbedding)
or isinstance(sub_module, Int8DynActInt8WeightLinear)
)
# Checkpoints with SpinQuant could come with two formats for scales:
# 1. scales is grouped by group size
@@ -245,6 +290,8 @@
checkpoint[new_key] = (
old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
)
elif isinstance(sub_module, Int8DynActInt8WeightLinear):
checkpoint[new_key] = old_val[:, 0]
elif isinstance(sub_module, QuantizedGroupEmbedding):
if (
embedding_group_size is None or embedding_group_size == 0
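One detail from the sanitize change above: for the per-channel Int8DynActInt8WeightLinear output layer, the scales entry is collapsed to one value per output channel via old_val[:, 0]. A small illustration of that slicing, under the assumption that the checkpoint stores the scales as a 2-D tensor whose columns are redundant in the per-channel case:

import torch

# Assumed layout: scales arrive as [out_features, n_groups] with identical
# columns for a per-channel checkpoint; column 0 is then the per-channel
# scale vector that the quantized output layer expects.
scales_2d = torch.tensor([[0.10, 0.10], [0.25, 0.25], [0.03, 0.03]])
scales_per_channel = scales_2d[:, 0]
print(scales_per_channel)        # tensor([0.1000, 0.2500, 0.0300])
print(scales_per_channel.shape)  # torch.Size([3])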
54 changes: 52 additions & 2 deletions examples/models/llama2/tests/test_spinquant_transforms.py
@@ -8,10 +8,14 @@

import torch
from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama2.source_transformation.quantize import (
dynamically_quantize_per_channel,
)
from executorch.examples.models.llama2.source_transformation.spin_quant import (
sanitize_checkpoint_from_spinquant,
transform_embedding_for_spinquant,
transform_linear_for_spinquant,
transform_output_linear_for_spinquant,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric

@@ -51,8 +55,7 @@ def test_transform_linear_for_spinquant(self):
n_bit = 4
scales_precision = torch.float32
for fqn, mod in model.named_modules():
# Quantize everything except the last layer
if isinstance(mod, torch.nn.Linear) and ("output" not in fqn):
if isinstance(mod, torch.nn.Linear):
weight = mod.weight.data
(
weight_int8,
@@ -92,6 +95,53 @@ def test_transform_linear_for_spinquant(self):
# have to iterate over the keys.
self.assertTrue(torch.allclose(new_checkpoint[k], v))

def test_transform_output_linear_for_spinquant(self):
# Step 1: Create llama class with dummy weights
model = self._prepare_dummy_model()
checkpoint = model.state_dict()

# Step 2:
# Do per-channel quantization and amend the checkpoints with
# int8 weight and fp32 scales
for fqn, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear) and fqn == "output":
weight = mod.weight.data
weight_int8, scales, _ = dynamically_quantize_per_channel(
weight,
quant_min=-128,
quant_max=127,
target_dtype=torch.int8,
scales_dtype=torch.float32,
)
checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
checkpoint[f"{fqn}.scale"] = scales.to("cpu")

# Step 3:
# Transform the model so that it is compatible with the new checkpoint
transform_output_linear_for_spinquant(
model,
checkpoint,
torch.float32,
)
sanitize_checkpoint_from_spinquant(
model,
checkpoint,
-1,
)

model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)

new_checkpoint = model.state_dict()

for k, v in checkpoint.items():
# The new_checkpoint contains zeros so
# have to iterate over the keys.
self.assertTrue(torch.allclose(new_checkpoint[k], v))

def test_transform_embedding_for_spinquant(self):

# Step 1: Create llama class with dummy weights
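For context, the per-channel quantization step exercised by the new test can be tried in isolation. A minimal sketch reusing dynamically_quantize_per_channel with the same call signature as in the test above (the dummy weight shape is arbitrary):

import torch

from executorch.examples.models.llama2.source_transformation.quantize import (
    dynamically_quantize_per_channel,
)

# Quantize a dummy linear weight to int8 with fp32 scales, mirroring Step 2
# of test_transform_output_linear_for_spinquant above.
weight = torch.randn(8, 16)
weight_int8, scales, _ = dynamically_quantize_per_channel(
    weight,
    quant_min=-128,
    quant_max=127,
    target_dtype=torch.int8,
    scales_dtype=torch.float32,
)
assert weight_int8.dtype == torch.int8
assert scales.dtype == torch.float32
print(weight_int8.shape, scales.shape)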