4 changes: 4 additions & 0 deletions .github/workflows/pull.yml
@@ -934,6 +934,10 @@ jobs:
./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
# Run e2e testing for selected operators. More operators will be tested via this
# route in the future.
python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
nxp-build-test:
name: nxp-build-test
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -118,6 +118,19 @@ runtime.python_library(
    ],
)

runtime.python_library(
    name = "fold_qdq",
    srcs = ["fold_qdq.py"],
    visibility = [
        "//executorch/backends/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/vulkan:utils_lib",
        "//executorch/exir:pass_base",
    ],
)

runtime.python_library(
    name = "fuse_patterns",
    srcs = ["fuse_patterns.py"],
@@ -144,6 +157,7 @@ runtime.python_library(
"//executorch/examples/...",
],
deps = [
":fold_qdq",
":fuse_patterns",
":fuse_quantized_ops",
":insert_prepack_nodes",
2 changes: 2 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -6,6 +6,7 @@

# pyre-strict

from executorch.backends.vulkan._passes.fold_qdq import FoldQDQPass
from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
    FuseQuantizedOpsTransform,
@@ -30,6 +31,7 @@
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass

__all__ = [
"FoldQDQPass",
"FusePatternsPass",
"FuseQuantizedOpsTransform",
"insert_prepack_nodes",
41 changes: 41 additions & 0 deletions backends/vulkan/_passes/fold_qdq.py
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.vulkan.utils as utils
import torch

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class FoldQDQPass(ExportPass):
    """
    Erase Q/DQ chains introduced by the PT2E quantization workflow. It is assumed
    that all valid quantized op patterns have already been fused before this pass
    runs.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super(FoldQDQPass, self).__init__()
        self.edge_program = edge_program

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if utils.is_quant_node(node):
                original_node = node.args[0]
                assert isinstance(original_node, torch.fx.Node)
                # For each direct user that is a dequant node, connect the original
                # node to the users of the dequant node.
                for user in node.users:
                    if utils.is_dequant_node(user):
                        dq_node = user
                        dq_node.replace_all_uses_with(original_node)

        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        # Re-trace to validate that the graph is still well-formed
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
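
A minimal usage sketch of the pass (hypothetical toy module; assumes the PT2E quantized_decomposed ops are registered, and note that in the real flow the pass runs on the edge-dialect program inside the Vulkan backend). Folding turns x -> Q -> DQ -> relu into x -> relu:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers the Q/DQ ops

from executorch.backends.vulkan._passes import FoldQDQPass


class M(torch.nn.Module):
    def forward(self, x):
        q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            x, 0.1, 0, -128, 127, torch.int8
        )
        dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            q, 0.1, 0, -128, 127, torch.int8
        )
        return torch.relu(dq)


ep = torch.export.export(M(), (torch.randn(4),))
result = FoldQDQPass(ep)(ep.graph_module)
print(result.graph_module.graph)  # relu consumes x directly; dead Q/DQ removed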
131 changes: 131 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.patterns as vk_patterns
import torch.library

@@ -321,6 +323,135 @@ def linear_qta8a_qga4w(
lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)

#######################
## linear_q8ta_q8csw ##
#######################


def linear_q8ta_q8csw(
    x: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
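    # Note: input_scale / input_zero_point describe the activation quantization and,
    # like weight_sums, are presumably consumed by the backend kernel; this
    # reference implementation receives x already in float and does not use them.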
    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
        weights,
        weight_scales,
        weight_zeros,
        0,
        -127,
        127,
        torch.int8,
    )

    # Perform linear operation
    out = torch.nn.functional.linear(x, weights)
    if bias is not None:
        out = out + bias

    return out


name = "linear_q8ta_q8csw"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
Tensor? bias = None) -> Tensor
"""
)
lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
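
A hedged call sketch for the reference implementation above, with toy shapes invented here and symmetric per-channel weight quantization (the quantized_decomposed and et_vk ops must already be registered):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers Q/DQ ops
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401 -- registers et_vk ops

OC, IC = 8, 16
x = torch.randn(2, IC)
w_fp = torch.randn(OC, IC)
# Symmetric per-output-channel quantization to int8 in [-127, 127]
weight_scales = w_fp.abs().amax(dim=1) / 127.0
w_q = torch.clamp(torch.round(w_fp / weight_scales[:, None]), -127, 127).to(torch.int8)
# Per-channel int32 sums of the quantized weights (assumed layout; only the
# backend kernel consumes this argument)
weight_sums = w_q.to(torch.int32).sum(dim=1)

out = torch.ops.et_vk.linear_q8ta_q8csw(x, 0.05, 0, w_q, weight_sums, weight_scales, None)
print(out.shape)  # torch.Size([2, 8])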

#######################
## conv2d_q8ta_q8csw ##
#######################


def conv2d_q8ta_q8csw(
    x: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    bias: Optional[torch.Tensor],
    kernel_size: list,
    stride: list,
    padding: list,
    dilation: list,
    groups: int,
):
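    # As in linear_q8ta_q8csw, input_scale / input_zero_point / weight_sums are
    # not used by this reference implementation.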
    IC = x.shape[1]
    K_h, K_w = kernel_size[0], kernel_size[1]

    canonical_weight_K_dim = K_h * K_w * IC
    # Remove any padding that was added to the flattened K dim (IC * K_h * K_w)
    # to align it to a multiple of 4
    if weights.shape[-1] != canonical_weight_K_dim:
        weights = weights[:, :canonical_weight_K_dim]
        weight_scales = weight_scales[:canonical_weight_K_dim]
        if bias is not None:
            bias = bias[:canonical_weight_K_dim]

    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)

    # Calculate dimensions
    OC = weights.shape[0]
    in_features = weights.shape[1]
    IC = in_features // (K_h * K_w)

    # Reshape to original 4D format (OC, IC, H, W)
    weights = weights.view(OC, IC, K_h, K_w)

    # Dequantize weights
    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
        weights,
        weight_scales,
        weight_zeros,
        0,  # axis=0 for output channel quantization
        -127,
        127,
        torch.int8,
    )

    # Perform convolution
    out = torch.nn.functional.conv2d(
        x, weights, bias, stride, padding, dilation, groups
    )

    return out


name = "conv2d_q8ta_q8csw"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
Tensor? bias,
SymInt[] kernel_size,
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
"""
)
lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
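
A similar hedged sketch for the convolution op (toy shapes, groups=1, no channel padding, so the de-padding branch above is a no-op):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers Q/DQ ops
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401 -- registers et_vk ops

N, IC, H, W = 1, 3, 8, 8
OC, K_h, K_w = 4, 3, 3
x = torch.randn(N, IC, H, W)
w_fp = torch.randn(OC, IC, K_h, K_w)
weight_scales = w_fp.abs().amax(dim=(1, 2, 3)) / 127.0
w_q = torch.clamp(
    torch.round(w_fp / weight_scales[:, None, None, None]), -127, 127
).to(torch.int8)
w_q_2d = w_q.reshape(OC, -1)  # the op consumes weights flattened to (OC, IC*K_h*K_w)
weight_sums = w_q_2d.to(torch.int32).sum(dim=1)

out = torch.ops.et_vk.conv2d_q8ta_q8csw(
    x, 0.05, 0, w_q_2d, weight_sums, weight_scales, None,
    [K_h, K_w], [1, 1], [1, 1], [1, 1], 1,
)
print(out.shape)  # torch.Size([1, 4, 8, 8])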

######################
## apply_rotary_emb ##
######################
40 changes: 40 additions & 0 deletions backends/vulkan/op_registry.py
@@ -318,6 +318,19 @@ def register_int8_mm_op():
    )


@update_features(
    [
        exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
    ]
)
def register_qa_qw_linear():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        supports_prepacking=True,
        supports_resize=False,
    )


@update_features(
    [
        exir_ops.edge.et_vk.linear_weight_int4.default,
@@ -457,6 +470,33 @@ def register_convolution_op():
    )


@update_features(
    [
        exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default,
    ]
)
def register_quantized_conv_op():
    return OpFeatures(
        inputs_storage=[
            utils.CHANNELS_PACKED_TEXTURE,  # input
            utils.NO_STORAGE,  # input_scale (non tensor)
            utils.NO_STORAGE,  # input_zero_point (non tensor)
            utils.NO_STORAGE,  # weight (prepacked)
            utils.NO_STORAGE,  # weight_sums (prepacked)
            utils.NO_STORAGE,  # weight_scales (prepacked)
            utils.NO_STORAGE,  # bias (prepacked)
            utils.NO_STORAGE,  # kernel_size (non tensor)
            utils.NO_STORAGE,  # stride (non tensor)
            utils.NO_STORAGE,  # padding (non tensor)
            utils.NO_STORAGE,  # dilation (non tensor)
            utils.NO_STORAGE,  # groups (non tensor)
            utils.NO_STORAGE,  # original OC count (non tensor)
        ],
        supports_resize=False,
        supports_prepacking=True,
    )


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_with_kv_cache_op():
    return OpFeatures(
9 changes: 5 additions & 4 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -22,6 +22,8 @@
    vulkan_supported_ops,
)

from executorch.backends.vulkan.patterns import PatternMatch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
@@ -41,7 +43,6 @@

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.matcher_utils import InternalMatch

# pyre-ignore
ops_not_to_decompose = [
@@ -60,7 +61,7 @@ def __init__(
        require_dynamic_shape: bool = False,
        operator_blocklist: Optional[Set[OpKey]] = None,
        operator_allowlist: Optional[Set[OpKey]] = None,
        fusable_subgraphs: Optional[List[InternalMatch]] = None,
        fusable_subgraphs: Optional[List[PatternMatch]] = None,
        nn_module_blocklist: Optional[Set[str]] = None,
        nn_module_allowlist: Optional[Set[str]] = None,
    ) -> None:
@@ -72,13 +73,13 @@
            operator_blocklist if operator_blocklist is not None else set()
        )
        self.operator_allowlist = operator_allowlist
        self.fusable_subgraphs: List[InternalMatch] = (
        self.fusable_subgraphs: List[PatternMatch] = (
            fusable_subgraphs if fusable_subgraphs is not None else []
        )
        # Create a set of all nodes that are part of fusable subgraphs for quick lookup
        self.fusable_nodes: Set[torch.fx.Node] = set()
        for match in self.fusable_subgraphs:
            self.fusable_nodes.update(match.nodes_map.values())
            self.fusable_nodes.update(match.all_nodes)

        self.nn_module_blocklist = nn_module_blocklist
        self.nn_module_allowlist = nn_module_allowlist
1 change: 1 addition & 0 deletions backends/vulkan/patterns/TARGETS
@@ -10,6 +10,7 @@ runtime.python_library(
"pattern_registry.py",
"rope.py",
"quantized_linear.py",
"quantized_convolution.py",
],
visibility = [
"//executorch/backends/...",