1 change: 1 addition & 0 deletions examples/models/llama2/TARGETS
@@ -80,6 +80,7 @@ runtime.python_library(
"export_llama_lib.py",
"model.py",
"source_transformation/apply_spin_quant_r1_r2.py",
"source_transformation/lora.py",
"source_transformation/pre_quantization.py",
"source_transformation/prune_output.py",
"source_transformation/quantize.py",
17 changes: 17 additions & 0 deletions examples/models/llama2/export_llama_lib.py
@@ -390,6 +390,23 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Use SpinQuant for better quantization performance. Only support cuda and native.",
)

parser.add_argument(
"-qat",
"--use_qat",
default=False,
action="store_true",
help="Whether the checkpoin is pre-quantized with QAT or not.",
)

parser.add_argument(
"-lora",
"--use_lora",
type=int,
default=0,
help="Whether the checkpoint contains LoRA adaptors or not. 0: no LoRA adaptors; "
"otherwise, it means the rank of LoRA adaptors. Currently it only works if QAT is enabled.",
)

parser.add_argument(
"--preq_mode",
type=str,
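For reference, the two new flags wire up as ordinary argparse options. A standalone sketch (not part of the diff; the parser below only mirrors the two add_argument calls above):

import argparse

# Mirrors the -qat/--use_qat and -lora/--use_lora options added above.
parser = argparse.ArgumentParser()
parser.add_argument("-qat", "--use_qat", default=False, action="store_true")
parser.add_argument("-lora", "--use_lora", type=int, default=0)

args = parser.parse_args(["-qat", "-lora", "16"])
assert args.use_qat is True
assert args.use_lora == 16  # a non-zero value is interpreted as the LoRA rank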
137 changes: 80 additions & 57 deletions examples/models/llama2/model.py
@@ -191,73 +191,31 @@ def __init__(self, **kwargs):
)
elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
print("Using SPIN quantization.")
-            assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-            assert self.args.preq_mode in [
-                "8da4w",
-                "8da4w_output_8da8w",
-            ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-            assert hasattr(
-                self.args, "preq_group_size"
-            ), "preq_group_size must be specified"
-            assert hasattr(
-                self.args, "dtype_override"
-            ), "dtype_override must be specified"
+            self._transform_for_pre_quantization(checkpoint)

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
-                transform_linear_for_pre_quantization,
            )

-            mapping = {
-                "fp32": torch.float32,
-                "fp16": torch.float16,
-                "bf16": torch.bfloat16,
-            }
-
-            # Transform the output layer first if needed.
-            if self.args.preq_mode == "8da4w_output_8da8w":
-                from .source_transformation.pre_quantization import (
-                    transform_output_linear_for_pre_quantization,
-                )
-
-                self.model_ = transform_output_linear_for_pre_quantization(
-                    module=self.model_,
-                    checkpoint=checkpoint,
-                    dtype=mapping[self.args.dtype_override],
-                )
-
-            self.model_ = transform_linear_for_pre_quantization(
-                self.model_,
-                checkpoint,
-                self.args.preq_group_size,
-                mapping[self.args.dtype_override],
-            )
-
-            embedding_bit_width, embedding_group_size = None, None
-            if hasattr(self.args, "preq_embedding_quantize"):
-                embedding_bit_width, embedding_group_size = (
-                    self.args.preq_embedding_quantize.split(",")
-                )
-                from .source_transformation.pre_quantization import (
-                    transform_embedding_for_pre_quantization,
+            sanitize_checkpoint_from_pre_quantization(checkpoint)
+        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+            print("Using QAT quantization.")
+            self._transform_for_pre_quantization(checkpoint)
+            if hasattr(self.args, "use_lora") and self.args.use_lora:
+                from .source_transformation.lora import (
+                    transform_linear_for_lora_after_quantization,
                )

-                if (
-                    embedding_group_size == "none"
-                    or embedding_group_size == "None"
-                    or embedding_group_size == "0"
-                ):
-                    embedding_group_size = None
-                else:
-                    embedding_group_size = int(embedding_group_size)
-
-                self.model_ = transform_embedding_for_pre_quantization(
+                self.model_ = transform_linear_for_lora_after_quantization(
                    self.model_,
                    checkpoint,
-                    mapping[self.args.dtype_override],
-                    int(embedding_bit_width),
-                    embedding_group_size,
+                    self.args.use_lora,
                )

+            from .source_transformation.pre_quantization import (
+                sanitize_checkpoint_from_pre_quantization,
+            )
+
            sanitize_checkpoint_from_pre_quantization(checkpoint)

# assign=True: load params/buffers by assignment instead of performing an in-place copy.
@@ -318,3 +276,68 @@ def get_example_inputs_kvcache_sdpa(self):
[0], dtype=torch.long
), # start_pos, what token of output are we on.
)

def _transform_for_pre_quantization(self, checkpoint):
assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
assert self.args.preq_mode in [
"8da4w",
"8da4w_output_8da8w",
], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
assert hasattr(
self.args, "preq_group_size"
), "preq_group_size must be specified"
assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
from .source_transformation.pre_quantization import (
transform_linear_for_pre_quantization,
)

mapping = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

# Transform the output layer first if needed.
if self.args.preq_mode == "8da4w_output_8da8w":
from .source_transformation.pre_quantization import (
transform_output_linear_for_pre_quantization,
)

self.model_ = transform_output_linear_for_pre_quantization(
module=self.model_,
checkpoint=checkpoint,
dtype=mapping[self.args.dtype_override],
)

self.model_ = transform_linear_for_pre_quantization(
self.model_,
checkpoint,
self.args.preq_group_size,
mapping[self.args.dtype_override],
)

embedding_bit_width, embedding_group_size = None, None
if hasattr(self.args, "preq_embedding_quantize"):
embedding_bit_width, embedding_group_size = (
self.args.preq_embedding_quantize.split(",")
)
from .source_transformation.pre_quantization import (
transform_embedding_for_pre_quantization,
)

if (
embedding_group_size == "none"
or embedding_group_size == "None"
or embedding_group_size == "0"
):
embedding_group_size = None
else:
embedding_group_size = int(embedding_group_size)

self.model_ = transform_embedding_for_pre_quantization(
self.model_,
checkpoint,
mapping[self.args.dtype_override],
int(embedding_bit_width),
embedding_group_size,
)
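One detail worth noting in _transform_for_pre_quantization above: preq_embedding_quantize is a "bitwidth,groupsize" string, and a group size of "none", "None", or "0" disables grouping. A standalone sketch of that parsing (the helper name is illustrative and not part of the diff):

def parse_preq_embedding_quantize(spec: str):
    # Mirrors the parsing in _transform_for_pre_quantization above.
    bit_width, group_size = spec.split(",")
    if group_size in ("none", "None", "0"):
        group_size = None
    else:
        group_size = int(group_size)
    return int(bit_width), group_size

assert parse_preq_embedding_quantize("4,32") == (4, 32)
assert parse_preq_embedding_quantize("8,none") == (8, None)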
141 changes: 141 additions & 0 deletions examples/models/llama2/source_transformation/lora.py
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so that it can load checkpoints with
# LoRA adaptors. See https://arxiv.org/abs/2106.09685 for more details about LoRA.

from typing import Any

import torch
from torch import nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class LoRAAdaptorLinear(nn.Module):
"""
LoRA adaptor for linear layers.

    This class implements Low-Rank Adaptation (LoRA) for linear layers.
    See https://arxiv.org/abs/2106.09685 for more details about LoRA.
"""

def __init__(
self,
in_features: int,
out_features: int,
rank: int,
scale: float = 2.0,
dtype=torch.float32,
device=None,
) -> None:
super().__init__()
self.scale = scale
self.A = nn.Linear(in_features, rank, bias=False, dtype=dtype, device=device)
self.B = nn.Linear(rank, out_features, bias=False, dtype=dtype, device=device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.scale * self.B(self.A(x)) # pyre-ignore[7]


class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
"""
Int8DynActInt4WeightLinear with LoRA adaptor.
"""

def __init__(
self,
in_features: int,
out_features: int,
lora_rank: int,
bias=True,
device=None,
groupsize: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
lora_adaptor_precision: torch.dtype = torch.bfloat16,
lora_scale: float = 2.0,
) -> None:
super().__init__(
in_features,
out_features,
bias=bias,
device=device,
groupsize=groupsize,
precision=precision,
scales_precision=scales_precision,
)
self.adaptor = LoRAAdaptorLinear(
in_features,
out_features,
lora_rank,
scale=lora_scale,
dtype=lora_adaptor_precision,
device=device,
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input) + self.adaptor(input).to(dtype=self.precision)


def _replace_linear_8da4w_for_lora(
module: torch.nn.Module,
checkpoint: Any,
lora_rank: int,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
# Only replace linear layers where the checkpoint contains explicit adaptors
adaptor_A_key = f"{cur_fqn}.adaptor.A.weight"
adaptor_B_key = f"{cur_fqn}.adaptor.B.weight"
if (
isinstance(child, Int8DynActInt4WeightLinear)
and adaptor_A_key in checkpoint
and adaptor_B_key in checkpoint
):
assert checkpoint[adaptor_A_key].dtype == torch.bfloat16
assert checkpoint[adaptor_A_key].shape[0] == lora_rank
assert checkpoint[adaptor_A_key].shape[1] == child.in_features
assert checkpoint[adaptor_B_key].dtype == torch.bfloat16
assert checkpoint[adaptor_B_key].shape[0] == child.out_features
assert checkpoint[adaptor_B_key].shape[1] == lora_rank
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt4WeightLinearLoRA(
child.in_features,
child.out_features,
lora_rank=lora_rank,
bias=False,
device=child.weight.device,
groupsize=child.groupsize,
precision=child.precision,
scales_precision=child.scales.dtype,
)
return new_linear

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_linear_for_lora_after_quantization(
module: torch.nn.Module,
checkpoint: Any,
lora_rank: int,
) -> torch.nn.Module:
"""
    Transform the model so that it can load checkpoints with LoRA adaptors.
    The model is expected to have already been transformed to load pre-quantized
    checkpoints, and the checkpoint itself should be pre-quantized with the LoRA
    adaptor weights included.
"""
_replace_linear_8da4w_for_lora(
module,
checkpoint,
lora_rank,
)
return module
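To make the adaptor math concrete: LoRAAdaptorLinear computes scale * B(A(x)) with A and B held in bfloat16, and Int8DynActInt4WeightLinearLoRA adds that delta to the quantized linear's output after casting it back to the layer's precision. A small torch-only check of the adaptor math (shapes and the FQN below are illustrative; this deliberately avoids the torchao dependency):

import torch
from torch import nn

# Re-create the adaptor math from LoRAAdaptorLinear above, outside the class.
in_features, out_features, rank, scale = 16, 8, 4, 2.0
A = nn.Linear(in_features, rank, bias=False, dtype=torch.bfloat16)
B = nn.Linear(rank, out_features, bias=False, dtype=torch.bfloat16)

x = torch.randn(2, in_features, dtype=torch.bfloat16)
delta = scale * B(A(x))  # what LoRAAdaptorLinear.forward returns
assert delta.shape == (2, out_features)
assert delta.dtype == torch.bfloat16  # the wrapper casts this back to the quantized layer's precision

# For a quantized linear with FQN "layers.0.attention.wq" (illustrative), the checkpoint
# must contain "layers.0.attention.wq.adaptor.A.weight" (bf16, [rank, in_features]) and
# "layers.0.attention.wq.adaptor.B.weight" (bf16, [out_features, rank]) for
# transform_linear_for_lora_after_quantization to wrap that layer.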