
Enables the per_tensor lowering patterns for weight pre-packing #2391


Open · choudhary-devang wants to merge 2 commits into main

Conversation

@choudhary-devang (Collaborator) commented Jun 17, 2025

This PR is an extension of PR #2139.

Major changes:
1) Introduced a lowering pattern for "per_tensor" quantized weights.
2) Modified the original API get_default_arm_inductor_quantization_config to let the user choose between "per_tensor" and "per_channel" granularity for the model's weight quantization.

Supported shapes:

  1. s8:s8:f32 (per_tensor / per_channel) - input: s8, weight: s8, output: f32
  2. u8:s8:f32 (per_tensor / per_channel) - input: u8, weight: s8, output: f32

Tested and verified on different models:

  • BERT
  • ResNet
  • ViT
  • Custom models

Example script for reference:

import torch
from transformers import BertModel
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as aiq
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import ArmInductorQuantizer
import torch._inductor.config as config
# Enable C++ wrapper for Inductor
config.cpp_wrapper = True
config.freezing = True

model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)

# Set the model to eval mode
model = model.eval()

# Create the data, using dummy data here as an example
traced_bs = 32
seq_length = 128
x = torch.randint(0, 10000, (traced_bs, seq_length))
attention_mask = torch.ones((traced_bs, seq_length))
example_inputs = (x, attention_mask)

# Capture the FX Graph to be quantized
with torch.no_grad():
    exported_model = torch.export.export_for_training(model, example_inputs).module()
    # Set up the quantizer and prepare the model for post-training quantization
    quantizer = ArmInductorQuantizer()
    quantizer.set_global(aiq.get_default_arm_inductor_quantization_config(is_dynamic=True, is_per_channel=True))
    prepared_model = prepare_pt2e(exported_model, quantizer)
    converted_model = convert_pt2e(prepared_model)
    converted_model = torch.compile(converted_model)
    with torch.profiler.profile(record_shapes=True) as prof:
        for _ in range(200):
            converted_model(*example_inputs)
print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_time_total"))
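
The script above uses is_per_channel=True, i.e. per_channel weight quantization. A minimal sketch of selecting the new per_tensor path instead, via the is_per_channel flag added to get_default_arm_inductor_quantization_config in this PR (only the quantizer configuration changes; the rest of the script is identical):

# Per_tensor weight quantization: only the config call changes.
quantizer = ArmInductorQuantizer()
quantizer.set_global(
    aiq.get_default_arm_inductor_quantization_config(
        is_dynamic=True,
        is_per_channel=False,  # per_tensor quantized weights (the new lowering path)
    )
)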

Results

Model    FP32      quant (int8)    Speedup
resnet   62.967    44.482          1.415561
bert     103.879   71.953          1.443706
vit      69.031    59.973          1.151035

All times in seconds (speedup = FP32 time / int8 time, e.g. 103.879 / 71.953 ≈ 1.44 for bert), measured on an AWS Graviton3E 32-core instance.

Pip list: (environment screenshot attached to the PR; not reproduced here)

cc: @jerryzh168, @fadara01, @Xia-Weiwen


pytorch-bot bot commented Jun 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2391

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e51e9ec with merge base 11ce634:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 17, 2025
@choudhary-devang (Collaborator Author)

Hi @jerryzh168, @fadara01, @Xia-Weiwen, could you please review this PR?
Thank you!

@jerryzh168 (Contributor)

Thanks! Can you add some tests in https://github.com/pytorch/ao/tree/main/test/quantization/pt2e?

@jerryzh168 jerryzh168 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Jun 26, 2025
@choudhary-devang (Collaborator Author)

Hi @jerryzh168,
I have added test cases specific to these changes; to keep them separate, they live in a new file: ao/test/quantization/pt2e/test_arm_inductor_quantizer_per_tensor.py.
Could you please review?
Thank you!

@fadara01
Thanks for your PR!
Do we see any speedups (against fp32) for e.g. bert / resnet50 as a result of this lowering?
Do we need to do any work in PyTorch (qconv and qlinear) to support such lowerings?

@choudhary-devang (Collaborator Author)

> Thanks for your PR! Do we see any speedups (against fp32) for e.g. bert / resnet50 as a result of this lowering? Do we need to do any work in PyTorch (qconv and qlinear) to support such lowerings?

Hi @fadara01, thanks for the response.
I have updated the description to include some of the details; we don't need any changes in PyTorch.
For my experimentation I used pip install torch torchvision.

To recreate the experiment:
FP32 script

import torch
from transformers import BertModel

# model loading
model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
# Create the data, using dummy data here as an example
traced_bs = 32
seq_length = 128
x = torch.randint(0, 10000, (traced_bs, seq_length))
attention_mask = torch.ones((traced_bs, seq_length))
example_inputs = (x, attention_mask)

# Inference
with torch.no_grad():
    model = torch.compile(model)
    with torch.profiler.profile(record_shapes=True) as prof:
        for _ in range(200):
            model(*example_inputs)
print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_time_total"))

Quant script: identical to the example script in the PR description above.

Current setup
Kernel (oneDNN verbose):
onednn_verbose,v1,primitive,exec,cpu,matmul,lowp_gemm:acl,undef,src:s8:a:blocked:ab::f0 wei:s8::blocked:ab::f0 bia:f32:a:blocked:ab::f0_mask2 dst:f32:a:blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+wei:0:f32 attr-zero-points:src0:0:s32,,50x512:512x1000,0.224854
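
For reference, the verbose line above can be reproduced by enabling oneDNN's verbose mode (the standard ONEDNN_VERBOSE environment variable) before running the quantized model:

import os

# Each executed oneDNN primitive prints one line, like the
# matmul/lowp_gemm:acl entry above. Set this before importing torch
# so it takes effect when oneDNN initializes.
os.environ["ONEDNN_VERBOSE"] = "1"

import torch
# ... then run the quant script as above ...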

@fadara01 commented Jul 21, 2025

Ahhh that's amazing! I remember doing a PoC for this exact thing back in the day and I had to tweak qlinear/qconv, hence my question.

@choudhary-devang (Collaborator Author)

Hi @jerryzh168, @fadara01, could you please approve and merge this change?
Thank you!
