[vulkan] Distribute weight prepacking along y dimension for conv2d
ghstack-source-id: f814ba805d5b801d7a8f9e7b652209c2c5487051
Pull Request resolved: #48266
SS-JIA committed Dec 2, 2020
1 parent 5f181e2 commit dfd372e
Showing 4 changed files with 82 additions and 101 deletions.
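
Note: this commit splits the prepacked conv2d weight texture into multiple "towers" laid out along the y (height) dimension, so the packed texture's depth no longer grows unbounded with the number of output channels (presumably to stay within device texture-extent limits; see ConvPrepackLimits::maxStackDepth below). A minimal standalone sketch of the partition arithmetic performed by pack_weights, with div_up/align_up re-implemented locally and made-up example channel counts:

#include <cstdint>
#include <iostream>

// Local stand-ins for the helpers pack_weights uses.
int64_t div_up(int64_t n, int64_t d) { return (n + d - 1) / d; }
int64_t align_up(int64_t n, int64_t m) { return div_up(n, m) * m; }

int main() {
  // Hypothetical filter: 1024 output channels, 512 input channels.
  const int64_t OC = 1024;
  const int64_t IC = 512;
  const int64_t maxStackDepth = 2048 * 4;  // ConvPrepackLimits::maxStackDepth

  // Each stack holds 4 output channels at depth 4 * (IC rounded up to 4).
  const int64_t num_stacks = div_up(OC, int64_t{4});                    // 256
  const int64_t stack_depth = 4 * align_up(IC, int64_t{4});             // 2048
  const int64_t max_stacks_per_tower = maxStackDepth / stack_depth;     // 4
  const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower);  // 64
  const int64_t stacks_per_tower =
      num_towers > 1 ? div_up(num_stacks, num_towers) : num_stacks;     // 4

  // The packed tensor is sized {stacks_per_tower, stack_depth,
  // KH * num_towers, KW}: towers sit one after another along y.
  std::cout << "num_towers=" << num_towers
            << ", stacks_per_tower=" << stacks_per_tower << std::endl;
}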
19 changes: 11 additions & 8 deletions aten/src/ATen/native/vulkan/glsl/conv2d.glsl
@@ -18,6 +18,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
   ivec2 padding;
   ivec2 dilate;
   vec2 clamp;
+  int stacks_per_tower;
 } uBlock;
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -28,7 +29,9 @@ void main() {
   /* Dynamically Uniform */
   const ivec3 size = imageSize(uOutput);
   const ivec3 isize = textureSize(uInput, 0);
-  const ivec4 block = pos.z * uBlock.kernel.z + ivec4(0, 1, 2, 3);
+  const int tower = pos.z/(uBlock.stacks_per_tower);
+  const int tower_offset = pos.z % uBlock.stacks_per_tower;
+  const ivec4 block = tower_offset * uBlock.kernel.z + ivec4(0, 1, 2, 3);
 
   if (all(lessThan(pos, size))) {
     const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
@@ -39,17 +42,17 @@
 
     vec4 sum = uBias.data[pos.z];
 
-    for (int z = 0; z < uBlock.kernel.z; ++z) {
-      const ivec4 kz = block + 4 * z;
+    for (int z = 0; z < uBlock.kernel.z; z+=4) {
+      const ivec4 kz = block + z;
 
       for (int y = start.y, ky = kstart.y; y < end.y; y += uBlock.dilate.y, ++ky) {
         for (int x = start.x, kx = kstart.x; x < end.x; x += uBlock.dilate.x, ++kx) {
-          const vec4 In = texelFetch(uInput, ivec3(x, y, z), 0);
+          const vec4 In = texelFetch(uInput, ivec3(x, y, z/4), 0);
 
-          sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, ky, kz.x), 0), sum);
-          sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, ky, kz.y), 0), sum);
-          sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, ky, kz.z), 0), sum);
-          sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, ky, kz.w), 0), sum);
+          sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.x), 0), sum);
+          sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.y), 0), sum);
+          sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.z), 0), sum);
+          sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.w), 0), sum);
         }
       }
     }
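Note: the z loop above now walks input channels four at a time (fetching input texel slice z/4) instead of walking texel slices directly, and kernel reads are shifted down by one whole filter height per tower. A scalar C++ restatement of the new per-invocation index math, for illustration only (names mirror the shader's uBlock fields):

// Illustration only: index math of the updated conv2d.glsl.
struct KernelCoords {
  int tower;         // which y-tower holds this output stack
  int tower_offset;  // stack index within its tower
  int ky_base;       // rows to skip in uKernel: kernel_y rows per tower
};

KernelCoords locate(int pos_z, int stacks_per_tower, int kernel_y) {
  const int tower = pos_z / stacks_per_tower;
  return {tower, pos_z % stacks_per_tower, kernel_y * tower};
}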
47 changes: 11 additions & 36 deletions aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl
@@ -17,7 +17,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
   ivec2 stride;
   ivec2 padding;
   vec2 clamp;
-  int W;
+  int stacks_per_tower;
 } uBlock;
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -28,48 +28,23 @@ void main() {
   /* Dynamically Uniform */
   const ivec3 size = imageSize(uOutput);
   const ivec3 isize = textureSize(uInput, 0);
-  const ivec4 block = pos.z * uBlock.kernel.x + ivec4(0, 1, 2, 3);
+  const int tower = pos.z/(uBlock.stacks_per_tower);
+  const int tower_offset = pos.z % uBlock.stacks_per_tower;
+  const ivec4 block = tower_offset * uBlock.kernel.x + ivec4(0, 1, 2, 3);
 
   if (all(lessThan(pos, size))) {
     const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
 
     vec4 sum = uBias.data[pos.z];
 
-    for (int z = 0; z < uBlock.kernel.x; ++z) {
-      const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z), 0);
-      const ivec4 kz = block + 4 * z;
+    for (int z = 0; z < uBlock.kernel.x; z+=4) {
+      const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z/4), 0);
+      const ivec4 kz = block + z;
 
-      const int W = uBlock.W;
-
-      const vec4 val1 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.x+0)%W, ((4*kz.x+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.x+1)%W, ((4*kz.x+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.x+2)%W, ((4*kz.x+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.x+3)%W, ((4*kz.x+3))/W, 0), 0).x
-      );
-      const vec4 val2 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.y+0)%W, ((4*kz.y+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.y+1)%W, ((4*kz.y+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.y+2)%W, ((4*kz.y+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.y+3)%W, ((4*kz.y+3))/W, 0), 0).x
-      );
-      const vec4 val3 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.z+0)%W, ((4*kz.z+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.z+1)%W, ((4*kz.z+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.z+2)%W, ((4*kz.z+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.z+3)%W, ((4*kz.z+3))/W, 0), 0).x
-      );
-      const vec4 val4 = vec4(
-        texelFetch(uKernel, ivec3((4*kz.w+0)%W, ((4*kz.w+0))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.w+1)%W, ((4*kz.w+1))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.w+2)%W, ((4*kz.w+2))/W, 0), 0).x,
-        texelFetch(uKernel, ivec3((4*kz.w+3)%W, ((4*kz.w+3))/W, 0), 0).x
-      );
-
-      sum = fma(In.xxxx, val1, sum);
-      sum = fma(In.yyyy, val2, sum);
-      sum = fma(In.zzzz, val3, sum);
-      sum = fma(In.wwww, val4, sum);
+      sum = fma(In.xxxx, texelFetch(uKernel, ivec3(0, tower, kz.x), 0), sum);
+      sum = fma(In.yyyy, texelFetch(uKernel, ivec3(0, tower, kz.y), 0), sum);
+      sum = fma(In.zzzz, texelFetch(uKernel, ivec3(0, tower, kz.z), 0), sum);
+      sum = fma(In.wwww, texelFetch(uKernel, ivec3(0, tower, kz.w), 0), sum);
     }
 
     imageStore(
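The pointwise shader previously consumed a separate 2D reshaped copy of the weights (created with an extra command buffer in conv2d_pointwise, removed below) and decoded each scalar with %W and /W — sixteen scalar fetches per loop iteration. It now samples the shared packed 3D weight texture directly, four vec4 fetches per iteration, with the tower index selecting the y row. A hypothetical helper spelling out the new coordinate mapping:

// Illustration only: where the updated conv2d_pw.glsl reads its weights.
struct TexelCoord { int x, y, z; };

TexelCoord pointwise_weight_coord(int tower, int kz_lane) {
  // A 1x1 kernel leaves x at 0; the tower picks the y row; kz_lane is one
  // component of the shader's kz vector and selects the weight texel along z.
  return {0, tower, kz_lane};
}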
4 changes: 4 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Common.h
@@ -39,6 +39,10 @@ struct Experimentation {
   static constexpr bool kUseConv2dOldApi = true;
 };
 
+struct ConvPrepackLimits final {
+  static constexpr int64_t maxStackDepth = 2048*4;
+};
+
 } // namespace ops
 } // namespace vulkan
 } // namespace native
113 changes: 56 additions & 57 deletions aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -45,10 +45,10 @@ vTensor pack_weights(
 
   if (is_depthwise(src_filter, groups)) {
     vTensor v_weight{
-      api::context(),
-      &pool,
-      src_filter,
-      weight.options(),
+        api::context(),
+        &pool,
+        src_filter,
+        weight.options(),
     };
 
     using Future = vTensor::Future<void, vTensor::Access::Write>;
@@ -153,55 +153,74 @@ vTensor pack_weights(
     return v_weight;
   }
 
+  const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4));
+  const int64_t stack_depth =
+      4 * api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4));
+  const int64_t max_stacks_per_tower =
+      ConvPrepackLimits::maxStackDepth / stack_depth;
+  const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower);
+  int64_t stacks_per_tower = num_stacks;
+  if (num_towers > 1) {
+    stacks_per_tower = div_up(num_stacks, num_towers);
+  }
   vTensor v_weight{
-    api::context(),
-    &pool,
-    {
-      div_up(src_filter[Layout::Filter::output], INT64_C(4)),
-      4 * align_up(src_filter[Layout::Filter::input], INT64_C(4)),
-      src_filter[Layout::Filter::height],
-      src_filter[Layout::Filter::width],
-    },
-    weight.options(),
+      api::context(),
+      &pool,
+      {
+          stacks_per_tower,
+          stack_depth,
+          src_filter[Layout::Filter::height] * num_towers,
+          src_filter[Layout::Filter::width],
+      },
+      weight.options(),
   };
 
   using Future = vTensor::Future<float, vTensor::Access::Write>;
   Future v_weight_future = v_weight.host<float, vTensor::Access::Write>();
   Future::Payload v_weight_payload = v_weight_future.wait();
 
   /* Source */
-  const int64_t src_kernel = src_filter[Layout::Filter::height] * src_filter[Layout::Filter::width];
-  const int64_t src_block = src_kernel * src_filter[Layout::Filter::input];
+  const int64_t src_kw_sz = src_filter[Layout::Filter::width];
+  const int64_t src_kh_sz = src_filter[Layout::Filter::height];
+  const int64_t src_kernel_sz = src_kw_sz * src_kh_sz;
+  const int64_t src_block_sz =
+      src_kernel_sz * src_filter[Layout::Filter::input];
 
   /* Destination */
   const IntArrayRef dst_filter = v_weight.sizes();
-  const int64_t dst_kernel = dst_filter[Layout::Filter::height] * dst_filter[Layout::Filter::width];
-  const int64_t dst_block = dst_kernel * dst_filter[Layout::Filter::input];
-  TORCH_INTERNAL_ASSERT(src_kernel == dst_kernel, "Internal error!");
+  const int64_t dst_kw_sz = src_filter[Layout::Filter::width];
+  const int64_t dst_kh_sz = src_filter[Layout::Filter::height] * num_towers;
+  const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz;
+  const int64_t dst_block_sz =
+      dst_kernel_sz * dst_filter[Layout::Filter::input];
+
+  TORCH_INTERNAL_ASSERT(src_kernel_sz*num_towers == dst_kernel_sz, "Internal error!");
 
   float* const dst_weight_ptr = v_weight_payload.get();
   memset(dst_weight_ptr, 0, v_weight.nbytes());
 
   for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) {
+    const int64_t i_tower = src_oc / (stacks_per_tower * 4);
     /* Source */
-    const float *const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block;
+    const float* const src_weight_oc_ptr =
+        src_weight_ptr + src_oc * src_block_sz;
 
     /* Destination */
-    const int64_t dst_oc = src_oc / 4;
-    const int64_t dst_oc_offset = src_oc % 4;
+    const int64_t local_oc = src_oc % (stacks_per_tower * 4);
+    const int64_t dst_oc = local_oc / 4;
+    const int64_t dst_oc_offset = local_oc % 4;
 
-    float* const dst_weight_oc_ptr =
-        dst_weight_ptr +
-        dst_oc * dst_block +
-        dst_oc_offset * dst_kernel;
+    float* const dst_weight_oc_ptr = dst_weight_ptr + dst_oc * dst_block_sz +
+        dst_oc_offset * dst_kernel_sz;
 
     for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) {
       const int64_t dst_ic = 4 * src_ic;
 
       memcpy(
-          dst_weight_oc_ptr + dst_ic * dst_kernel,
-          src_weight_oc_ptr + src_ic * src_kernel,
-          sizeof(float) * dst_kernel);
+          dst_weight_oc_ptr + dst_ic * dst_kernel_sz +
+              (i_tower * src_kernel_sz),
+          src_weight_oc_ptr + src_ic * src_kernel_sz,
+          sizeof(float) * src_kernel_sz);
     }
   }
 
@@ -431,36 +450,14 @@ void conv2d_pointwise(
     const float output_min,
     const float output_max) {
   if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
-
-    vTensor v_weight_reshaped{
-      context,
-      {1,1, v_weight.sizes()[0], v_weight.sizes()[1]},
-      v_input.options(),
-    };
-
-    api::Command::Buffer temp_command_buffer =
-      api::context()->command().pool.allocate();
-    temp_command_buffer.begin();
-
-    temp_command_buffer.copy(
-      v_weight.buffer(
-        temp_command_buffer,
-        vTensor::Stage::Transfer),
-      v_weight_reshaped.buffer(
-        temp_command_buffer,
-        vTensor::Stage::Transfer,
-        vTensor::Access::Write)
-    );
-
-    temp_command_buffer.end();
-    temp_command_buffer.submit(api::context()->gpu().queue);
+    const int64_t stacks_per_tower = v_weight.sizes()[0];
 
     const struct {
       int32_t kernel_ic, kernel_oc;
       int32_t stride_x, stride_y;
       int32_t padding_x, padding_y;
       float clamp_x, clamp_y;
-      int32_t w;
+      int32_t stacks_per_tower;
     } block {
       safe_downcast<int32_t>(filter[Layout::Filter::input]),
       safe_downcast<int32_t>(filter[Layout::Filter::output]),
@@ -470,7 +467,7 @@
       safe_downcast<int32_t>(padding[Layout::Parameter::height]),
       output_min,
       output_max,
-      v_weight.sizes()[1],
+      safe_downcast<int32_t>(stacks_per_tower),
     };
 
     context->dispatch(
@@ -497,10 +494,9 @@
           vTensor::Stage::Compute),
       // Read-only access is implied on const tensors and triggers an async
       // synchronization if necessary.
-      v_weight_reshaped.image(
+      v_weight.image(
           command_buffer,
-          vTensor::Stage::Compute,
-          vTensor::Access::Read),
+          vTensor::Stage::Compute),
       // Read-only access is implied on const tensors and triggers an async
       // synchronization if necessary.
       v_bias.buffer(
@@ -529,12 +525,14 @@ void conv2d(
     const float output_min,
     const float output_max) {
   if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
+    const int64_t stacks_per_tower = v_weight.sizes()[0];
     const struct {
       int32_t kernel_x, kernel_y, kernel_ic, kernel_oc;
       int32_t stride_x, stride_y;
       int32_t padding_x, padding_y;
       int32_t dilate_x, dilate_y;
       float clamp_x, clamp_y;
+      int32_t stacks_per_tower;
     } block {
       safe_downcast<int32_t>(filter[Layout::Filter::width]),
       safe_downcast<int32_t>(filter[Layout::Filter::height]),
@@ -548,6 +546,7 @@
       safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
       output_min,
       output_max,
+      safe_downcast<int32_t>(stacks_per_tower),
     };
 
     context->dispatch(
@@ -931,7 +930,7 @@ c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
       /* output_padding = */ {},
       groups,
       output_min,
-      output_min));
+      output_max));
 }
 
 Tensor conv2d_clamp_run(
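For reference, the destination offset computed by the repacking loop in pack_weights above, factored into a single function — a sketch using the diff's variable names, not shipped code:

#include <cstdint>

// Offset, in floats, at which the (src_oc, src_ic) filter plane lands.
// dst_kernel_sz = KW * KH * num_towers; dst_block_sz = dst_kernel_sz * IC_dst;
// src_kernel_sz = KW * KH (one tower's worth of kernel rows).
int64_t dst_plane_offset(
    int64_t src_oc,
    int64_t src_ic,
    int64_t stacks_per_tower,
    int64_t dst_block_sz,
    int64_t dst_kernel_sz,
    int64_t src_kernel_sz) {
  const int64_t i_tower = src_oc / (stacks_per_tower * 4);   // y-tower
  const int64_t local_oc = src_oc % (stacks_per_tower * 4);  // oc within tower
  const int64_t dst_oc = local_oc / 4;                       // stack index
  const int64_t dst_oc_offset = local_oc % 4;                // vec4 lane
  const int64_t dst_ic = 4 * src_ic;                         // strided by 4
  return dst_oc * dst_block_sz + dst_oc_offset * dst_kernel_sz +
      dst_ic * dst_kernel_sz + i_tower * src_kernel_sz;
}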
