[CoreML] ValueError: ("'activation_dtype' must be in [torch.quint8, torch.float32] (got torch.uint8)" #7587

@daniil-lyakhov

🐛 Describe the bug

Hi there,
I'm trying to run the quantization example from the Core ML README: https://github.com/pytorch/executorch/tree/main/backends/apple/coreml#quantization

import torch
import executorch.exir

from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from coremltools.optimize.torch.quantization.quantization_config import (
    LinearQuantizerConfig,
    QuantizationScheme,
)

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.conv(x)
        return self.relu(a)

source_model = Model()
example_inputs = (torch.randn((1, 3, 256, 256)), )

pre_autograd_aten_dialect = export_for_training(source_model, example_inputs).module()

quantization_config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": QuantizationScheme.symmetric,
            "activation_dtype": torch.uint8,
            "weight_dtype": torch.int8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(quantization_config)

# For post-training quantization, use `prepare_pt2e`
# For quantization-aware training, use `prepare_qat_pt2e`
prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)

prepared_graph(*example_inputs)
converted_graph = convert_pt2e(prepared_graph)

Even with my typo fix (#7586) applied, the code still fails with the following error:
Torch version 2.5.0 has not been tested with coremltools. You may run into unexpected errors. Torch 2.4.0 is the most recent version that has been tested.
Traceback (most recent call last):
  File "<cattrs generated structure coremltools.optimize.torch.quantization.quantization_config.ModuleLinearQuantizerConfig>", line 51, in structure_ModuleLinearQuantizerConfig
  File "<attrs generated init coremltools.optimize.torch.quantization.quantization_config.ModuleLinearQuantizerConfig>", line 13, in __init__
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/attr/_make.py", line 2972, in __call__
    v(inst, attr, value)
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/attr/validators.py", line 246, in __call__
    raise ValueError(
ValueError: ("'activation_dtype' must be in [torch.quint8, torch.float32] (got torch.uint8)", Attribute(name='activation_dtype', default=torch.quint8, validator=_AndValidator(_validators=(<instance_of validator for type <class 'torch.dtype'>>, <in_ validator with options [torch.quint8, torch.float32]>)), repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=<class 'torch.dtype'>, converter=<function maybe_convert_str_to_dtype at 0x16d8dc040>, kw_only=False, inherited=False, on_setattr=None, alias='activation_dtype'), [torch.quint8, torch.float32], torch.uint8)

During handling of the above exception, another exception occurred:

  + Exception Group Traceback (most recent call last):
  |   File "/Users/dupeljan/Documents/executorch/coreml_tut/main_q_no_export.py", line 34, in <module>
  |     quantization_config = LinearQuantizerConfig.from_dict(
  |                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |   File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/coremltools/optimize/torch/quantization/quantization_config.py", line 411, in from_dict
  |     return converter.structure_attrs_fromdict(config_dict, cls)
  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |   File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/cattrs/converters.py", line 754, in structure_attrs_fromdict
  |     conv_obj[getattr(a, "alias", a.name)] = self._structure_attribute(a, val)
  |                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |   File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/cattrs/converters.py", line 735, in _structure_attribute
  |     return self._structure_func.dispatch(type_)(value, type_)
  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |   File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/cattrs/converters.py", line 879, in _structure_optional
  |     return self._structure_func.dispatch(other)(obj, other)
  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |   File "<cattrs generated structure coremltools.optimize.torch.quantization.quantization_config.ModuleLinearQuantizerConfig>", line 54, in structure_ModuleLinearQuantizerConfig
  | cattrs.errors.ClassValidationError: While structuring ModuleLinearQuantizerConfig (1 sub-exception)
  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "<cattrs generated structure coremltools.optimize.torch.quantization.quantization_config.ModuleLinearQuantizerConfig>", line 51, in structure_ModuleLinearQuantizerConfig
    |   File "<attrs generated init coremltools.optimize.torch.quantization.quantization_config.ModuleLinearQuantizerConfig>", line 13, in __init__
    |   File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/attr/_make.py", line 2972, in __call__
    |     v(inst, attr, value)
    |   File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/attr/validators.py", line 246, in __call__
    |     raise ValueError(
    | ValueError: ("'activation_dtype' must be in [torch.quint8, torch.float32] (got torch.uint8)", Attribute(name='activation_dtype', default=torch.quint8, validator=_AndValidator(_validators=(<instance_of validator for type <class 'torch.dtype'>>, <in_ validator with options [torch.quint8, torch.float32]>)), repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=True, metadata=mappingproxy({}), type=<class 'torch.dtype'>, converter=<function maybe_convert_str_to_dtype at 0x16d8dc040>, kw_only=False, inherited=False, on_setattr=None, alias='activation_dtype'), [torch.quint8, torch.float32], torch.uint8)
    +------------------------------------
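
The traceback above shows that the error is raised by an attrs `in_` validator on `activation_dtype` while `LinearQuantizerConfig.from_dict` builds the config, before any quantization runs. The value is checked against `[torch.quint8, torch.float32]`, and `torch.uint8` is a different dtype that is not in that list. The following is a minimal, stdlib-only sketch of that kind of membership check. `ALLOWED_ACTIVATION_DTYPES`, `SketchConfig`, and `try_config` are hypothetical names, and plain strings stand in for the `torch.dtype` objects; this is not the coremltools implementation.

```python
from dataclasses import dataclass

# Hypothetical stand-ins: strings replace the torch.dtype objects named in the error.
ALLOWED_ACTIVATION_DTYPES = ("torch.quint8", "torch.float32")


@dataclass
class SketchConfig:
    """Toy config that mimics a membership validator on a single field."""

    activation_dtype: str = "torch.quint8"

    def __post_init__(self) -> None:
        # Reject any value outside the allowed set, as the validator in the traceback does.
        if self.activation_dtype not in ALLOWED_ACTIVATION_DTYPES:
            allowed = ", ".join(ALLOWED_ACTIVATION_DTYPES)
            raise ValueError(
                f"'activation_dtype' must be in [{allowed}] "
                f"(got {self.activation_dtype})"
            )


def try_config(dtype_name: str) -> str:
    """Return 'ok' if the toy config accepts the dtype name, else the error text."""
    try:
        SketchConfig(activation_dtype=dtype_name)
        return "ok"
    except ValueError as err:
        return str(err)
```

Judging by the error message alone, `torch.quint8` is the value the validator would accept in place of `torch.uint8`. Whether the rest of the example then runs is not confirmed here.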

I then tried installing torch==2.4.0, but with that version importing executorch fails with the following error:

Traceback (most recent call last):
  File "/Users/dupeljan/Documents/executorch/coreml_tut/main_q_no_export.py", line 2, in <module>
    import executorch.exir
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/executorch/exir/__init__.py", line 9, in <module>
    from executorch.exir.capture import (
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/executorch/exir/capture/__init__.py", line 9, in <module>
    from executorch.exir.capture._capture import (
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/executorch/exir/capture/_capture.py", line 17, in <module>
    from executorch.exir.program import ExirExportedProgram
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/executorch/exir/program/__init__.py", line 10, in <module>
    from executorch.exir.program._program import (
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/executorch/exir/program/_program.py", line 53, in <module>
    from executorch.exir.verification.verifier import (
  File "/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/executorch/exir/verification/verifier.py", line 23, in <module>
    from torch._export.utils import _detect_fake_mode_from_gm
ImportError: cannot import name '_detect_fake_mode_from_gm' from 'torch._export.utils' (/Users/dupeljan/Documents/executorch/executorch_env/lib/python3.12/site-packages/torch/_export/utils.py)
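
The ImportError above means that the installed torch build does not expose `_detect_fake_mode_from_gm` in `torch._export.utils`, which this executorch build imports at load time. A small stdlib-only sketch for checking whether a given environment provides a symbol, before importing executorch, could look like this. `has_symbol` is a hypothetical helper name, not part of torch or executorch.

```python
import importlib


def has_symbol(module_name: str, symbol: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `symbol`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module (or one of its parent packages) is not importable here.
        return False
    return hasattr(module, symbol)
```

Calling `has_symbol("torch._export.utils", "_detect_fake_mode_from_gm")` in each candidate environment shows whether that torch version provides the name that executorch expects.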
CC: @dbort @swolchok @Olivia-liu

Versions

PyTorch version: 2.5.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.31.2
Libc version: N/A

Python version: 3.12.5 (main, Aug 6 2024, 19:08:49) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2

Versions of relevant libraries:
[pip3] executorch==0.4.0a0+6a085ff
[pip3] executorchcoreml==0.0.1
[pip3] numpy==1.26.4
[pip3] torch==2.5.0
[pip3] torchaudio==2.5.0
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0
[conda] Could not collect

Labels

module: coreml (Issues related to Apple's Core ML delegation and code under backends/apple/coreml/)
module: examples (Issues related to demos under examples/)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
