[vulkan] Distribute weight prepacking along y dimension for conv2d #48266

Closed · wants to merge 10 commits
19 changes: 11 additions & 8 deletions aten/src/ATen/native/vulkan/glsl/conv2d.glsl
@@ -18,6 +18,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
   ivec2 padding;
   ivec2 dilate;
   vec2 clamp;
+  int stacks_per_tower;
 } uBlock;
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -28,7 +29,9 @@ void main() {
   /* Dynamically Uniform */
   const ivec3 size = imageSize(uOutput);
   const ivec3 isize = textureSize(uInput, 0);
-  const ivec4 block = pos.z * uBlock.kernel.z + ivec4(0, 1, 2, 3);
+  const int tower = pos.z/(uBlock.stacks_per_tower);
+  const int tower_offset = pos.z % uBlock.stacks_per_tower;
+  const ivec4 block = tower_offset * uBlock.kernel.z + ivec4(0, 1, 2, 3);
 
   if (all(lessThan(pos, size))) {
     const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
@@ -39,17 +42,17 @@
 
     vec4 sum = uBias.data[pos.z];
 
-    for (int z = 0; z < uBlock.kernel.z; ++z) {
-      const ivec4 kz = block + 4 * z;
+    for (int z = 0; z < uBlock.kernel.z; z+=4) {
+      const ivec4 kz = block + z;
 
       for (int y = start.y, ky = kstart.y; y < end.y; y += uBlock.dilate.y, ++ky) {
         for (int x = start.x, kx = kstart.x; x < end.x; x += uBlock.dilate.x, ++kx) {
-          const vec4 In = texelFetch(uInput, ivec3(x, y, z), 0);
+          const vec4 In = texelFetch(uInput, ivec3(x, y, z/4), 0);
 
-          sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, ky, kz.x), 0), sum);
-          sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, ky, kz.y), 0), sum);
-          sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, ky, kz.z), 0), sum);
-          sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, ky, kz.w), 0), sum);
+          sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.x), 0), sum);
+          sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.y), 0), sum);
+          sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.z), 0), sum);
+          sum = fma(In.wwww, texelFetch(uKernel, ivec3(kx, (uBlock.kernel.y*tower) + ky, kz.w), 0), sum);
         }
       }
     }
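Net effect of the shader change: pos.z no longer indexes a single flat stack of weight slices. It is split into a tower index and an offset within that tower, and the kernel-row coordinate ky is shifted by uBlock.kernel.y * tower so each tower reads from its own band along the texture's y axis. A minimal C++ sketch of the same index arithmetic (hypothetical sizes; illustration only, not part of the patch):

#include <cstdio>

// Mirrors the conv2d.glsl indexing: split pos.z into (tower, tower_offset)
// and shift the kernel fetch row by kernel_h * tower. Sizes are hypothetical.
int main() {
  const int stacks_per_tower = 16; // uBlock.stacks_per_tower
  const int kernel_h = 3;          // uBlock.kernel.y

  for (int pos_z = 0; pos_z < 48; pos_z += 17) {
    const int tower = pos_z / stacks_per_tower;        // which y band
    const int tower_offset = pos_z % stacks_per_tower; // stack inside the band
    const int ky_base = kernel_h * tower;              // row shift into uKernel
    std::printf("pos.z=%2d -> tower=%d offset=%2d ky_base=%d\n",
                pos_z, tower, tower_offset, ky_base);
  }
  return 0;
}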
47 changes: 11 additions & 36 deletions aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl
@@ -17,7 +17,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block
   ivec2 stride;
   ivec2 padding;
   vec2 clamp;
-  int W;
+  int stacks_per_tower;
 } uBlock;
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -28,48 +28,23 @@ void main() {
   /* Dynamically Uniform */
   const ivec3 size = imageSize(uOutput);
   const ivec3 isize = textureSize(uInput, 0);
-  const ivec4 block = pos.z * uBlock.kernel.x + ivec4(0, 1, 2, 3);
+  const int tower = pos.z/(uBlock.stacks_per_tower);
+  const int tower_offset = pos.z % uBlock.stacks_per_tower;
+  const ivec4 block = tower_offset * uBlock.kernel.x + ivec4(0, 1, 2, 3);
 
   if (all(lessThan(pos, size))) {
     const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding;
 
     vec4 sum = uBias.data[pos.z];
 
-    for (int z = 0; z < uBlock.kernel.x; ++z) {
-      const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z), 0);
-      const ivec4 kz = block + 4 * z;
+    for (int z = 0; z < uBlock.kernel.x; z+=4) {
+      const vec4 In = texelFetch(uInput, ivec3(ipos.x, ipos.y, z/4), 0);
+      const ivec4 kz = block + z;
 
-      const int W = uBlock.W;
-
-      const vec4 val1 = vec4(
-          texelFetch(uKernel, ivec3((4*kz.x+0)%W, ((4*kz.x+0))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.x+1)%W, ((4*kz.x+1))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.x+2)%W, ((4*kz.x+2))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.x+3)%W, ((4*kz.x+3))/W, 0), 0).x
-      );
-      const vec4 val2 = vec4(
-          texelFetch(uKernel, ivec3((4*kz.y+0)%W, ((4*kz.y+0))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.y+1)%W, ((4*kz.y+1))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.y+2)%W, ((4*kz.y+2))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.y+3)%W, ((4*kz.y+3))/W, 0), 0).x
-      );
-      const vec4 val3 = vec4(
-          texelFetch(uKernel, ivec3((4*kz.z+0)%W, ((4*kz.z+0))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.z+1)%W, ((4*kz.z+1))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.z+2)%W, ((4*kz.z+2))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.z+3)%W, ((4*kz.z+3))/W, 0), 0).x
-      );
-      const vec4 val4 = vec4(
-          texelFetch(uKernel, ivec3((4*kz.w+0)%W, ((4*kz.w+0))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.w+1)%W, ((4*kz.w+1))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.w+2)%W, ((4*kz.w+2))/W, 0), 0).x,
-          texelFetch(uKernel, ivec3((4*kz.w+3)%W, ((4*kz.w+3))/W, 0), 0).x
-      );
-
-      sum = fma(In.xxxx, val1, sum);
-      sum = fma(In.yyyy, val2, sum);
-      sum = fma(In.zzzz, val3, sum);
-      sum = fma(In.wwww, val4, sum);
+      sum = fma(In.xxxx, texelFetch(uKernel, ivec3(0, tower, kz.x), 0), sum);
+      sum = fma(In.yyyy, texelFetch(uKernel, ivec3(0, tower, kz.y), 0), sum);
+      sum = fma(In.zzzz, texelFetch(uKernel, ivec3(0, tower, kz.z), 0), sum);
+      sum = fma(In.wwww, texelFetch(uKernel, ivec3(0, tower, kz.w), 0), sum);
     }
 
     imageStore(
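The old pointwise shader carried a W uniform and rebuilt every scalar weight with %W and /W addressing into a flattened 2D texture — sixteen scalar fetches per loop iteration. With the weights prepacked into towers it issues four vec4 fetches straight from the 3D kernel image. A small C++ sketch contrasting the two addressing schemes (hypothetical values; it mirrors the GLSL rather than quoting it):

#include <cstdio>

int main() {
  const int W = 64;  // old uBlock.W: width of the flattened 2D weight texture
  const int kz = 5;  // one lane of the ivec4 block index
  const int tower = 0;

  // Old scheme: four scalar fetches per lane, each at a computed (x, y).
  for (int i = 0; i < 4; ++i) {
    const int elem = 4 * kz + i;
    std::printf("old: texel (%2d, %d), component .x\n", elem % W, elem / W);
  }

  // New scheme: one vec4 fetch per lane at (0, tower, kz) in the 3D image.
  std::printf("new: texel (0, %d, %d), full vec4\n", tower, kz);
  return 0;
}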
4 changes: 4 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Common.h
@@ -39,6 +39,10 @@ struct Experimentation {
   static constexpr bool kUseConv2dOldApi = true;
 };
 
+struct ConvPrepackLimits final {
+  static constexpr int64_t maxStackDepth = 2048*4;
+};
+
 } // namespace ops
 } // namespace vulkan
 } // namespace native
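maxStackDepth caps stacks_per_tower * stack_depth, the per-tower extent of the packed weight image; 2048*4 presumably keeps a vec4-packed axis within a 2048-texel device limit, though the patch does not state the rationale. A worked example of the tower split this constant produces, with div_up/align_up re-sketched locally and hypothetical layer sizes:

#include <cstdint>
#include <cstdio>

// Re-sketched helpers; the real ones live in api::utils.
constexpr int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }
constexpr int64_t align_up(int64_t a, int64_t b) { return div_up(a, b) * b; }
constexpr int64_t kMaxStackDepth = 2048 * 4; // ConvPrepackLimits::maxStackDepth

int main() {
  // Hypothetical 3x3 convolution with 512 input and 512 output channels.
  const int64_t OC = 512, IC = 512;

  const int64_t num_stacks = div_up(OC, INT64_C(4));                   // 128
  const int64_t stack_depth = 4 * align_up(IC, INT64_C(4));            // 2048
  const int64_t max_stacks_per_tower = kMaxStackDepth / stack_depth;   // 4
  const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower); // 32
  const int64_t stacks_per_tower =
      num_towers > 1 ? div_up(num_stacks, num_towers) : num_stacks;    // 4

  std::printf("stacks=%lld depth=%lld towers=%lld per_tower=%lld\n",
              (long long)num_stacks, (long long)stack_depth,
              (long long)num_towers, (long long)stacks_per_tower);
  return 0;
}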
113 changes: 56 additions & 57 deletions aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -45,10 +45,10 @@ vTensor pack_weights(
 
   if (is_depthwise(src_filter, groups)) {
     vTensor v_weight{
-      api::context(),
-      &pool,
-      src_filter,
-      weight.options(),
+        api::context(),
+        &pool,
+        src_filter,
+        weight.options(),
     };
 
     using Future = vTensor::Future<void, vTensor::Access::Write>;
@@ -153,55 +153,74 @@ vTensor pack_weights(
     return v_weight;
   }
 
+  const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4));
+  const int64_t stack_depth =
+      4 * api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4));
+  const int64_t max_stacks_per_tower =
+      ConvPrepackLimits::maxStackDepth / stack_depth;
+  const int64_t num_towers = div_up(num_stacks, max_stacks_per_tower);
+  int64_t stacks_per_tower = num_stacks;
+  if (num_towers > 1) {
+    stacks_per_tower = div_up(num_stacks, num_towers);
+  }
   vTensor v_weight{
-    api::context(),
-    &pool,
-    {
-      div_up(src_filter[Layout::Filter::output], INT64_C(4)),
-      4 * align_up(src_filter[Layout::Filter::input], INT64_C(4)),
-      src_filter[Layout::Filter::height],
-      src_filter[Layout::Filter::width],
-    },
-    weight.options(),
+      api::context(),
+      &pool,
+      {
+          stacks_per_tower,
+          stack_depth,
+          src_filter[Layout::Filter::height] * num_towers,
+          src_filter[Layout::Filter::width],
+      },
+      weight.options(),
   };
 
   using Future = vTensor::Future<float, vTensor::Access::Write>;
   Future v_weight_future = v_weight.host<float, vTensor::Access::Write>();
   Future::Payload v_weight_payload = v_weight_future.wait();
 
   /* Source */
-  const int64_t src_kernel = src_filter[Layout::Filter::height] * src_filter[Layout::Filter::width];
-  const int64_t src_block = src_kernel * src_filter[Layout::Filter::input];
+  const int64_t src_kw_sz = src_filter[Layout::Filter::width];
+  const int64_t src_kh_sz = src_filter[Layout::Filter::height];
+  const int64_t src_kernel_sz = src_kw_sz * src_kh_sz;
+  const int64_t src_block_sz =
+      src_kernel_sz * src_filter[Layout::Filter::input];
 
   /* Destination */
   const IntArrayRef dst_filter = v_weight.sizes();
-  const int64_t dst_kernel = dst_filter[Layout::Filter::height] * dst_filter[Layout::Filter::width];
-  const int64_t dst_block = dst_kernel * dst_filter[Layout::Filter::input];
-  TORCH_INTERNAL_ASSERT(src_kernel == dst_kernel, "Internal error!");
+  const int64_t dst_kw_sz = src_filter[Layout::Filter::width];
+  const int64_t dst_kh_sz = src_filter[Layout::Filter::height] * num_towers;
+  const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz;
+  const int64_t dst_block_sz =
+      dst_kernel_sz * dst_filter[Layout::Filter::input];
+
+  TORCH_INTERNAL_ASSERT(src_kernel_sz*num_towers == dst_kernel_sz, "Internal error!");
 
   float* const dst_weight_ptr = v_weight_payload.get();
   memset(dst_weight_ptr, 0, v_weight.nbytes());
 
   for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) {
+    const int64_t i_tower = src_oc / (stacks_per_tower * 4);
     /* Source */
-    const float *const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block;
+    const float* const src_weight_oc_ptr =
+        src_weight_ptr + src_oc * src_block_sz;
 
     /* Destination */
-    const int64_t dst_oc = src_oc / 4;
-    const int64_t dst_oc_offset = src_oc % 4;
+    const int64_t local_oc = src_oc % (stacks_per_tower * 4);
+    const int64_t dst_oc = local_oc / 4;
+    const int64_t dst_oc_offset = local_oc % 4;
 
-    float* const dst_weight_oc_ptr =
-        dst_weight_ptr +
-        dst_oc * dst_block +
-        dst_oc_offset * dst_kernel;
+    float* const dst_weight_oc_ptr = dst_weight_ptr + dst_oc * dst_block_sz +
+        dst_oc_offset * dst_kernel_sz;
 
     for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) {
      const int64_t dst_ic = 4 * src_ic;
 
       memcpy(
-          dst_weight_oc_ptr + dst_ic * dst_kernel,
-          src_weight_oc_ptr + src_ic * src_kernel,
-          sizeof(float) * dst_kernel);
+          dst_weight_oc_ptr + dst_ic * dst_kernel_sz +
+              (i_tower * src_kernel_sz),
+          src_weight_oc_ptr + src_ic * src_kernel_sz,
+          sizeof(float) * src_kernel_sz);
     }
   }

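To make the loop above concrete: each source output channel is assigned first to a tower, then to a stack and vec4 lane within that tower, and the memcpy lands each kernel window in the tower's own vertical band (the i_tower * src_kernel_sz term). A small sketch of the destination arithmetic for one (src_oc, src_ic) pair, with hypothetical sizes:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical geometry: 3x3 kernel, 4 stacks per tower, 2 towers.
  const int64_t stacks_per_tower = 4;
  const int64_t src_kernel_sz = 3 * 3;
  const int64_t dst_kernel_sz = src_kernel_sz * 2; // kernel rows x num_towers

  const int64_t src_oc = 21, src_ic = 3;

  const int64_t i_tower = src_oc / (stacks_per_tower * 4);  // tower 1
  const int64_t local_oc = src_oc % (stacks_per_tower * 4); // 5
  const int64_t dst_oc = local_oc / 4;        // stack within the tower: 1
  const int64_t dst_oc_offset = local_oc % 4; // lane within the vec4: 1
  const int64_t dst_ic = 4 * src_ic;

  // Element offset of this kernel window inside the packed block, matching
  // the memcpy destination in pack_weights.
  const int64_t elem = dst_ic * dst_kernel_sz + i_tower * src_kernel_sz;
  std::printf("tower=%lld stack=%lld lane=%lld element offset=%lld\n",
              (long long)i_tower, (long long)dst_oc,
              (long long)dst_oc_offset, (long long)elem);
  return 0;
}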
@@ -431,36 +450,14 @@ void conv2d_pointwise(
     const float output_min,
     const float output_max) {
   if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
-
-    vTensor v_weight_reshaped{
-      context,
-      {1,1, v_weight.sizes()[0], v_weight.sizes()[1]},
-      v_input.options(),
-    };
-
-    api::Command::Buffer temp_command_buffer =
-      api::context()->command().pool.allocate();
-    temp_command_buffer.begin();
-
-    temp_command_buffer.copy(
-      v_weight.buffer(
-        temp_command_buffer,
-        vTensor::Stage::Transfer),
-      v_weight_reshaped.buffer(
-        temp_command_buffer,
-        vTensor::Stage::Transfer,
-        vTensor::Access::Write)
-    );
-
-    temp_command_buffer.end();
-    temp_command_buffer.submit(api::context()->gpu().queue);
+    const int64_t stacks_per_tower = v_weight.sizes()[0];
 
     const struct {
       int32_t kernel_ic, kernel_oc;
       int32_t stride_x, stride_y;
       int32_t padding_x, padding_y;
       float clamp_x, clamp_y;
-      int32_t w;
+      int32_t stacks_per_tower;
     } block {
       safe_downcast<int32_t>(filter[Layout::Filter::input]),
       safe_downcast<int32_t>(filter[Layout::Filter::output]),
@@ -470,7 +467,7 @@
       safe_downcast<int32_t>(padding[Layout::Parameter::height]),
       output_min,
       output_max,
-      v_weight.sizes()[1],
+      safe_downcast<int32_t>(stacks_per_tower),
     };
 
     context->dispatch(
@@ -497,10 +494,9 @@
             vTensor::Stage::Compute),
         // Read-only access is implied on const tensors and triggers an async
        // synchronization if necessary.
-        v_weight_reshaped.image(
+        v_weight.image(
            command_buffer,
-            vTensor::Stage::Compute,
-            vTensor::Access::Read),
+            vTensor::Stage::Compute),
         // Read-only access is implied on const tensors and triggers an async
         // synchronization if necessary.
         v_bias.buffer(
@@ -529,12 +525,14 @@ void conv2d(
     const float output_min,
     const float output_max) {
   if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
+    const int64_t stacks_per_tower = v_weight.sizes()[0];
     const struct {
       int32_t kernel_x, kernel_y, kernel_ic, kernel_oc;
       int32_t stride_x, stride_y;
       int32_t padding_x, padding_y;
       int32_t dilate_x, dilate_y;
       float clamp_x, clamp_y;
+      int32_t stacks_per_tower;
     } block {
       safe_downcast<int32_t>(filter[Layout::Filter::width]),
       safe_downcast<int32_t>(filter[Layout::Filter::height]),
@@ -548,6 +546,7 @@
       safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
       output_min,
       output_max,
+      safe_downcast<int32_t>(stacks_per_tower),
     };
 
     context->dispatch(
@@ -931,7 +930,7 @@ c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
       /* output_padding = */ {},
       groups,
       output_min,
-      output_min));
+      output_max));
 }
 
 Tensor conv2d_clamp_run(