diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h
index be89073e90ba..376914ad0ead 100644
--- a/aten/src/ATen/native/vulkan/api/Common.h
+++ b/aten/src/ATen/native/vulkan/api/Common.h
@@ -50,6 +50,8 @@ namespace native {
 namespace vulkan {
 namespace api {
 
+const int64_t MAX_STACK_DEPTH = 2048 * 4;
+
 struct Adapter;
 struct Command;
 class Context;
diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl
index 0a58c5d0a2f6..6e1e8f121dae 100644
--- a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl
@@ -18,6 +18,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
   ivec2 padding;
   ivec2 dilate;
   vec2 clamp;
+  int stacks_per_tower;
 } uBlock;
 
 layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
@@ -28,7 +29,9 @@ void main() {
   /* Dynamically Uniform */
   const ivec3 size = imageSize(uOutput);
   const ivec3 isize = textureSize(uInput, 0);
-  const ivec4 block = pos.z * uBlock.kernel.z + ivec4(0, 1, 2, 3);
+  const int tower = pos.z / uBlock.stacks_per_tower;
+  const int tower_offset = pos.z % uBlock.stacks_per_tower;
+  const ivec4 block = tower_offset * uBlock.kernel.z + ivec4(0, 1, 2, 3);
 
   if (all(lessThan(pos, size))) {
     const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
@@ -46,10 +49,10 @@ void main() {
       for (int x = start.x, kx = kstart.x; x < end.x; x += uBlock.dilate.x, ++kx) {
         const vec4 In = texelFetch(uInput, ivec3(x, y, z), 0);
 
-        sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, ky, kz.x), 0), sum);
-        sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, ky, kz.y), 0), sum);
-        sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, ky, kz.z), 0), sum);
-        sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, ky, kz.w), 0), sum);
+        sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y * tower) + ky, kz.x), 0), sum);
+        sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y * tower) + ky, kz.y), 0), sum);
+        sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y * tower) + ky, kz.z), 0), sum);
+        sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y * tower) + ky, kz.w), 0), sum);
      }
     }
   }
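Reviewer note on the conv2d.glsl hunks above: the flat block index `pos.z * uBlock.kernel.z` is replaced by a two-level decomposition, where `pos.z` first selects a tower and the remainder selects the 4-output-channel stack inside it; the kernel texture rows belonging to tower `t` start at `uBlock.kernel.y * t`. Below is a minimal host-side C++ sketch of that same index arithmetic, handy for sanity-checking; all constants are made-up example values, not anything read from ATen.

```cpp
#include <cstdio>

int main() {
  // Illustrative stand-ins for uBlock.stacks_per_tower, uBlock.kernel.y
  // (filter height) and uBlock.kernel.z (4-aligned input depth).
  const int stacks_per_tower = 3;
  const int kernel_y = 5;
  const int kernel_z = 8;

  for (int pos_z = 0; pos_z < 7; ++pos_z) {
    const int tower = pos_z / stacks_per_tower;        // which tower
    const int tower_offset = pos_z % stacks_per_tower; // stack within it
    const int block0 = tower_offset * kernel_z;        // first of 4 kernel slices
    const int ky_base = kernel_y * tower;              // row shift into the tower
    std::printf("pos.z=%d -> tower=%d offset=%d block0=%d ky_base=%d\n",
                pos_z, tower, tower_offset, block0, ky_base);
  }
  return 0;
}
```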
diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl
index c2962844e0bc..0bf0de43abda 100644
--- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl
@@ -17,7 +17,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
   ivec2 stride;
   ivec2 padding;
   vec2 clamp;
-  int W;
+  int stacks_per_tower;
 } uBlock;
 
 layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
@@ -28,7 +28,9 @@ void main() {
   /* Dynamically Uniform */
   const ivec3 size = imageSize(uOutput);
   const ivec3 isize = textureSize(uInput, 0);
-  const ivec4 block = pos.z * uBlock.kernel.x + ivec4(0, 1, 2, 3);
+  const int tower = pos.z / uBlock.stacks_per_tower;
+  const int tower_offset = pos.z % uBlock.stacks_per_tower;
+  const ivec4 block = tower_offset * uBlock.kernel.x + ivec4(0, 1, 2, 3);
 
   if (all(lessThan(pos, size))) {
     const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
@@ -39,37 +41,10 @@ void main() {
       const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z), 0);
       const ivec4 kz = block + 4 * z;
 
-      const int W = uBlock.W;
-
-      const vec4 val1 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.x+0)%W, ((4*kz.x+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.x+1)%W, ((4*kz.x+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.x+2)%W, ((4*kz.x+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.x+3)%W, ((4*kz.x+3))/W, 0), 0).x
-      );
-      const vec4 val2 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.y+0)%W, ((4*kz.y+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.y+1)%W, ((4*kz.y+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.y+2)%W, ((4*kz.y+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.y+3)%W, ((4*kz.y+3))/W, 0), 0).x
-      );
-      const vec4 val3 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.z+0)%W, ((4*kz.z+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.z+1)%W, ((4*kz.z+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.z+2)%W, ((4*kz.z+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.z+3)%W, ((4*kz.z+3))/W, 0), 0).x
-      );
-      const vec4 val4 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.w+0)%W, ((4*kz.w+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.w+1)%W, ((4*kz.w+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.w+2)%W, ((4*kz.w+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.w+3)%W, ((4*kz.w+3))/W, 0), 0).x
-      );
-
-      sum = fma(In.xxxx, val1, sum);
-      sum = fma(In.yyyy, val2, sum);
-      sum = fma(In.zzzz, val3, sum);
-      sum = fma(In.wwww, val4, sum);
+      sum = fma(In.xxxx, texelFetch(uKernel, ivec3(0, tower, kz.x), 0), sum);
+      sum = fma(In.yyyy, texelFetch(uKernel, ivec3(0, tower, kz.y), 0), sum);
+      sum = fma(In.zzzz, texelFetch(uKernel, ivec3(0, tower, kz.z), 0), sum);
+      sum = fma(In.wwww, texelFetch(uKernel, ivec3(0, tower, kz.w), 0), sum);
     }
 
     imageStore(
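The conv2d_pw.glsl rewrite above drops the `%W` / `/W` arithmetic that unpacked weights from a flat 2D texture: pointwise weights now live in the same tower-packed 3D texture as the general case, so each texel is fetched directly at x = 0, y = tower, z = channel block. A hedged C++ sketch of that coordinate mapping follows; the names are illustrative, not ATen API.

```cpp
// Where the pointwise shader finds one weight texel in the packed 3D kernel
// texture. c in 0..3 selects kz.x..kz.w; kernel_ic mirrors uBlock.kernel.x.
struct Texel {
  int x, y, z;
};

Texel pw_weight_coord(
    int pos_z, int stacks_per_tower, int kernel_ic, int z_in, int c) {
  const int tower = pos_z / stacks_per_tower;
  const int tower_offset = pos_z % stacks_per_tower;
  const int base = tower_offset * kernel_ic; // block = base + ivec4(0,1,2,3)
  return Texel{0, tower, base + 4 * z_in + c}; // the ivec3 given to texelFetch
}
```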
diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
index 7cf3b4fe5137..4e63dbf2b8de 100644
--- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -42,10 +42,10 @@ vTensor pack_weights(
   if (is_depthwise(src_filter, groups)) {
     vTensor v_weight{
-      api::context(),
-      &pool,
-      src_filter,
-      weight.options(),
+        api::context(),
+        &pool,
+        src_filter,
+        weight.options(),
     };
 
     using Future = vTensor::Future<float, vTensor::Access::Write>;
@@ -66,16 +66,26 @@ vTensor pack_weights(
   using namespace api::utils;
 
+  const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], 4);
+  const int64_t stack_depth =
+      4 * api::utils::align_up(src_filter[Layout::Filter::input], 4);
+  const int64_t max_stacks_per_tower =
+      at::native::vulkan::api::MAX_STACK_DEPTH / stack_depth;
+  const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower);
+  int64_t stacks_per_tower = num_stacks;
+  if (num_towers > 1) {
+    stacks_per_tower = div_up(num_stacks, num_towers);
+  }
   vTensor v_weight{
-    api::context(),
-    &pool,
-    {
-      div_up(src_filter[Layout::Filter::output], 4),
-      4 * align_up(src_filter[Layout::Filter::input], 4),
-      src_filter[Layout::Filter::height],
-      src_filter[Layout::Filter::width],
-    },
-    weight.options(),
+      api::context(),
+      &pool,
+      {
+          stacks_per_tower,
+          stack_depth,
+          src_filter[Layout::Filter::height] * num_towers,
+          src_filter[Layout::Filter::width],
+      },
+      weight.options(),
   };
 
   using Future = vTensor::Future<float, vTensor::Access::Write>;
@@ -83,38 +93,49 @@ vTensor pack_weights(
   Future::Payload v_weight_payload = v_weight_future.wait();
 
   /* Source */
-  const int64_t src_kernel = src_filter[Layout::Filter::height] * src_filter[Layout::Filter::width];
-  const int64_t src_block = src_kernel * src_filter[Layout::Filter::input];
+  const int64_t src_kw_sz = src_filter[Layout::Filter::width];
+  const int64_t src_kh_sz = src_filter[Layout::Filter::height];
+  const int64_t src_kernel_sz = src_kw_sz * src_kh_sz;
+  const int64_t src_block_sz =
+      src_kernel_sz * src_filter[Layout::Filter::input];
 
   /* Destination */
   const IntArrayRef dst_filter = v_weight.sizes();
-  const int64_t dst_kernel = dst_filter[Layout::Filter::height] * dst_filter[Layout::Filter::width];
-  const int64_t dst_block = dst_kernel * dst_filter[Layout::Filter::input];
-  TORCH_INTERNAL_ASSERT(src_kernel == dst_kernel, "Internal error!");
+  const int64_t dst_kw_sz = src_filter[Layout::Filter::width];
+  const int64_t dst_kh_sz = src_filter[Layout::Filter::height] * num_towers;
+  const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz;
+  const int64_t dst_block_sz =
+      dst_kernel_sz * dst_filter[Layout::Filter::input];
+  // TORCH_INTERNAL_ASSERT(src_kernel_sz == dst_kernel_sz, "Internal error!");
 
   float* const dst_weight_ptr = v_weight_payload.get();
   memset(dst_weight_ptr, 0, v_weight.nbytes());
 
-  for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) {
-    /* Source */
-    const float *const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block;
+  for (int64_t i_tower = 0; i_tower < num_towers; ++i_tower) {
+    const float* const src_tower_ptr =
+        src_weight_ptr + (i_tower * stacks_per_tower * 4) * src_block_sz;
+    for (int64_t src_oc = 0; src_oc < (stacks_per_tower * 4); ++src_oc) {
+      /* Source */
+      const float* const src_weight_oc_ptr =
+          src_tower_ptr + src_oc * src_block_sz;
 
-    /* Destination */
-    const int64_t dst_oc = src_oc / 4;
-    const int64_t dst_oc_offset = src_oc % 4;
+      /* Destination */
+      const int64_t dst_oc = src_oc / 4;
+      const int64_t dst_oc_offset = src_oc % 4;
 
-    float* const dst_weight_oc_ptr =
-        dst_weight_ptr +
-        dst_oc * dst_block +
-        dst_oc_offset * dst_kernel;
+      float* const dst_weight_oc_ptr = dst_weight_ptr +
+          dst_oc * dst_block_sz + dst_oc_offset * dst_kernel_sz;
 
-    for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) {
-      const int64_t dst_ic = 4 * src_ic;
+      for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input];
+           ++src_ic) {
+        const int64_t dst_ic = 4 * src_ic;
 
-      memcpy(
-          dst_weight_oc_ptr + dst_ic * dst_kernel,
-          src_weight_oc_ptr + src_ic * src_kernel,
-          sizeof(float) * dst_kernel);
+        memcpy(
+            dst_weight_oc_ptr + dst_ic * dst_kernel_sz +
+                (i_tower * src_kernel_sz),
+            src_weight_oc_ptr + src_ic * src_kernel_sz,
+            sizeof(float) * src_kernel_sz);
+      }
     }
   }
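The pack_weights change above is the heart of the patch: when a filter's 4-output-channel stacks would exceed MAX_STACK_DEPTH texture layers, the stacks are split across several towers laid out along the texture's height. A minimal, self-contained sketch of that partition arithmetic with a worked example; div_up/align_up mirror api::utils, and the filter sizes are invented.

```cpp
#include <cstdint>
#include <cstdio>

int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }
int64_t align_up(int64_t a, int64_t b) { return div_up(a, b) * b; }

int main() {
  const int64_t MAX_STACK_DEPTH = 2048 * 4; // as added to api/Common.h
  const int64_t OC = 1280; // hypothetical output channels
  const int64_t IC = 1000; // hypothetical input channels

  const int64_t num_stacks = div_up(OC, 4);        // 320
  const int64_t stack_depth = 4 * align_up(IC, 4); // 4000
  const int64_t max_stacks_per_tower =
      MAX_STACK_DEPTH / stack_depth;               // 8192 / 4000 = 2
  const int64_t num_towers =
      div_up(num_stacks, max_stacks_per_tower);    // 160
  const int64_t stacks_per_tower =
      num_towers > 1 ? div_up(num_stacks, num_towers) : num_stacks; // 2

  std::printf(
      "%lld towers x %lld stacks; kernel texture height grows by %lldx\n",
      static_cast<long long>(num_towers),
      static_cast<long long>(stacks_per_tower),
      static_cast<long long>(num_towers));
  return 0;
}
```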
@@ -339,31 +360,14 @@ void conv2d_pointwise(
   using namespace api::utils;
 
   if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
-
-    vTensor v_weight_reshaped{
-        context,
-        {1,1, v_weight.sizes()[0], v_weight.sizes()[1]},
-        v_input.options(),
-    };
-
-    api::Command::Buffer temp_command_buffer =
-        api::context()->command().pool.allocate();
-    temp_command_buffer.begin();
-
-    temp_command_buffer.copy(
-        v_weight.buffer(temp_command_buffer),
-        v_weight_reshaped.buffer(temp_command_buffer, vTensor::Access::Write)
-    );
-
-    temp_command_buffer.end();
-    temp_command_buffer.submit(api::context()->gpu().queue);
+    const int64_t stacks_per_tower = v_weight.sizes()[0];
 
     const struct {
       int32_t kernel_ic, kernel_oc;
       int32_t stride_x, stride_y;
       int32_t padding_x, padding_y;
       float clamp_x, clamp_y;
-      int32_t w;
+      int32_t stacks_per_tower;
     } block {
       safe_downcast<int32_t>(filter[Layout::Filter::input]),
       safe_downcast<int32_t>(filter[Layout::Filter::output]),
@@ -373,7 +377,7 @@ void conv2d_pointwise(
       safe_downcast<int32_t>(padding[Layout::Parameter::height]),
       output_min,
       output_max,
-      v_weight.sizes()[1],
+      safe_downcast<int32_t>(stacks_per_tower),
     };
 
     context->dispatch(
@@ -395,7 +399,7 @@ void conv2d_pointwise(
         v_input.image(command_buffer),
         // Read-only access is implied on const tensors and triggers an async
         // synchronization if necessary.
-        v_weight_reshaped.image(command_buffer, vTensor::Access::Read),
+        v_weight.image(command_buffer),
         // Read-only access is implied on const tensors and triggers an async
         // synchronization if necessary.
         v_bias.buffer(command_buffer),
@@ -424,12 +428,14 @@ void conv2d(
   using namespace api::utils;
 
   if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
+    const int64_t stacks_per_tower = v_weight.sizes()[0];
     const struct {
       int32_t kernel_x, kernel_y, kernel_ic, kernel_oc;
       int32_t stride_x, stride_y;
       int32_t padding_x, padding_y;
       int32_t dilate_x, dilate_y;
       float clamp_x, clamp_y;
+      int32_t stacks_per_tower;
     } block {
       safe_downcast<int32_t>(filter[Layout::Filter::width]),
       safe_downcast<int32_t>(filter[Layout::Filter::height]),
@@ -443,6 +449,7 @@ void conv2d(
       safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
       output_min,
       output_max,
+      safe_downcast<int32_t>(stacks_per_tower),
     };
 
     context->dispatch(
@@ -706,7 +713,7 @@ c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
       /* output_padding = */ {},
       groups,
       output_min,
-      output_min));
+      output_max));
 }
 
 Tensor conv2d_clamp_run(
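Unrelated to the tower work but worth calling out: the final hunk fixes a real bug in conv2d_clamp_prepack, which passed output_min as both clamp bounds, so a prepacked conv2d with a ReLU6-style clamp of [0, 6] would have clamped every output to 0. The intended semantics, sketched for clarity; this is illustrative, not the ATen implementation.

```cpp
#include <algorithm>

// Output activation clamp as the shaders apply it via uBlock.clamp.
float clamp_output(float v, float output_min, float output_max) {
  return std::min(std::max(v, output_min), output_max);
}
```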