88
99#include < executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010
11+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112#include < executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
1213#include < executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1314
@@ -37,12 +38,12 @@ void check_matmul_args(
3738void resize_matmul_node (
3839 ComputeGraph* graph,
3940 const std::vector<ArgGroup>& args,
40- const std::vector<ValueRef>& extra_args ) {
41+ const std::vector<ValueRef>& resize_args ) {
4142 vTensorPtr out = graph->get_tensor (args[0 ].refs [0 ]);
4243 vTensorPtr mat1 = graph->get_tensor (args[1 ].refs [0 ]);
4344 vTensorPtr mat2 = graph->get_tensor (args[1 ].refs [1 ]);
4445
45- bool mat2_is_transposed = graph->get_bool (extra_args [0 ]);
46+ bool mat2_is_transposed = graph->get_bool (resize_args [0 ]);
4647
4748 const int out_cols = utils::val_at (-2 , mat1->sizes ());
4849 const int out_rows = mat2_is_transposed ? utils::val_at (-2 , mat2->sizes ())
@@ -56,6 +57,23 @@ void resize_matmul_node(
5657 out->virtual_resize (new_out_sizes);
5758}
5859
60+ /* *
61+ * Custom global workgroup size function for naive buffer matmul operations.
62+ */
63+ utils::uvec3 matmul_naive_buffer_global_wg_size (
64+ ComputeGraph* graph,
65+ const vkapi::ShaderInfo& shader,
66+ const std::vector<ArgGroup>& args,
67+ const std::vector<ValueRef>& resize_args) {
68+ (void )shader;
69+ (void )resize_args;
70+ const ValueRef out = args.at (0 ).refs .at (0 );
71+ return {
72+ graph->size_at <uint32_t >(-1 , out),
73+ graph->size_at <uint32_t >(-2 , out),
74+ graph->size_at <uint32_t >(-3 , out) * graph->size_at <uint32_t >(-4 , out)};
75+ }
76+
5977void add_matmul_naive_buffer_node (
6078 ComputeGraph& graph,
6179 const ValueRef mat1,
@@ -72,21 +90,16 @@ void add_matmul_naive_buffer_node(
7290 std::string kernel_name = " matmul_naive_buffer" ;
7391 add_dtype_suffix (kernel_name, graph.dtype_of (out));
7492
75- utils::uvec3 global_size = {
76- graph.size_at <uint32_t >(-1 , out),
77- graph.size_at <uint32_t >(-2 , out),
78- graph.size_at <uint32_t >(-3 , out) * graph.size_at <uint32_t >(-4 , out)};
79-
8093 int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
8194 graph.get_bool (mat2_is_transposed))
8295 ? 1
8396 : 0 ;
8497
85- graph.execute_nodes ().emplace_back (new DispatchNode (
98+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
8699 graph,
87100 VK_KERNEL_FROM_STR (kernel_name),
88- global_size ,
89- graph. create_local_wg_size (global_size) ,
101+ matmul_naive_buffer_global_wg_size ,
102+ default_pick_local_wg_size ,
90103 // Inputs and Outputs
91104 {{out, vkapi::kWrite }, {{mat1, mat2}, vkapi::kRead }},
92105 // Shader params buffers
@@ -109,6 +122,22 @@ void add_matmul_naive_buffer_node(
109122 resize_matmul_node));
110123}
111124
125+ vkapi::ShaderInfo pick_matmul_naive_texture3d_shader (
126+ ComputeGraph* graph,
127+ const std::vector<ArgGroup>& args,
128+ const std::vector<ValueRef>& resize_args) {
129+ const ValueRef out = args.at (0 ).refs .at (0 );
130+ const bool is_transposed = graph->get_bool (resize_args.at (0 ));
131+
132+ std::string kernel_name =
133+ is_transposed ? " matmul_transposed_naive" : " matmul_naive" ;
134+ kernel_name.reserve (kShaderNameReserve );
135+ add_storage_type_suffix (kernel_name, graph->storage_type_of (out));
136+ add_dtype_suffix (kernel_name, graph->dtype_of (out));
137+
138+ return VK_KERNEL_FROM_STR (kernel_name);
139+ }
140+
112141void add_matmul_naive_texture3d_node (
113142 ComputeGraph& graph,
114143 const ValueRef mat1,
@@ -122,19 +151,11 @@ void add_matmul_naive_texture3d_node(
122151 utils::kHeightPacked ,
123152 /* passthrough = */ true );
124153
125- std::string kernel_name = graph.get_bool (mat2_is_transposed)
126- ? " matmul_transposed_naive"
127- : " matmul_naive" ;
128- kernel_name.reserve (kShaderNameReserve );
129- add_storage_type_suffix (kernel_name, graph.storage_type_of (out));
130- add_dtype_suffix (kernel_name, graph.dtype_of (out));
131-
132- utils::uvec3 global_wg_size = graph.logical_limits_of (out);
133- graph.execute_nodes ().emplace_back (new DispatchNode (
154+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
134155 graph,
135- VK_KERNEL_FROM_STR (kernel_name) ,
136- global_wg_size ,
137- graph. create_local_wg_size (global_wg_size) ,
156+ pick_matmul_naive_texture3d_shader ,
157+ default_pick_global_wg_size ,
158+ default_pick_local_wg_size ,
138159 // Inputs and Outputs
139160 {{out, vkapi::kWrite }, {{mat1, mat2}, vkapi::kRead }},
140161 // Shader params buffers
@@ -156,6 +177,59 @@ void add_matmul_naive_texture3d_node(
156177 resize_matmul_node));
157178}
158179
180+ vkapi::ShaderInfo pick_matmul_optimized_shader (
181+ ComputeGraph* graph,
182+ const std::vector<ArgGroup>& args,
183+ const std::vector<ValueRef>& resize_args) {
184+ const ValueRef out = args.at (0 ).refs .at (0 );
185+ const ValueRef mat1_W_packed = resize_args.at (1 );
186+ const bool mat2_is_transposed_val = graph->get_bool (resize_args.at (0 ));
187+
188+ std::string kernel_name = mat2_is_transposed_val
189+ ? " matmul_transposed_optimized"
190+ : " matmul_optimized" ;
191+
192+ std::vector<int64_t > mat1_sizes = graph->sizes_of (mat1_W_packed);
193+ size_t mat1_dims = mat1_sizes.size ();
194+ if (mat1_dims == 3 ) {
195+ kernel_name = " batch_" + kernel_name;
196+ }
197+ if (mat1_sizes.at (mat1_dims - 2 ) < 8 ) {
198+ kernel_name += " _tile_row_2" ;
199+ } else {
200+ kernel_name += " _tile_row_4" ;
201+ }
202+
203+ add_dtype_suffix (kernel_name, graph->dtype_of (out));
204+
205+ return VK_KERNEL_FROM_STR (kernel_name);
206+ }
207+
208+ utils::uvec3 matmul_optimized_global_wg_size (
209+ ComputeGraph* graph,
210+ const vkapi::ShaderInfo& shader,
211+ const std::vector<ArgGroup>& args,
212+ const std::vector<ValueRef>& resize_args) {
213+ (void )shader;
214+
215+ const ValueRef out = args.at (0 ).refs .at (0 );
216+ const ValueRef mat1_W_packed = resize_args.at (1 );
217+
218+ const std::vector<int64_t > mat1_sizes = graph->sizes_of (mat1_W_packed);
219+ const size_t mat1_dims = mat1_sizes.size ();
220+
221+ utils::uvec3 global_size = graph->logical_limits_of (out);
222+ if (mat1_sizes.at (mat1_dims - 2 ) < 8 ) {
223+ // Use `logical_extents` instead of `image_extents` because the workgroup
224+ // axes need to correspond to tensor dimensions.
225+ global_size = utils::divup_vec (global_size, {4 , 2 , 1 });
226+ } else {
227+ global_size = utils::divup_vec (global_size, {4 , 4 , 1 });
228+ }
229+
230+ return global_size;
231+ }
232+
159233void add_matmul_optimized_node (
160234 ComputeGraph& graph,
161235 const ValueRef mat1,
@@ -192,45 +266,11 @@ void add_matmul_optimized_node(
192266 viewFn (graph, {mat2, graph.add_none (), mat2_packed});
193267 }
194268
195- std::string kernel_name = mat2_is_transposed_val
196- ? " matmul_transposed_optimized"
197- : " matmul_optimized" ;
198-
199- std::vector<int64_t > mat1_sizes = graph.sizes_of (mat1_W_packed);
200- int mat1_dims = mat1_sizes.size ();
201- if (mat1_dims == 3 ) {
202- kernel_name = " batch_" + kernel_name;
203- }
204- if (mat1_sizes.at (mat1_dims - 2 ) < 8 ) {
205- kernel_name += " _tile_row_2" ;
206- } else {
207- kernel_name += " _tile_row_4" ;
208- }
209-
210- add_dtype_suffix (kernel_name, graph.dtype_of (out));
211-
212- // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the
213- // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is
214- // channels packed, C does not need to be divided by 4. The "identity" of each
215- // thread is the (x, y, z) coordinate of the output tile it is computing, and
216- // this identity can be used to compute the tensor index of the top left
217- // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0]
218- utils::uvec3 global_size = graph.logical_limits_of (out);
219- if (mat1_sizes.at (mat1_dims - 2 ) < 8 ) {
220- // Use `logical_extents` instead of `image_extents` because the workgroup
221- // axes need to correspond to tensor dimensions.
222- global_size = utils::divup_vec (global_size, {4 , 2 , 1 });
223- } else {
224- global_size = utils::divup_vec (global_size, {4 , 4 , 1 });
225- }
226-
227- utils::uvec3 local_size = adaptive_work_group_size (global_size);
228-
229- graph.execute_nodes ().emplace_back (new DispatchNode (
269+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
230270 graph,
231- VK_KERNEL_FROM_STR (kernel_name) ,
232- global_size ,
233- local_size ,
271+ pick_matmul_optimized_shader ,
272+ matmul_optimized_global_wg_size ,
273+ default_pick_local_wg_size ,
234274 // Inputs and Outputs
235275 {{out, vkapi::kWrite }, {{mat1_W_packed, mat2_packed}, vkapi::kRead }},
236276 // Shader params buffers
@@ -246,7 +286,7 @@ void add_matmul_optimized_node(
246286 graph.hashed_layout_of (mat1_W_packed),
247287 graph.hashed_layout_of (mat2_packed)},
248288 // Resize Args
249- {mat2_is_transposed},
289+ {mat2_is_transposed, mat1_W_packed },
250290 // Resizing Logic
251291 resize_matmul_node));
252292}
0 commit comments