
RuntimeError: Didn't find engine for operation quantized::linear_prepack NoQEngine on Apple Silicon PyTorch 2.2.2 #123507

Closed
RazeBerry opened this issue Apr 6, 2024 · 7 comments
Labels
has workaround module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 oncall: quantization Quantization support in PyTorch

@RazeBerry

RazeBerry commented Apr 6, 2024

🐛 Describe the bug

When converting a standard torch.nn.Linear layer to its quantized counterpart, calling nnq.Linear.from_float(linear) directly raises an AssertionError if no qconfig is set on the input module. With a qconfig set, as in the script below, convert() (which calls from_float internally) instead fails with RuntimeError: Didn't find engine for operation quantized::linear_prepack NoQEngine. I have searched for a solution and have not found one, and the problem appears to affect Apple Silicon users in particular.

import torch
from torch import nn
import torch.nn.quantized as nnq
from torch.quantization import get_default_qconfig, prepare, convert

# Define the model
class SimpleLinearModel(nn.Module):
    def __init__(self):
        super(SimpleLinearModel, self).__init__()
        self.linear = nn.Linear(5, 10)  # Example dimensions

    def forward(self, x):
        return self.linear(x)

# Instantiate the model
model = SimpleLinearModel()

# Define the qconfig (using 'fbgemm' or 'qnnpack' configuration)
qconfig = get_default_qconfig('fbgemm')  # or 'qnnpack'

# Apply the qconfig to the model
model.qconfig = qconfig

# Prepare the model for quantization
prepared_model = prepare(model, inplace=False)

# Convert the prepared model to a quantized model
quantized_model = convert(prepared_model, inplace=False)

# Now, quantized_model is ready for inference or further operations
/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/quantization/observer.py:220: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/quantization/observer.py:1263: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point
  warnings.warn(
Traceback (most recent call last):
  File "/Users/sihao/Documents/errorreproducer.py", line 28, in <module>
    quantized_model = convert(prepared_model, inplace=False)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/quantization/quantize.py", line 553, in convert
    _convert(
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/quantization/quantize.py", line 593, in _convert
    reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/quantization/quantize.py", line 626, in swap_module
    new_mod = qmod.from_float(mod)
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/linear.py", line 277, in from_float
    qlinear = cls(mod.in_features,
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/linear.py", line 151, in __init__
    self._packed_params = LinearPackedParams(dtype)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/linear.py", line 27, in __init__
    self.set_weight_bias(wq, None)
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/linear.py", line 32, in set_weight_bias
    self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/_ops.py", line 755, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Didn't find engine for operation quantized::linear_prepack NoQEngine
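
The last frames show that conversion fails inside torch.ops.quantized.linear_prepack, which picks a kernel based on the currently active quantized engine; with the engine at 'none' (NoQEngine) there is nothing to pick. A minimal sketch that isolates this op outside the model, assuming a build that lists 'qnnpack' in torch.backends.quantized.supported_engines (as malfet's output in this thread shows for Apple Silicon). The weight shape mirrors nn.Linear(5, 10); the scale and zero point are arbitrary illustration values:

```python
import torch

# Select an engine explicitly instead of relying on the build default ('none' here).
torch.backends.quantized.engine = "qnnpack"

# A 10x5 weight (out_features x in_features, as in nn.Linear(5, 10)),
# quantized per-tensor to qint8 with illustrative scale 0.1 and zero point 0.
weight = torch.quantize_per_tensor(torch.randn(10, 5), 0.1, 0, torch.qint8)

# The call that raised "Didn't find engine ... NoQEngine" in the traceback above.
packed = torch.ops.quantized.linear_prepack(weight, None)

# Unpacking gives back the packed weight (and bias) for inspection.
unpacked_weight, unpacked_bias = torch.ops.quantized.linear_unpack(packed)
print(unpacked_weight.shape)  # torch.Size([10, 5])
```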

Versions

PyTorch version: 2.2.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.2.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.28.3
Libc version: N/A

Python version: 3.11.0 (main, Feb 1 2024, 23:57:54) [Clang 15.0.0 (clang-1500.1.0.2.5)] (64-bit runtime)
Python platform: macOS-14.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnxruntime==1.17.0
[pip3] optree==0.11.0
[pip3] pytorch-lightning==2.2.1
[pip3] pytorch-metric-learning==2.4.1
[pip3] torch==2.2.2
[pip3] torch-audiomentations==0.11.1
[pip3] torch-pitch-shift==1.2.4
[pip3] torchaudio==2.2.2
[pip3] torchmetrics==1.3.1
[pip3] torchvision==0.17.2
[conda] numpy 1.26.4 py311h7125741_0 conda-forge
[conda] pytorch 2.2.1 py3.11_0 pytorch
[conda] torchvision 0.17.1 py311_cpu pytorch

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @malfet @snadampal

@malfet malfet added oncall: quantization Quantization support in PyTorch module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 labels Apr 7, 2024
@malfet
Contributor

malfet commented Apr 7, 2024

fbgemm has always been x86-oriented; I'm not sure what's going on with qnnpack, though:

% python -c "import torch;print(torch.backends.quantized.supported_engines)"
['qnnpack', 'none']
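
That list shows which engines the build supports, but not which one is currently active. A small sketch that prints both; on the reporter's machine the active engine would be expected to read 'none', matching the NoQEngine in the error (an inference from the traceback, not output captured in this thread):

```python
import torch

supported = torch.backends.quantized.supported_engines
active = torch.backends.quantized.engine

print("supported:", supported)
print("active:", active)

if active == "none":
    # No engine selected: ops such as quantized::linear_prepack will fail.
    print("No quantized engine is active; set torch.backends.quantized.engine first.")
```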

@RazeBerry
Author

RazeBerry commented Apr 8, 2024


Sorry about the error in the code. Even if I change it to qconfig = get_default_qconfig('qnnpack'), the issue persists.

Furthermore, qnnpack is listed as a supported engine:

(base) sihao@Sihaos-MacBook-Pro-2 documents % python -c "import torch;print(torch.backends.quantized.supported_engines)"
['qnnpack', 'none']

@Xia-Weiwen
Collaborator

Hi @RazeBerry, how about adding the following line before calling get_default_qconfig?

torch.backends.quantized.engine = 'qnnpack'
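
A slightly more portable variant of this one-liner, as a sketch: a hypothetical helper select_quantized_engine that activates qnnpack on ARM machines and otherwise prefers the x86/fbgemm engines, checking each name against torch.backends.quantized.supported_engines, and then builds the qconfig for the same engine so the two stay consistent. The preference order is an assumption for illustration, not PyTorch's own default-selection logic:

```python
import platform

import torch
from torch.ao.quantization import get_default_qconfig


def select_quantized_engine() -> str:
    """Hypothetical helper: activate a quantized engine this build supports."""
    supported = torch.backends.quantized.supported_engines
    # On ARM (e.g. Apple Silicon) qnnpack is the only real engine reported in this thread;
    # elsewhere, try the x86-oriented engines first (assumed preference order).
    if platform.machine().lower() in ("arm64", "aarch64"):
        preferences = ["qnnpack"]
    else:
        preferences = ["x86", "fbgemm", "qnnpack"]
    for name in preferences:
        if name in supported:
            torch.backends.quantized.engine = name
            return name
    raise RuntimeError(f"No usable quantized engine; supported_engines={supported}")


engine = select_quantized_engine()
# Build the qconfig for the same engine that is now active.
qconfig = get_default_qconfig(engine)
print(engine)
```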

@malfet
Contributor

malfet commented Apr 8, 2024

@RazeBerry, as @Xia-Weiwen suggested, specifying the quantized engine should fix your problem, though I agree that the qnnpack backend (the only one available on ARM platforms) should have been selected by default.

@RazeBerry
Author

@Xia-Weiwen @malfet

That works! Thank you so much! It is a bit confusing that the engine has to be specified explicitly at all, considering qnnpack is the only one available on ARM. It would be great if PyTorch could be changed so that this line isn't needed. I use many packages that depend on PyTorch, and I almost always run into quantization problems because those packages don't set the engine.
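
Until the default changes, one way to cope with third-party packages that quantize internally without setting the engine is to activate it around their calls. A minimal sketch using a hypothetical context manager, quantized_engine, that validates the name against torch.backends.quantized.supported_engines and restores the previous setting (possibly 'none') on exit; simply setting the engine once at program start, as suggested above, is the simpler alternative:

```python
import contextlib

import torch


@contextlib.contextmanager
def quantized_engine(name: str):
    """Hypothetical helper: temporarily activate a quantized engine."""
    supported = torch.backends.quantized.supported_engines
    if name not in supported:
        raise ValueError(f"{name!r} is not in supported_engines {supported}")
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = name
    try:
        yield
    finally:
        # Restore whatever engine was active before entering the block.
        torch.backends.quantized.engine = previous


# Usage: wrap a third-party call that quantizes a model internally.
with quantized_engine("qnnpack"):
    print(torch.backends.quantized.engine)  # qnnpack
```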

@RazeBerry
Author

@Xia-Weiwen @malfet One last question: is the following error expected when running on ARM? I have seen Whisper throw a similar error about an operator not being implemented. Thank you very much!

import torch
from torch import nn
import torch.nn.quantized as nnq
from torch.quantization import get_default_qconfig, prepare, convert

# Define the model
class SimpleLinearModel(nn.Module):
    def __init__(self):
        super(SimpleLinearModel, self).__init__()
        self.linear = nn.Linear(5, 10)  # Example dimensions

    def forward(self, x):
        return self.linear(x)

# Instantiate the model
model = SimpleLinearModel()

torch.backends.quantized.engine = 'qnnpack'

# Define the qconfig (using 'fbgemm' or 'qnnpack' configuration)
qconfig = get_default_qconfig('qnnpack')  # or 'fbgemm'

# Apply the qconfig to the model
model.qconfig = qconfig

# Prepare the model for quantization
model.eval()  # Set the model to evaluation mode
prepared_model = prepare(model)

# Calibrate the model with sample data
calibration_data = torch.randn(64, 5)  # Generate sample data for calibration
prepared_model(calibration_data)

# Convert the prepared model to a quantized model
quantized_model = convert(prepared_model)

# Test the quantized model
input_data = torch.randn(1, 5)
output = quantized_model(input_data)
print(output)
Traceback (most recent call last):
  File "/Users/sihao/Documents/errorreproducer.py", line 39, in <module>
    output = quantized_model(input_data)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/Documents/errorreproducer.py", line 13, in forward
    return self.linear(x)
           ^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/linear.py", line 168, in forward
    return torch.ops.quantized.linear(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/_ops.py", line 755, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear' is only available for these backends: [MPS, Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

MPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:75 [backend fallback]
Meta: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/quantized/cpu/qlinear.cpp:1140 [kernel]
BackendSelect: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradMeta: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/TraceTypeManual.cpp:297 [backend fallback]
AutocastCPU: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastCUDA: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]

@Xia-Weiwen
Collaborator

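@RazeBerry, you are using eager mode quantization in your script. In this mode, you need to insert QuantStub and DeQuantStub into your model to quantize the input and dequantize the output; otherwise the quantized linear layer receives a regular float (CPU) tensor, while 'quantized::linear' only has a kernel for QuantizedCPU tensors, hence the error above. See the documentation here: https://pytorch.org/docs/stable/quantization.html#post-training-static-quantization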
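
A sketch of the reporter's SimpleLinearModel with the stubs added, following the post-training static quantization flow linked above and assuming 'qnnpack' appears in torch.backends.quantized.supported_engines. The random calibration batch stands in for real data:

```python
import torch
from torch import nn
from torch.ao.quantization import DeQuantStub, QuantStub, convert, get_default_qconfig, prepare


class SimpleLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float input -> quantized tensor after convert()
        self.linear = nn.Linear(5, 10)
        self.dequant = DeQuantStub()  # quantized output -> float tensor after convert()

    def forward(self, x):
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)


torch.backends.quantized.engine = "qnnpack"
model = SimpleLinearModel().eval()
model.qconfig = get_default_qconfig("qnnpack")

prepared_model = prepare(model)
prepared_model(torch.randn(64, 5))  # calibration pass populates the observers
quantized_model = convert(prepared_model)

output = quantized_model(torch.randn(1, 5))
print(output.shape)  # torch.Size([1, 10])
```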
Alternatively, you may use FX graph mode, in which quant/dequant operations are inserted automatically so you don't have to add them yourself: https://pytorch.org/docs/stable/quantization.html#prototype-maintaince-mode-fx-graph-mode-quantization
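
For comparison, a sketch of the FX graph mode route applied to the original model without stubs, assuming the prepare_fx/convert_fx APIs from torch.ao.quantization.quantize_fx and a build where 'qnnpack' is a supported engine. As before, the calibration data is random:

```python
import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class SimpleLinearModel(nn.Module):
    # The original model from the issue; FX mode needs no QuantStub/DeQuantStub.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)


torch.backends.quantized.engine = "qnnpack"
model = SimpleLinearModel().eval()
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
example_inputs = (torch.randn(1, 5),)

# prepare_fx traces the model and inserts observers.
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
prepared_model(torch.randn(64, 5))  # calibration
# convert_fx adds quantize/dequantize nodes around the quantized linear.
quantized_model = convert_fx(prepared_model)

output = quantized_model(torch.randn(1, 5))
print(output.shape)  # torch.Size([1, 10])
```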
