From dfd372ead74773eececbe2c91ab14ecf30e2ff6f Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 2 Dec 2020 13:16:54 -0500 Subject: [PATCH] [vulkan] Distribute weight prepacking along y dimension for conv2d ghstack-source-id: f814ba805d5b801d7a8f9e7b652209c2c5487051 Pull Request resolved: https://github.com/pytorch/pytorch/pull/48266 --- aten/src/ATen/native/vulkan/glsl/conv2d.glsl | 19 +-- .../ATen/native/vulkan/glsl/conv2d_pw.glsl | 47 ++------ aten/src/ATen/native/vulkan/ops/Common.h | 4 + .../ATen/native/vulkan/ops/Convolution.cpp | 113 +++++++++--------- 4 files changed, 82 insertions(+), 101 deletions(-) diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl index fd54c2f38721..9646eb8c9f19 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl @@ -18,6 +18,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block ivec2 padding; ivec2 dilate; vec2 clamp; + int stacks_per_tower; } uBlock; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -28,7 +29,9 @@ void main() { /* Dynamically Uniform */ const ivec3 size = imageSize(uOutput); const ivec3 isize = textureSize(uInput, 0); - const ivec4 block = pos.z * uBlock.kernel.z + ivec4(0, 1, 2, 3); + const int tower = pos.z/(uBlock.stacks_per_tower); + const int tower_offset = pos.z % uBlock.stacks_per_tower; + const ivec4 block = tower_offset * uBlock.kernel.z + ivec4(0, 1, 2, 3); if (all(lessThan(pos, size))) { const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; @@ -39,17 +42,17 @@ void main() { vec4 sum = uBias.data[pos.z]; - for (int z = 0; z < uBlock.kernel.z; ++z) { - const ivec4 kz = block + 4 * z; + for (int z = 0; z < uBlock.kernel.z; z+=4) { + const ivec4 kz = block + z; for (int y = start.y, ky = kstart.y; y < end.y; y += uBlock.dilate.y, ++ky) { for (int x = start.x, kx = kstart.x; x < end.x; x += uBlock.dilate.x, ++kx) { - const vec4 In = texelFetch(uInput, ivec3(x, y, z), 0); + const vec4 In = texelFetch(uInput, ivec3(x, y, z/4), 0); - sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, ky, kz.x), 0), sum); - sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, ky, kz.y), 0), sum); - sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, ky, kz.z), 0), sum); - sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, ky, kz.w), 0), sum); + sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.x), 0), sum); + sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.y), 0), sum); + sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.z), 0), sum); + sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.w), 0), sum); } } } diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl index bbc745ca8efd..48d9f785008b 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl @@ -17,7 +17,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block ivec2 stride; ivec2 padding; vec2 clamp; - int W; + int stacks_per_tower; } uBlock; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -28,48 +28,23 @@ void main() { /* Dynamically Uniform */ const ivec3 size = imageSize(uOutput); const ivec3 isize = textureSize(uInput, 0); - const ivec4 block = pos.z * uBlock.kernel.x + ivec4(0, 1, 2, 3); + const int tower = pos.z/(uBlock.stacks_per_tower); + const int tower_offset = pos.z % uBlock.stacks_per_tower; + const ivec4 block = tower_offset * uBlock.kernel.x + ivec4(0, 1, 2, 3); if (all(lessThan(pos, size))) { const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; vec4 sum = uBias.data[pos.z]; - for (int z = 0; z < uBlock.kernel.x; ++z) { - const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z), 0); - const ivec4 kz = block + 4 * z; + for (int z = 0; z < uBlock.kernel.x; z+=4) { + const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z/4), 0); + const ivec4 kz = block + z; - const int W = uBlock.W; - - const vec4 val1 = vec4( - texelFetch(uKernel, ivec3((4*kz.x+0)%W, ((4*kz.x+0))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.x+1)%W, ((4*kz.x+1))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.x+2)%W, ((4*kz.x+2))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.x+3)%W, ((4*kz.x+3))/W, 0), 0).x - ); - const vec4 val2 = vec4( - texelFetch(uKernel, ivec3((4*kz.y+0)%W, ((4*kz.y+0))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.y+1)%W, ((4*kz.y+1))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.y+2)%W, ((4*kz.y+2))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.y+3)%W, ((4*kz.y+3))/W, 0), 0).x - ); - const vec4 val3 = vec4( - texelFetch(uKernel, ivec3((4*kz.z+0)%W, ((4*kz.z+0))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.z+1)%W, ((4*kz.z+1))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.z+2)%W, ((4*kz.z+2))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.z+3)%W, ((4*kz.z+3))/W, 0), 0).x - ); - const vec4 val4 = vec4( - texelFetch(uKernel, ivec3((4*kz.w+0)%W, ((4*kz.w+0))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.w+1)%W, ((4*kz.w+1))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.w+2)%W, ((4*kz.w+2))/W, 0), 0).x, - texelFetch(uKernel, ivec3((4*kz.w+3)%W, ((4*kz.w+3))/W, 0), 0).x - ); - - sum = fma(In.xxxx, val1, sum); - sum = fma(In.yyyy, val2, sum); - sum = fma(In.zzzz, val3, sum); - sum = fma(In.wwww, val4, sum); + sum = fma(In.xxxx, texelFetch(uKernel, ivec3(0, tower, kz.x), 0), sum); + sum = fma(In.yyyy, texelFetch(uKernel, ivec3(0, tower, kz.y), 0), sum); + sum = fma(In.zzzz, texelFetch(uKernel, ivec3(0, tower, kz.z), 0), sum); + sum = fma(In.wwww, texelFetch(uKernel, ivec3(0, tower, kz.w), 0), sum); } imageStore( diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h index d8b99052795c..6f7080f71a80 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.h +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -39,6 +39,10 @@ struct Experimentation { static constexpr bool kUseConv2dOldApi = true; }; +struct ConvPrepackLimits final { + static constexpr int64_t maxStackDepth = 2048*4; +}; + } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index 0a12c7b4933a..a77b1935eda6 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -45,10 +45,10 @@ vTensor pack_weights( if (is_depthwise(src_filter, groups)) { vTensor v_weight{ - api::context(), - &pool, - src_filter, - weight.options(), + api::context(), + &pool, + src_filter, + weight.options(), }; using Future = vTensor::Future; @@ -153,16 +153,26 @@ vTensor pack_weights( return v_weight; } + const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t stack_depth = + 4 * api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); + const int64_t max_stacks_per_tower = + ConvPrepackLimits::maxStackDepth / stack_depth; + const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower); + int64_t stacks_per_tower = num_stacks; + if (num_towers > 1) { + stacks_per_tower = div_up(num_stacks, num_towers); + } vTensor v_weight{ - api::context(), - &pool, - { - div_up(src_filter[Layout::Filter::output], INT64_C(4)), - 4 * align_up(src_filter[Layout::Filter::input], INT64_C(4)), - src_filter[Layout::Filter::height], - src_filter[Layout::Filter::width], - }, - weight.options(), + api::context(), + &pool, + { + stacks_per_tower, + stack_depth, + src_filter[Layout::Filter::height] * num_towers, + src_filter[Layout::Filter::width], + }, + weight.options(), }; using Future = vTensor::Future; @@ -170,38 +180,47 @@ vTensor pack_weights( Future::Payload v_weight_payload = v_weight_future.wait(); /* Source */ - const int64_t src_kernel = src_filter[Layout::Filter::height] * src_filter[Layout::Filter::width]; - const int64_t src_block = src_kernel * src_filter[Layout::Filter::input]; + const int64_t src_kw_sz = src_filter[Layout::Filter::width]; + const int64_t src_kh_sz = src_filter[Layout::Filter::height]; + const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; + const int64_t src_block_sz = + src_kernel_sz * src_filter[Layout::Filter::input]; /* Destination */ const IntArrayRef dst_filter = v_weight.sizes(); - const int64_t dst_kernel = dst_filter[Layout::Filter::height] * dst_filter[Layout::Filter::width]; - const int64_t dst_block = dst_kernel * dst_filter[Layout::Filter::input]; - TORCH_INTERNAL_ASSERT(src_kernel == dst_kernel, "Internal error!"); + const int64_t dst_kw_sz = src_filter[Layout::Filter::width]; + const int64_t dst_kh_sz = src_filter[Layout::Filter::height] * num_towers; + const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + const int64_t dst_block_sz = + dst_kernel_sz * dst_filter[Layout::Filter::input]; + + TORCH_INTERNAL_ASSERT(src_kernel_sz*num_towers == dst_kernel_sz, "Internal error!"); float* const dst_weight_ptr = v_weight_payload.get(); memset(dst_weight_ptr, 0, v_weight.nbytes()); for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) { + const int64_t i_tower = src_oc / (stacks_per_tower * 4); /* Source */ - const float *const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block; + const float* const src_weight_oc_ptr = + src_weight_ptr + src_oc * src_block_sz; /* Destination */ - const int64_t dst_oc = src_oc / 4; - const int64_t dst_oc_offset = src_oc % 4; + const int64_t local_oc = src_oc % (stacks_per_tower * 4); + const int64_t dst_oc = local_oc / 4; + const int64_t dst_oc_offset = local_oc % 4; - float* const dst_weight_oc_ptr = - dst_weight_ptr + - dst_oc * dst_block + - dst_oc_offset * dst_kernel; + float* const dst_weight_oc_ptr = dst_weight_ptr + dst_oc * dst_block_sz + + dst_oc_offset * dst_kernel_sz; for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) { const int64_t dst_ic = 4 * src_ic; memcpy( - dst_weight_oc_ptr + dst_ic * dst_kernel, - src_weight_oc_ptr + src_ic * src_kernel, - sizeof(float) * dst_kernel); + dst_weight_oc_ptr + dst_ic * dst_kernel_sz + + (i_tower * src_kernel_sz), + src_weight_oc_ptr + src_ic * src_kernel_sz, + sizeof(float) * src_kernel_sz); } } @@ -431,36 +450,14 @@ void conv2d_pointwise( const float output_min, const float output_max) { if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { - - vTensor v_weight_reshaped{ - context, - {1,1, v_weight.sizes()[0], v_weight.sizes()[1]}, - v_input.options(), - }; - - api::Command::Buffer temp_command_buffer = - api::context()->command().pool.allocate(); - temp_command_buffer.begin(); - - temp_command_buffer.copy( - v_weight.buffer( - temp_command_buffer, - vTensor::Stage::Transfer), - v_weight_reshaped.buffer( - temp_command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write) - ); - - temp_command_buffer.end(); - temp_command_buffer.submit(api::context()->gpu().queue); + const int64_t stacks_per_tower = v_weight.sizes()[0]; const struct { int32_t kernel_ic, kernel_oc; int32_t stride_x, stride_y; int32_t padding_x, padding_y; float clamp_x, clamp_y; - int32_t w; + int32_t stacks_per_tower; } block { safe_downcast(filter[Layout::Filter::input]), safe_downcast(filter[Layout::Filter::output]), @@ -470,7 +467,7 @@ void conv2d_pointwise( safe_downcast(padding[Layout::Parameter::height]), output_min, output_max, - v_weight.sizes()[1], + safe_downcast(stacks_per_tower), }; context->dispatch( @@ -497,10 +494,9 @@ void conv2d_pointwise( vTensor::Stage::Compute), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. - v_weight_reshaped.image( + v_weight.image( command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read), + vTensor::Stage::Compute), // Read-only access is implied on const tensors and triggers an async // synchronization if necessary. v_bias.buffer( @@ -529,12 +525,14 @@ void conv2d( const float output_min, const float output_max) { if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const int64_t stacks_per_tower = v_weight.sizes()[0]; const struct { int32_t kernel_x, kernel_y, kernel_ic, kernel_oc; int32_t stride_x, stride_y; int32_t padding_x, padding_y; int32_t dilate_x, dilate_y; float clamp_x, clamp_y; + int32_t stacks_per_tower; } block { safe_downcast(filter[Layout::Filter::width]), safe_downcast(filter[Layout::Filter::height]), @@ -548,6 +546,7 @@ void conv2d( safe_downcast(dilation[Layout::Parameter::height]), output_min, output_max, + safe_downcast(stacks_per_tower), }; context->dispatch( @@ -931,7 +930,7 @@ c10::intrusive_ptr conv2d_clamp_prepack( /* output_padding = */ {}, groups, output_min, - output_min)); + output_max)); } Tensor conv2d_clamp_run(