Merged
.ci/scripts/test_llama.sh (1 addition, 1 deletion)
@@ -237,7 +237,7 @@ if [[ "${CUSTOM}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} model.use_sdpa_with_kv_cache=true"
 fi
 if [[ "${QE}" == "ON" ]]; then
-  EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,1024\""
+  EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,768\""
 fi
 if [[ "${MPS}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} backend.mps.enabled=true model.enable_dynamic_shape=false debug.verbose=true"
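Note on the value being changed: the `quantization.embedding_quantize` option appears to encode "<bitwidth>,<group_size>", so this edit keeps 8-bit embedding weights and only changes the quantization group size from 1024 to 768. A minimal sketch under that assumption (the parser name is illustrative, not part of the executorch CLI):

# Hypothetical helper, assuming the "<bitwidth>,<group_size>" encoding noted above.
def parse_embedding_quantize(value: str) -> tuple[int, int]:
    bitwidth_str, group_size_str = value.split(",")
    return int(bitwidth_str), int(group_size_str)

assert parse_embedding_quantize("8,768") == (8, 768)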
examples/apple/coreml/llama/export.py (1 addition, 4 deletions)
@@ -23,7 +23,6 @@
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.exir.passes import MemoryPlanningPass
-from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.extension.export_util.utils import save_pte_program

@@ -211,9 +210,7 @@ def main() -> None:
     executorch_program = edge_manager.to_executorch(
         ExecutorchBackendConfig(
             extract_delegate_segments=True,
-            passes=[
-                QuantFusionPass(),
-            ],
+            do_quant_fusion_and_const_prop=True,
             memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
         )
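For readers skimming the CoreML export change: the explicit QuantFusionPass entry in the pass list is replaced by a single backend-config flag whose name suggests it both fuses quant/dequant patterns and constant-propagates the result. A minimal sketch of the resulting configuration, assuming the executorch APIs already imported in this file:

# Sketch only: mirrors the diff above; do_quant_fusion_and_const_prop semantics
# are inferred from the flag name, not verified here.
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

backend_config = ExecutorchBackendConfig(
    extract_delegate_segments=True,
    do_quant_fusion_and_const_prop=True,  # replaces passes=[QuantFusionPass()]
    memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
    sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)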
examples/models/llama/source_transformation/quantize.py (22 additions, 21 deletions)
@@ -595,19 +595,16 @@ def __init__(

     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            IntxWeightOnlyConfig,
+            MappingType,
+            quantize_,
+        )
+
         cur_state_dict = self.mod.state_dict()

-        if self.bitwidth == 2:
-            range_min = -2
-            range_max = 1
-        elif self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
+        assert self.bitwidth in [2, 4, 8], f"Unsupported bitwidth {self.bitwidth}"

         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, nn.Embedding):
@@ -619,18 +616,22 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                 print(
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
-                weight, scales, _ = dynamically_quantize_per_channel(
-                    (
-                        mod.weight.to(dtype=self.precision)
-                        if self.precision
-                        else mod.weight
+                tmp_model = nn.Embedding(mod.weight.shape[0], mod.weight.shape[1])
+                if self.precision:
+                    tmp_model = tmp_model.to(dtype=self.precision)
+                tmp_model.weight = nn.Parameter(mod.weight)
+                config = IntxWeightOnlyConfig(
+                    weight_dtype=getattr(torch, f"int{self.bitwidth}"),
+                    granularity=(
+                        PerAxis(0)
+                        if (self.group_size is None or self.group_size == 0)
+                        else PerGroup(self.group_size)
                     ),
-                    range_min,
-                    range_max,
-                    torch.int8,
-                    self.group_size,
-                    scales_dtype=mod.weight.dtype,
+                    mapping_type=MappingType.SYMMETRIC,
                 )
+                quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
+                weight = tmp_model.weight.qdata  # pyre-ignore[16]
+                scales = tmp_model.weight.scale  # pyre-ignore[16]

                 if packed:
                     if self.bitwidth == 2:
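To see the new create_quantized_state_dict flow in isolation: the diff swaps the hand-rolled dynamically_quantize_per_channel call for torchao's quantize_ with an IntxWeightOnlyConfig, then reads the integer weights and scales back off the quantized tensor. A standalone sketch of that flow, assuming a torchao version that exposes these APIs and the qdata/scale attributes used in the diff:

import torch
import torch.nn as nn
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, MappingType, quantize_

# Throwaway embedding standing in for the module being quantized.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=768)

bitwidth = 8
group_size = 0  # 0 or None -> per-channel (PerAxis); otherwise per-group

config = IntxWeightOnlyConfig(
    weight_dtype=getattr(torch, f"int{bitwidth}"),  # torch.int8 here
    granularity=PerAxis(0) if group_size in (None, 0) else PerGroup(group_size),
    mapping_type=MappingType.SYMMETRIC,
)

# Quantize only nn.Embedding modules; other module types are left untouched.
quantize_(embedding, config, lambda m, fqn: isinstance(m, nn.Embedding))

# Integer data and scales, accessed as in the diff (attribute names may vary
# across torchao releases).
int_weight = embedding.weight.qdata
scales = embedding.weight.scale
print(int_weight.dtype, int_weight.shape, scales.shape)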