Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
#define T ${texel_load_component_type(DTYPE, "buffer")}
#define BUF_T ${buffer_scalar_type(BUF_DTYPE)}
#define VEC4_T ${texel_load_type(DTYPE, PACKED_STORAGE)}
#define T ${texel_load_component_type(DTYPE, PACKED_STORAGE)}

$if PACKED_STORAGE == "buffer":
#define OUTPUT_BUFFER

#extension GL_EXT_control_flow_attributes : require

${define_required_extensions("buffer", DTYPE)}
${define_required_extensions("buffer", BUF_DTYPE)}
$if PACKED_STORAGE != "buffer":
${define_required_extensions(PACKED_STORAGE, DTYPE)}

Expand All @@ -29,7 +30,7 @@ $if PACKED_STORAGE == "buffer":
${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, "buffer", is_scalar_array=False)}
$else:
${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, PACKED_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_src", DTYPE, "buffer", is_scalar_array=True)}
${layout_declare_tensor(B, "r", "t_weight_src", BUF_DTYPE, "buffer", is_scalar_array=True)}

layout(push_constant) uniform restrict Block {
int N;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
pack_fp_linear_weight:
parameter_names_with_default_values:
DTYPE: float
BUF_DTYPE: float
PACKED_STORAGE: texture2d
generate_variant_forall:
PACKED_STORAGE:
- VALUE: texture2d
- VALUE: buffer
DTYPE:
- VALUE: float
- VALUE: half
combination:
parameter_names: [PACKED_STORAGE, DTYPE, BUF_DTYPE]
combos:
- parameter_values: [texture2d, float, float]
- parameter_values: [texture2d, half, half]
- parameter_values: [texture2d, half, float]
- parameter_values: [buffer, float, float]
- parameter_values: [buffer, half, half]
shader_variants:
- NAME: pack_fp_linear_weight
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ static ValueRef prepack_conv1d_pw_weight(
std::string kernel_name = "pack_fp_linear_weight";
add_storage_type_suffix(kernel_name, weight_storage);
add_dtype_suffix(kernel_name, graph.dtype_of(weight_data));
add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(weight_data));

graph.prepack_nodes().emplace_back(new PrepackNode(
graph,
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/impl/Conv2dPW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ ValueRef prepack_conv2d_pw_weight(
std::string pack_kernel_name = "pack_fp_linear_weight";
add_storage_type_suffix(pack_kernel_name, weight_storage);
add_dtype_suffix(pack_kernel_name, graph.dtype_of(weight_data));
add_dtype_suffix(pack_kernel_name, graph.get_staging_dtype_for(weight_data));

graph.prepack_nodes().emplace_back(new PrepackNode(
graph,
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/impl/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ ValueRef prepack_fp_linear_weight(
std::string kernel_name = "pack_fp_linear_weight";
add_storage_type_suffix(kernel_name, weight_storage);
add_dtype_suffix(kernel_name, graph.dtype_of(weight_data));
add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(weight_data));

graph.prepack_nodes().emplace_back(new PrepackNode(
graph,
Expand Down
Loading