Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ def register_rotary_emb_op():
@update_features(
[
exir_ops.edge.aten.permute.default,
exir_ops.edge.aten.permute_copy.default,
]
)
def register_view_ops():
Expand All @@ -506,6 +505,7 @@ def register_view_ops():
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.permute_copy.default,
]
)
def register_view_ops_with_buffer_meta():
Expand Down
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ uint idx_at(const TensorIndex tidx, const int dim) {
return tidx.data[div_4(dim)][mod_4(dim)];
}

// Reorders the components of `tidx` in place: after the call, component d
// holds the value that component permute_order[d] held on entry. Every one of
// the DIMLIMIT components is rewritten.
void permute(inout TensorIndex tidx, const ivec4 permute_order[DIMLIMIT_DIV4]) {
  // Read from an untouched snapshot so that components already overwritten in
  // this loop never feed into a later read.
  TensorIndex src_tidx = tidx;
  for (int dst_dim = 0; dst_dim < DIMLIMIT; ++dst_dim) {
    tidx.data[div_4(dst_dim)][mod_4(dst_dim)] =
        idx_at(src_tidx, permute_order[div_4(dst_dim)][mod_4(dst_dim)]);
  }
}

//
// Index Conversions
//
Expand Down
52 changes: 14 additions & 38 deletions backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,55 +18,31 @@ ${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"
#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
${layout_declare_tensor(B, "w", "t_outp", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_inp", DTYPE, "buffer")}

${layout_declare_ubo(B, "ivec4", "in_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}
${layout_declare_ubo(B, "int", "out_numel")}
${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(push_constant) uniform restrict Block {
ivec4 in_strides;
ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j
};

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}

const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
${layout_declare_ubo(B, "ivec4[DIMLIMIT_DIV4]", "permute_order")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Convert output tensor index to input tensor index based on permutation.
//
// out_tidx: 4-component index into the output tensor.
// Returns the 4-component index into the input tensor, built by scattering
// each component of out_tidx into the slot named by the matching component of
// the `permute_dims` push constant.
ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) {
// Declared without an initializer: a component that no entry of permute_dims
// selects keeps an undefined value.
// NOTE(review): this relies on permute_dims being a permutation of 0..3 so
// that all four components get written - confirm the host side guarantees it.
ivec4 in_tidx;

// Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i]
in_tidx[permute_dims.x] = out_tidx.x;
in_tidx[permute_dims.y] = out_tidx.y;
in_tidx[permute_dims.z] = out_tidx.z;
in_tidx[permute_dims.w] = out_tidx.w;

return in_tidx;
}

void main() {
const int out_bufi = ivec3(gl_GlobalInvocationID).x;
if (out_bufi >= out_numel) {
const uint inp_bufi = gl_GlobalInvocationID.x;
if (inp_bufi >= numel(inp)) {
return;
}

// Convert buffer index to tensor index for output
const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);

// Convert output tensor index to input tensor index using permutation
const ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx);
TensorIndex inp_tidx;
linear_idx_to_tensor_idx(inp, inp_bufi, inp_tidx);

// Convert input tensor index back to buffer index
const int in_bufi = tidx_to_bufi(in_tidx, in_strides);
TensorIndex outp_tidx = inp_tidx;
permute(outp_tidx, permute_order);

const uint outp_bufi = tensor_idx_to_linear_idx(outp, outp_tidx);
// Copy data from input to output
t_out[out_bufi] = t_in[in_bufi];
t_outp[outp_bufi] = t_inp[inp_bufi];
}
120 changes: 90 additions & 30 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,37 +129,22 @@ void add_permute_node(
std::vector<PushConstantDataInfo> push_constants;
vkapi::SpecVarList spec_vars;

if (graph.is_buffer_storage(out)) {
param_buffers.append(graph.sizes_ubo(in));
param_buffers.append(graph.strides_ubo(out));
param_buffers.append(graph.numel_ubo(out));

// Buffer storage - use permute_buffer shader
push_constants = {
graph.strides_pc_of(in),
PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims)),
};

spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
} else {
// Texture storage - use permute_texture shader
const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));

const int32_t packed_dim = graph.packed_dim_of(in);
ivec2 channel_info = {out_channels, in_channels};
if (packed_dim == WHCN::kChannelsDim) {
channel_info[0] = utils::align_up_4(channel_info[0]);
channel_info[1] = utils::align_up_4(channel_info[1]);
}
const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));

const int32_t packed_dim = graph.packed_dim_of(in);
ivec2 channel_info = {out_channels, in_channels};
if (packed_dim == WHCN::kChannelsDim) {
channel_info[0] = utils::align_up_4(channel_info[0]);
channel_info[1] = utils::align_up_4(channel_info[1]);
}

push_constants = {
graph.sizes_pc_of(out),
graph.sizes_pc_of(in),
PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))};
push_constants = {
graph.sizes_pc_of(out),
graph.sizes_pc_of(in),
PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))};

spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
}
spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
Expand All @@ -179,8 +164,83 @@ void add_permute_node(
resize_permute_node));
}

// Permutation order expressed with WHCN dimension indices and padded to
// api::kTensorDimLimit entries.
struct WHCNPermuteDims {
  int32_t whcn_permute_dims[api::kTensorDimLimit];

  // Fills whcn_permute_dims from a permutation given in NCHW order.
  //
  // permute_dims: one entry per tensor dimension; entry i names the source
  //   dimension for result dimension i. A negative entry has the number of
  //   dimensions added to it before use.
  //
  // Both the position in the list and the value stored there are mirrored
  // (ndim - 1 - x) to move from NCHW to WHCN numbering. Slots at or beyond
  // permute_dims.size() are set to their own index.
  void initialize(const std::vector<int64_t>& permute_dims) {
    const int32_t ndim = permute_dims.size();
    const int32_t last_dim = ndim - 1;

    int32_t whcn_i = 0;
    for (; whcn_i < ndim; ++whcn_i) {
      // Mirror the position: WHCN slot whcn_i corresponds to NCHW slot
      // last_dim - whcn_i.
      int64_t nchw_src = permute_dims.at(last_dim - whcn_i);
      if (nchw_src < 0) {
        nchw_src += ndim;
      }
      // Mirror the value as well so it names a WHCN dimension.
      const int32_t whcn_src = last_dim - nchw_src;
      whcn_permute_dims[whcn_i] = whcn_src;
    }

    // Remaining slots map each dimension to itself.
    for (; whcn_i < api::kTensorDimLimit; ++whcn_i) {
      whcn_permute_dims[whcn_i] = whcn_i;
    }
  }
};

// Appends a DynamicDispatchNode to the graph's execute nodes that runs the
// "permute" shader selected by the storage type and dtype of `out`.
//
// graph:        graph receiving the new execute node.
// in:           tensor read by the shader.
// permute_dims: int list holding the permutation in NCHW order; also passed
//               as the resize argument.
// out:          tensor written by the shader.
void add_permute_buffer_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef permute_dims,
    const ValueRef out) {
  check_args(graph, in, permute_dims, out);

  // The compute shaders use WHCN dimension order, so the permutation is
  // translated before upload:
  //   1. dimension index values change from NCHW order to WHCN order;
  //   2. the array is extended to kTensorDimLimit entries.
  WHCNPermuteDims whcn_order;
  {
    // Scoped so the list handle is released as soon as it has been read.
    IntListPtr nchw_order = graph.get_int_list(permute_dims);
    whcn_order.initialize(*nchw_order);
  }

  // Shader name: "permute" followed by storage-type and dtype suffixes.
  std::string shader_name = "permute";
  shader_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(shader_name, graph.storage_type_of(out));
  add_dtype_suffix(shader_name, graph.dtype_of(out));

  // Uniform buffers in binding order: output metadata, input metadata, then
  // the WHCN permutation.
  vkapi::ParamsBindList ubos = {
      graph.buffer_meta_ubo(out),
      graph.buffer_meta_ubo(in),
      graph.create_params_buffer(whcn_order),
  };

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      // Tensor bindings
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Parameter buffers
      ubos,
      // Push Constants (none)
      {},
      // Specialization Constants (none)
      {},
      // Resize Args
      {permute_dims},
      // Resizing Logic
      resize_permute_node));
}

// Operator entry point for permute.
//
// args: {in, permute_dims, out}. Each element is fetched with at(), so a
//   short argument list throws std::out_of_range instead of reading out of
//   bounds.
//
// When `out` uses buffer storage the work goes to add_permute_buffer_node;
// otherwise it goes to add_permute_node.
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  int idx = 0;
  const ValueRef in = args.at(idx++);
  const ValueRef permute_dims = args.at(idx++);
  const ValueRef out = args.at(idx++);

  // Use the already bounds-checked `out` rather than indexing args again.
  if (graph.is_buffer_storage(out)) {
    return add_permute_buffer_node(graph, in, permute_dims, out);
  }
  return add_permute_node(graph, in, permute_dims, out);
}

REGISTER_OPERATORS {
Expand Down
Loading