Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[vulkan] Add 2D transposed convolutions #67104

Closed
wants to merge 1 commit
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 65 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/conv_transpose2d.glsl
@@ -0,0 +1,65 @@
#version 450 core
#define PRECISION $precision

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  ivec4 size;     // xyz: output extents, w: packed input-channel count (aligned up)
  ivec4 kernel;   // xy: kernel width/height, zw: input width/height
  ivec2 ikernel;  // per-output-channel stride into the packed kernel texture
  ivec2 stride;
  ivec2 padding;
  ivec2 dilate;   // unused in this shader; kept so the block layout matches the CPU-side struct
  vec2 clamp;     // x: min, y: max for the fused output clamp
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Transposed 2D convolution: one invocation computes one output texel
 * (4 output channels at pos.xy).  The kernel texture is expected to be
 * pre-flipped and packed by pack_weights_2d_reverse on the CPU side.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  const ivec2 isize = ivec2(uBlock.kernel.zw);
  const vec2 ksize = vec2(uBlock.kernel.xy);
  const vec2 stride = vec2(uBlock.stride);

  if (all(lessThan(pos, uBlock.size.xyz))) {
    const ivec2 ipos = pos.xy + uBlock.padding;
    const vec2 ipos_f = vec2(ipos);

    // Range of input positions whose (stride-spaced) kernel window covers
    // this output position.
    const ivec2 start = max(ivec2(0), ivec2(ceil((ipos_f - ksize + 1)/stride)));
    const ivec2 end = min(isize, ivec2(floor(ipos_f/stride))+1);

    vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);

    // Starting kernel-texture coordinates for (start.x, start.y); the loops
    // below keep them in sync with x/y instead of recomputing per texel.
    const int ky_start = uBlock.kernel.y - 1 - (ipos.y - uBlock.stride.y*start.y) + pos.z * uBlock.ikernel.y;
    const int kx_start = (uBlock.kernel.x - 1 - (ipos.x - uBlock.stride.x*start.x)) * uBlock.size.w;
    // The z4 loop advances kx by size.w per x step; this adds the remainder
    // so each x step advances kx by size.w * stride.x in total.
    const int kx_stride = uBlock.size.w * (uBlock.stride.x - 1);
    for (int y = start.y, ky = ky_start; y < end.y; ++y, ky += uBlock.stride.y) {
      for (int x = start.x, kx = kx_start; x < end.x; ++x, kx += kx_stride) {
        for (int z4 = 0; z4 < uBlock.size.w/4; ++z4, kx += 4) {
          const vec4 In = texelFetch(uInput, ivec3(x, y, z4), 0);
          const ivec4 kxs = kx + ivec4(0, 1, 2, 3);

          // Accumulate 4 input channels against their packed kernel texels.
          sum = fma(In.xxxx, texelFetch(uKernel, ivec2(kxs.x, ky), 0), sum);
          sum = fma(In.yyyy, texelFetch(uKernel, ivec2(kxs.y, ky), 0), sum);
          sum = fma(In.zzzz, texelFetch(uKernel, ivec2(kxs.z, ky), 0), sum);
          sum = fma(In.wwww, texelFetch(uKernel, ivec2(kxs.w, ky), 0), sum);
        }
      }
    }

    imageStore(
        uOutput,
        pos,
        clamp(sum, uBlock.clamp.x, uBlock.clamp.y));
  }
}
8 changes: 8 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Common.h
Expand Up @@ -28,6 +28,14 @@ struct Layout final {
static constexpr size_t width = 3u;
};

// Transposed Convolution Filters
//
// Dimension indices for a transposed-convolution weight tensor.  The first
// two dimensions are swapped relative to a regular convolution filter
// (input channels come first, output channels second).
struct TransposedFilter final {
  static constexpr size_t input{0u};   // dim 0: input channels
  static constexpr size_t output{1u};  // dim 1: output channels
  static constexpr size_t height{2u};  // dim 2: kernel height
  static constexpr size_t width{3u};   // dim 3: kernel width
};

// Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.)
struct Parameter final {
static constexpr size_t height = 0u;
Expand Down
151 changes: 134 additions & 17 deletions aten/src/ATen/native/vulkan/ops/Convolution.cpp
Expand Up @@ -53,7 +53,11 @@ Conv2dMethod determine_method(
const IntArrayRef stride,
const IntArrayRef padding,
const IntArrayRef dilation,
const int64_t groups) {
const int64_t groups,
const bool transposed) {
if (transposed)
return Conv2dTranspose;

if (is_depthwise(filter, groups))
return Conv2dDepthwise;
if (is_pointwise(filter))
Expand Down Expand Up @@ -188,6 +192,73 @@ vTensor pack_weights_2d(
return v_weight;
}

// Repacks a 4D convolution weight tensor of shape (output, input, height,
// width) into the 3-D vTensor layout {4, num_stacks * kh, aligned_ic * kw}
// consumed by the conv_transpose2d shader.  Output channels are folded into
// groups of 4 (one per texel component); when `reversed` is true the kernel
// window is flipped in both spatial dimensions.
// NOTE(review): for the transposed path the caller (pack_weights) permutes
// the weight to (output, input, h, w) order before calling this — the
// Layout::Filter indices below rely on that.
vTensor pack_weights_2d_reverse(
    api::Context* const context,
    api::Command::Buffer& command_buffer,
    const Tensor& weight,
    bool reversed) {
  /* Source */
  const IntArrayRef src_filter = weight.sizes();
  const float* const src_weight_ptr = weight.data_ptr<float>();

  const int64_t src_kw_sz = src_filter[Layout::Filter::width];
  const int64_t src_kh_sz = src_filter[Layout::Filter::height];
  const int64_t src_kernel_sz = src_kw_sz * src_kh_sz;
  const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input];

  // Output channels are stacked in groups of 4; input channels are padded up
  // to a multiple of 4 so each texel row has a uniform stride.
  const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4));
  const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4));

  /* Destination */
  const int64_t dst_kw_sz = src_kw_sz * stack_depth;
  const int64_t dst_kh_sz = src_kh_sz * num_stacks;
  const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz;

  vTensor v_weight{
      context,
      {
          4,
          dst_kh_sz,
          dst_kw_sz,
      },
      weight.options(),
  };

  using Future = vTensor::Future<float, vTensor::Access::Write>;
  Future v_weight_future = v_weight.host<float, vTensor::Access::Write>(command_buffer);
  Future::Payload v_weight_payload = v_weight_future.wait();

  float* const dst_weight_ptr = v_weight_payload.get();
  // Zero-fill so alignment padding contributes nothing to the convolution.
  memset(dst_weight_ptr, 0, v_weight.nbytes());

  for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) {
    /* Source */
    const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz;

    /* Destination */
    // Map output channel -> (stack row, texel component).
    const int64_t dst_oh = src_oc / 4;
    const int64_t dst_c = src_oc % 4;

    float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz;

    for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) {
      for (int64_t src_ih = 0; src_ih < src_kh_sz; ++src_ih) {
        // Flip vertically when reversed.
        const int64_t dst_h = reversed ? (src_kh_sz - 1 - src_ih) : src_ih;
        for (int64_t src_iw = 0; src_iw < src_kw_sz; ++src_iw) {
          // Flip horizontally when reversed.
          const int64_t dst_w = reversed ? (src_kw_sz - 1 - src_iw) : src_iw;
          const int64_t dst_w_offset = dst_w * stack_depth;
          // Copy a single scalar weight into its packed position:
          // row = stack * kh + dst_h, column = dst_w * stack_depth + src_ic.
          memcpy(
              dst_weight_c_ptr + (dst_oh * src_kh_sz + dst_h) * dst_kw_sz + src_ic + dst_w_offset,
              src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw,
              sizeof(float));
        }
      }
    }
  }

  return v_weight;
}

vTensor pack_weights_2d_winograd_2_3(
api::Context* const context,
api::Command::Buffer& command_buffer,
Expand Down Expand Up @@ -283,15 +354,16 @@ vTensor pack_weights_2d_winograd_2_3(

vTensor pack_weights(
const Tensor& weight_arg,
const Conv2dMethod conv_method) {
const Conv2dMethod conv_method,
const bool transposed) {
if (weight_arg.is_vulkan()) {
return convert(weight_arg);
}

api::Context* const context = api::context();
api::Command::Buffer& command_buffer = context->command().pool.stream();

const Tensor weight = weight_arg.contiguous();
const Tensor weight = transposed ? at::permute(weight_arg, {1, 0, 2, 3}).contiguous() : weight_arg.contiguous();

if (conv_method == Conv2dDepthwise) {
return pack_weights_dw(
Expand All @@ -306,6 +378,13 @@ vTensor pack_weights(
command_buffer,
weight);
}
if (conv_method == Conv2dTranspose) {
return pack_weights_2d_reverse(
context,
command_buffer,
weight,
true);
}

return pack_weights_2d(
context,
Expand Down Expand Up @@ -365,16 +444,20 @@ vTensor pack_biases(

std::array<int64_t, 4> pack_filter(
const Tensor& weight,
const IntArrayRef dilation) {
const IntArrayRef dilation,
const bool transposed) {
const IntArrayRef filter = weight.sizes();

const auto effective = [](const int64_t k, const int64_t d) {
return k + (k - 1) * (d - 1);
};

const size_t filter_output_ind = transposed ? Layout::TransposedFilter::output : Layout::Filter::output;
const size_t filter_input_ind = transposed ? Layout::TransposedFilter::input : Layout::Filter::input;

return {
align_up(filter[Layout::Filter::output], INT64_C(4)),
align_up(filter[Layout::Filter::input], INT64_C(4)),
align_up(filter[filter_output_ind], INT64_C(4)),
align_up(filter[filter_input_ind], INT64_C(4)),
effective(
filter[Layout::Filter::height],
dilation[Layout::Parameter::height]),
Expand Down Expand Up @@ -417,7 +500,8 @@ bool available(
((bias->device().is_cpu()) ||
(c10::DeviceType::Vulkan == bias->device().type())) &&
(kFloat == bias->scalar_type()) &&
(transposed ? false /* to be addded in the future */
(transposed ? (weight.size(Layout::TransposedFilter::output) ==
bias->size(Layout::Filter::output))
: (weight.size(Layout::Filter::output) ==
bias->size(Layout::Filter::output))))
: true) &&
Expand All @@ -428,8 +512,10 @@ bool available(
(padding[Layout::Parameter::height] >= 0) &&
(padding[Layout::Parameter::width] >= 0) &&
// Dilation
(dilation[Layout::Parameter::height] > 0) &&
(dilation[Layout::Parameter::width] > 0) &&
(transposed ? (dilation[Layout::Parameter::height] == 1) &&
(dilation[Layout::Parameter::width] == 1)
: (dilation[Layout::Parameter::height] > 0) &&
(dilation[Layout::Parameter::width] > 0)) &&
// Groups
(groups > 0) &&
// Input
Expand Down Expand Up @@ -457,6 +543,28 @@ bool usable(const Tensor& input) {
true;
}

// Computes the output sizes of a (possibly transposed) 2D convolution.
//
// For a regular convolution this defers to conv_output_size().  For a
// transposed convolution each spatial dim is
//   stride * (in - 1) + dilation * (kernel - 1) + 1 - 2 * padding
// with dilation defaulting to 1 when none is supplied (output_padding is
// not supported by this backend and is treated as 0).
static inline std::vector<int64_t> get_conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef(),
    const bool transposed = false
) {
  if (transposed) {
    const auto dim = input_size.size();
    std::vector<int64_t> output_size(dim);
    output_size[0] = input_size[input_batch_size_dim];
    // Transposed weights are (in, out, kH, kW): dim 1 of the weight holds
    // the output-channel count.
    output_size[1] = weight_size[weight_input_channels_dim];
    for (size_t d = 2; d < dim; ++d) {
      // Honor dilation when provided; the previous form assumed dilation == 1.
      const int64_t dil = (d - 2 < dilation.size()) ? dilation[d - 2] : 1;
      output_size[d] = stride[d - 2] * (input_size[d] - 1) +
          dil * (weight_size[d] - 1) + 1 - 2 * padding[d - 2];
    }
    return output_size;
  }
  return conv_output_size(
      input_size,
      weight_size,
      padding,
      stride,
      dilation);
}

Tensor convolution(
const Tensor& input,
Expand Down Expand Up @@ -496,16 +604,16 @@ Conv2dOpContext::Conv2dOpContext(
const IntArrayRef stride,
const IntArrayRef padding,
const IntArrayRef dilation,
const bool /* transposed */,
const bool transposed,
const IntArrayRef /* output_padding */,
const int64_t groups,
const Conv2dMethod method,
const c10::optional<Scalar>& output_min,
const c10::optional<Scalar>& output_max)
: packed_{
pack_weights(weight, method),
pack_weights(weight, method, transposed),
pack_biases(bias, weight),
pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2)),
pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2), transposed),
pack_params(expand_param_if_needed(stride, "stride", 2)),
pack_params(expand_param_if_needed(padding, "padding", 2)),
pack_params(expand_param_if_needed(dilation, "dilation", 2)),
Expand All @@ -524,7 +632,8 @@ Conv2dOpContext::Conv2dOpContext(
output_min,
output_max,
},
method_(method) {
method_(method),
transposed_(transposed) {
}

Conv2dOpContext Conv2dOpContext::create(
Expand Down Expand Up @@ -565,7 +674,8 @@ Conv2dOpContext Conv2dOpContext::create(
stride,
padding,
dilation,
groups);
groups,
transposed);

// Pass in the originals
return Conv2dOpContext{
Expand Down Expand Up @@ -606,7 +716,7 @@ void Conv2dOpContext::conv2d_sliding_window(
ivec4 src_filter;
} block {
v_output.extents(),
safe_downcast<int32_t>(packed_.filter[Layout::Filter::input]),
safe_downcast<int32_t>(packed_.filter[Layout::Filter::input]), /* this is aligned up */
{
safe_downcast<int32_t>(packed_.filter[Layout::Filter::width]),
safe_downcast<int32_t>(packed_.filter[Layout::Filter::height]),
Expand Down Expand Up @@ -811,12 +921,13 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const {

vTensor v_output{
context,
conv_output_size(
get_conv_output_size(
v_input.sizes(),
unpacked_.filter,
packed_.padding,
packed_.stride,
packed_.dilation),
packed_.dilation,
transposed_),
input.options(),
};

Expand All @@ -835,6 +946,12 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
v_output,
v_input);
break;
case Conv2dTranspose:
conv2d_sliding_window(
VK_KERNEL(conv_transpose2d),
v_output,
v_input);
break;
default:
conv2d_sliding_window(
VK_KERNEL(conv2d),
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Convolution.h
Expand Up @@ -15,6 +15,7 @@ enum Conv2dMethod {
Conv2dPointwise,
Conv2dSlidingWindow,
Conv2dWinograd_2_3,
Conv2dTranspose,
};

class Conv2dOpContext final : public torch::jit::CustomClassHolder {
Expand Down Expand Up @@ -93,6 +94,7 @@ class Conv2dOpContext final : public torch::jit::CustomClassHolder {
} unpacked_;

Conv2dMethod method_;
bool transposed_;
};

Tensor conv2d_clamp_run(
Expand Down