Skip to content

Conversation

fadara01
Copy link
Collaborator

@fadara01 fadara01 commented Mar 5, 2025

Stack from ghstack (oldest at bottom):

This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PRs #126687, #139887 enabled an optimized implementation for qlinear and qlinear_dynamic for aarch64 through ideep → oneDNN → ACL which improved performance by ~10x compared to the previous implementation.
However, the current qlinear and qlinear_dynamic path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API - for example, ACL's lowp_gemm objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with qlinear and qlinear_dynamic.

  • For qlinear_dynamic (dynamically quantized matmuls):

This PR yields an average speedup (averaged over context_lengths of 2^3 up to 2^9) of ~ 50% for bert-base-uncased, bert-large-uncased, roberta-base, distilbert-base-uncased with 16 threads on a Neoverse-V1 (with transformers==4.48) for the benchmarking script below:

# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default=None)
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)         
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))  
  • For qlinear (statically quantized matmuls):

This PR yields an average speedup of 2x for signed activations (s8s8s8) and 95x for unsigned activations (u8s8u8) on a Neoverse-V1 with 16 threads for the benchmarking script below.
The averages are over for all combinations of M = [8, 16, ..., 512], K = [768, 1024, 2048, 4096], N = [768, 1024, 2048, 4096].
The astronomical speedup for unsigned activation is because oneDNN v3.7 does not have an optimized implementation for u8s8u8 on AArch64.

# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
from torch.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer
import torch
import torch.nn as nn
import numpy as np
import random
from argparse import ArgumentParser
import time

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M",
                            help="M dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--K",
                            help="K dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--N",
                            help="N dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--signed_input",
                            help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                            action="store_true"
        )
        self.add_argument("--seed",
                          help="Random seed",
                          type=int,
                          default=42
        )
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def quantize_model(model, args):
    qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
            weight=default_weight_observer,
    )
    # Prepare the model for static quantization
    # Specify quantization configurations
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model_fp32)

    # Calibrate the model with sample inputs
    # Example input data for calibration
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized


if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    
    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)
    
    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: quantization release notes category labels Mar 5, 2025
Copy link

pytorch-bot bot commented Mar 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148585

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 52 Pending

As of commit 52b8690 with merge base 46f096b (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
fadara01 added a commit that referenced this pull request Mar 5, 2025
This enables a fast path for eager mode dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PR #126687 enabled an optimized implementation for qlinear_dynamic for aarch64 through ideep → oneDNN → ACL which improved performance by ~10x compared to the previous implementation.
However, the current qlinear_dynamic path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API - for example, ACL's lowp_gemm objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with qlinear_dynamic. This approach yields an average speedup (averaged over context_lengths of 2^3 up to 2^9) of ~ 50% for bert-base-uncased, bert-large-uncased, roberta-base, distilbert-base-uncased with 16 threads on a Neoverse-V1 (with transformers==4.48).
To achieve this we introduce PackedLinearWeightsACL (as a subclasses of PackedLinearWeightsOnednn ) with an implementation of qlinear_dynamic that uses ACL directly, while qlinear still follows the oneDNN path.

ghstack-source-id: 05435a0
Pull Request resolved: #148585

Enable a fast path for (static) qlinear for AArch64 through ACL directly.

This enables a fast path for eager mode static quantization for AArch64 through Arm Compute Library (ACL) directly.

PR #145942 addressed the high overhead in qlinear_dynamic on AArch64 (due to redundant weight pretranspositions and reductions) by enabling a path that calls ACL directly.
This does the same thing but for (static) qlinear.

ghstack-source-id: 05435a0
Pull Request resolved: #148586
fadara01 added a commit that referenced this pull request Mar 5, 2025
This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PR #126687, #139887 enabled an optimized implementation for qlinear[_dynamic] for AArch64 through ideep → oneDNN → ACL which improved performance by ~10x compared to the previous implementation.
However, the current qlinear[_dynamic] path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API - for example, ACL's lowp_gemm objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.

This PR addresses the sub-optimalities above by introducing PackedLinearWeightsACL (as a subclasses of PackedLinearWeightsOnednn ) with an implementation of qlinear[_dynamic] that uses ACL directly.
ghstack-source-id: 05435a0
Pull Request resolved: #148585

ghstack-source-id: 05435a0
Pull Request resolved: #148586
fadara01 added a commit that referenced this pull request Mar 6, 2025
This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PR #126687, #139887 enabled an optimized implementation for qlinear[_dynamic] for AArch64 through ideep → oneDNN → ACL which improved performance by ~10x compared to the previous implementation.
However, the current qlinear[_dynamic] path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API - for example, ACL's lowp_gemm objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.

This PR addresses the sub-optimalities above by introducing PackedLinearWeightsACL (as a subclasses of PackedLinearWeightsOnednn ) with an implementation of qlinear[_dynamic] that uses ACL directly.
ghstack-source-id: 05435a0
Pull Request resolved: #148585

ghstack-source-id: 05435a0
Pull Request resolved: #148586
@fadara01 fadara01 changed the title Enable fast qlinear_dynamic path for AArch64 through ACL directly Enable fast qlinear static/dynamic path for AArch64 through ACL directly Mar 6, 2025
@fadara01
Copy link
Collaborator Author

fadara01 commented Mar 6, 2025

@pytorchbot label "module: arm"

@pytorch-bot pytorch-bot bot added the module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 label Mar 6, 2025
@fadara01
Copy link
Collaborator Author

fadara01 commented Mar 6, 2025

@pytorchbot label "ciflow/linux-aarch64"

@pytorch-bot pytorch-bot bot added the ciflow/linux-aarch64 linux aarch64 CI workflow label Mar 6, 2025
[ghstack-poisoned]
@fadara01
Copy link
Collaborator Author

fadara01 commented Mar 7, 2025

This is the ghstack replacement for #145942 and #147337 - I thought it'd be easier for review and for addressing comments to have both in the same PR given they're very similar.

I addressed all your comments here, @malfet

@fadara01
Copy link
Collaborator Author

fadara01 commented Mar 10, 2025

@jerryzh168, @digantdesai could you guys please have a look at this?
We'd really like to land it before the branch cut. Thanks for you help!

[ghstack-poisoned]
fadara01 and others added 4 commits March 14, 2025 15:05
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Copy link
Contributor

@malfet malfet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fadara01 please re-push the changes to make code compilable again(and thank you for accepting my suggestions, but it does not work with stacked PRs, you'll have to manually integrate them and re-push)

[ghstack-poisoned]
@fadara01
Copy link
Collaborator Author

@fadara01 please re-push the changes to make code compilable again(and thank you for accepting my suggestions, but it does not work with stacked PRs, you'll have to manually integrate them and re-push)

Done, thank you @malfet!

Copy link
Contributor

@malfet malfet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for all the changes, but I hope it's a temporary (or if it's permanently we need to figure out how to kill oneDNN codepaths)

@malfet
Copy link
Contributor

malfet commented Mar 17, 2025

@pytorchbot merge -f "Let's test in trunk"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Mar 18, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

Pull Request resolved: #148653
Approved by: https://github.com/malfet
ghstack dependencies: #148585
@fadara01
Copy link
Collaborator Author

@pytorchbot cherry-pick --onto release/2.7 -c "Special Dispensation"

Copy link

pytorch-bot bot commented Mar 18, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot cherry-pick: error: argument -c/--classification: invalid choice: 'Special Dispensation' (choose from 'regression', 'critical', 'fixnewfeature', 'docs', 'release')

usage: @pytorchbot cherry-pick --onto ONTO [--fixes FIXES] -c
                               {regression,critical,fixnewfeature,docs,release}

Try @pytorchbot --help for more info.

@fadara01
Copy link
Collaborator Author

@pytorchbot cherry-pick --onto release/2.7

Copy link

pytorch-bot bot commented Mar 18, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot cherry-pick: error: the following arguments are required: -c/--classification

usage: @pytorchbot cherry-pick --onto ONTO [--fixes FIXES] -c
                               {regression,critical,fixnewfeature,docs,release}

Try @pytorchbot --help for more info.

fadara01 added a commit to fadara01/pytorch that referenced this pull request Mar 18, 2025
…tly (pytorch#148585)

This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PRs pytorch#126687, pytorch#139887 enabled an optimized implementation for `qlinear` and `qlinear_dynamic` for aarch64 through `ideep → oneDNN → ACL` which improved performance by ~10x compared to the previous implementation.
However, the current `qlinear` and `qlinear_dynamic` path (`ideep → oneDNN → ACL`) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (`lowp_gemm`) API - for example, ACL's `lowp_gemm` objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with `qlinear` and `qlinear_dynamic`.

- **For `qlinear_dynamic` (dynamically quantized matmuls):**

This PR yields an ****average speedup** (averaged over context_lengths of 2^3 up to 2^9) of ~ **50%** for `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, `distilbert-base-uncased`** with 16 threads on a Neoverse-V1 (with transformers==4.48) for the benchmarking script below:
```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default=None)
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))
```

- **For `qlinear` (statically quantized matmuls):**

This PR yields an **average speedup of 2x for signed activations (`s8s8s8`) and 95x for unsigned activations (u8s8u8)** on a Neoverse-V1 with 16 threads for the benchmarking script below.
The averages are over for all combinations of `M = [8, 16, ..., 512]`, `K = [768, 1024, 2048, 4096]`, `N = [768, 1024, 2048, 4096]`.
The astronomical speedup for unsigned activation is because oneDNN v3.7 does not have an optimized implementation for `u8s8u8` on AArch64.

```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
from torch.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer
import torch
import torch.nn as nn
import numpy as np
import random
from argparse import ArgumentParser
import time

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M",
                            help="M dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--K",
                            help="K dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--N",
                            help="N dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--signed_input",
                            help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                            action="store_true"
        )
        self.add_argument("--seed",
                          help="Random seed",
                          type=int,
                          default=42
        )
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def quantize_model(model, args):
    qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
            weight=default_weight_observer,
    )
    # Prepare the model for static quantization
    # Specify quantization configurations
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model_fp32)

    # Calibrate the model with sample inputs
    # Example input data for calibration
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()

    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)

    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))
```

Pull Request resolved: pytorch#148585
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 08a644a)
fadara01 added a commit to fadara01/pytorch that referenced this pull request Mar 18, 2025
…48653)

This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

Pull Request resolved: pytorch#148653
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#148585

(cherry picked from commit 6c2db8f)
jurgen-paul pushed a commit to jurgen-paul/pytorch.git.file that referenced this pull request Mar 19, 2025
This enables a fast path for eager mode dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PR #126687 enabled an optimized implementation for qlinear_dynamic for aarch64 through ideep → oneDNN → ACL which improved performance by ~10x compared to the previous implementation.
However, the current qlinear_dynamic path (ideep → oneDNN → ACL) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (lowp_gemm) API - for example, ACL's lowp_gemm objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with qlinear_dynamic. This approach yields an average speedup (averaged over context_lengths of 2^3 up to 2^9) of ~ 50% for bert-base-uncased, bert-large-uncased, roberta-base, distilbert-base-uncased with 16 threads on a Neoverse-V1 (with transformers==4.48).
To achieve this we introduce PackedLinearWeightsACL (as a subclasses of PackedLinearWeightsOnednn ) with an implementation of qlinear_dynamic that uses ACL directly, while qlinear still follows the oneDNN path.

ghstack-source-id: 555097f
Pull Request resolved: pytorch/pytorch#148585
pytorch-bot bot pushed a commit that referenced this pull request Mar 27, 2025
…tly (#148585)

This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PRs #126687, #139887 enabled an optimized implementation for `qlinear` and `qlinear_dynamic` for aarch64 through `ideep → oneDNN → ACL` which improved performance by ~10x compared to the previous implementation.
However, the current `qlinear` and `qlinear_dynamic` path (`ideep → oneDNN → ACL`) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (`lowp_gemm`) API - for example, ACL's `lowp_gemm` objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with `qlinear` and `qlinear_dynamic`.

- **For `qlinear_dynamic` (dynamically quantized matmuls):**

This PR yields an ****average speedup** (averaged over context_lengths of 2^3 up to 2^9) of ~ **50%** for `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, `distilbert-base-uncased`** with 16 threads on a Neoverse-V1 (with transformers==4.48) for the benchmarking script below:
```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default=None)
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))
```

- **For `qlinear` (statically quantized matmuls):**

This PR yields an **average speedup of 2x for signed activations (`s8s8s8`) and 95x for unsigned activations (u8s8u8)** on a Neoverse-V1 with 16 threads for the benchmarking script below.
The averages are over for all combinations of `M = [8, 16, ..., 512]`, `K = [768, 1024, 2048, 4096]`, `N = [768, 1024, 2048, 4096]`.
The astronomical speedup for unsigned activation is because oneDNN v3.7 does not have an optimized implementation for `u8s8u8` on AArch64.

```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
from torch.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer
import torch
import torch.nn as nn
import numpy as np
import random
from argparse import ArgumentParser
import time

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M",
                            help="M dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--K",
                            help="K dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--N",
                            help="N dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--signed_input",
                            help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                            action="store_true"
        )
        self.add_argument("--seed",
                          help="Random seed",
                          type=int,
                          default=42
        )
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def quantize_model(model, args):
    qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
            weight=default_weight_observer,
    )
    # Prepare the model for static quantization
    # Specify quantization configurations
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model_fp32)

    # Calibrate the model with sample inputs
    # Example input data for calibration
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()

    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)

    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))
```

Pull Request resolved: #148585
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 08a644a)
pytorch-bot bot pushed a commit that referenced this pull request Mar 27, 2025
This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

Pull Request resolved: #148653
Approved by: https://github.com/malfet
ghstack dependencies: #148585

(cherry picked from commit 6c2db8f)
malfet added a commit that referenced this pull request Mar 28, 2025
…ough ACL directly. (#149435)

* Enable fast qlinear static/dynamic path for AArch64 through ACL directly (#148585)

This enables a fast path for eager mode static/dynamic quantization for AArch64 through Arm Compute Library (ACL) directly.

Context: PRs #126687, #139887 enabled an optimized implementation for `qlinear` and `qlinear_dynamic` for aarch64 through `ideep → oneDNN → ACL` which improved performance by ~10x compared to the previous implementation.
However, the current `qlinear` and `qlinear_dynamic` path (`ideep → oneDNN → ACL`) suffers from high overhead due to the API friction between the stateless oneDNN API and the stateful ACL low-precision GEMM (`lowp_gemm`) API - for example, ACL's `lowp_gemm` objects cache information like weights reduction or weights in optimized memory format which oneDNN does not allow due to its stateless nature.
Hence, ACL currently runs a (redundant) sum of columns and pre-transposition (to the gemm kerne's optimal format) for each GEMM operation.
This PR addresses the sub-optimalities above by integrating ACL directly with `qlinear` and `qlinear_dynamic`.

- **For `qlinear_dynamic` (dynamically quantized matmuls):**

This PR yields an ****average speedup** (averaged over context_lengths of 2^3 up to 2^9) of ~ **50%** for `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, `distilbert-base-uncased`** with 16 threads on a Neoverse-V1 (with transformers==4.48) for the benchmarking script below:
```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default=None)
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("Model = ", model_name)
    print("Context Length = ", args.context_length)
    print("Min (ms) = ", min(times))
    print("Mean (ms) = ", np.mean(times))
```

- **For `qlinear` (statically quantized matmuls):**

This PR yields an **average speedup of 2x for signed activations (`s8s8s8`) and 95x for unsigned activations (u8s8u8)** on a Neoverse-V1 with 16 threads for the benchmarking script below.
The averages are over for all combinations of `M = [8, 16, ..., 512]`, `K = [768, 1024, 2048, 4096]`, `N = [768, 1024, 2048, 4096]`.
The astronomical speedup for unsigned activation is because oneDNN v3.7 does not have an optimized implementation for `u8s8u8` on AArch64.

```
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
from torch.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, default_weight_observer
import torch
import torch.nn as nn
import numpy as np
import random
from argparse import ArgumentParser
import time

class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__()
        self.add_argument("--M",
                            help="M dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--K",
                            help="K dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--N",
                            help="N dimension",
                            type=int,
                            default=64
        )
        self.add_argument("--signed_input",
                            help="Use (signed) torch.qint8 for inputs instead of (unsigned) torch.quint8",
                            action="store_true"
        )
        self.add_argument("--seed",
                          help="Random seed",
                          type=int,
                          default=42
        )
        self.add_argument("--iters",
                          help="benchmark iterations",
                          default=500)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class LinearModel(nn.Module):
    def __init__(self, K, N):
        super(LinearModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(K, N)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

def quantize_model(model, args):
    qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False,
            dtype=torch.qint8 if args.signed_input else torch.quint8),
            weight=default_weight_observer,
    )
    # Prepare the model for static quantization
    # Specify quantization configurations
    model.qconfig = qconfig
    model_prepared = torch.quantization.prepare(model_fp32)

    # Calibrate the model with sample inputs
    # Example input data for calibration
    with torch.no_grad():
        sample_data = torch.randn(args.M, args.K)
        model_prepared(sample_data)
    # Convert the prepared model to a quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()

    set_seed(args.seed)
    model_fp32 = LinearModel(args.K, args.N)
    model_quantized = quantize_model(model_fp32, args)

    inputs = torch.randn(args.M, args.K)
    times = []
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model_quantized(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model_quantized(inputs)
            times.append((time.time_ns() - s) / 1e6)

    print("M,K,N,signed = ", args.M, args.K, args.N, args.signed_input)
    print("Min Times (ms) = ", min(times))
    print("Mean Times (ms) = ", np.mean(times))
```

Pull Request resolved: #148585
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 08a644a)

* Enable qint8 and quint8 add for AArch64 using ACL directly (#148653)

This enables qint8 and quint8 add for AArch64 through Arm Compute Library (ACL) directly.
Relative performance improvement using OMP_NUM_THREADS=1 is ~15x, using OMP_NUM_THREADS=32 it’s ~5.4x.

Co-authored-by: David Svantesson <david.svantesson-yeung@arm.com>

Pull Request resolved: #148653
Approved by: https://github.com/malfet
ghstack dependencies: #148585

(cherry picked from commit 6c2db8f)

* [Build] Guard per-op headers in ACLUtils.cpp (#149417)

To fix internal build failures, where per-op headers are not generated.
We really should have lint for something like that.

Test Plan: CI

Reviewed By: izaitsevfb

Differential Revision: D71406882

Pull Request resolved: #149417
Approved by: https://github.com/Skylion007, https://github.com/izaitsevfb

(cherry picked from commit 5db3a4a)

---------

Co-authored-by: Nikita Shulga <nshulga@meta.com>
@github-actions github-actions bot deleted the gh/fadara01/5/head branch April 20, 2025 02:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
arm priority ciflow/linux-aarch64 linux aarch64 CI workflow Merged module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 module: cpu CPU specific problem (e.g., perf, algorithm) open source release notes: quantization release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants