Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ def get_arg_tensor_source_repset(
"""
arg_node = op_node.args[arg_i]

# For non-tensor arguments, return ALL_STORAGES_REPSET so that the respset does
# For non-tensor arguments, return ANY_STORAGE_INCL_PACKED_INT8 so that the repset does
# not appear to be empty.
if not utils.is_tensor_arg_node(arg_node):
return utils.ALL_STORAGES_REPSET
return utils.ANY_STORAGE_INCL_PACKED_INT8

# Special case for cat - use the first tensor in the list as representative
if isinstance(arg_node, list):
Expand Down
33 changes: 33 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,39 @@ def q8ta_relu_impl(
lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)

###########################
## q8ta_pixel_shuffle ##
###########################


def q8ta_pixel_shuffle_impl(
    input: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    output_inv_scale: float,
    output_zero_point: int,
    upscale_factor: int,
):
    """Reference (eager) implementation backing the q8ta_pixel_shuffle op.

    Dequantizes the int8 input, applies pixel_shuffle, and requantizes to
    int8. The Vulkan runtime kernel performs the same computation as a fused
    byte-shuffle (with an optional requantize when the scales differ); this
    Python version exists only so the op can be registered and traced.

    Note: the op takes the *inverse* of the output scale, so the quantize
    step uses 1.0 / output_inv_scale as its scale.
    """
    fp_input = torch.ops.quantized_decomposed.dequantize_per_tensor(
        input, input_scale, input_zero_point, -128, 127, input.dtype
    )
    fp_shuffled = torch.nn.functional.pixel_shuffle(fp_input, upscale_factor)
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        fp_shuffled,
        1.0 / output_inv_scale,
        output_zero_point,
        -128,
        127,
        torch.int8,
    )


# Register the reference implementation so exported graphs can carry
# q8ta_pixel_shuffle nodes under the custom-op namespace.
name = "q8ta_pixel_shuffle"
lib.define(
    f"{name}(Tensor input, float input_scale, int input_zero_point, float output_inv_scale, int output_zero_point, int upscale_factor) -> Tensor"
)
lib.impl(name, q8ta_pixel_shuffle_impl, "CompositeExplicitAutograd")
# Handle to the registered op for direct Python calls.
# NOTE(review): `namespace` and `lib` are defined earlier in this file
# (outside this view) — presumably the et_vk namespace; confirm.
q8ta_pixel_shuffle_op = getattr(getattr(torch.ops, namespace), name)

########################
## embedding_q4gsw ##
########################
Expand Down
41 changes: 37 additions & 4 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,25 @@ def register_q8ta_relu():
)


# =============================================================================
# Q8taPixelShuffle.cpp
# =============================================================================


@update_features(exir_ops.edge.et_vk.q8ta_pixel_shuffle.default)
def register_q8ta_pixel_shuffle():
    """Declare op features for the fused int8 pixel-shuffle custom op.

    The fused runtime kernel only accepts the channels-packed int8 buffer
    layouts (see add_q8ta_pixel_shuffle_node in Q8taPixelShuffle.cpp for the
    runtime assertion), so input storage is restricted to that family.
    Sharing the packed layout with the surrounding q8ta ops lets the
    partitioner route through this op without inserting layout-transition
    q8ta_clone dispatches.
    """
    return OpFeatures(
        inputs_storage=utils.PACKED_INT8_CHANNELS_PACKED_BUFFER,
        supports_resize=True,
    )


# =============================================================================
# =============================================================================

Expand Down Expand Up @@ -1158,7 +1177,7 @@ def register_permute_copy():
@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_copy():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
Expand Down Expand Up @@ -1213,7 +1232,7 @@ def register_unsqueeze_copy():
@update_features(exir_ops.edge.aten.clone.default)
def register_clone():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
Expand All @@ -1223,7 +1242,7 @@ def register_clone():
@update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default)
def register_clone_dim_order():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
Expand All @@ -1237,7 +1256,7 @@ def register_clone_dim_order():
@update_features(exir_ops.edge.aten.alias_copy.default)
def register_alias_copy():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
supports_highdim=True,
Expand Down Expand Up @@ -1505,6 +1524,20 @@ def register_upsample_cpp_ops():
)


# =============================================================================
# PixelShuffle.cpp
# =============================================================================


@update_features(exir_ops.edge.aten.pixel_shuffle.default)
def register_pixel_shuffle():
    """Declare op features for the floating-point aten.pixel_shuffle op."""
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
    )
    return features


# =============================================================================
# GridPriors.cpp
# =============================================================================
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
# Ops the partitioner asks EXIR to preserve (not decompose) during edge
# lowering, so the backend can match and lower them directly.
# NOTE(review): pixel_shuffle is kept whole so the quantized_pixel_shuffle
# pattern can fuse it instead of re-detecting the view/permute/view
# decomposition — confirm against the pattern module.
ops_not_to_decompose = [
    torch.ops.aten.hardswish.default,
    torch.ops.aten.upsample_nearest2d.vec,
    torch.ops.aten.pixel_shuffle.default,
]

logger: logging.Logger = logging.getLogger("")
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/patterns/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ fbcode_target(_kind = runtime.python_library,
"quantized_linear.py",
"quantized_convolution.py",
"quantized_binary.py",
"quantized_pixel_shuffle.py",
"quantized_unary.py",
"rms_norm.py",
"sdpa.py",
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import executorch.backends.vulkan.patterns.quantized_linear # noqa

import executorch.backends.vulkan.patterns.quantized_pixel_shuffle # noqa

import executorch.backends.vulkan.patterns.quantized_unary # noqa

import executorch.backends.vulkan.patterns.rms_norm # noqa
Expand Down
180 changes: 180 additions & 0 deletions backends/vulkan/patterns/quantized_pixel_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Set

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch.fx.node import Argument


# Set of ops that act as no-ops on values (i.e. clones / dim_order copies that
# preserve dtype and shape). The matcher transparently skips these between the
# dequantize, pixel_shuffle, and quantize nodes.
# NOTE(review): a clone that changes memory_format/dim_order is still treated
# as a value no-op here — assumed safe for this pattern; confirm.
_NOOP_PASSTHROUGH_TARGETS: Set[object] = {
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.dim_order_ops._clone_dim_order.default,
}


def _is_noop_passthrough(node: torch.fx.Node) -> bool:
    """Return True for clone-style call_function nodes that pass values through."""
    if node.op != "call_function":
        return False
    return node.target in _NOOP_PASSTHROUGH_TARGETS


def _skip_passthrough_user(
    node: torch.fx.Node, collected: List[torch.fx.Node]
) -> Optional[torch.fx.Node]:
    """Advance from `node` to its next non-passthrough user.

    Walks through any chain of clone / dim-order-clone ops, appending each
    skipped node to `collected`. Returns None when `node` (or any skipped
    intermediate) does not have exactly one user.
    """
    current = node
    while True:
        users = list(current.users)
        if len(users) != 1:
            return None
        current = users[0]
        if not _is_noop_passthrough(current):
            return current
        collected.append(current)


class QuantizedPixelShuffleMatch(PatternMatch):
    """
    Matches an un-decomposed PixelShuffle wrapped between a quant/dequant pair:

        q8ta_dequantize_per_tensor (int8 -> fp32)
        [optional] clone / _clone_dim_order
        aten.pixel_shuffle.default (upscale_factor = r)
        [optional] clone / _clone_dim_order
        q8ta_quantize_per_tensor (fp32 -> int8)

    The anchor is the dequantize node since it is a unique entry point.

    This relies on the partitioner's `ops_to_not_decompose()` hook preserving
    `aten.pixel_shuffle.default` through edge lowering, so we do not need to
    re-detect the decomposed view -> permute -> view pattern.

    After construction, `match_found` indicates success; the capture
    attributes (`pixel_shuffle_node`, `quantize_output_node`, quant params,
    etc.) are only set when `match_found` is True.
    """

    def __init__(self, dequantize_node: torch.fx.Node) -> None:
        # The anchor is always recorded, even for failed matches.
        self.anchor_node: torch.fx.Node = dequantize_node
        self.match_found: bool = False
        self.all_nodes: List[torch.fx.Node] = [dequantize_node]

        # Validate the dequantize node is one of the quant decomposed ops.
        if not utils.is_dequant_node(dequantize_node):
            return

        # Walk forward to the pixel_shuffle node (skipping any clones, which
        # are collected into all_nodes by the helper).
        pixel_shuffle_node = _skip_passthrough_user(dequantize_node, self.all_nodes)
        if pixel_shuffle_node is None:
            return
        if pixel_shuffle_node.op != "call_function":
            return
        if pixel_shuffle_node.target != exir_ops.edge.aten.pixel_shuffle.default:
            return

        # Walk forward to the quantize node (skipping any clones).
        quantize_node = _skip_passthrough_user(pixel_shuffle_node, self.all_nodes)
        if quantize_node is None or not utils.is_quant_node(quantize_node):
            return

        # pixel_shuffle args are (input, upscale_factor).
        if len(pixel_shuffle_node.args) < 2:
            return
        upscale_factor = pixel_shuffle_node.args[1]
        if not isinstance(upscale_factor, int):
            return

        # Capture the nodes and quant params we need for the replacement.
        self.dequantize_input_node = dequantize_node
        self.pixel_shuffle_node: torch.fx.Node = pixel_shuffle_node
        self.quantize_output_node: torch.fx.Node = quantize_node

        # NOTE(review): assumes the quant/dequant nodes follow the
        # quantized_decomposed per-tensor arg layout
        # (tensor, scale, zero_point, ...) — args[1]/args[2] are not
        # length-checked here; `is_dequant_node` / `is_quant_node` are
        # presumed to guarantee this shape. Confirm against utils.
        self.input_int8_node: Argument = dequantize_node.args[0]
        self.input_scales_node: Argument = dequantize_node.args[1]
        self.input_zeros_node: Argument = dequantize_node.args[2]
        self.output_scales_node: Argument = quantize_node.args[1]
        self.output_zeros_node: Argument = quantize_node.args[2]
        self.upscale_factor: int = upscale_factor

        self.all_nodes.extend([pixel_shuffle_node, quantize_node])
        # The replacement target replaces uses of the quantize node.
        self.output_node: torch.fx.Node = quantize_node

        self.match_found = True


@register_pattern_detector("quantized_pixel_shuffle")
def find_quantized_pixel_shuffle_pattern(
    node: torch.fx.Node,
) -> Optional[QuantizedPixelShuffleMatch]:
    """Try to anchor a quantized pixel-shuffle match at `node`.

    Only dequantize call_function nodes can anchor the pattern. Returns the
    match object when the full dequant -> pixel_shuffle -> quant chain is
    present, otherwise None.
    """
    if node.op != "call_function" or not utils.is_dequant_node(node):
        return None
    candidate = QuantizedPixelShuffleMatch(node)
    return candidate if candidate.match_found else None


@register_pattern_replacement("quantized_pixel_shuffle")
def make_quantized_pixel_shuffle_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedPixelShuffleMatch,
) -> None:
    """Replace a matched dequant -> pixel_shuffle -> quant chain with the
    fused et_vk.q8ta_pixel_shuffle custom op.

    The fused op consumes the int8 input directly and takes the *inverse* of
    the output scale, matching the runtime kernel's expectation.
    """
    scale_arg = match.output_scales_node
    if not isinstance(scale_arg, (int, float)):
        # Output scales are expected to be baked-in Python floats for static
        # per-tensor quantization. If a graph node ever shows up here, fail
        # loudly at fusion time rather than producing a silent miscompile.
        raise NotImplementedError(
            "quantized_pixel_shuffle pattern only supports scalar output scales"
        )
    inv_output_scale = 1.0 / float(scale_arg)

    fused_target = exir_ops.edge.et_vk.q8ta_pixel_shuffle.default
    graph = graph_module.graph
    with graph.inserting_before(match.output_node):
        fused_node = graph.create_node(
            "call_function",
            fused_target,
            args=(
                match.input_int8_node,
                match.input_scales_node,
                match.input_zeros_node,
                inv_output_scale,
                match.output_zeros_node,
                match.upscale_factor,
            ),
        )

    # Carry over the fake-tensor metadata from the node being replaced so
    # downstream passes still see the correct output spec.
    fused_node.meta["val"] = match.output_node.meta["val"]
    match.quantize_output_node.replace_all_uses_with(fused_node)
Loading
Loading