2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -39,6 +39,7 @@
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_inf_values import ReplaceInfValues
from .seq_mse import SeqMSE
from .tag_quant_io import TagQuantIO


@@ -78,5 +79,6 @@
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceInfValues,
SeqMSE,
TagQuantIO,
]
223 changes: 223 additions & 0 deletions backends/qualcomm/_passes/seq_mse.py
@@ -0,0 +1,223 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import types
from contextlib import contextmanager

import torch
import torchao
from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
PerBlockParamObserver,
)
from executorch.exir.pass_base import ExportPass, PassResult
from torchao.quantization.pt2e import PerChannelMinMaxObserver


class SeqMseModule(torch.nn.Module):
"""
Args:
nominal_weight: Tensor
nominal (unquantized) weight parameter taken from the operator
nominal_bias: Tensor
nominal (unquantized) bias parameter taken from the operator, or None
operator: fx.Node
operator node to be executed during the search
observer: UniformQuantizationObserverBase
parameter observer (specific to the weight)
num_candidates: int
number of scale candidates to search for the minimal MSE loss
"""

def __init__(
self,
nominal_weight,
nominal_bias,
operator,
observer,
num_candidates,
):
super().__init__()
self.nominal_weight = nominal_weight
self.nominal_bias = nominal_bias
self.observer = observer
self.steps = torch.linspace(
1 / num_candidates, 1, steps=num_candidates
).tolist()
self.operator = self._make_operator(operator)
self.best_candidate_step = 1.0

def _make_operator(self, aten_op):
if aten_op.target == torch.ops.aten.conv2d.default:
stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3]
padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4]
dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5]
groups = 1 if len(aten_op.args) < 7 else aten_op.args[6]
has_bias = self.nominal_bias is not None
module = torch.nn.Conv2d(
in_channels=self.nominal_weight.shape[1],
out_channels=self.nominal_weight.shape[0],
kernel_size=self.nominal_weight.shape[-2:],
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=has_bias,
)
module.weight.data = self.nominal_weight
if has_bias:
module.bias.data = self.nominal_bias
return module
else:
raise NotImplementedError(f"target of {aten_op.target} is not implemented")

def _per_block_qdq(self, scale, zero_point):
return torchao.quantization.quant_primitives._fake_quantize_affine(
input=self.nominal_weight,
block_size=self.observer.block_size,
scale=scale,
zero_point=zero_point,
quant_dtype=self.observer.dtype,
quant_min=self.observer.quant_min,
quant_max=self.observer.quant_max,
)

def _per_channel_qdq(self, scale, zero_point):
return torch.fake_quantize_per_channel_affine(
input=self.nominal_weight,
scale=scale,
zero_point=zero_point,
axis=0,
quant_min=self.observer.quant_min,
quant_max=self.observer.quant_max,
)

def _fake_quant(self, scale, zero_point):
dispatcher = {
PerChannelMinMaxObserver: self._per_channel_qdq,
PerBlockParamObserver: self._per_block_qdq,
}
return dispatcher[type(self.observer)](scale, zero_point)

def _find_best_candidate(self, nominal_input, nominal_output):
# calculate current baseline
scale, zero_point = self.observer.calculate_qparams()
zero_point = zero_point.to(torch.int32)
self.operator.weight.data = self._fake_quant(scale, zero_point)
candidate, current_loss = (
1,
torch.nn.functional.mse_loss(
self.operator(nominal_input), nominal_output
).item(),
)
for step in self.steps:
self.operator.weight.data = self._fake_quant(scale * step, zero_point)
loss = torch.nn.functional.mse_loss(
self.operator(nominal_input), nominal_output
).item()
if loss < current_loss:
candidate, current_loss = step, loss
return candidate

def forward(self, nominal_input, nominal_output):
self.best_candidate_step = self._find_best_candidate(
nominal_input=nominal_input, nominal_output=nominal_output
)


class InsertSeqMse(ExportPass):
"""
Insert SeqMse observer modules to find the best weight quantization encoding for supported nodes.
"""

seq_mse_ops = {torch.ops.aten.conv2d.default}

def __init__(self, num_candidates=1000):
super(InsertSeqMse, self).__init__()
self.num_candidates = num_candidates

def _insert_seq_mse(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
count = 0
for node in graph_module.graph.nodes:
if node.target in self.seq_mse_ops:
# extract observer
weight_node_obs = node.args[1]
observer = getattr(graph_module, weight_node_obs.name)
# extract parameters
weight_node = weight_node_obs.args[0]
weight_tensor = graph_module.get_parameter(weight_node.target).detach()
bias_tensor = None
if len(node.args) > 2 and node.args[2] is not None:
bias_tensor = graph_module.get_parameter(
node.args[2].args[0].target
).detach()

with graph_module.graph.inserting_after(node):
seq_mse_mod = SeqMseModule(
nominal_weight=weight_tensor,
nominal_bias=bias_tensor,
operator=node,
observer=observer,
num_candidates=self.num_candidates,
)
module_name = f"seq_mse_{count}"
count += 1
setattr(graph_module, module_name, seq_mse_mod)
input_nodes = (node.args[0], node)
graph_module.graph.create_node(
"call_module", module_name, input_nodes, {}
)

def call(self, graph_module: torch.fx.GraphModule):
self._insert_seq_mse(graph_module)
graph_module.recompile()
return PassResult(graph_module, True)


class RemoveSeqMse(ExportPass):
"""
Remove SeqMse modules before invoking convert_pt2e and update the final quantization encodings.
"""

def __init__(self):
super(RemoveSeqMse, self).__init__()
Comment on lines +185 to +186 (Contributor): can't we just omit this?

def _remove_seq_mse(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
node_to_erase = []
for node in graph_module.graph.nodes:
if node.op == "call_module":
# try extracting SeqMse module
module = getattr(graph_module, node.target)
if isinstance(module, SeqMseModule):
# rewrite observer method for pre-calculated scale
scale, zero_point = module.observer.calculate_qparams()
module.observer.updated_encoding = (
scale * module.best_candidate_step,
zero_point,
)
module.observer.calculate_qparams = types.MethodType(
lambda s: s.updated_encoding, module.observer
)
node_to_erase.append(node)

for node in node_to_erase:
graph_module.graph.erase_node(node)

def call(self, graph_module: torch.fx.GraphModule):
self._remove_seq_mse(graph_module)
graph_module.recompile()
return PassResult(graph_module, True)


@contextmanager
def SeqMSE(prepared_gm, num_candidates):
prepared_gm = InsertSeqMse(num_candidates)(prepared_gm).graph_module
try:
yield
finally:
prepared_gm = RemoveSeqMse()(prepared_gm).graph_module
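Below is a minimal calibration sketch of how the SeqMSE context manager is meant to be driven, mirroring the new test_qnn_backend_seq_mse added further down in this PR. SimpleConv is a hypothetical stand-in for the test's Conv2dSingle, and the import paths for make_quantizer, prepare_pt2e and convert_pt2e are assumptions that may need adjusting to the local torch/torchao version.

import torch
from executorch.backends.qualcomm._passes.seq_mse import SeqMSE
from executorch.backends.qualcomm.quantizer.quantizer import make_quantizer  # assumed import path
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e  # assumed import path


class SimpleConv(torch.nn.Module):  # hypothetical stand-in for the test's Conv2dSingle
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(512, 32, kernel_size=(1, 1))

    def forward(self, x):
        return self.conv(x)


module = SimpleConv().eval()
sample_input = (torch.randn(1, 512, 1, 32),)

# make_quantizer() is used with its defaults in the new test, which yields a
# per-channel conv weight observer, one of the two observer types SeqMseModule
# knows how to fake-quantize.
quantizer = make_quantizer()
exported = torch.export.export(module, sample_input).module()
prepared = prepare_pt2e(exported, quantizer)

# Calibrate inside the context: InsertSeqMse hooks a SeqMseModule after each
# supported conv, and RemoveSeqMse patches calculate_qparams with the best
# scale candidate on exit.
with SeqMSE(prepared, num_candidates=100):
    prepared(*sample_input)

converted = convert_pt2e(prepared)  # picks up the tuned encodings

Both passes mutate the prepared graph module in place, so the same prepared handle is passed straight to convert_pt2e once the context exits.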
19 changes: 19 additions & 0 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -31,6 +31,25 @@
)


def annotate_down_proj(
gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
):
for node in gm.graph.nodes:
if (
node.target == torch.ops.aten.conv2d.default
and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"])
and node.args[0].target == torch.ops.aten.mul.Tensor
):
input_qspec_map = {}
input_qspec_map[node.args[0]] = quantization_config.input_activation
input_qspec_map[node.args[1]] = quantization_config.weight
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)


def annotate_eurobert(gm: torch.fx.GraphModule):
"""
QNN does not support int32 -> signed 16bit quant
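For the annotate_down_proj helper added above, a hedged wiring sketch is shown below. It assumes the quantizer's existing add_custom_quant_annotations hook and the get_ptq_per_channel_quant_config helper keep their current signatures; the 16a8w dtypes are only an illustrative choice.

from functools import partial

import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_down_proj
from executorch.backends.qualcomm.quantizer.qconfig import get_ptq_per_channel_quant_config  # assumed helper
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

# Assumed wiring: bind a per-channel 16a8w config to the down-projection convs
# and register the partial through the quantizer's custom-annotation hook.
quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(
    (
        partial(
            annotate_down_proj,
            quantization_config=get_ptq_per_channel_quant_config(
                act_dtype=torch.uint16, weight_dtype=torch.int8
            ),
        ),
    )
)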
backends/qualcomm/quantizer/observers/per_block_param_observer.py
@@ -34,6 +34,7 @@ def __init__(
eps=eps,
**kwargs,
)
self.dtype = dtype
self.block_size = block_size
# TODO: expand this when QNN starts to support more configurations
self.bitwidth_of_scale = 4
71 changes: 46 additions & 25 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -4636,6 +4636,33 @@ def test_qnn_backend_generate_optrace(self):
qhas_data = json.load(qhas_file)
self.assertIn("data", qhas_data)

def test_qnn_backend_seq_mse(self):
from executorch.backends.qualcomm._passes.seq_mse import SeqMSE

o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
module = Conv2dSingle( # noqa: F405
in_channel=i_ch,
out_channel=o_ch,
kernel_size=kernel,
padding=padding,
)
sample_input = (torch.randn(1, i_ch, 1, o_ch),)
# per-channel / per-block
quantizers = [
make_quantizer(),
make_quantizer(quant_dtype=QuantDtype.use_16a4w_block),
]
quantizers[-1].set_block_size_map({"conv2d": (1, 32, 1, 1)})

for i, quantizer in enumerate(quantizers):
with self.subTest(i=i):
ep = torch.export.export(module, sample_input).module()
prepared = prepare_pt2e(ep, quantizer)
with SeqMSE(prepared, 100):
prepared(*sample_input)
converted = convert_pt2e(prepared)
self.lower_module_and_test_output(converted, sample_input)


class TestExampleLLMScript(TestQNN):
def test_static_gemma3_1b(self):
@@ -4709,7 +4736,7 @@ def test_static_gemma3_1b(self):
msg["inference_speed"], inference_speed_ref[self.model]
)

def test_llama3_2_1b(self):
def test_llama3_2_instruct(self):
if not self.required_envs():
self.skipTest("missing required envs")
assert (
@@ -4741,13 +4768,16 @@ def test_llama3_2_1b(self):
"--temperature",
"0",
"--decoder_model",
"llama3_2",
"llama3_2-1b_instruct",
"--model_mode",
"hybrid",
"--prefill_ar_len",
"32",
"kv",
"--max_seq_len",
"512",
"1024",
"--eval_perplexity",
"--tasks",
"wikitext",
"--limit",
"1",
]
if self.compile_only:
cmds.extend(["--compile_only"])
@@ -4760,7 +4790,6 @@
if self.pre_gen_pte:
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

golden_start_with = "<|start_header_id|>user<|end_header_id|>"
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
@@ -4769,19 +4798,17 @@
if "Error" in msg:
self.fail(msg["Error"])
else:
if not self.compile_only:
model_out = msg["result"][0]
self.assertTrue(
model_out.startswith(golden_start_with),
f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
inference_speed_ref = {"SM8650": 37, "SM8750": 49}
if (
not self.compile_only
and not self.enable_x86_64
and self.model in inference_speed_ref
):
self.assertLessEqual(msg["pte_size"], 1_500_000_000)
self.assertLessEqual(msg["wiki_ppl"], 15)
self.assertGreaterEqual(
msg["inference_speed"], inference_speed_ref[self.model]
)
# x86 does not allow weight sharing, so we don't check pte size.
# Inference speed on x86 is slow, so we only check when running on Android
if not self.enable_x86_64:
pte_size = msg["pte_size"]
self.assertLessEqual(pte_size, 1_300_000_000) # 1.3GB
if not self.compile_only and not self.enable_x86_64:
self.assertGreaterEqual(msg["inference_speed"], 66) # Lanai

def test_llama_stories_260k(self):
if not self.required_envs():
@@ -4976,12 +5003,6 @@ def test_static_phi4(self):
cmds.extend(["--enable_x86_64"])
if self.pre_gen_pte:
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
cmds.extend(
[
"--quant_attrs_path",
f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json",
]
)

p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener: