Merged
7 changes: 5 additions & 2 deletions backends/apple/coreml/test/test_coreml_recipes.py
@@ -3,6 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.


+import copy
 import unittest

 import coremltools as ct
@@ -152,8 +153,9 @@ def forward(self, x):
         # Test with different group sizes
         for group_size in [8, 16, 32]:
             with self.subTest(group_size=group_size):
+                model_to_export = copy.deepcopy(model)
Contributor comment on the model_to_export line: @metascroy can you try without a copy here, as the PR I landed should take care of this within the source transform stage.
                 session = export(
-                    model=model,
+                    model=model_to_export,
                     example_inputs=example_inputs,
                     export_recipe=ExportRecipe.get_recipe(
                         CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
@@ -219,8 +221,9 @@ def forward(self, x):
         # Test with different group sizes
         for group_size in [16, 32, 64]:
             with self.subTest(group_size=group_size):
+                model_to_export = copy.deepcopy(model)
                 session = export(
-                    model=model,
+                    model=model_to_export,
                     example_inputs=example_inputs,
                     export_recipe=ExportRecipe.get_recipe(
                         CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
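For context on the reviewer's question: torchao's quantize_ rewrites a module's weights in place, which is why each subTest iteration now exports a fresh deep copy. A minimal sketch of that in-place behavior, assuming the recipe ultimately applies quantize_ during its source transform stage (the toy Linear model and group size are illustrative only, not from this PR):

    import copy

    import torch
    from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
    from torchao.quantization.granularity import PerGroup

    model = torch.nn.Sequential(torch.nn.Linear(64, 64))

    # Quantize a deep copy so the original stays float and can be reused
    # by the next subTest iteration.
    model_to_export = copy.deepcopy(model)
    quantize_(
        model_to_export,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4, weight_granularity=PerGroup(32)
        ),
    )
    print(model[0].weight.dtype)  # original weights are still torch.float32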
7 changes: 5 additions & 2 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -2680,14 +2680,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
    def apply_8da4w_quantization(self):
        """Apply TorchAO 8da4w quantization (int8 dynamic activation + int4 weight)."""
        from torchao.quantization import (
-            int8_dynamic_activation_int4_weight,
+            Int8DynamicActivationIntxWeightConfig,
            quantize_,
        )
+        from torchao.quantization.granularity import PerGroup
        from torchao.utils import unwrap_tensor_subclass

        quantize_(
            self,
-            int8_dynamic_activation_int4_weight(group_size=self.group_size),
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4, weight_granularity=PerGroup(self.group_size)
+            ),
        )
        unwrap_tensor_subclass(self)
        return self
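The same API migration repeats across the files in this PR, so a standalone before/after sketch may help. The deprecated int8_dynamic_activation_int4_weight helper is replaced by the Int8DynamicActivationIntxWeightConfig class, which spells out the weight bit width and granularity that the old helper's name only implied (the toy Linear below is illustrative; the argument names are the ones used in the hunk above):

    import torch
    from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
    from torchao.quantization.granularity import PerGroup
    from torchao.utils import unwrap_tensor_subclass

    linear = torch.nn.Linear(128, 128)

    # Before: quantize_(linear, int8_dynamic_activation_int4_weight(group_size=32))
    quantize_(
        linear,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,          # int4 weights, as the old name implied
            weight_granularity=PerGroup(32),  # replaces group_size=32
        ),
    )
    unwrap_tensor_subclass(linear)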
10 changes: 8 additions & 2 deletions backends/xnnpack/test/ops/test_linear.py
@@ -34,8 +34,9 @@
 from torch.export.graph_signature import ExportGraphSignature, InputKind

 try:
+    from torchao.quantization.granularity import PerGroup
     from torchao.quantization.quant_api import (
-        int8_dynamic_activation_int4_weight,
+        Int8DynamicActivationIntxWeightConfig,
         quantize_,
     )
     from torchao.utils import unwrap_tensor_subclass
@@ -391,7 +392,12 @@ def _test_groupwise_dq_linear(
         """
         Helper function to test groupwise dynamic quantized linear op with different configurations.
         """
-        quantize_(mod, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(
+            mod,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
+            ),
+        )
         unwrap_tensor_subclass(mod)
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
             config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
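Note that unwrap_tensor_subclass still runs between quantization and partitioning: it converts torchao's tensor-subclass weights into a form that torch.export can trace. A short sketch of that ordering under the new config, with arbitrary shapes and group size:

    import torch
    from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
    from torchao.quantization.granularity import PerGroup
    from torchao.utils import unwrap_tensor_subclass

    mod = torch.nn.Sequential(torch.nn.Linear(256, 256))
    quantize_(
        mod,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4, weight_granularity=PerGroup(64)
        ),
    )
    unwrap_tensor_subclass(mod)  # subclass weights -> plain tensors for tracing

    # The exported program is what XnnpackPartitioner consumes downstream.
    ep = torch.export.export(mod, (torch.randn(1, 256),))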
24 changes: 17 additions & 7 deletions examples/models/llama/source_transformation/quantize.py
@@ -116,7 +116,6 @@ def quantize( # noqa C901
     assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
     bitwidth = int(matches[0][0])

-    from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
     from torchao.quantization.granularity import PerAxis, PerGroup
     from torchao.quantization.quant_api import (
         Int8DynamicActivationIntxWeightConfig,
@@ -136,7 +135,7 @@ def quantize( # noqa C901
                 PerAxis(0) if group_size == 0 else PerGroup(group_size)
             ),
             weight_mapping_type=MappingType.SYMMETRIC,
-            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+            intx_packing_format="opaque_torchao_auto",
         ),
     )
     model = unwrap_tensor_subclass(model)
@@ -148,10 +147,21 @@ def quantize( # noqa C901
     # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
     group_size = 128

-    from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+    from torchao.quantization import (
+        Int8DynamicActivationIntxWeightConfig,
+        quantize_,
+    )
+    from torchao.quantization.granularity import PerGroup
     from torchao.utils import unwrap_tensor_subclass

-    quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+    quantize_(
+        model,
+        Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(group_size),
+        ),
+    )
+
     model = unwrap_tensor_subclass(model)

     # TODO: deal with checkpoint / computation dtype decoupling.
@@ -744,9 +754,9 @@ def get_quant_embedding_transform(
     dtype_override: Optional[DType] = None,
 ):
     if embedding_quantize.startswith("torchao:"):
-        from torchao.experimental.quant_api import (
+        from torchao.prototype.quantization.embedding.api import (
             EmbeddingQuantizer,
-            SharedEmbeddingQuantizer,
+            TiedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType
@@ -780,7 +790,7 @@ def _torchao_embedding_quantizer(model):
                 use_fallback=False,
             ).quantize(model)
         else:
-            SharedEmbeddingQuantizer(
+            TiedEmbeddingQuantizer(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 mapping_type=mapping_type,
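A sketch of the renamed embedding quantizers, assuming the import path from the hunk above and the constructor arguments visible at the two call sites (the toy Embedding model is illustrative; TiedEmbeddingQuantizer, renamed from SharedEmbeddingQuantizer, appears to target embeddings whose weights are tied to the output projection):

    import torch
    from torchao.prototype.quantization.embedding.api import (
        EmbeddingQuantizer,
        TiedEmbeddingQuantizer,
    )
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import MappingType

    model = torch.nn.Sequential(torch.nn.Embedding(1000, 64))

    # Standalone embeddings: same arguments the call site above passes.
    EmbeddingQuantizer(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        mapping_type=MappingType.SYMMETRIC,
        use_fallback=False,
    ).quantize(model)

    # Tied input/output embeddings instead go through (per the else branch above):
    # TiedEmbeddingQuantizer(
    #     weight_dtype=torch.int4,
    #     granularity=PerGroup(32),
    #     mapping_type=MappingType.SYMMETRIC,
    # ).quantize(model)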
2 changes: 1 addition & 1 deletion third-party/ao
Submodule ao updated 103 files