3 changes: 3 additions & 0 deletions .github/workflows/pull.yml
@@ -933,10 +933,13 @@ jobs:
PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add
./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row

# Run e2e testing for selected operators. More operators will be tested via this
# route in the future.
python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*torchao*"

nxp-build-test:
name: nxp-build-test
30 changes: 30 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -293,6 +293,19 @@ def linear_q4gsw(
    return out


def linear_dq8ca_q4gsw(
    x: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    group_size: int,
    bias: Optional[torch.Tensor] = None,
):
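    # Eager/composite fallback: reuse the weight-only linear_q4gsw reference path.
    # The activation quantization arguments, precomputed weight sums, and bias are
    # not used here.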
    return linear_q4gsw(x, weights, weight_scales, group_size)


name = "linear_q4gsw"
lib.define(
f"""
Expand All @@ -307,6 +320,23 @@ def linear_q4gsw(
lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd")
linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

name = "linear_dq8ca_q4gsw"
lib.define(
f"""
{name}(
Tensor input,
Tensor input_scales,
Tensor input_zp,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
int group_size,
Tensor? bias = None) -> Tensor
"""
)
lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd")
linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name)

########################
## linear_qta8a_qga4w ##
########################
22 changes: 19 additions & 3 deletions backends/vulkan/op_registry.py
@@ -190,8 +190,8 @@ def register_torchao_choose_qparams_affine():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        outputs_storage=[
            utils.CONTIGUOUS_BUFFER,  # scales
            utils.CONTIGUOUS_BUFFER,  # zero_points
            utils.WIDTH_PACKED_TEXTURE,  # scales
            utils.WIDTH_PACKED_TEXTURE,  # zero_points
        ],
        supports_resize=True,
    )
@@ -341,7 +341,23 @@ def register_quantized_linear_ops():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        supports_prepacking=True,
        supports_resize=False,
    )


@update_features(exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default)
def register_linear_dqa_qw_ops():
    return OpFeatures(
        inputs_storage=[
            utils.CONTIGUOUS_ANY,  # input
            utils.WIDTH_PACKED_TEXTURE,  # input_scale
            utils.WIDTH_PACKED_TEXTURE,  # input_zero_point
            utils.NO_STORAGE,  # weight (prepacked)
            utils.NO_STORAGE,  # weight_sums (prepacked)
            utils.NO_STORAGE,  # weight_scales (prepacked)
            utils.NO_STORAGE,  # group_size (scalar)
            utils.NO_STORAGE,  # bias (prepacked)
        ],
        supports_prepacking=True,
    )


143 changes: 139 additions & 4 deletions backends/vulkan/patterns/quantized_linear.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

from typing import Optional

import executorch.backends.vulkan.utils as utils
@@ -117,8 +119,19 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
            self.match_found = True
            return

        self.input_scales_node = self.quantize_input_node.args[1]
        self.input_zeros_node = self.quantize_input_node.args[2]
        scales_arg_idx = 1
        zeros_arg_idx = 2

        # torchao op has a slightly different function schema
        if (
            self.quantize_input_node.target
            == exir_ops.edge.torchao.quantize_affine.default
        ):
            scales_arg_idx = 2
            zeros_arg_idx = 3

        self.input_scales_node = self.quantize_input_node.args[scales_arg_idx]
        self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx]

        assert dq_node is not None
        self.all_nodes.extend(
@@ -164,6 +177,27 @@ def is_input_static_per_tensor_quantized(self) -> bool:
        # are scalars.
        return isinstance(self.input_scales_node, float)

    def is_input_dynamic_perchannel_quantized(self) -> bool:
        if self.quantize_input_node is None:
            return False

        if not isinstance(self.input_scales_node, torch.fx.Node):
            return False

        # For dynamic quantization, input scale node should be a getitem operator
        # retrieving the output of a choose_qparams op
        if self.input_scales_node.target != operator.getitem:
            return False

        # The getitem node should be retrieving from a choose_qparams op
        if not utils.is_choose_qparams_node(self.input_scales_node.args[0]):
            return False

        scales_shape = self.input_scales_node.meta["val"].shape
        input_shape = self.fp_input_node.meta["val"].shape

        return input_shape[-2] == scales_shape[-1]
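
Note (illustration, not part of the diff): the shape check above accepts matches where there is one scale per input row (per token). A minimal sketch with assumed shapes:

import torch

x = torch.randn(1, 8, 64)   # [batch, seq_len, in_features]
scales = torch.rand(1, 8)   # one scale per token (row)
assert x.shape[-2] == scales.shape[-1]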


linear_anchor_nodes = {
    exir_ops.edge.aten.linear.default,
@@ -230,6 +264,34 @@ def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
    return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]
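
Note (illustration, not part of the diff): the nibble packing above stores two 4-bit values per byte, with odd-indexed columns in the high nibble. A tiny sketch, assuming the values have already been mapped into the unsigned 4-bit range:

import torch

w = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)  # [out_channels=1, in_channels=4]
packed = w[::, 1::2] << 4 | w[::, ::2]                # odd columns -> high nibble
print(packed)  # [[33, 67]], i.e. 0x21 = (2 << 4) | 1 and 0x43 = (4 << 4) | 3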


def compute_per_group_sums(weight_tensor: torch.Tensor, group_size: int):
    """
    Compute the sum of weights per quantization group.

    Args:
        weight_tensor (torch.Tensor): Tensor of shape [out_channels, in_channels], dtype int8.
        group_size (int): Number of input channels per quantization group.

    Returns:
        torch.Tensor: int32 tensor of shape [num_groups, out_channels], where
        num_groups = in_channels // group_size and the out_channels dim is padded
        up to a multiple of 8 if needed.
    """
    out_channels, in_channels = weight_tensor.shape
    num_groups = in_channels // group_size
    # Reshape to [out_channels, num_groups, group_size]
    reshaped = weight_tensor.view(out_channels, num_groups, group_size)
    # Sum over group_size dimension to get [out_channels, num_groups]
    sums = reshaped.sum(dim=2)
    # Transpose to [num_groups, out_channels]
    sums = sums.transpose(0, 1).contiguous()
    # Pad out_channels dim (dim=1) to be a multiple of 8 if needed
    out_channels = sums.shape[1]
    if out_channels % 8 != 0:
        num_pad = 8 - (out_channels % 8)
        sums = F.pad(sums, (0, num_pad))

    return sums.to(torch.int32).contiguous()
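
Note (illustration, not part of the diff): a minimal sketch of what compute_per_group_sums produces for a tiny int8 weight tensor; shapes follow the docstring above:

import torch
import torch.nn.functional as F

w = torch.arange(16, dtype=torch.int8).reshape(2, 8)  # [out_channels=2, in_channels=8]
group_size = 4
sums = w.view(2, 2, 4).sum(dim=2).transpose(0, 1).contiguous()  # [num_groups=2, out_channels=2]
# out_channels (2) is not a multiple of 8, so pad dim=1 up to 8 columns.
sums = F.pad(sums, (0, 8 - sums.shape[1])).to(torch.int32)
print(sums.shape)  # torch.Size([2, 8])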


##
## Pattern Replacement
##
@@ -281,6 +343,73 @@ def make_linear_q4gsw_op(
    match.output_node.replace_all_uses_with(linear_q4gsw_node)


def make_linear_dq8ca_q4gsw_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedLinearMatch,
    weight_tensor: torch.Tensor,
    weight_scales_tensor: torch.Tensor,
):
    num_groups = weight_scales_tensor.shape[-1]
    in_channels = weight_tensor.shape[-1]
    group_size = in_channels // num_groups

    # Compute per quant group sums before packing the weight tensor
    sum_per_quant_group = compute_per_group_sums(weight_tensor, group_size)

    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
    # Use this function for convenience to update the state dict with the packed
    # weight tensor. Alignment will already have been done in the above function.
    weight_tensor = utils.align_width_and_update_state_dict(
        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
    )

    # Also transpose the weight scales tensor to shape [num_groups, N]
    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
    utils.align_width_and_update_state_dict(
        ep,
        match.weight_scales_node,
        weight_scales_tensor,
        align_to=1,
        force_update=True,
    )

    first_graph_node = list(graph_module.graph.nodes)[0]
    with graph_module.graph.inserting_before(first_graph_node):
        weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
        # Pre-compute the weight sums which are needed to apply activation zero point
        # when using integer accumulation.
        sums_name = weight_tensor_name + "_sums"
        # Sanitize the name
        sums_name = sums_name.replace(".", "_")

        weight_sums_node = create_constant_placeholder(
            exp_program=ep,
            graph=graph_module.graph,
            kind=InputKind.CONSTANT_TENSOR,
            name=sums_name,
            data=sum_per_quant_group,
        )

    with graph_module.graph.inserting_before(match.output_node):
        qlinear_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default,
            args=(
                match.fp_input_node,
                match.input_scales_node,
                match.input_zeros_node,
                match.weight_node,
                weight_sums_node,
                match.weight_scales_node,
                group_size,
            ),
        )

    qlinear_node.meta["val"] = match.output_node.meta["val"]
    match.output_node.replace_all_uses_with(qlinear_node)
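
Aside (not part of the diff): the precomputed weight sums exist because, with integer accumulation, the activation zero point can be folded out of the inner product as a per-output-channel correction term. A small numerical sketch of the identity, under assumed 8-bit activations and 4-bit symmetric weights:

import torch

q_x = torch.randint(0, 256, (16,), dtype=torch.int32)  # quantized activations
q_w = torch.randint(-8, 8, (16,), dtype=torch.int32)   # quantized (4-bit) weights
zp = 7                                                  # activation zero point
lhs = ((q_x - zp) * q_w).sum()
rhs = (q_x * q_w).sum() - zp * q_w.sum()                # zp * weight_sum correction
assert lhs == rhs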


def make_linear_q8ta_q8csw_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
@@ -354,10 +483,16 @@ def replace_quantized_linear_patterns(
            make_linear_q4gsw_op(
                ep, graph_module, match, weight_tensor, weight_scales_tensor
            )
        elif (
            match.is_input_dynamic_perchannel_quantized()
            and match.is_weight_pergroup_quantized()
            and utils.is_in_4bit_range(weight_tensor)
        ):
            make_linear_dq8ca_q4gsw_op(
                ep, graph_module, match, weight_tensor, weight_scales_tensor
            )
        elif (
            match.is_input_static_per_tensor_quantized()
            and match.is_weight_perchannel_quantized()
        ):
            make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor)

    # No-op for unsupported quant patterns
7 changes: 7 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -44,6 +44,13 @@ void DispatchNode::encode(ComputeGraph* graph) {
  if (!shader_) {
    return;
  }

  // If any global wg size element is 0, then skip encoding this shader
  if (global_workgroup_size_[0] == 0 || global_workgroup_size_[1] == 0 ||
      global_workgroup_size_[2] == 0) {
    return;
  }

  api::Context* const context = graph->context();
  vkapi::PipelineBarrier pipeline_barrier{};

@@ -136,6 +136,9 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
    const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0);

    for (int r = 0; r < 4; ++r) {
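      // Skip entries whose unpacked index (4 * i + r) falls past the end of
      // mat2 along y (its extent may not be a multiple of 4).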
      if (4 * i + r >= mat2_sizes.y) {
        continue;
      }
      // On-demand construction of mat2_pos appears to provide the lowest
      // latency. Surprisingly, this doesn't translate to mat1_pos.
      ivec3 mat2_pos = ivec3(0);