Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ${layout_declare_tensor(B, "r", "mat2_tensor", DTYPE, "texture3d")}
$if HAS_BIAS:
${layout_declare_tensor(B, "r", "bias_tensor", DTYPE, "texture3d")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec3", "out_logical_limits")}
${layout_declare_ubo(B, "ivec3", "out_limits")}
${layout_declare_ubo(B, "ivec4", "out_axis_map")}
${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
${layout_declare_ubo(B, "ivec4", "mat1_axis_map")}
Expand Down Expand Up @@ -63,11 +63,11 @@ vec4 get_bias_texel_W_packed(ivec3 logical_pos) {
}
#endif // HAS_BIAS

vec4 matmul_naive_k_dim_packed(const ivec3 out_mpos) {
vec4 matmul_naive_k_dim_packed(const ivec3 out_lpos) {
ivec3 mat1_pos;
mat1_pos[mat1_axis_map.x] = 0;
mat1_pos[mat1_axis_map.y] = out_mpos.y;
mat1_pos[mat1_axis_map.z] = out_mpos.z;
mat1_pos[mat1_axis_map.y] = out_lpos.y;
mat1_pos[mat1_axis_map.z] = out_lpos.z;
#ifdef MAT2_IS_TRANSPOSED
const int mat2_k_axis = mat2_axis_map.x;
const int mat2_row_axis = mat2_axis_map.y;
Expand All @@ -88,9 +88,9 @@ vec4 matmul_naive_k_dim_packed(const ivec3 out_mpos) {
// latency. Surprisingly, this doesn't translate to mat1_pos.
ivec3 mat2_pos = ivec3(0);
mat2_pos[mat2_k_axis] = i;
mat2_pos[mat2_row_axis] = out_mpos.x * 4 + r;
mat2_pos[mat2_row_axis] = out_lpos.x * 4 + r;
#ifndef MAT2_IS_TRANSPOSED
mat2_pos[mat2_axis_map.z] = out_mpos.z;
mat2_pos[mat2_axis_map.z] = out_lpos.z;
#endif // MAT2_IS_TRANSPOSED
sums[r] = dot(mat1_tex, texelFetch(mat2_tensor, mat2_pos, 0));
}
Expand All @@ -103,16 +103,16 @@ vec4 matmul_naive_k_dim_packed(const ivec3 out_mpos) {
return texel;
}

vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_mpos) {
vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
ivec3 mat1_pos;
mat1_pos[mat1_axis_map.x] = 0;
mat1_pos[mat1_axis_map.y] = out_mpos.y;
mat1_pos[mat1_axis_map.z] = out_mpos.z;
mat1_pos[mat1_axis_map.y] = out_lpos.y;
mat1_pos[mat1_axis_map.z] = out_lpos.z;

ivec3 mat2_pos;
mat2_pos[mat2_axis_map.x] = out_mpos.x;
mat2_pos[mat2_axis_map.x] = out_lpos.x;
mat2_pos[mat2_axis_map.y] = 0;
mat2_pos[mat2_axis_map.z] = out_mpos.z;
mat2_pos[mat2_axis_map.z] = out_lpos.z;

ivec3 mat2_pos_offset = ivec3(0);
mat2_pos_offset[mat2_axis_map.y] = 1;
Expand All @@ -131,9 +131,9 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_mpos) {
// On-demand construction of mat2_pos appears to provide the lowest
// latency. Surprisingly, this doesn't translate to mat1_pos.
ivec3 mat2_pos = ivec3(0);
mat2_pos[mat2_axis_map.x] = out_mpos.x;
mat2_pos[mat2_axis_map.x] = out_lpos.x;
mat2_pos[mat2_axis_map.y] = 4 * i + r;
mat2_pos[mat2_axis_map.z] = out_mpos.z;
mat2_pos[mat2_axis_map.z] = out_lpos.z;

vec4 mat1_comp_vec = vec4(mat1_tex[r]);
texel = fma(mat1_comp_vec, texelFetch(mat2_tensor, mat2_pos, 0), texel);
Expand All @@ -144,33 +144,31 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_mpos) {
}

// Shader entry point. Each invocation takes its position from
// gl_GlobalInvocationID, computes one texel through one of the matmul_naive_*
// helpers, optionally blends in a bias texel, and stores the result in
// out_tensor.
//
// NOTE(review): this text is a rendered diff without +/- markers, so most
// statements appear twice. The out_mpos / out_logical_limits line of each pair
// appears to be the removed version and the out_lpos / out_limits line that
// follows it the replacement; only one line of each pair can be kept in a
// compilable shader.
void main() {
// Early exit when any component of the position is >= the corresponding limit.
const ivec3 out_mpos = ivec3(gl_GlobalInvocationID);
if (any(greaterThanEqual(out_mpos, out_logical_limits))) {
const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
if (any(greaterThanEqual(out_lpos, out_limits))) {
return;
}

// Zero-initialized; every branch below overwrites it with a helper result.
vec4 texel = vec4(0);

// Helper selection depends on both the MAT2_IS_TRANSPOSED define and
// mat2_packed_dim: with the define, W_DIM packing uses
// matmul_naive_k_dim_packed; without it, W_DIM packing uses
// matmul_naive_k_dim_packed_row_dim_packed. The other packing swaps the choice.
#ifdef MAT2_IS_TRANSPOSED
if (mat2_packed_dim == W_DIM) {
texel = matmul_naive_k_dim_packed(out_mpos);
texel = matmul_naive_k_dim_packed(out_lpos);
} else {
texel = matmul_naive_k_dim_packed_row_dim_packed(out_mpos);
texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos);
}
#else
if (mat2_packed_dim == W_DIM) {
texel = matmul_naive_k_dim_packed_row_dim_packed(out_mpos);
texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos);
} else {
texel = matmul_naive_k_dim_packed(out_mpos);
texel = matmul_naive_k_dim_packed(out_lpos);
}
#endif // MAT2_IS_TRANSPOSED

// Only when HAS_BIAS is defined: result = beta * bias + alpha * product.
#ifdef HAS_BIAS
vec4 bias_texel = get_bias_texel_W_packed(out_mpos);
vec4 bias_texel = get_bias_texel_W_packed(out_lpos);
texel = beta * bias_texel + alpha * texel;
#endif // HAS_BIAS

// The first form maps the position through to_texture_pos with out_axis_map
// into a named out_pos before storing; the last line does the same mapping
// inline through lpos_to_pos.
ivec3 out_pos = to_texture_pos(out_mpos, out_axis_map);

imageStore(out_tensor, out_pos, texel);
imageStore(out_tensor, lpos_to_pos(out_lpos, out_axis_map), texel);
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ FloatMatrix matmul_partial(const ivec4 out_idx_tl) {
//

void write_results_C_packed(const ivec4 out_idx_tl, FloatMatrix results) {
ivec3 out_pos = to_texture_pos(
ivec3 out_pos = tidx_to_pos(
out_idx_tl, out_sizes, out_axis_map, out_packed_dim);

for (int tile_c = 0;
Expand Down
18 changes: 8 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,26 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
layout(constant_id = 3) const int packed_dim = C_DIM;

void main() {
// pos is physical (x, y, z), as global workgroup uses image extents
const ivec3 pos = ivec3(gl_GlobalInvocationID);
// physical pos (x, y, z) -> logical (w, c, h, n) output
const ivec4 idx = to_tensor_idx(pos, out_sizes, out_axis_map, packed_dim);
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
const ivec4 tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, packed_dim);

if (any(greaterThanEqual(idx, out_sizes))) {
if (any(greaterThanEqual(tidx, out_sizes))) {
return;
}

// broadcast on logical sizes
ivec4 in_idx = broadcast_indices(idx, in_sizes);
ivec4 in_idx = broadcast_indices(tidx, in_sizes);
VEC4_T in_texel = VEC4_T(load_texel(
t_in,
// read axis mapped texel
to_texture_pos(in_idx, in_sizes, in_axis_map, packed_dim)));
tidx_to_pos(in_idx, in_sizes, in_axis_map, packed_dim)));

// broadcast on logical sizes
ivec4 other_idx = broadcast_indices(idx, other_sizes);
ivec4 other_idx = broadcast_indices(tidx, other_sizes);
VEC4_T other_texel = VEC4_T(load_texel(
t_other,
// read axis mapped texel
to_texture_pos(other_idx, other_sizes, other_axis_map, packed_dim)));
tidx_to_pos(other_idx, other_sizes, other_axis_map, packed_dim)));

// Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
if (broadcast_params.x > 0) {
Expand All @@ -68,6 +66,6 @@ void main() {
}

imageStore(t_out,
to_texture_pos(idx, out_sizes, out_axis_map, packed_dim),
tidx_to_pos(tidx, out_sizes, out_axis_map, packed_dim),
VEC4_T(op(in_texel, other_texel, alpha)));
}
10 changes: 5 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
layout(constant_id = 3) const int UNUSED_packed_dim = W_DIM;

// Shader entry point. Each invocation copies a single element from t_in into
// nchw_buf, using gl_GlobalInvocationID.x as the destination index.
//
// NOTE(review): this text is a rendered diff without +/- markers. Each
// statement appears twice: the out_id / in_id / from_nchw_buffer_i /
// to_buffer_id line appears to be the removed version, and the nchwi / in_bufi
// / nchwi_to_tidx / tidx_to_bufi line that follows it the replacement.
void main() {
// Invocations whose index is at or beyond numel do nothing.
int out_id = int(gl_GlobalInvocationID.x);
if (out_id >= numel) {
int nchwi = int(gl_GlobalInvocationID.x);
if (nchwi >= numel) {
return;
}

// Two-step index translation: destination index + in_sizes -> ivec4 index,
// then ivec4 index + in_strides -> integer offset used to read t_in.
// NOTE(review): helper semantics are inferred from their names and arguments;
// confirm against their definitions.
ivec4 t_in_idx = from_nchw_buffer_i(out_id, in_sizes);
const int in_id = to_buffer_id(t_in_idx, in_strides);
ivec4 in_tidx = nchwi_to_tidx(nchwi, in_sizes);
const int in_bufi = tidx_to_bufi(in_tidx, in_strides);

nchw_buf[out_id] = t_in[in_id];
nchw_buf[nchwi] = t_in[in_bufi];
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void main() {
}

// Map tensor_idx to normal buffer_i
const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim);
const ivec4 p0 = tidx_to_nchwi(idx, sizes, packed_dim);

// Compute modified tensor_idx by inverting the CPU function
const int N = original_sizes.w;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void main() {
}

// Map tensor_idx to normal buffer_i
const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim);
const ivec4 p0 = tidx_to_nchwi(idx, sizes, packed_dim);

// Compute modified tensor_idx by inverting the CPU function
const int N = original_sizes.w;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void main() {
}

// Map tensor_idx to normal buffer_i
const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim);
const ivec4 p0 = tidx_to_nchwi(idx, sizes, packed_dim);

// Compute modified tensor_idx by inverting the CPU function
const int N = original_sizes.w;
Expand Down
14 changes: 7 additions & 7 deletions backends/vulkan/runtime/graph/ops/glsl/embedding.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
layout(constant_id = 3) const int packed_dim = C_DIM;

// Shader entry point. Each invocation builds one 4-component output texel: for
// each component it reads an integer from t_in, uses that integer as a
// coordinate into t_weight, and keeps the .x component of the fetched texel.
//
// NOTE(review): this text is a rendered diff without +/- markers. Statements
// appear in pairs: the out_pos / out_idx / to_tensor_idx / to_texture_pos line
// appears to be the removed version, and the out_lpos / out_tidx /
// lpos_to_tidx / lpos_to_pos line that follows it the replacement.
void main() {
// Convert the invocation position to an ivec4 index and exit when any
// component is >= the matching component of sizes. The second form passes only
// out_axis_map.w to the conversion helper rather than the whole axis map.
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
const ivec4 out_idx = to_tensor_idx(out_pos, sizes, out_axis_map, packed_dim);
if (any(greaterThanEqual(out_idx, sizes))) {
const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
const ivec4 out_tidx = lpos_to_tidx(out_lpos, sizes, out_axis_map.w, packed_dim);
if (any(greaterThanEqual(out_tidx, sizes))) {
return;
}
VEC4_T out_texel;

// Consider optimizing via W-packing format for t_in and t_weight.
for (int i = 0; i < 4; ++i) {
// Read input tensor for embedding index.
// The position is built from index components (.y, .z * 4 + i, .w / 4) and
// mapped with in_axis_map; the component picked from the texel is .w % 4.
const ivec3 in_pos = to_texture_pos(ivec3(out_idx.y, out_idx.z * 4 + i, out_idx.w / 4), in_axis_map);
const int in_texel_elem = load_texel(t_in, in_pos)[out_idx.w % 4];
const ivec3 in_pos = lpos_to_pos(ivec3(out_tidx.y, out_tidx.z * 4 + i, out_tidx.w / 4), in_axis_map);
const int in_texel_elem = load_texel(t_in, in_pos)[out_tidx.w % 4];

// Read weight tensor for embedding.
// The value read from t_in becomes the second coordinate of the weight lookup.
const ivec3 weight_pos = to_texture_pos(ivec3(out_idx.x, in_texel_elem, 0), weight_axis_map);
const ivec3 weight_pos = lpos_to_pos(ivec3(out_tidx.x, in_texel_elem, 0), weight_axis_map);
out_texel[i] = load_texel(t_weight, weight_pos).x;
}

// The first store writes at the raw invocation position; the second maps the
// position through lpos_to_pos with out_axis_map before storing.
imageStore(t_out, out_pos, out_texel);
imageStore(t_out, lpos_to_pos(out_lpos, out_axis_map), out_texel);
}
12 changes: 6 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
layout(constant_id = 3) const int packed_dim = C_DIM;

void write_out_texel(VEC4_T texel, ivec4 tensor_idx) {
const ivec4 buf_indices = get_texel_nchw_buffer_ixs(
const ivec4 buf_indices = tidx_to_nchwi(
tensor_idx,
sizes,
packed_dim);
Expand All @@ -51,13 +51,13 @@ void write_out_texel(VEC4_T texel, ivec4 tensor_idx) {
}

// Shader entry point. Each invocation loads one texel from t_in and passes it,
// together with the ivec4 index derived from the invocation position, to
// write_out_texel.
//
// NOTE(review): this text is a rendered diff without +/- markers. Statements
// appear in pairs: the pos / tensor_idx / to_tensor_idx line appears to be the
// removed version, and the lpos / tidx / lpos_to_tidx line that follows it the
// replacement.
void main() {
// The second form passes only axis_map.w to the conversion helper rather than
// the whole axis map.
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_map, packed_dim);
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map.w, packed_dim);

// Exit when any index component is >= the matching component of sizes.
if (any(greaterThanEqual(tensor_idx, sizes))) {
if (any(greaterThanEqual(tidx, sizes))) {
return;
}

// The first load reads t_in at the raw invocation position; the second maps
// the position through lpos_to_pos with axis_map before loading.
const VEC4_T intex = load_texel(t_in, pos);
write_out_texel(intex, tensor_idx);
const VEC4_T intex = load_texel(t_in, lpos_to_pos(lpos, axis_map));
write_out_texel(intex, tidx);
}
12 changes: 6 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ void main() {
}

const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim);
const ivec4 buffer_ixs = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim);
const ivec4 buffer_ixs = tidx_to_nchwi(idx, out_sizes, packed_dim);

VEC4_T out_texel;
for (int i = 0; i < 4; ++i) {
const ivec4 out_idx = from_nchw_buffer_i(buffer_ixs[i], out_sizes);
int out_channel = out_idx.z;
const ivec4 out_tidx = nchwi_to_tidx(buffer_ixs[i], out_sizes);
int out_channel = out_tidx.z;
int in_channel = texelFetch(t_idx, ivec3(out_channel, 0, 0), 0).x;

ivec4 in_idx = out_idx;
in_idx.z = in_channel;
ivec4 in_tidx = out_tidx;
in_tidx.z = in_channel;

ivec4 in_elem_pos = to_texture_elem_pos(in_idx, in_sizes, packed_dim);
ivec4 in_elem_pos = to_texture_elem_pos(in_tidx, in_sizes, packed_dim);

VEC4_T in_texel = texelFetch(t_in, in_elem_pos.xyz, 0);

Expand Down
Loading
Loading