33 changes: 16 additions & 17 deletions backends/vulkan/_passes/replace_qdq.py
@@ -32,24 +32,23 @@ def call(self, graph_module: torch.fx.GraphModule):
                 exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
                 exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
             ]:
-                # Replace quantize op feeding into conv2d (first argument is the quantized input)
-                quantized_input_node = node.args[0]
-                if isinstance(
-                    quantized_input_node, torch.fx.Node
-                ) and utils.is_quant_node(quantized_input_node):
-                    # Get the arguments from the original quantize node
-                    input_tensor = quantized_input_node.args[0]
-                    scale = quantized_input_node.args[1]
-                    zero_point = quantized_input_node.args[2]
+                for quantized_input_node in node.args:
+                    if isinstance(
+                        quantized_input_node, torch.fx.Node
+                    ) and utils.is_quant_node(quantized_input_node):
+                        # Get the arguments from the original quantize node
+                        input_tensor = quantized_input_node.args[0]
+                        scale = quantized_input_node.args[1]
+                        zero_point = quantized_input_node.args[2]
 
-                    nodes_to_replace.append(
-                        {
-                            "old_node": quantized_input_node,
-                            "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
-                            "args": (input_tensor, scale, zero_point),
-                            "node_type": "quantize_input",
-                        }
-                    )
+                        nodes_to_replace.append(
+                            {
+                                "old_node": quantized_input_node,
+                                "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
+                                "args": (input_tensor, scale, zero_point),
+                                "node_type": "quantize_input",
+                            }
+                        )
 
             # Find dequantize ops that consume the output of this conv2d
             for user in node.users:
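Context for the loop change: the newly handled add_q8ta_q8ta_q8to target takes two quantized inputs, so inspecting only `node.args[0]` would leave the second quantize node unreplaced. A toy torch.fx sketch of the visit-all-args pattern (the real pass checks `utils.is_quant_node` and records a replacement instead of printing):

```python
import torch
import torch.fx as fx

def toy(x, y):
    return torch.add(torch.relu(x), torch.relu(y))

gm = fx.symbolic_trace(toy)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.add:
        # Visit every argument, not just args[0], so binary ops are covered.
        arg_nodes = [a for a in node.args if isinstance(a, fx.Node)]
        print(arg_nodes)  # [relu, relu_1]: both producer nodes are found
```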
30 changes: 21 additions & 9 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -368,7 +368,7 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
 
         arg_repset = op_repsets.get_arg_repset(arg_i)
         if arg_repset.is_constrained():
-            return arg_repset
+            return
 
         arg_node = op_repsets.op_node.args[arg_i]
 
@@ -378,21 +378,33 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
         arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset)
         op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset)
 
+    def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
+        """
+        Similar to the `constrain_op_arg_repset` function, but for the output
+        repset of the operator.
+        """
+        out_repset = op_repsets.get_out_repset(0)
+        if out_repset.is_constrained():
+            return
+
+        op_node = op_repsets.op_node
+        out_repset = self.trace_node_users_to_constrain_repset(op_node, out_repset)
+
+        op_repsets.try_constrain_with_out_repset(out_repset)
+
     def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None:
         # For most ops, constraining the argument repsets will also constrain the
         # output repset due to OpRepSets maintaining synchronization rules.
         for i in range(len(op_repsets.op_node.args)):
             if utils.is_tensor_arg_node(op_repsets.op_node.args[i]):
                 self.constrain_op_arg_repset(i, op_repsets)
 
-        # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so
-        # there is no need to constrain output repsets explicitly. Currently, the
-        # exceptions (i.e. choose qparams) already define constrained repsets for
-        # the output, so there is again no need to explicitly constrain the
-        # outputs. If an operator appears later on that does not sync input and
-        # output representations, and defines ambiguous repsets for the output
-        # tensor(s), then we will need to add additional logic to this function to
-        # constrain the output repsets separately from the input repsets.
+        # However, some operators do not sync input and output representations and
+        # also define ambiguous repsets for the output tensor(s). In those cases
+        # we will need to execute additional logic to constrain the output repsets
+        # separately from the input repsets.
+        if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr:
+            self.constrain_op_out_repset(op_repsets)
 
     def set_op_node_tensor_reprs(
         self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node
5 changes: 3 additions & 2 deletions backends/vulkan/op_registry.py
@@ -636,7 +636,7 @@ def register_quantized_binary_op():
 def register_quantize_for_conv2d_op():
     return OpFeatures(
         inputs_storage=[
-            utils.CHANNELS_PACKED_TEXTURE,
+            utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
         ],
         outputs_storage=[
             utils.PACKED_INT8_4W4C_BUFFER,
@@ -656,7 +656,7 @@ def register_dequantize_for_conv2d_op():
             utils.PACKED_INT8_4W4C_BUFFER,
         ],
         outputs_storage=[
-            utils.CHANNELS_PACKED_TEXTURE,
+            utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
         ],
         supports_resize=False,
     )
@@ -711,6 +711,7 @@ def register_view_ops():
         exir_ops.edge.aten.unsqueeze_copy.default,
         exir_ops.edge.aten.clone.default,
         exir_ops.edge.aten.permute_copy.default,
+        exir_ops.edge.aten.gather.default,
     ]
 )
 def register_view_ops_with_buffer_meta():
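Registering exir_ops.edge.aten.gather.default alongside the view ops routes gather to the new shaders added below. A hedged end-to-end sketch (export API spelling varies across versions, and the to_edge/partitioning steps are elided):

```python
import torch

class GatherModule(torch.nn.Module):
    def forward(self, x, idx):
        # out[i][j] = x[i][idx[i][j]] for dim=1
        return torch.gather(x, 1, idx)

x = torch.randn(2, 8)
idx = torch.randint(0, 8, (2, 4))
ep = torch.export.export(GatherModule(), (x, idx))
# After to_edge and Vulkan partitioning, the aten.gather.default node should
# dispatch to gather_buffer.glsl or gather_texture.glsl based on storage type.
```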
@@ -14,7 +14,21 @@
 #include "linear_fp_input_tile.glslh"
 
 VEC4_T load_fp_input_texel(const Conv2dTensorIndex tidx) {
+#ifdef INPUT_BUFFER
+  VEC4_T texel = VEC4_T(0);
+  const int c_idx = mul_4(tidx.data.z);
+  const int c_stride = input_sizes.y * input_sizes.x;
+
+  const int base_buf_i = c_idx * c_stride + tidx.data.y * input_sizes.x + tidx.data.x;
+  const int limit = min(input_sizes.z - c_idx, 4);
+
+  for (int i = 0; i < limit; i++) {
+    texel[i] = t_fp_input[base_buf_i + i * c_stride];
+  }
+  return texel;
+#else
   return texelFetch(t_fp_input, tidx.data, 0);
+#endif
 }
 
 void load_fp_input_tile(
@@ -23,7 +37,9 @@ void load_fp_input_tile(
 #if TILE_M == 4 && TILE_K4 == 1
   Conv2dTensorIndex load_tidx = block_idx_to_tensor_idx(block_idx);
   [[unroll]] for (int w = 0; w < TILE_M; w++) {
-    tile.data[w][0] = load_fp_input_texel(load_tidx);
+    if (load_tidx.data.x < input_sizes.x) {
+      tile.data[w][0] = load_fp_input_texel(load_tidx);
+    }
     load_tidx.data.x++;
   }
 #else
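The buffer paths added here (this input load and the matching output store in the unpack_and_dequantize shader below) share the same address math for a contiguous CHW buffer, where the texel at depth z covers channels [4z, 4z+4). A small Python model of the indexing, assuming input_sizes.xyz = (W, H, C):

```python
def load_texel(buf, W, H, C, x, y, z):
    # Texel z holds channels [4*z, 4*z + 4); consecutive channels sit
    # one H*W-sized plane apart in a contiguous CHW buffer.
    c_idx = 4 * z
    c_stride = H * W
    base = c_idx * c_stride + y * W + x
    limit = min(C - c_idx, 4)  # guard the tail when C % 4 != 0
    return [buf[base + i * c_stride] for i in range(limit)]

# For W=2, H=3, C=4: element (x=1, y=0) of each of the four channel planes.
print(load_texel(list(range(24)), 2, 3, 4, x=1, y=0, z=0))  # [1, 7, 13, 19]
```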
57 changes: 57 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+
+${define_active_storage_type("buffer")}
+${define_required_extensions(DTYPE)}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_input", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_index", "int", "buffer")}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+${layout_declare_ubo(B, "BufferMetadata", "index")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int gather_dim = 0;
+
+void main() {
+  const uint out_bufi = gl_GlobalInvocationID.x;
+  if (out_of_bounds(out_bufi, outp)) {
+    return;
+  }
+
+  TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi);
+
+  // Load the index value at the same position in the index tensor
+  const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx);
+  const int gather_idx = t_index[index_bufi];
+
+  // Construct the input tensor index by replacing the gather dimension
+  // with the gathered index value
+  TensorIndex input_tidx = out_tidx;
+  input_tidx.data[div_4(gather_dim)][mod_4(gather_dim)] = gather_idx;
+
+  // Load from input tensor and store to output
+  const uint input_bufi = tensor_idx_to_linear_idx(inp, input_tidx);
+
+  t_out[out_bufi] = t_input[input_bufi];
+}
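For reference, the buffer shader reproduces torch.gather element-for-element: each output position copies the input element whose coordinate along gather_dim is read from the index tensor at that same position. A quick PyTorch check of the semantics:

```python
import torch

# For dim == 0 on a 2D tensor: out[i][j] = input[index[i][j]][j]
inp = torch.arange(12.0).reshape(3, 4)
idx = torch.tensor([[0, 2, 1, 0]])
out = torch.gather(inp, 0, idx)
print(out)  # tensor([[0., 9., 6., 3.]])
```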
16 changes: 16 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+gather_buffer:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: gather_buffer
67 changes: 67 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
+#define T ${texel_load_component_type(DTYPE, "texture3d")}
+
+${define_active_storage_type("texture3d")}
+${define_required_extensions(DTYPE)}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "common.glslh"
+#include "indexing.glslh"
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
+${layout_declare_tensor(B, "r", "t_input", DTYPE, "texture3d")}
+${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")}
+
+${layout_declare_ubo(B, "TextureMetadata", "outp")}
+${layout_declare_ubo(B, "TextureMetadata", "inp")}
+${layout_declare_ubo(B, "TextureMetadata", "index")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int gather_dim = 0;
+
+void main() {
+  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+
+  if (out_of_bounds(out_pos, outp)) {
+    return;
+  }
+
+  TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
+  ivec4 idx_texel = texelFetch(t_index, out_pos, 0);
+
+  VEC4_T out_texel = VEC4_T(0);
+
+  int limit = min(
+      4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]);
+  for (int comp = 0; comp < limit; comp++) {
+    TensorIndex4D input_tidx = out_tidx;
+    int gather_idx = idx_texel[comp];
+    input_tidx.data[gather_dim] = gather_idx;
+
+    TextureElementIndex input_elem_pos = tensor4d_idx_to_texture_element_idx_simple(
+        inp, input_tidx);
+
+    VEC4_T input_texel = texelFetch(t_input, input_elem_pos.pos, 0);
+    out_texel[comp] = input_texel[input_elem_pos.comp];
+
+    out_tidx.data[outp.packed_dim]++;
+  }
+
+  imageStore(t_out, out_pos, out_texel);
+}
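The texture variant has to cover four output elements per invocation, since each texel packs four consecutive elements along packed_dim and each one gathers independently; the loop bound uses the computed limit so tail texels past the tensor extent are left zeroed. A pure-Python model of the inner loop (illustrative only: `inp` is assumed to be a dict from tensor-index tuples to values, and names mirror the shader):

```python
def gather_texel(inp, idx_texel, out_tidx, gather_dim, packed_dim, sizes):
    out_texel = [0.0, 0.0, 0.0, 0.0]
    # Only `limit` components are valid when sizes[packed_dim] % 4 != 0.
    limit = min(4, sizes[packed_dim] - out_tidx[packed_dim])
    tidx = list(out_tidx)
    for comp in range(limit):
        input_tidx = list(tidx)
        input_tidx[gather_dim] = idx_texel[comp]  # swap in the gathered index
        out_texel[comp] = inp[tuple(input_tidx)]  # element-wise fetch
        tidx[packed_dim] += 1  # advance to the next packed element
    return out_texel
```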
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+gather_texture:
+  parameter_names_with_default_values:
+    DTYPE: float
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: gather_texture3d
@@ -31,7 +31,7 @@ layout(std430) buffer;
 #include "conv2d_common.glslh"
 
 ${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE)}
 
 ${layout_declare_ubo(B, "ivec4", "input_sizes")}
 
@@ -15,6 +15,7 @@ quantize_and_pack_q8ta_conv2d_input:
     combos:
       - parameter_values: [texture3d, texture3d]
       - parameter_values: [buffer, texture3d]
+      - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
   shader_variants:
@@ -30,7 +30,7 @@ layout(std430) buffer;
 
 #include "conv2d_common.glslh"
 
-${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE)}
 ${layout_declare_tensor(B, "r", "t_packed_int8_output", "int", INPUT_STORAGE, is_scalar_array=False)}
 
 ${layout_declare_ubo(B, "ivec4", "output_sizes")}
@@ -84,15 +84,29 @@ void unpack_and_dequantize(
 void store_fp_output_texel(
     const Conv2dTensorIndex tidx,
     const VEC4_T out_texel) {
+#ifdef OUTPUT_BUFFER
+  const int c_idx = mul_4(tidx.data.z);
+  const int c_stride = output_sizes.y * output_sizes.x;
+
+  const int base_buf_i = c_idx * c_stride + tidx.data.y * output_sizes.x + tidx.data.x;
+  const int limit = min(output_sizes.z - c_idx, 4);
+
+  for (int i = 0; i < limit; ++i) {
+    t_fp_output[base_buf_i + i * c_stride] = out_texel[i];
+  }
+#else
   imageStore(t_fp_output, tidx.data, out_texel);
+#endif
 }
 
 void store_fp_tile(
     const FPInputTile block,
     const Conv2dBlockIndex block_idx) {
   Conv2dTensorIndex store_tidx = block_idx_to_tensor_idx(block_idx);
   [[unroll]] for (int w = 0; w < 4; w++) {
-    store_fp_output_texel(store_tidx, block.data[w][0]);
+    if (store_tidx.data.x < output_sizes.x) {
+      store_fp_output_texel(store_tidx, block.data[w][0]);
+    }
     store_tidx.data.x++;
   }
 }
@@ -15,6 +15,7 @@ unpack_and_dequantize_q8ta_conv2d_output:
     combos:
       - parameter_values: [texture3d, texture3d]
      - parameter_values: [texture3d, buffer]
+      - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
   shader_variants: