14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -117,6 +117,19 @@ runtime.python_library(
],
)

runtime.python_library(
name = "replace_qdq",
srcs = ["replace_qdq.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan:utils_lib",
"//executorch/exir:pass_base",
],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
@@ -150,6 +163,7 @@ runtime.python_library(
":remove_asserts",
":remove_local_scalar_dense",
":remove_redundant_ops",
":replace_qdq",
":squeeze_unsqueeze_inputs",
":tag_memory_meta_pass",
]
2 changes: 2 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -22,6 +22,7 @@
from executorch.backends.vulkan._passes.remove_redundant_ops import (
RemoveRedundantOpsTransform,
)
from executorch.backends.vulkan._passes.replace_qdq import ReplaceQDQPass
from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
SqueezeUnsqueezeInputs,
)
@@ -36,6 +37,7 @@
"RemoveAssertsTransform",
"RemoveLocalScalarDenseOpsTransform",
"RemoveRedundantOpsTransform",
"ReplaceQDQPass",
"SqueezeUnsqueezeInputs",
"TagMemoryMetaPass",
]
93 changes: 93 additions & 0 deletions backends/vulkan/_passes/replace_qdq.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.vulkan.utils as utils
import torch
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult


class ReplaceQDQPass(ExportPass):
"""
Replace standard quantize/dequantize ops with custom conv-specific ops when they
feed into, or consume the output of, quantized convolution operations. This
optimization lets the backend handle quantization around convolutions more efficiently.
"""

def __init__(self):
super(ReplaceQDQPass, self).__init__()

def call(self, graph_module: torch.fx.GraphModule):
# Track nodes that need to be replaced
nodes_to_replace = []

for node in graph_module.graph.nodes:
# Check if this is the custom quantized conv2d op
if node.target in [
exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
]:
# Replace quantize op feeding into conv2d (first argument is the quantized input)
quantized_input_node = node.args[0]
if isinstance(
quantized_input_node, torch.fx.Node
) and utils.is_quant_node(quantized_input_node):
# Get the arguments from the original quantize node
input_tensor = quantized_input_node.args[0]
scale = quantized_input_node.args[1]
zero_point = quantized_input_node.args[2]

nodes_to_replace.append(
{
"old_node": quantized_input_node,
"new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
"args": (input_tensor, scale, zero_point),
"node_type": "quantize_input",
}
)

# Find dequantize ops that consume the output of this conv2d
for user in node.users:
if utils.is_dequant_node(user):
# Get the arguments from the original dequantize node
scale = user.args[1]
zero_point = user.args[2]

nodes_to_replace.append(
{
"old_node": user,
"new_target": exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default,
"args": (
node,
scale,
zero_point,
), # node is the conv2d output
"node_type": "dequantize_output",
}
)

# Apply the replacements
for replacement in nodes_to_replace:
old_node = replacement["old_node"]
new_target = replacement["new_target"]
new_args = replacement["args"]

with graph_module.graph.inserting_before(old_node):
new_node = graph_module.graph.create_node(
"call_function", new_target, args=new_args
)
new_node.meta = old_node.meta.copy()
old_node.replace_all_uses_with(new_node)

# Clean up the graph
graph_module.graph.eliminate_dead_code()
graph_module.recompile()

# Re-trace to validate everything is ok
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
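
For context (not part of this diff), here is a minimal sketch of applying the new pass to an edge-dialect graph module on its own. It assumes the standard ExportPass calling convention, where the pass object is callable and returns a PassResult; the helper name run_replace_qdq is hypothetical.

# Hypothetical usage sketch, not part of this PR: run ReplaceQDQPass on a
# graph module produced during edge lowering and recover the rewritten module.
import torch
from executorch.backends.vulkan._passes import ReplaceQDQPass

def run_replace_qdq(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # ExportPass instances are callable and wrap the result in a PassResult
    result = ReplaceQDQPass()(graph_module)
    return result.graph_module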
158 changes: 139 additions & 19 deletions backends/vulkan/custom_ops_lib.py
@@ -354,46 +354,124 @@ def linear_q8ta_q8csw(
lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)

#######################
## conv2d_q8ta_q8csw ##
#######################
############################
## conv2d_q8ta_q8csw_q8to ##
############################


def conv2d_q8ta_q8csw(
def conv2d_q8ta_q8csw_q8to(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
weights: torch.Tensor,
weight_sums: torch.Tensor,
weight_scales: torch.Tensor,
output_scale: float,
output_zero_point: int,
bias: Optional[torch.Tensor],
kernel_size: list,
stride: list,
padding: list,
dilation: list,
groups: int,
):
IC = x.shape[1]
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
)

# Calculate weight dimensions
OC = weights.shape[0]
assert OC % groups == 0, "Output channels must be divisible by groups"
IC_per_group = int(x.shape[1] / groups)
K_h, K_w = kernel_size[0], kernel_size[1]

canonical_weight_K_dim = K_h * K_w * IC
orig_weight_K_dim = K_h * K_w * IC_per_group
# Remove any padding added to in_features dim to align to a multiple of 4
if weights.shape[-1] > orig_weight_K_dim:
weights = weights[:, :orig_weight_K_dim]

# Remove any padding added to output channels dim to align to a multiple of 4
if weights.shape[-1] != canonical_weight_K_dim:
weights = weights[:, :canonical_weight_K_dim]
weight_scales = weight_scales[:canonical_weight_K_dim]
if weight_scales.shape[0] > OC:
weight_scales = weight_scales[:OC]
if bias is not None:
bias = bias[:canonical_weight_K_dim]
bias = bias[:OC]

# Reshape to original 4D format (OC, IC, H, W)
weights = weights.view(OC, IC_per_group, K_h, K_w)

weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
# Dequantize weights
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
weights,
weight_scales,
weight_zeros,
0, # axis=0 for output channel quantization
-127,
127,
torch.int8,
)

# Calculate dimensions
OC = weights.shape[0]
in_features = weights.shape[1]
IC = in_features // (K_h * K_w)
# Perform convolution
out = torch.nn.functional.conv2d(
x, weights, bias, stride, padding, dilation, groups
)

# Reshape to original 4D format (OC, IC, H, W)
weights = weights.view(OC, IC, K_h, K_w)
out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)

return out


name = "conv2d_q8ta_q8csw_q8to"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
float output_scale,
int output_zero_point,
Tensor? bias,
SymInt[] kernel_size,
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
"""
)
lib.impl(name, conv2d_q8ta_q8csw_q8to, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)


def conv2d_q8ta_q8csw_q8to_dw(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
weights: torch.Tensor,
weight_sums: torch.Tensor,
weight_scales: torch.Tensor,
output_scale: float,
output_zero_point: int,
bias: Optional[torch.Tensor],
kernel_size: list,
stride: list,
padding: list,
dilation: list,
groups: int,
):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
)

# Restore weight to original data layout
K_h, K_w, OC = weights.shape
weights = weights.permute(2, 0, 1).reshape(OC, 1, K_h, K_w)

weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
# Dequantize weights
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
weights,
@@ -410,10 +488,14 @@ def conv2d_q8ta_q8csw(
x, weights, bias, stride, padding, dilation, groups
)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)

return out


name = "conv2d_q8ta_q8csw"
name = "conv2d_q8ta_q8csw_q8to_dw"
lib.define(
f"""
{name}(
@@ -423,6 +505,8 @@ def conv2d_q8ta_q8csw(
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
float output_scale,
int output_zero_point,
Tensor? bias,
SymInt[] kernel_size,
SymInt[] stride,
@@ -431,8 +515,8 @@ def conv2d_q8ta_q8csw(
SymInt groups) -> Tensor
"""
)
lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
lib.impl(name, conv2d_q8ta_q8csw_q8to_dw, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)

######################
## apply_rotary_emb ##
@@ -452,3 +536,39 @@ def apply_rotary_emb_impl(
)
lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)

#############################
## quantize/dequantize ops ##
#############################


def quantize_q8ta_for_conv2d_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
):
return torch.ops.quantized_decomposed.quantize_per_tensor(
input, scale, zero_point, -128, 127, torch.int8
)


name = "quantize_q8ta_for_conv2d"
lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd")
quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name)


def dequantize_q8to_from_conv2d_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
):
return torch.ops.quantized_decomposed.dequantize_per_tensor(
input, scale, zero_point, -128, 127, input.dtype
)


name = "dequantize_q8to_from_conv2d"
lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)
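
As a quick sanity check (not part of this diff), here is a hedged sketch of calling the two new helper ops through the registered library. It assumes the namespace used elsewhere in this file is et_vk and that the reference implementations above are registered by importing custom_ops_lib.

# Hypothetical round-trip check; the "et_vk" namespace is an assumption.
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401  (registers the ops)
import torch

x = torch.randn(1, 3, 8, 8)
scale, zero_point = 0.05, 10
q = torch.ops.et_vk.quantize_q8ta_for_conv2d(x, scale, zero_point)      # int8 tensor
xd = torch.ops.et_vk.dequantize_q8to_from_conv2d(q, scale, zero_point)  # float tensor
# For values inside the representable int8 range, the round-trip error is at most ~scale/2.
print((x - xd).abs().max())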