34 changes: 25 additions & 9 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -91,7 +91,7 @@ def __init__(
self.default_layout: VkMemoryLayout = default_memory_layout
self.texture_limits = texture_limits

def propose_node_storage(
def propose_node_storage( # noqa: C901
self,
node: torch.fx.Node,
) -> Optional[VkStorageType]:
@@ -138,15 +138,23 @@ def propose_node_storage
for arg in node.args:
if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
storage = utils.get_node_storage_type(arg)
# Some operators which return multiple output tensors may specify a
# different storage type for each output. In this case, the storage type
# for the first output is used.
if isinstance(storage, (list, tuple)):
storage = storage[0]
if storage is not None and storage in valid_storage_types:
return storage

# If no storage type has been resolved yet, assume the optimal storage type of
# the first opinionated user. This search is recursive.
for user in node.users:
optimal_storage = self.propose_node_storage(user)
if optimal_storage is not None:
return optimal_storage
storage = self.propose_node_storage(user)
# See above
if isinstance(storage, (list, tuple)):
storage = storage[0]
if storage is not None:
return storage

if self.default_storage in valid_storage_types:
return self.default_storage
@@ -179,15 +187,23 @@ def propose_node_layout(
for arg in node.args:
if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
layout = utils.get_node_memory_layout(arg)
# Some operators which return multiple output tensors may specify a
# different memory layout for each output. In this case, the memory
# layout for the first output is used.
if isinstance(layout, (list, tuple)):
layout = layout[0]
if layout is not None and layout in valid_layouts:
return layout

# If no storage type has been resolved yet, assume the optimal storage type of
# the first opinionated user. This search is recursive.
# If no memory layout has been resolved yet, assume the optimal layout of the
# first opinionated user. This search is recursive.
for user in node.users:
optimal_storage = self.propose_node_layout(user, storage)
if optimal_storage is not None:
return optimal_storage
layout = self.propose_node_layout(user, storage)
# See above comment
if isinstance(layout, (list, tuple)):
layout = layout[0]
if layout is not None:
return layout

# As a last resort, return the default storage type that should be used.
if self.default_layout in valid_layouts:
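The change above lets a node's proposed storage type be either a single value or a list with one entry per output for multi-output operators; in the list case the pass falls back to the first output's entry. Below is a minimal, self-contained Python sketch of the resolution order (an already-resolved proposal first, then the first opinionated user, recursively); the FakeNode class and string storage names are illustrative stand-ins, not the real torch.fx or Vulkan types:

from typing import Optional, Sequence, Union

StorageProposal = Union[str, Sequence[str], None]

def first_output_storage(storage: StorageProposal) -> Optional[str]:
    # Multi-output ops may carry one storage entry per output; use the first.
    if isinstance(storage, (list, tuple)):
        return storage[0] if storage else None
    return storage

class FakeNode:
    def __init__(self, storage: StorageProposal = None, users=()):
        self.storage = storage
        self.users = list(users)

def propose_storage(node: FakeNode, valid: set) -> Optional[str]:
    # 1) Prefer a storage type that has already been resolved for this node.
    s = first_output_storage(node.storage)
    if s is not None and s in valid:
        return s
    # 2) Otherwise adopt the proposal of the first opinionated user, recursively.
    for user in node.users:
        s = first_output_storage(propose_storage(user, valid))
        if s is not None:
            return s
    return None

# An un-annotated node adopts the first output's storage of its multi-output
# user, here "TEXTURE_3D".
consumer = FakeNode(storage=["TEXTURE_3D", "BUFFER", "BUFFER"])
print(propose_storage(FakeNode(users=[consumer]), {"TEXTURE_3D", "BUFFER"}))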
26 changes: 26 additions & 0 deletions backends/vulkan/op_registry.py
@@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
return features


@update_features(
[
exir_ops.edge.aten.native_group_norm.default,
]
)
def register_native_group_norm(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.CHANNELS},
)
features.handles_own_prepacking = True

features.optimal_storage = [
VkStorageType.TEXTURE_3D,
VkStorageType.BUFFER,
VkStorageType.BUFFER,
]

features.optimal_layout = [
VkMemoryLayout.TENSOR_CHANNELS_PACKED,
VkMemoryLayout.TENSOR_WIDTH_PACKED,
VkMemoryLayout.TENSOR_WIDTH_PACKED,
]

return features


# Ported ops that support their own prepacking.
@update_features(
[
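The three entries in optimal_storage and optimal_layout above line up with the three outputs of aten.native_group_norm, which are returned in the order (out, mean, rstd): the normalized output stays a channels-packed texture, while mean and rstd are width-packed buffers, matching the reduction shader added below. A small illustrative pairing, with plain strings standing in for the VkStorageType / VkMemoryLayout enums:

# Output order follows aten.native_group_norm: (out, mean, rstd).
outputs = ["out", "mean", "rstd"]
optimal_storage = ["TEXTURE_3D", "BUFFER", "BUFFER"]
optimal_layout = [
    "TENSOR_CHANNELS_PACKED",
    "TENSOR_WIDTH_PACKED",
    "TENSOR_WIDTH_PACKED",
]
for name, storage, layout in zip(outputs, optimal_storage, optimal_layout):
    print(f"{name}: storage={storage}, layout={layout}")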
32 changes: 32 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -272,6 +272,38 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
VK_THROW("Could not get dtype of value with type ", val.type());
}

bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const {
if (!val_is_tensor(idx)) {
return false;
}
if (!is_buffer_storage(idx)) {
return false;
}
return is_contiguous(idx);
}

bool ComputeGraph::is_standard_channels_packed_texture_tensor(
const ValueRef idx) const {
if (!val_is_tensor(idx)) {
return false;
}
if (is_buffer_storage(idx)) {
return false;
}
return has_standard_axis_map(idx) && packed_dim_of(idx) == 2;
}

bool ComputeGraph::is_standard_width_packed_texture_tensor(
const ValueRef idx) const {
if (!val_is_tensor(idx)) {
return false;
}
if (is_buffer_storage(idx)) {
return false;
}
return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
}

ValueRef ComputeGraph::add_tensor(
const std::vector<int64_t>& sizes,
const vkapi::ScalarType dtype,
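The predicates above combine several checks; the packed_dim comparisons follow the backend's WHCN dim-order convention (0 = width, 2 = channels), as the function names indicate. A small Python sketch of the same logic, using a plain dict as a stand-in for the tensor metadata that ComputeGraph actually holds:

# Stand-in for the C++ predicates added above; a dict mimics tensor metadata.
PACKED_DIM_WIDTH, PACKED_DIM_HEIGHT, PACKED_DIM_CHANNELS = 0, 1, 2

def is_standard_channels_packed_texture(t: dict) -> bool:
    return (
        t["is_tensor"]
        and not t["is_buffer_storage"]
        and t["has_standard_axis_map"]
        and t["packed_dim"] == PACKED_DIM_CHANNELS
    )

def is_contiguous_buffer(t: dict) -> bool:
    return t["is_tensor"] and t["is_buffer_storage"] and t["is_contiguous"]

example = {
    "is_tensor": True,
    "is_buffer_storage": False,
    "has_standard_axis_map": True,
    "packed_dim": PACKED_DIM_CHANNELS,
    "is_contiguous": False,
}
print(is_standard_channels_packed_texture(example))  # True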
30 changes: 28 additions & 2 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -231,7 +231,7 @@ class ComputeGraph final {
inline ptr_type get_##short_name(const ValueRef idx) { \
return ptr_type(this, idx); \
} \
inline bool val_is_##short_name(const ValueRef idx) { \
inline bool val_is_##short_name(const ValueRef idx) const { \
return values_.at(idx).is##type_name(); \
}

@@ -314,6 +314,32 @@ class ComputeGraph final {
return values_.at(idx).toConstTensor().has_buffer_storage();
}

/*
* Checks that the following is true:
* 1. The value at `idx` is a tensor
* 2. The tensor at `idx` has buffer storage
* 3. The buffer backed tensor at `idx` has a contiguous memory layout
*/
bool is_contiguous_buffer_tensor(const ValueRef idx) const;

/*
* Checks that the following is true:
* 1. The value at `idx` is a tensor
* 2. The tensor at `idx` has texture storage
* 3. The texture backed tensor at `idx` has a standard axis mapping
* 4. The texture backed tensor at `idx` is channels packed
*/
bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;

/*
* Checks that the following is true:
* 1. The value at `idx` is a tensor
* 2. The tensor at `idx` has texture storage
* 3. The texture backed tensor at `idx` has a standard axis mapping
* 4. The texture backed tensor at `idx` is width packed
*/
bool is_standard_width_packed_texture_tensor(const ValueRef idx) const;

inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
const {
return values_.at(maybe_view)
@@ -354,7 +380,7 @@ class ComputeGraph final {
return values_.at(idx).toTensor().numel_ubo();
}

inline bool has_standard_axis_map(const ValueRef idx) {
inline bool has_standard_axis_map(const ValueRef idx) const {
return values_.at(idx).toTensor().has_standard_axis_map();
}

@@ -0,0 +1,189 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#include "broadcasting_utils.h"
#include "indexing_utils.h"

#define PRECISION ${PRECISION}

#define BUF_T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_mean", DTYPE, "buffer")}
${layout_declare_tensor(B, "w", "t_rstd", DTYPE, "buffer")}

${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")}

${layout_declare_ubo(B, "ivec4", "mean_strides")}
${layout_declare_ubo(B, "int", "mean_numel")}
${layout_declare_ubo(B, "ivec3", "in_limits")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}

layout(push_constant) uniform PRECISION restrict Block {
int group;
float epsilon;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "mean_layout", "DEFAULT_DIM_ORDER")}
const lowp ivec4 mean_dim_order = unhash_dim_order(mean_layout);

#define LOCAL_WORK_GROUP_SIZE 64
shared float shared_sum[LOCAL_WORK_GROUP_SIZE];
shared float shared_sum_sq[LOCAL_WORK_GROUP_SIZE];

/*
* Computes the mean and standard deviation of one group of channels of the
* input tensor for the group normalization operator.
*
* Given an input tensor of shape [W, H, C, N], the mean and rstd tensors will
* have a shape of [G, N], where G is the number of groups.
*
* The input tensor is assumed to be a channels-packed texture tensor with the
* standard axis mapping. The output tensors are assumed to be contiguous buffer
* tensors.
*
* Algorithm:
* 1. Each shader invocation corresponds to one group in one batch
* 2. The local work group cooperatively reduces over all spatial locations (H×W)
* and all channels within the group (C/group channels)
* 3. Uses shared memory for efficient parallel reduction
* 4. Main thread (local ID 0) writes the final mean and rstd to buffer
*
* Global work group size: {mean_numel, 1, 1}
* mean_numel is the number of elements in the output mean/rstd buffers; each
* work group computes one output element.
*
* Local work group size: {1, LOCAL_WORK_GROUP_SIZE, 1}
* LOCAL_WORK_GROUP_SIZE should be a power of 2, typically 64 or 128 threads.
* This allows an efficient tree-based reduction in shared memory. Each local
* group will cooperate to compute one output element.
*
* Each shader invocation will compute the mean and standard deviation for one
* channel group in the input, and write out the corresponding result.
*/
void group_norm_reduce_C_packed() {
const int global_idx = int(gl_GlobalInvocationID.x);
const int local_idx = int(gl_LocalInvocationID.y);

// Calculate group dimensions
const int D = in_sizes.z / group; // channels per group
const int HxW = in_sizes.y * in_sizes.x; // spatial size
const int group_size = D * HxW; // total elements per group

// Convert global index to (group_idx, batch_idx)
const ivec4 mean_tidx = bufi_to_tidx(global_idx, mean_strides, mean_dim_order);

// Initialize local sums
float local_sum = 0.0;
float local_sum_sq = 0.0;
int local_count = 0;

// Calculate the range of channels for this group
const int group_start_channel = mean_tidx.x * D;
const int group_end_channel = group_start_channel + D;

// Calculate the range of texels that contain channels from this group
const int start_texel_idx = group_start_channel / 4;
const int end_texel_idx = divup4(group_end_channel);
const int texels_in_group = end_texel_idx - start_texel_idx;

// Total texels to process across all spatial locations
const int total_texels = texels_in_group * HxW;

// Each thread processes a subset of texels
const int texels_per_thread = (total_texels + LOCAL_WORK_GROUP_SIZE - 1) / LOCAL_WORK_GROUP_SIZE;
const int start_texel = local_idx * texels_per_thread;
const int end_texel = min(start_texel + texels_per_thread, total_texels);

// Process assigned texels
for (int texel_idx = start_texel; texel_idx < end_texel; texel_idx++) {
// Convert texel index to spatial and channel coordinates
const int spatial_idx = texel_idx / texels_in_group;
const int texel_in_group = texel_idx % texels_in_group;

// Convert to spatial coordinates
const int w = spatial_idx % in_sizes.x;
const int h = spatial_idx / in_sizes.x;

// Calculate the global texel index
const int global_texel_idx = start_texel_idx + texel_in_group;

// Convert to texture position using default axis mapping
ivec3 tex_pos = ivec3(w, h, global_texel_idx);

// Adjust for batch dimension if needed
if (in_sizes.w > 1) {
// default axis mapping means channels is the batch concat dim
tex_pos.z += mean_tidx.y * divup4(in_sizes.z);
}

// Check bounds and load texel
if (all(lessThan(tex_pos, in_limits))) {
const vec4 texel_val = load_texel(t_in, tex_pos);

// Process all components of the texel that belong to this group
const int texel_start_channel = global_texel_idx * 4;
for (int comp = 0; comp < 4; comp++) {
const int current_channel = texel_start_channel + comp;

// Check if this component belongs to the current group
if (current_channel >= group_start_channel && current_channel < group_end_channel) {
const float val = texel_val[comp];
local_sum += val;
local_sum_sq += val * val;
local_count++;
}
}
}
}

// Store local results in shared memory
shared_sum[local_idx] = local_sum;
shared_sum_sq[local_idx] = local_sum_sq;

// Synchronize threads
memoryBarrierShared();
barrier();

// Perform tree-based reduction in shared memory
for (int stride = LOCAL_WORK_GROUP_SIZE / 2; stride > 0; stride /= 2) {
if (local_idx < stride) {
shared_sum[local_idx] += shared_sum[local_idx + stride];
shared_sum_sq[local_idx] += shared_sum_sq[local_idx + stride];
}
memoryBarrierShared();
barrier();
}

// Main thread writes the result
if (local_idx == 0 && global_idx < mean_numel) {
const float total_sum = shared_sum[0];
const float total_sum_sq = shared_sum_sq[0];
const float count = float(group_size);

// Calculate mean and reciprocal standard deviation
const float mean_val = total_sum / count;
const float variance = (total_sum_sq / count) - (mean_val * mean_val);
const float rstd_val = 1.0 / sqrt(variance + epsilon);

// Write to buffer-backed tensors
t_mean[global_idx] = BUF_T(mean_val);
t_rstd[global_idx] = BUF_T(rstd_val);
}
}

void main() {
group_norm_reduce_C_packed();
}
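For reference, a small NumPy sketch of the reduction this shader performs: for an NCHW input and a given number of groups, it produces one mean and one rstd per (batch, group) using the same sum / sum-of-squares formulation. This is an illustrative cross-check, not part of the PR:

import numpy as np

def group_norm_mean_rstd(x: np.ndarray, group: int, epsilon: float = 1e-5):
    # x has shape [N, C, H, W]; `group` is the number of channel groups.
    N, C, H, W = x.shape
    grouped = x.reshape(N, group, (C // group) * H * W)
    count = grouped.shape[-1]
    total_sum = grouped.sum(axis=-1)
    total_sum_sq = (grouped ** 2).sum(axis=-1)
    mean = total_sum / count                     # shape [N, group]
    var = total_sum_sq / count - mean ** 2
    rstd = 1.0 / np.sqrt(var + epsilon)          # shape [N, group]
    return mean, rstd

x = np.random.rand(2, 8, 4, 4).astype(np.float32)
mean, rstd = group_norm_mean_rstd(x, group=4)
print(mean.shape, rstd.shape)  # (2, 4) (2, 4)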
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

group_norm_reduce_texture:
parameter_names_with_default_values:
DTYPE: float
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: group_norm_reduce_texture