1 change: 1 addition & 0 deletions .github/workflows/pull.yml
@@ -1010,6 +1010,7 @@ jobs:
./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row
./cmake-out/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations
./cmake-out/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add

# "Classic" Operator tests
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
1 change: 1 addition & 0 deletions backends/vulkan/_passes/replace_qdq.py
@@ -30,6 +30,7 @@ def call(self, graph_module: torch.fx.GraphModule):
if node.target in [
exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
]:
# Replace quantize op feeding into conv2d (first argument is the quantized input)
quantized_input_node = node.args[0]
42 changes: 42 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -572,3 +572,45 @@ def dequantize_q8to_from_conv2d_impl(
lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)

########################
## add_q8ta_q8ta_q8to ##
########################


def add_q8ta_q8ta_q8to_impl(
input_a: torch.Tensor,
input_b: torch.Tensor,
input_a_scale: float,
input_a_zero_point: int,
input_b_scale: float,
input_b_zero_point: int,
output_scale: float,
output_zero_point: int,
alpha: float,
):
# Dequantize inputs to float
dequant_a = torch.ops.quantized_decomposed.dequantize_per_tensor(
input_a, input_a_scale, input_a_zero_point, -128, 127, input_a.dtype
)
dequant_b = torch.ops.quantized_decomposed.dequantize_per_tensor(
input_b, input_b_scale, input_b_zero_point, -128, 127, input_b.dtype
)

# Perform addition with alpha scaling
result = dequant_a + alpha * dequant_b

# Quantize the result back to int8
quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor(
result, output_scale, output_zero_point, -128, 127, torch.int8
)

return quantized_result


name = "add_q8ta_q8ta_q8to"
lib.define(
f"{name}(Tensor input_a, Tensor input_b, float input_a_scale, int input_a_zero_point, float input_b_scale, int input_b_zero_point, float output_scale, int output_zero_point, float alpha) -> Tensor"
)
lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)
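
For reference, a minimal usage sketch of the composite op defined above (not part of the diff; it assumes the library namespace is et_vk, matching exir_ops.edge.et_vk.add_q8ta_q8ta_q8to elsewhere in this PR, and that importing custom_ops_lib registers the op):

import torch
# Assumed registration import; loading this module defines the et_vk custom ops.
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401

# Two per-tensor-quantized int8 activations and example quantization parameters.
qa = torch.randint(-128, 128, (1, 8, 4, 4), dtype=torch.int8)
qb = torch.randint(-128, 128, (1, 8, 4, 4), dtype=torch.int8)

out = torch.ops.et_vk.add_q8ta_q8ta_q8to(
    qa, qb,
    0.05, 3,    # input_a scale / zero point
    0.04, -2,   # input_b scale / zero point
    0.08, 1,    # output scale / zero point
    1.0,        # alpha
)
# `out` is int8: quantize(dequant(qa) + alpha * dequant(qb)) using the output scale / zero point.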
13 changes: 13 additions & 0 deletions backends/vulkan/op_registry.py
@@ -523,6 +523,19 @@ def register_quantized_conv_op():
)


@update_features(
[
exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
]
)
def register_quantized_binary_op():
return OpFeatures(
inputs_storage=utils.PACKED_INT8_4W4C_BUFFER,
supports_resize=False,
supports_prepacking=True,
)


@update_features(
[
exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
1 change: 1 addition & 0 deletions backends/vulkan/patterns/TARGETS
@@ -11,6 +11,7 @@ runtime.python_library(
"rope.py",
"quantized_linear.py",
"quantized_convolution.py",
"quantized_binary.py",
],
visibility = [
"//executorch/backends/...",
2 changes: 2 additions & 0 deletions backends/vulkan/patterns/__init__.py
@@ -6,6 +6,8 @@

from typing import List

import executorch.backends.vulkan.patterns.quantized_binary # noqa

import executorch.backends.vulkan.patterns.quantized_convolution # noqa

import executorch.backends.vulkan.patterns.quantized_linear # noqa
161 changes: 161 additions & 0 deletions backends/vulkan/patterns/quantized_binary.py
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops


class QuantizedBinaryMatch(PatternMatch):
def __init__(self, binary_node: torch.fx.Node) -> None:
self.anchor_node = binary_node
self.match_found = False
self.all_nodes = [self.anchor_node]

# Extract alpha parameter if it exists (for add operations)
self.alpha = 1.0
if len(binary_node.args) > 2 and binary_node.args[2] is not None:
# Alpha is typically a scalar value
if isinstance(binary_node.args[2], (int, float)):
self.alpha = binary_node.args[2]

# Identify input nodes - both should be dequantize nodes for static quantization
if len(binary_node.args) < 2:
return

input_a_node = binary_node.args[0]
assert isinstance(input_a_node, torch.fx.Node)
input_b_node = binary_node.args[1]
assert isinstance(input_b_node, torch.fx.Node)

# Both arguments must be dequant nodes for static quantization
if not utils.is_dequant_node(input_a_node) or not utils.is_dequant_node(
input_b_node
):
return

self.dequantize_input_a_node = input_a_node
self.dequantize_input_b_node = input_b_node

# Extract quantization parameters for input A
self.quantize_input_a_node = self.dequantize_input_a_node.args[0]
self.input_a_scales_node = self.dequantize_input_a_node.args[1]
self.input_a_zeros_node = self.dequantize_input_a_node.args[2]

# Extract quantization parameters for input B
self.quantize_input_b_node = self.dequantize_input_b_node.args[0]
self.input_b_scales_node = self.dequantize_input_b_node.args[1]
self.input_b_zeros_node = self.dequantize_input_b_node.args[2]

self.all_nodes.extend(
[self.dequantize_input_a_node, self.dequantize_input_b_node]
)

# Identify output node
self.output_node = self.anchor_node

# The binary operation output must have only one user; it will be either a relu node
# or a quantize node.
if len(self.output_node.users) != 1:
return

cur_node = list(self.output_node.users)[0]
self.relu_node = None
if cur_node.target == exir_ops.edge.aten.relu.default:
self.relu_node = cur_node
self.all_nodes.append(self.relu_node)
# If there's a relu, get its user (should be the quantize node)
if len(cur_node.users) != 1:
return
cur_node = list(cur_node.users)[0]

if not utils.is_quant_node(cur_node):
return

self.quantize_output_node = cur_node
self.output_scales_node = self.quantize_output_node.args[1]
self.output_zeros_node = self.quantize_output_node.args[2]

self.all_nodes.append(self.quantize_output_node)

self.match_found = True


# Define the binary operation anchor nodes that we support
binary_anchor_nodes = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.add_.Tensor,
}


@register_pattern_detector("quantized_binary")
def find_quantized_binary_patterns(
node: torch.fx.Node,
) -> Optional[QuantizedBinaryMatch]:
if node.target not in binary_anchor_nodes:
return None

matched_pattern = QuantizedBinaryMatch(node)
if matched_pattern.match_found:
return matched_pattern

return None


##
## Pattern Replacement
##


@register_pattern_replacement("quantized_binary")
def make_add_q8ta_q8ta_q8to_custom_op(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: QuantizedBinaryMatch,
):
# Determine the operation type based on the anchor node
op_target = None
if match.anchor_node.target in {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.add_.Tensor,
}:
op_target = exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default
else:
# For future binary operations, add more mappings here
raise NotImplementedError(
f"Unsupported binary operation: {match.anchor_node.target}"
)

with graph_module.graph.inserting_before(match.output_node):
qbinary_node = graph_module.graph.create_node(
"call_function",
op_target,
args=(
match.quantize_input_a_node,
match.quantize_input_b_node,
match.input_a_scales_node,
match.input_a_zeros_node,
match.input_b_scales_node,
match.input_b_zeros_node,
match.output_scales_node,
match.output_zeros_node,
match.alpha, # Alpha parameter for scaling
),
)

qbinary_node.meta["val"] = match.output_node.meta["val"]
match.quantize_output_node.replace_all_uses_with(qbinary_node)
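
The replacement above relies on the fused op computing exactly what the matched dequantize -> add -> quantize chain computes. A small numerical check (a sketch, not part of the diff; it assumes the et_vk namespace and that custom_ops_lib has been imported to register both the custom op and the quantized_decomposed ops it calls):

import torch
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401  (assumed registration import)

qd = torch.ops.quantized_decomposed
a_scale, a_zp = 0.05, 3
b_scale, b_zp = 0.04, -2
out_scale, out_zp = 0.08, 1
alpha = 1.0

qa = torch.randint(-128, 128, (2, 8, 4, 4), dtype=torch.int8)
qb = torch.randint(-128, 128, (2, 8, 4, 4), dtype=torch.int8)

# Decomposed chain, as it appears in the exported graph before fusion.
da = qd.dequantize_per_tensor(qa, a_scale, a_zp, -128, 127, torch.int8)
db = qd.dequantize_per_tensor(qb, b_scale, b_zp, -128, 127, torch.int8)
reference = qd.quantize_per_tensor(da + alpha * db, out_scale, out_zp, -128, 127, torch.int8)

# Fused op inserted by the replacement pass.
fused = torch.ops.et_vk.add_q8ta_q8ta_q8to(
    qa, qb, a_scale, a_zp, b_scale, b_zp, out_scale, out_zp, alpha
)
assert torch.equal(reference, fused)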
@@ -0,0 +1,78 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define NAME ${VARIANT_NAME}

#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
#define T ${texel_load_component_type(DTYPE, "buffer")}

$if IO_STORAGE == "buffer":
#define PACKED_INT8_OUTPUT_BUFFER
#define PACKED_INT8_INPUT_BUFFER

#define op(X, Y) ${OPERATOR}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

#extension GL_EXT_debug_printf : enable
#define DEBUG_MODE
#include "indexing.glslh"
#include "common.glslh"

${layout_declare_tensor(B, "w", "t_packed_int8_out", "int", IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_in_a", "int", IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_in_b", "int", IO_STORAGE, is_scalar_array=False)}

${layout_declare_ubo(B, "ivec4", "out_sizes")}

layout(push_constant) uniform restrict Block {
float input_a_scale;
int input_a_zp;
float input_b_scale;
int input_b_zp;
float output_inv_scale;
int output_zp;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const int tid = int(gl_GlobalInvocationID.x);

const int W4 = div_up_4(out_sizes.x);
const int H = out_sizes.y;
const int C4 = div_up_4(out_sizes.z);
const int N = out_sizes.w;

if (tid >= W4 * H * C4 * N) {
return;
}

const ivec4 in_block_1 = t_packed_int8_in_a[tid];
const ivec4 in_block_2 = t_packed_int8_in_b[tid];

ivec4 out_block = ivec4(pack_into_int32(ivec4(output_zp)));

for (int row = 0; row < 4; row++) {
vec4 in_texel_1 = unpack_and_dequantize(
in_block_1[row], input_a_scale, input_a_zp);
vec4 in_texel_2 = unpack_and_dequantize(
in_block_2[row], input_b_scale, input_b_zp);

vec4 out_texel = op(in_texel_1, in_texel_2);
out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp);
}

t_packed_int8_out[tid] = out_block;
}
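
Each invocation handles one ivec4 block, i.e. 16 int8 values (reading the PACKED_INT8_4W4C layout as a 4-wide by 4-channel tile; that interpretation is an assumption, not spelled out in this diff). A quick sanity check of the expected dispatch size:

from math import ceil

def num_invocations(N: int, C: int, H: int, W: int) -> int:
    # mirrors the bound checked in main(): div_up_4(W) * H * div_up_4(C) * N
    return ceil(W / 4) * H * ceil(C / 4) * N

# e.g. a 1x8x4x4 (NCHW) activation -> 1 * 4 * 2 * 1 = 8 invocations
assert num_invocations(1, 8, 4, 4) == 8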
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

binary_q8ta_q8ta_q8to:
parameter_names_with_default_values:
OPERATOR: X + Y
NDIM: 3
DTYPE: float
PACKING: C_packed
IO_STORAGE: buffer
generate_variant_forall:
IO_STORAGE:
- VALUE: buffer
shader_variants:
- NAME: add_q8ta_q8ta_q8to
OPERATOR: X + Y
14 changes: 14 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/common.glslh
@@ -72,6 +72,20 @@ int pack_into_int32(const ivec4 quant_vals) {
return packed;
}

vec4 unpack_and_dequantize(
const int packed_int8_vals,
const float scale,
const int zp) {
ivec4 unpacked = unpack_int8x4(packed_int8_vals);
return vec4(unpacked - zp) * scale;
}

int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
ivec4 quantized = ivec4(round(vals * inv_scale) + zp);
quantized = clamp(quantized, -128, 127);
return pack_into_int32(quantized);
}

#ifdef DEBUG_MODE

#extension GL_EXT_debug_printf : require
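
The two helpers added to common.glslh implement the usual affine int8 mapping; note that the shader receives output_inv_scale = 1 / output_scale via its push constants. A plain-Python sketch of the per-value math (the packing of four int8 values per int32 is left to pack_into_int32 / unpack_int8x4 and not reproduced here):

def dequantize(q: int, scale: float, zp: int) -> float:
    # mirrors unpack_and_dequantize: (q - zp) * scale
    return (q - zp) * scale

def quantize(x: float, inv_scale: float, zp: int) -> int:
    # mirrors quantize_and_pack: round(x * inv_scale) + zp, clamped to the int8 range
    # (ties-at-.5 rounding may differ slightly between Python and GLSL round())
    return max(-128, min(127, round(x * inv_scale) + zp))

scale, zp = 0.05, 3
assert quantize(dequantize(57, scale, zp), 1.0 / scale, zp) == 57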