@@ -33,16 +33,6 @@ void add_copy_offset_node(
3333 add_dtype_suffix (kernel_name, *t_out);
3434 add_storage_type_suffix (kernel_name, *t_out);
3535
36- const struct Block final {
37- alignas (16 ) ivec3 range;
38- alignas (16 ) ivec3 src_offset;
39- alignas (16 ) ivec3 dst_offset;
40- } offset_params{
41- range,
42- src_offset,
43- dst_offset,
44- };
45-
4636 auto shader = VK_KERNEL_FROM_STR (kernel_name);
4737
4838 graph.execute_nodes ().emplace_back (new DispatchNode (
@@ -56,11 +46,18 @@ void add_copy_offset_node(
5646 {in, vkapi::kRead },
5747 },
5848 // Parameter buffers
59- {
60- graph.create_params_buffer (offset_params),
61- },
49+ {},
6250 // Specialization Constants
63- {graph.hashed_layout_of (out), graph.hashed_layout_of (in)}));
51+ {graph.hashed_layout_of (out), graph.hashed_layout_of (in)},
52+ nullptr ,
53+ {},
54+ {
55+ PushConstantDataInfo (&range, sizeof (range), sizeof (utils::ivec4)),
56+ PushConstantDataInfo (
57+ &src_offset, sizeof (src_offset), sizeof (utils::ivec4)),
58+ PushConstantDataInfo (
59+ &dst_offset, sizeof (dst_offset), sizeof (utils::ivec4)),
60+ }));
6461}
6562
6663void add_copy_channel_offset_node (
@@ -128,28 +125,23 @@ void add_copy_channel_offset_node(
128125 // The shader combines the global invocation id and the dst_offset to get
129126 // the actual coordinate.
130127
131- ivec3 dst_offset{
128+ const ivec3 dst_offset{
132129 0 , 0 , dst_first_z + batch_idx * utils::div_up_4 (out_channels)};
133130
134- uvec3 global_size{
131+ const uvec3 global_size{
135132 utils::safe_downcast<uint32_t >(dim_at<kWidth4D >(in_sizes)),
136133 utils::safe_downcast<uint32_t >(dim_at<kHeight4D >(in_sizes)),
137134 utils::safe_downcast<uint32_t >(dst_last_z - dst_first_z + 1 )};
138- uvec3 local_size = graph.create_local_wg_size (global_size);
139-
140- const struct Block final {
141- ivec3 range;
142- int32_t channel_range;
143- ivec3 dst_offset;
144- int32_t dst_channel_offset;
145- int32_t src_channel_offset;
146- } channel_offset_params{
147- utils::make_ivec3 (global_size),
148- channel_range,
149- dst_offset,
150- dst_channel_offset,
151- src_channel_offset,
152- };
135+ const uvec3 local_size = graph.create_local_wg_size (global_size);
136+
137+ const utils::ivec4 range_params = {
138+ static_cast <int >(global_size[0 ]),
139+ static_cast <int >(global_size[1 ]),
140+ static_cast <int >(global_size[2 ]),
141+ channel_range};
142+
143+ const utils::ivec4 offset_params = {
144+ dst_offset[0 ], dst_offset[1 ], dst_offset[2 ], dst_channel_offset};
153145
154146 auto shader = VK_KERNEL_FROM_STR (kernel_name);
155147
@@ -165,13 +157,17 @@ void add_copy_channel_offset_node(
165157 {in, vkapi::MemoryAccessType::READ},
166158 },
167159 // Parameter buffers
168- {
169- t_out->sizes_ubo (),
170- t_in->sizes_ubo (),
171- graph.create_params_buffer (channel_offset_params),
172- },
160+ {},
173161 // Specialization Constants
174- {graph.hashed_layout_of (out), graph.hashed_layout_of (in)}));
162+ {graph.hashed_layout_of (out), graph.hashed_layout_of (in)},
163+ nullptr ,
164+ {},
165+ {graph.sizes_pc_of (out),
166+ graph.sizes_pc_of (in),
167+ PushConstantDataInfo (&range_params, sizeof (range_params)),
168+ PushConstantDataInfo (&offset_params, sizeof (offset_params)),
169+ PushConstantDataInfo (
170+ &src_channel_offset, sizeof (src_channel_offset))}));
175171 }
176172}
177173
0 commit comments