[vulkan] Distribute weight prepacking along y dimension for conv2d
ghstack-source-id: 43df57658e4d522b63b9f8aea08b7f09e19e3053
Pull Request resolved: #48266
SS-JIA committed Dec 1, 2020
1 parent ea0ffbb commit 39e2a52
Showing 4 changed files with 82 additions and 101 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/vulkan/api/Common.h
@@ -50,6 +50,8 @@ namespace native {
namespace vulkan {
namespace api {

const int64_t MAX_STACK_DEPTH=2048*4;

struct Adapter;
struct Command;
class Context;
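The new constant bounds the per-tower channel extent of the packed weight texture; when a filter exceeds it, the rewritten pack_weights() in Convolution.cpp (further down) splits the weight into multiple towers laid out along the kernel's y axis. A minimal standalone sketch of that tower arithmetic; the div_up/align_up helpers are reimplemented here and the channel counts are only an illustrative example:

```cpp
#include <cstdint>
#include <iostream>

// Reimplemented stand-ins for the ATen helpers used by pack_weights().
int64_t div_up(int64_t num, int64_t den) { return (num + den - 1) / den; }
int64_t align_up(int64_t num, int64_t mult) { return div_up(num, mult) * mult; }

int main() {
  const int64_t MAX_STACK_DEPTH = 2048 * 4;  // 8192, as added to Common.h

  // Hypothetical filter: 512 output channels, 512 input channels.
  const int64_t OC = 512, IC = 512;

  const int64_t num_stacks = div_up(OC, 4);                             // 128
  const int64_t stack_depth = 4 * align_up(IC, 4);                      // 2048
  const int64_t max_stacks_per_tower = MAX_STACK_DEPTH / stack_depth;   // 4
  const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower);  // 32
  const int64_t stacks_per_tower =
      num_towers > 1 ? div_up(num_stacks, num_towers) : num_stacks;     // 4

  std::cout << "towers: " << num_towers
            << ", stacks per tower: " << stacks_per_tower << '\n';
}
```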
13 changes: 8 additions & 5 deletions aten/src/ATen/native/vulkan/glsl/conv2d.glsl
@@ -18,6 +18,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
ivec2 padding;
ivec2 dilate;
vec2 clamp;
int stacks_per_tower;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -28,7 +29,9 @@ void main() {
/* Dynamically Uniform */
const ivec3 size = imageSize(uOutput);
const ivec3 isize = textureSize(uInput, 0);
const ivec4 block = pos.z * uBlock.kernel.z + ivec4(0, 1, 2, 3);
const int tower = pos.z/(uBlock.stacks_per_tower);
const int tower_offset = pos.z % uBlock.stacks_per_tower;
const ivec4 block = tower_offset * uBlock.kernel.z + ivec4(0, 1, 2, 3);

if (all(lessThan(pos, size))) {
const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
@@ -46,10 +49,10 @@ void main() {
for (int x = start.x, kx = kstart.x; x < end.x; x += uBlock.dilate.x, ++kx) {
const vec4 In = texelFetch(uInput, ivec3(x, y, z), 0);

sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, ky, kz.x), 0), sum);
sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, ky, kz.y), 0), sum);
sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, ky, kz.z), 0), sum);
sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, ky, kz.w), 0), sum);
sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.x), 0), sum);
sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.y), 0), sum);
sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.z), 0), sum);
sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.w), 0), sum);
}
}
}
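For reference, the new indexing: each output z-slice first resolves which tower holds its weights, and that tower shifts the kernel fetch down by one filter-height band per tower. A minimal C++ restatement of the fetch coordinates (not the shader itself; parameter names are illustrative):

```cpp
#include <cstdint>

struct Texel { int32_t x, y, z; };

// Restates the texelFetch coordinates in the new conv2d.glsl: for an output
// slice pos_z, the weights live in tower = pos_z / stacks_per_tower, and that
// tower's rows start at kernel_h * tower in the packed kernel image. kz is the
// depth coordinate the shader derives from
// tower_offset = pos_z % stacks_per_tower
// (block = tower_offset * kernel.z + lane, presumably advanced per input slice
// elsewhere in the loop, as conv2d_pw.glsl does with `block + 4 * z`).
Texel conv2d_weight_coord(
    int32_t pos_z, int32_t stacks_per_tower,
    int32_t kernel_h, int32_t kx, int32_t ky, int32_t kz) {
  const int32_t tower = pos_z / stacks_per_tower;
  return Texel{kx, kernel_h * tower + ky, kz};
}
```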
41 changes: 8 additions & 33 deletions aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl
@@ -17,7 +17,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
ivec2 stride;
ivec2 padding;
vec2 clamp;
int W;
int stacks_per_tower;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
Expand All @@ -28,7 +28,9 @@ void main() {
/* Dynamically Uniform */
const ivec3 size = imageSize(uOutput);
const ivec3 isize = textureSize(uInput, 0);
const ivec4 block = pos.z * uBlock.kernel.x + ivec4(0, 1, 2, 3);
const int tower = pos.z/(uBlock.stacks_per_tower);
const int tower_offset = pos.z % uBlock.stacks_per_tower;
const ivec4 block = tower_offset * uBlock.kernel.x + ivec4(0, 1, 2, 3);

if (all(lessThan(pos, size))) {
const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
@@ -39,37 +41,10 @@ void main() {
const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z), 0);
const ivec4 kz = block + 4 * z;

const int W = uBlock.W;

const vec4 val1 = vec4(
texelFetch(uKernel, ivec3((4*kz.x+0)%W, ((4*kz.x+0))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.x+1)%W, ((4*kz.x+1))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.x+2)%W, ((4*kz.x+2))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.x+3)%W, ((4*kz.x+3))/W, 0), 0).x
);
const vec4 val2 = vec4(
texelFetch(uKernel, ivec3((4*kz.y+0)%W, ((4*kz.y+0))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.y+1)%W, ((4*kz.y+1))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.y+2)%W, ((4*kz.y+2))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.y+3)%W, ((4*kz.y+3))/W, 0), 0).x
);
const vec4 val3 = vec4(
texelFetch(uKernel, ivec3((4*kz.z+0)%W, ((4*kz.z+0))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.z+1)%W, ((4*kz.z+1))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.z+2)%W, ((4*kz.z+2))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.z+3)%W, ((4*kz.z+3))/W, 0), 0).x
);
const vec4 val4 = vec4(
texelFetch(uKernel, ivec3((4*kz.w+0)%W, ((4*kz.w+0))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.w+1)%W, ((4*kz.w+1))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.w+2)%W, ((4*kz.w+2))/W, 0), 0).x,
texelFetch(uKernel, ivec3((4*kz.w+3)%W, ((4*kz.w+3))/W, 0), 0).x
);

sum = fma(In.xxxx, val1, sum);
sum = fma(In.yyyy, val2, sum);
sum = fma(In.zzzz, val3, sum);
sum = fma(In.wwww, val4, sum);
sum = fma(In.xxxx, texelFetch(uKernel, ivec3(0, tower, kz.x), 0), sum);
sum = fma(In.yyyy, texelFetch(uKernel, ivec3(0, tower, kz.y), 0), sum);
sum = fma(In.zzzz, texelFetch(uKernel, ivec3(0, tower, kz.z), 0), sum);
sum = fma(In.wwww, texelFetch(uKernel, ivec3(0, tower, kz.w), 0), sum);
}

imageStore(
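The pointwise shader gets the same tower logic and no longer unflattens a flat index with %W and /W into a reshaped 2-D weight texture; it samples the prepacked kernel image directly. A minimal C++ restatement of the new fetch coordinate, under the same assumptions as the sketch above:

```cpp
#include <cstdint>

struct Texel { int32_t x, y, z; };

// Restates the new conv2d_pw.glsl addressing: one texel per lane of kz,
// fetched straight from the prepacked kernel image with the tower index as
// the y coordinate.
Texel pointwise_weight_coord(int32_t pos_z, int32_t stacks_per_tower, int32_t kz) {
  const int32_t tower = pos_z / stacks_per_tower;
  return Texel{0, tower, kz};  // the ivec3 passed to texelFetch(uKernel, ...)
}
```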
127 changes: 64 additions & 63 deletions aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -44,10 +44,10 @@ vTensor pack_weights(

if (is_depthwise(src_filter, groups)) {
vTensor v_weight{
api::context(),
&pool,
src_filter,
weight.options(),
api::context(),
&pool,
src_filter,
weight.options(),
};

using Future = vTensor::Future<void, vTensor::Access::Write>;
@@ -67,55 +67,76 @@ vTensor pack_weights(
//


const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4));
const int64_t stack_depth =
4 * api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4));
const int64_t max_stacks_per_tower =
at::native::vulkan::api::MAX_STACK_DEPTH / stack_depth;
const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower);
int64_t stacks_per_tower = num_stacks;
if (num_towers > 1) {
stacks_per_tower = div_up(num_stacks, num_towers);
}
vTensor v_weight{
api::context(),
&pool,
{
div_up(src_filter[Layout::Filter::output], INT64_C(4)),
4 * align_up(src_filter[Layout::Filter::input], INT64_C(4)),
src_filter[Layout::Filter::height],
src_filter[Layout::Filter::width],
},
weight.options(),
api::context(),
&pool,
{
stacks_per_tower,
stack_depth,
src_filter[Layout::Filter::height] * num_towers,
src_filter[Layout::Filter::width],
},
weight.options(),
};

using Future = vTensor::Future<float, vTensor::Access::Write>;
Future v_weight_future = v_weight.host<float, vTensor::Access::Write>();
Future::Payload v_weight_payload = v_weight_future.wait();

/* Source */
const int64_t src_kernel = src_filter[Layout::Filter::height] * src_filter[Layout::Filter::width];
const int64_t src_block = src_kernel * src_filter[Layout::Filter::input];
const int64_t src_kw_sz = src_filter[Layout::Filter::width];
const int64_t src_kh_sz = src_filter[Layout::Filter::height];
const int64_t src_kernel_sz = src_kw_sz * src_kh_sz;
const int64_t src_block_sz =
src_kernel_sz * src_filter[Layout::Filter::input];

/* Destination */
const IntArrayRef dst_filter = v_weight.sizes();
const int64_t dst_kernel = dst_filter[Layout::Filter::height] * dst_filter[Layout::Filter::width];
const int64_t dst_block = dst_kernel * dst_filter[Layout::Filter::input];
TORCH_INTERNAL_ASSERT(src_kernel == dst_kernel, "Internal error!");
const int64_t dst_kw_sz = src_filter[Layout::Filter::width];
const int64_t dst_kh_sz = src_filter[Layout::Filter::height] * num_towers;
const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz;
const int64_t dst_block_sz =
dst_kernel_sz * dst_filter[Layout::Filter::input];
// TORCH_INTERNAL_ASSERT(src_kernel_sz == dst_kernel_sz, "Internal error!");

float* const dst_weight_ptr = v_weight_payload.get();
memset(dst_weight_ptr, 0, v_weight.nbytes());

for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) {
/* Source */
const float *const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block;
for (int64_t i_tower = 0; i_tower < num_towers; ++i_tower) {
const float* const src_tower_ptr =
src_weight_ptr + (i_tower * stacks_per_tower * 4) * src_block_sz;
for (int64_t src_oc = 0; src_oc < (stacks_per_tower * 4); ++src_oc) {
/* Source */
const float* const src_weight_oc_ptr =
src_tower_ptr + src_oc * src_block_sz;

/* Destination */
const int64_t dst_oc = src_oc / 4;
const int64_t dst_oc_offset = src_oc % 4;
/* Destination */
const int64_t dst_oc = src_oc / 4;
const int64_t dst_oc_offset = src_oc % 4;

float* const dst_weight_oc_ptr =
dst_weight_ptr +
dst_oc * dst_block +
dst_oc_offset * dst_kernel;
float* const dst_weight_oc_ptr = dst_weight_ptr + dst_oc * dst_block_sz +
dst_oc_offset * dst_kernel_sz;

for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) {
const int64_t dst_ic = 4 * src_ic;
for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input];
++src_ic) {
const int64_t dst_ic = 4 * src_ic;

memcpy(
dst_weight_oc_ptr + dst_ic * dst_kernel,
src_weight_oc_ptr + src_ic * src_kernel,
sizeof(float) * dst_kernel);
memcpy(
dst_weight_oc_ptr + dst_ic * dst_kernel_sz +
(i_tower * src_kernel_sz),
src_weight_oc_ptr + src_ic * src_kernel_sz,
sizeof(float) * src_kernel_sz);
}
}
}
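On the CPU side, towers are laid side by side along the kernel-height axis: each tower packs stacks_per_tower * 4 consecutive output channels, and tower i's kernel data is written starting i_tower * src_kernel_sz floats into each destination kernel plane. A hedged sketch that collapses the loop's destination pointer arithmetic into a single offset computation (illustrative only, mirroring the diff above):

```cpp
#include <cstdint>

// Restates the destination indexing in the rewritten pack_weights() loop:
// the float offset at which one (tower, source output channel, source input
// channel) kernel plane is written into the packed buffer. dst_kernel_sz is
// src_kernel_sz * num_towers, and dst_block_sz is dst_kernel_sz times the
// packed input extent, as in the diff.
int64_t packed_plane_offset(
    int64_t i_tower,        // tower the source output channel falls into
    int64_t src_oc,         // output channel index within that tower
    int64_t src_ic,         // source input channel index
    int64_t src_kernel_sz,  // filter height * filter width
    int64_t dst_kernel_sz,  // src_kernel_sz * num_towers
    int64_t dst_block_sz) { // dst_kernel_sz * packed input extent
  const int64_t dst_oc = src_oc / 4;         // destination stack
  const int64_t dst_oc_offset = src_oc % 4;  // lane within the stack
  const int64_t dst_ic = 4 * src_ic;
  return dst_oc * dst_block_sz +
         dst_oc_offset * dst_kernel_sz +
         dst_ic * dst_kernel_sz +
         i_tower * src_kernel_sz;
}
```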

@@ -345,36 +366,14 @@ void conv2d_pointwise(
const float output_min,
const float output_max) {
if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {

vTensor v_weight_reshaped{
context,
{1,1, v_weight.sizes()[0], v_weight.sizes()[1]},
v_input.options(),
};

api::Command::Buffer temp_command_buffer =
api::context()->command().pool.allocate();
temp_command_buffer.begin();

temp_command_buffer.copy(
v_weight.buffer(
temp_command_buffer,
vTensor::Stage::Transfer),
v_weight_reshaped.buffer(
temp_command_buffer,
vTensor::Stage::Transfer,
vTensor::Access::Write)
);

temp_command_buffer.end();
temp_command_buffer.submit(api::context()->gpu().queue);
const int64_t stacks_per_tower = v_weight.sizes()[0];

const struct {
int32_t kernel_ic, kernel_oc;
int32_t stride_x, stride_y;
int32_t padding_x, padding_y;
float clamp_x, clamp_y;
int32_t w;
int32_t stacks_per_tower;
} block {
safe_downcast<int32_t>(filter[Layout::Filter::input]),
safe_downcast<int32_t>(filter[Layout::Filter::output]),
@@ -384,7 +383,7 @@ void conv2d_pointwise(
safe_downcast<int32_t>(padding[Layout::Parameter::height]),
output_min,
output_max,
v_weight.sizes()[1],
safe_downcast<int32_t>(stacks_per_tower),
};

context->dispatch(
@@ -411,10 +410,9 @@ void conv2d_pointwise(
vTensor::Stage::Compute),
// Read-only access is implied on const tensors and triggers an async
// synchronization if necessary.
v_weight_reshaped.image(
v_weight.image(
command_buffer,
vTensor::Stage::Compute,
vTensor::Access::Read),
vTensor::Stage::Compute),
// Read-only access is implied on const tensors and triggers an async
// synchronization if necessary.
v_bias.buffer(
@@ -443,12 +441,14 @@ void conv2d(
const float output_min,
const float output_max) {
if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
const int64_t stacks_per_tower = v_weight.sizes()[0];
const struct {
int32_t kernel_x, kernel_y, kernel_ic, kernel_oc;
int32_t stride_x, stride_y;
int32_t padding_x, padding_y;
int32_t dilate_x, dilate_y;
float clamp_x, clamp_y;
int32_t stacks_per_tower;
} block {
safe_downcast<int32_t>(filter[Layout::Filter::width]),
safe_downcast<int32_t>(filter[Layout::Filter::height]),
@@ -462,6 +462,7 @@ void conv2d(
safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
output_min,
output_max,
safe_downcast<int32_t>(stacks_per_tower),
};

context->dispatch(
@@ -734,7 +735,7 @@ c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
/* output_padding = */ {},
groups,
output_min,
output_min));
output_max));
}

Tensor conv2d_clamp_run(