Merged
4 changes: 3 additions & 1 deletion examples/models/checkpoint.py
@@ -9,6 +9,8 @@
from pathlib import Path
from typing import Any, Dict, Optional

import torch


def get_default_model_resource_dir(model_file_path: str) -> Path:
"""
@@ -52,7 +54,7 @@ def get_default_model_resource_dir(model_file_path: str) -> Path:
return resource_dir


def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
"""
Get the dtype of the checkpoint, returning "None" if the checkpoint is empty.
"""
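For context, a minimal sketch of what the updated helper could look like under the new Optional[torch.dtype] annotation. The function body is not part of this diff, so the first-tensor lookup below is an assumption:

import torch
from typing import Any, Dict, Optional

def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
    # Return the dtype of the first tensor found in the checkpoint, or None if it is empty.
    for value in checkpoint.values():
        if isinstance(value, torch.Tensor):
            return value.dtype
    return None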
67 changes: 60 additions & 7 deletions examples/models/llama/export_llama_lib.py
@@ -16,6 +16,7 @@
import re
import shlex
from enum import Enum
from functools import partial
from json import JSONDecodeError
from pathlib import Path
from typing import Callable, List, Optional, Union
@@ -594,9 +595,36 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
)

# At this point, the model is loaded in the default fp32.

# Checkpoint dtype should be lower or equal precision to the dtype override.
checkpoint_dtype = edge_manager.model.checkpoint_dtype
if not (
checkpoint_dtype == dtype_override.to_torch_dtype()
or (
checkpoint_dtype == torch.float16
and dtype_override.to_torch_dtype() == torch.float32
)
or (
checkpoint_dtype == torch.bfloat16
and dtype_override.to_torch_dtype() == torch.float32
)
):
logging.warning(
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
)

edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
edge_manager.set_output_dir(output_dir_path).source_transform(
_get_source_transforms(args.model, dtype_override, args)

# We want to quantize (in the source transforms) the weights of the model
# in the checkpoint dtype.
logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
_get_source_transforms(
modelname=args.model,
dtype_override=dtype_override,
checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
args=args,
)
Contributor Author

Added

edge_manager.model(
    torch.tensor([[2, 3, 4]], dtype=torch.long),
    {"input_pos": torch.tensor([0], dtype=torch.long)},
)

Here to test

)

return edge_manager
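Reading the new dtype check as a standalone predicate may make the intent clearer: the export stays silent only when the checkpoint dtype matches the override, or when a 16-bit checkpoint is upcast to fp32. The helper below is hypothetical and purely illustrative:

import torch

def checkpoint_dtype_is_compatible(checkpoint_dtype: torch.dtype, override: torch.dtype) -> bool:
    # Equal dtypes are always fine; fp16/bf16 checkpoints may be upcast to fp32.
    if checkpoint_dtype == override:
        return True
    return checkpoint_dtype in (torch.float16, torch.bfloat16) and override == torch.float32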
@@ -784,8 +812,6 @@ def _to_edge_and_lower_llama( # noqa: C901
shares=args.num_sharding,
)

from functools import partial

# pyre-ignore
from executorch.backends.qualcomm.quantizer.custom_annotation import (
get_custom_quant_ios_dtype,
@@ -1069,8 +1095,31 @@ def _load_llama_model(


def _get_source_transforms( # noqa
modelname: str, dtype_override: Optional[DType], args
modelname: str,
dtype_override: DType,
*,
checkpoint_dtype: Optional[DType] = None,
args,
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
"""
Return a list of functions that transform a graph.

Args:
modelname: The name of the model.
dtype_override: The dtype to use for the model.
checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
it means that you want to run quantize transformations on the weights represented
in their original dtype, while the overall dtype of the model may be something
different. If not specified, defaults to dtype_override.
args: The arguments passed to the script.

Returns:
A list of transformation functions.
"""

if not checkpoint_dtype:
checkpoint_dtype = dtype_override

transforms = []

if args.use_spin_quant:
@@ -1103,7 +1152,11 @@ def _get_source_transforms( # noqa
"""
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
get_quant_weight_transform(
args=args,
computation_dtype=dtype_override,
checkpoint_dtype=checkpoint_dtype,
)
)

if args.embedding_quantize:
@@ -1117,7 +1170,7 @@ def _get_source_transforms( # noqa
this will be a no-op.
"""
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args))
transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)
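For callers outside this file, the new keyword-only signature would be used roughly as in the sketch below, assuming args, model, dtype_override, and checkpoint_dtype are already in scope as they are in _prepare_for_llama_export:

transforms = _get_source_transforms(
    modelname=args.model,
    dtype_override=dtype_override,  # dtype the exported model computes in
    checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # dtype the weights were saved in
    args=args,
)
# Each transform maps torch.nn.Module -> torch.nn.Module and is applied in order.
for transform in transforms:
    model = transform(model)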
99 changes: 82 additions & 17 deletions examples/models/llama/source_transformation/quantize.py
@@ -18,6 +18,7 @@

from sentencepiece import SentencePieceProcessor


try:
from fairseq2.nn.embedding import (
Embedding as fsEmbedding,
@@ -36,7 +37,8 @@
def quantize( # noqa C901
model: torch.nn.Module,
qmode: str,
activation_dtype: Optional[DType],
computation_dtype: Optional[DType] = None,
checkpoint_dtype: Optional[DType] = None,
checkpoint_path: Optional[Path] = None,
# following arguments only available when setting int4 or gptq quantization.
group_size: Optional[int] = 128,
@@ -52,20 +54,33 @@ def quantize( # noqa C901
) -> torch.nn.Module:
"""
Quantizes a model by converting all weights to int8.

Args:
model: A model to quantize.
qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
model: The model to quantize.
qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
Also the dtype of the rest of the non-quantized components of the model.
checkpoint_dtype: The dtype of the checkpoint; this arg exists since it is more accurate to
quantize the weights in their original dtype.

Returns:
A quantized model.
"""
if activation_dtype is not None:
torch_dtype = activation_dtype.to_torch_dtype()
if computation_dtype:
computation_torch_dtype = computation_dtype.to_torch_dtype()
else:
torch_dtype = torch.float16
computation_torch_dtype = torch.float32

if not checkpoint_dtype:
checkpoint_torch_dtype = computation_torch_dtype
else:
checkpoint_torch_dtype = checkpoint_dtype.to_torch_dtype()

if qmode == "int8":
# Add quantization mode options here: group size, bit width, etc.
return WeightOnlyInt8QuantHandler(model).quantized_model()
return WeightOnlyInt8QuantHandler(
model, precision=checkpoint_torch_dtype
).quantized_model()
elif qmode.startswith("torchao:fpa"):
pattern = r"torchao:fpa(\d+)w"
matches = re.findall(pattern, qmode)
@@ -75,10 +90,12 @@
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

with torch.no_grad():
# This quantize() is currently doing a model.to(self.precision) so cannot
# decouple computation and checkpoint dtypes.
model = (
UIntxWeightOnlyLinearQuantizer(
device="mps",
precision=torch.float32,
precision=computation_torch_dtype,
groupsize=group_size,
bitwidth=bitwidth,
)
@@ -101,6 +118,8 @@
from torchao.utils import unwrap_tensor_subclass

with torch.no_grad():
# Computation dtype is fixed to fp32 in the implementation of quantize_, so
# no way to decouple checkpoint and computation dtype.
quantize_(
model,
Int8DynamicActivationIntxWeightConfig(
@@ -121,9 +140,12 @@
raise Exception("For 8da4w quantization, group size must be specified.")
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# 1. Quantize in checkpoint dtype.
model = Int8DynActInt4WeightQuantizer(
precision=torch_dtype, groupsize=group_size
precision=checkpoint_torch_dtype, groupsize=group_size
).quantize(model)
# 2. Set the computation dtype (what weights/acts dequantize to).
model = set_8da4w_computation_dtype(model, computation_torch_dtype)

if verbose:
print("quantized model:", model)
@@ -177,7 +199,7 @@ def quantize( # noqa C901
blocksize,
percdamp,
group_size,
)
) # TODO: separate computation and checkpoint dtype for GPTQ.
model = gptq_quantizer.quantize(model, inputs)
return model
elif qmode == "vulkan_4w":
@@ -190,9 +212,12 @@
# at the moment
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# 1. Quantize in checkpoint dtype.
model = Int8DynActInt4WeightQuantizer(
precision=torch_dtype, groupsize=q_group_size
precision=checkpoint_torch_dtype, groupsize=q_group_size
).quantize(model)
# 2. Set the computation dtype (what weights/acts dequantize to).
model = set_8da4w_computation_dtype(model, computation_torch_dtype)

return model
else:
@@ -348,6 +373,7 @@ def __init__(
node_type: str = "*",
bitwidth: Optional[int] = None,
group_size: Optional[int] = None,
precision: torch.dtype = torch.float32,
):
self.mod = mod
self.group_size = group_size
@@ -356,6 +382,7 @@
self.bitwidth = 8
else:
self.bitwidth = bitwidth
self.precision = precision

@torch.no_grad()
def create_quantized_state_dict(self) -> Dict:
@@ -391,7 +418,7 @@ def create_quantized_state_dict(self) -> Dict:

# print(f"expanded weight shape {input_weight.shape}")
weight, scales, _ = dynamically_quantize_per_channel(
input_weight,
input_weight.to(dtype=self.precision),
range_min,
range_max,
torch.int8,
@@ -576,6 +603,7 @@ def __init__(
bitwidth: int = 8,
group_size: Optional[int] = None,
packed=False,
precision: Optional[torch.dtype] = None,
):
if isinstance(packed, str):
packed = packed == "True"
@@ -584,6 +612,8 @@
self.group_size = group_size
self.bitwidth = bitwidth
self.packed = packed
# Dtype of the weights right before quantization.
self.precision = precision
if (bitwidth not in [2, 4]) and packed:
raise RuntimeError("pack only works with bitsize 2, 4")

@@ -614,7 +644,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
)
weight, scales, _ = dynamically_quantize_per_channel(
mod.weight.float(),
(
mod.weight.to(dtype=self.precision)
if self.precision
else mod.weight
),
range_min,
range_max,
torch.int8,
@@ -750,7 +784,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
############################ Source Transform Start #######################


def get_quant_embedding_transform(args):
def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
if args.embedding_quantize.startswith("torchao:"):
bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
group_size = int(group_size)
@@ -775,16 +809,22 @@ def _torchao_embedding_quantizer(model):
else:
group_size = int(group_size)
bitwidth = int(bitwidth)
torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
return lambda model: EmbeddingQuantHandler(
model,
bitwidth=bitwidth,
group_size=group_size,
packed=(bitwidth in [2, 4]),
precision=torch_dtype,
).quantized_model()


def get_quant_weight_transform(args, dtype_override, verbose):
# If these optional args are None, don't provide them to quantize()
def get_quant_weight_transform(
args,
computation_dtype: Optional[DType] = None,
checkpoint_dtype: Optional[DType] = None,
):
# If these optional args are None, don't provide them to quantize().
quant_args_str = [
"group_size",
"calibration_tasks",
@@ -802,7 +842,8 @@ def get_quant_weight_transform(args, dtype_override, verbose):
quantize,
**quant_args,
qmode=args.quantization_mode,
activation_dtype=dtype_override,
computation_dtype=computation_dtype,
checkpoint_dtype=checkpoint_dtype,
checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
tokenizer_path=(
Path(path) if (path := args.tokenizer_path) is not None else None
@@ -829,4 +870,28 @@ def _load_torchao_aten_lib(libname):
torch.ops.load_library(libs[0])


# We want to compute the actual ops in the computation dtype, since the precision of the
# quantized linear will initially be the dtype of the checkpoint.
def set_8da4w_computation_dtype(
module: nn.Module, computation_dtype: torch.dtype
) -> nn.Module:

from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
"""
Recursively iterate through the module and set the precision attributes
of all Int8DynActInt4WeightLinears.
"""
for _name, child in module.named_children():
if isinstance(child, Int8DynActInt4WeightLinear):
child.precision = dtype
else:
# Recursively apply to child modules
_set_8da4w_computation_dtype(child, dtype)

_set_8da4w_computation_dtype(module, computation_dtype)
return module


############################ Source Transform End #######################
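Taken together, the 8da4w branches above now separate the quantization dtype from the computation dtype. A condensed sketch of that flow, where model is a placeholder nn.Module and the bf16/fp32 dtypes and group size of 128 are example values:

import torch
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# 1. Quantize the linear weights while they are still in the checkpoint dtype (e.g. bf16).
model = Int8DynActInt4WeightQuantizer(precision=torch.bfloat16, groupsize=128).quantize(model)

# 2. Flip the dequantize/compute dtype to the override (e.g. fp32) using the helper defined
#    above, so the surrounding non-quantized ops run in the higher precision.
model = set_8da4w_computation_dtype(model, torch.float32)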
2 changes: 1 addition & 1 deletion examples/models/llava/export_llava.py
@@ -100,7 +100,7 @@ def forward(self, input_pos, embeddings):
args = parser.parse_args(
["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
)
quant_transform = get_quant_weight_transform(args, dtype_override, False)
quant_transform = get_quant_weight_transform(args, dtype_override)
_, quantizers, _ = get_quantizer_and_quant_params(args)
source_transforms = []
if llava.use_sdpa_with_kv_cache_op:
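Because get_quant_weight_transform now returns a functools.partial of quantize() with everything except the model bound, the llava export can treat it as a plain module-to-module function. A sketch, where text_model is a placeholder for whichever submodule gets quantized:

quant_transform = get_quant_weight_transform(args, dtype_override)
text_model = quant_transform(text_model)  # equivalent to quantize(text_model, qmode=args.quantization_mode, ...)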
4 changes: 2 additions & 2 deletions exir/tests/test_memory_planning.py
@@ -708,10 +708,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
et_program = et.executorch_program
inputs = et_program.execution_plan[0].inputs
self.assertNotEqual(
et_program.execution_plan[0] # pyre-ignore
et_program.execution_plan[0]
.values[inputs[0]]
.val.allocation_info.memory_offset_low,
et_program.execution_plan[0] # pyre-ignore
et_program.execution_plan[0]
.values[inputs[1]]
.val.allocation_info.memory_offset_low,
)