[vulkan] Add 2D transposed convolutions (#67104)
Summary:
Pull Request resolved: #67104

Add 2D transposed convolutions to Vulkan. Currently, only `dilation={1,1}` is supported; support for general dilation is planned for a later change.
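
For illustration, here is a minimal sketch of exercising the new path from libtorch (assumes a Vulkan-enabled build; the shapes, stride, padding, and tolerances are illustrative, not taken from this commit):

```cpp
// Run conv_transpose2d on the Vulkan backend and compare against the CPU
// reference. Requires a build with USE_VULKAN=ON.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Transposed-conv weights are laid out (in_channels, out_channels, kH, kW).
  const auto input = torch::rand({1, 3, 8, 8});
  const auto weight = torch::rand({3, 5, 2, 2});
  const auto bias = torch::rand({5});

  const auto reference = torch::conv_transpose2d(
      input, weight, bias, /*stride=*/{2, 2}, /*padding=*/{1, 1});
  const auto vulkan_out = torch::conv_transpose2d(
      input.vulkan(), weight, bias, /*stride=*/{2, 2}, /*padding=*/{1, 1});

  std::cout << std::boolalpha
            << torch::allclose(reference, vulkan_out.cpu(), /*rtol=*/1e-2, /*atol=*/1e-3)
            << std::endl;
}
```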

Test Plan:
Build and run `vulkan_api_test`:

```
cd ~/pytorch
BUILD_CUSTOM_PROTOBUF=OFF \
  BUILD_TEST=ON \
  USE_EIGEN_FOR_BLAS=OFF \
  USE_FBGEMM=OFF \
  USE_MKLDNN=OFF \
  USE_NNPACK=OFF \
  USE_NUMPY=OFF \
  USE_OBSERVERS=OFF \
  USE_PYTORCH_QNNPACK=OFF \
  USE_QNNPACK=OFF \
  USE_VULKAN=ON \
  USE_VULKAN_API=ON \
  USE_VULKAN_SHADERC_RUNTIME=ON \
  USE_VULKAN_WRAPPER=OFF \
  MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python3 setup.py develop --cmake && ./build/bin/vulkan_api_test
```
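
Of these flags, `USE_VULKAN=ON` and `USE_VULKAN_API=ON` enable the code paths touched here and `BUILD_TEST=ON` builds `vulkan_api_test`; the remaining switches just slim down the build.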

Reviewed By: beback4u

Differential Revision: D31731742

fbshipit-source-id: b79c946c8d988bb4d83e9fd3381992a4f2f4be80
SS-JIA authored and facebook-github-bot committed Oct 25, 2021
1 parent 059ae96 commit 0acc21b
Showing 5 changed files with 289 additions and 17 deletions.
65 changes: 65 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/conv_transpose2d.glsl
@@ -0,0 +1,65 @@
#version 450 core
#define PRECISION $precision

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
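/*
 * As used in main() below: size.xyz holds the output extents and size.w the
 * 4-aligned input channel count; kernel.xy is the kernel width/height and
 * kernel.zw the input width/height; ikernel.y is the row stride between
 * output-channel stacks in the packed kernel texture.
 */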
layout(set = 0, binding = 4) uniform PRECISION restrict Block {
  ivec4 size;
  ivec4 kernel;
  ivec2 ikernel;
  ivec2 stride;
  ivec2 padding;
  ivec2 dilate;
  vec2 clamp;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  const ivec2 isize = ivec2(uBlock.kernel.zw);
  const vec2 ksize = vec2(uBlock.kernel.xy);
  const vec2 stride = vec2(uBlock.stride);
  const vec2 padding = vec2(uBlock.padding);

  if (all(lessThan(pos, uBlock.size.xyz))) {
    ivec2 ipos = pos.xy + uBlock.padding;
    vec2 ipos_f = vec2(ipos);

    const ivec2 start = max(ivec2(0), ivec2(ceil((ipos_f - ksize + 1)/stride)));
    const ivec2 end = min(isize, ivec2(floor(ipos_f/stride))+1);

    vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);

    int ky_start = uBlock.kernel.y - 1 - (ipos.y - uBlock.stride.y*start.y) + pos.z * uBlock.ikernel.y;
    int kx_start = (uBlock.kernel.x - 1 - (ipos.x - uBlock.stride.x*start.x)) * uBlock.size.w;
    int kx_stride = uBlock.size.w * (uBlock.stride.x - 1);
    for (int y = start.y, ky = ky_start; y < end.y; ++y, ky += uBlock.stride.y) {
      int kx = kx_start;
      for (int x = start.x; x < end.x; ++x, kx += kx_stride) {
        for (int z4 = 0; z4 < uBlock.size.w/4; ++z4, kx += 4) {
          const vec4 In = texelFetch(uInput, ivec3(x, y, z4), 0);
          const ivec4 kxs = kx + ivec4(0, 1, 2, 3);

          sum = fma(In.xxxx, texelFetch(uKernel, ivec2(kxs.x, ky), 0), sum);
          sum = fma(In.yyyy, texelFetch(uKernel, ivec2(kxs.y, ky), 0), sum);
          sum = fma(In.zzzz, texelFetch(uKernel, ivec2(kxs.z, ky), 0), sum);
          sum = fma(In.wwww, texelFetch(uKernel, ivec2(kxs.w, ky), 0), sum);
        }
      }
    }

    imageStore(
        uOutput,
        pos,
        clamp(sum, uBlock.clamp.x, uBlock.clamp.y));
  }
}
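
The shader computes each output texel as a gather: for output position `op` (with `ip = op + padding`), the contributing input positions `i` are exactly those with `0 <= ip - stride*i < kernel`, which is the `start`/`end` window above; the kernel texture is walked in lockstep, and because it is packed spatially flipped (see `pack_weights_2d_reverse` below), the taps come out in the right order. A one-dimensional scalar sketch of the same arithmetic, with an unflipped kernel (illustrative only, not part of the commit):

```cpp
// 1-D scalar reference for the gather formulation used by the shader:
// output[op] = sum over i of input[i] * kernel[k], where op = stride*i - padding + k.
#include <algorithm>
#include <vector>

float transposed_conv1d_at(
    const std::vector<float>& input,
    const std::vector<float>& kernel,
    int op,       // output position
    int stride,
    int padding) {
  const int ip = op + padding;
  const int isize = static_cast<int>(input.size());
  const int ksize = static_cast<int>(kernel.size());

  // Same bounds the shader derives: ceil((ip - ksize + 1) / stride) and
  // floor(ip / stride) + 1. The max()/min() clamps keep i in range (and
  // also absorb the inexact integer ceil when the numerator is negative).
  const int start = std::max(0, (ip - ksize + stride) / stride);
  const int end = std::min(isize, ip / stride + 1);

  float sum = 0.f;  // the shader seeds this with the bias texel instead
  for (int i = start; i < end; ++i) {
    const int k = ip - stride * i;  // 0 <= k < ksize by construction
    sum += input[i] * kernel[k];
  }
  return sum;
}
```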
8 changes: 8 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Common.h
@@ -28,6 +28,14 @@ struct Layout final {
     static constexpr size_t width = 3u;
   };
 
+  // Transposed Convolution Filters
+  struct TransposedFilter final {
+    static constexpr size_t input = 0u;
+    static constexpr size_t output = 1u;
+    static constexpr size_t height = 2u;
+    static constexpr size_t width = 3u;
+  };
+
   // Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.)
   struct Parameter final {
     static constexpr size_t height = 0u;
151 changes: 134 additions & 17 deletions aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -53,7 +53,11 @@ Conv2dMethod determine_method(
     const IntArrayRef stride,
     const IntArrayRef padding,
     const IntArrayRef dilation,
-    const int64_t groups) {
+    const int64_t groups,
+    const bool transposed) {
+  if (transposed)
+    return Conv2dTranspose;
+
   if (is_depthwise(filter, groups))
     return Conv2dDepthwise;
   if (is_pointwise(filter))
@@ -188,6 +192,73 @@ vTensor pack_weights_2d(
   return v_weight;
 }

+vTensor pack_weights_2d_reverse(
+    api::Context* const context,
+    api::Command::Buffer& command_buffer,
+    const Tensor& weight,
+    bool reversed) {
+  /* Source */
+  const IntArrayRef src_filter = weight.sizes();
+  const float* const src_weight_ptr = weight.data_ptr<float>();
+
+  const int64_t src_kw_sz = src_filter[Layout::Filter::width];
+  const int64_t src_kh_sz = src_filter[Layout::Filter::height];
+  const int64_t src_kernel_sz = src_kw_sz * src_kh_sz;
+  const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input];
+
+  const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4));
+  const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4));
+
+  /* Destination */
+  const int64_t dst_kw_sz = src_kw_sz * stack_depth;
+  const int64_t dst_kh_sz = src_kh_sz * num_stacks;
+  const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz;
+
+  vTensor v_weight{
+      context,
+      {
+          4,
+          dst_kh_sz,
+          dst_kw_sz,
+      },
+      weight.options(),
+  };
+
+  using Future = vTensor::Future<float, vTensor::Access::Write>;
+  Future v_weight_future = v_weight.host<float, vTensor::Access::Write>(command_buffer);
+  Future::Payload v_weight_payload = v_weight_future.wait();
+
+  float* const dst_weight_ptr = v_weight_payload.get();
+  memset(dst_weight_ptr, 0, v_weight.nbytes());
+
+  for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) {
+    /* Source */
+    const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz;
+
+    /* Destination */
+    const int64_t dst_oh = src_oc / 4;
+    const int64_t dst_c = src_oc % 4;
+
+    float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz;
+
+    for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) {
+      for (int64_t src_ih = 0; src_ih < src_kh_sz; ++src_ih) {
+        const int64_t dst_h = reversed ? (src_kh_sz - 1 - src_ih) : src_ih;
+        for (int64_t src_iw = 0; src_iw < src_kw_sz; ++src_iw) {
+          const int64_t dst_w = reversed ? (src_kw_sz - 1 - src_iw) : src_iw;
+          const int64_t dst_w_offset = dst_w * stack_depth;
+          memcpy(
+              dst_weight_c_ptr + (dst_oh * src_kh_sz + dst_h) * dst_kw_sz + src_ic + dst_w_offset,
+              src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw,
+              sizeof(float));
+        }
+      }
+    }
+  }
+
+  return v_weight;
+}
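
Reading the packing above: after `pack_weights` (below) permutes a transposed filter from (I, O, kH, kW) to (O, I, kH, kW), output channel `oc` lands in row band `oc / 4` and texel component `oc % 4` of a 4-deep texture with extents {4, kH * ceil(O / 4), kW * align(I, 4)}, flipped spatially when `reversed` is set. With illustrative sizes I=3, O=5, kH=kW=2, that gives extents {4, 2 * 2, 2 * 4} = {4, 4, 8}.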

 vTensor pack_weights_2d_winograd_2_3(
     api::Context* const context,
     api::Command::Buffer& command_buffer,
@@ -283,15 +354,16 @@

 vTensor pack_weights(
     const Tensor& weight_arg,
-    const Conv2dMethod conv_method) {
+    const Conv2dMethod conv_method,
+    const bool transposed) {
   if (weight_arg.is_vulkan()) {
     return convert(weight_arg);
   }
 
   api::Context* const context = api::context();
   api::Command::Buffer& command_buffer = context->command().pool.stream();
 
-  const Tensor weight = weight_arg.contiguous();
+  const Tensor weight = transposed ? at::permute(weight_arg, {1, 0, 2, 3}).contiguous() : weight_arg.contiguous();
 
   if (conv_method == Conv2dDepthwise) {
     return pack_weights_dw(
@@ -306,6 +378,13 @@
         command_buffer,
         weight);
   }
+  if (conv_method == Conv2dTranspose) {
+    return pack_weights_2d_reverse(
+        context,
+        command_buffer,
+        weight,
+        true);
+  }
 
   return pack_weights_2d(
       context,
@@ -365,16 +444,20 @@ vTensor pack_biases(

 std::array<int64_t, 4> pack_filter(
     const Tensor& weight,
-    const IntArrayRef dilation) {
+    const IntArrayRef dilation,
+    const bool transposed) {
   const IntArrayRef filter = weight.sizes();
 
   const auto effective = [](const int64_t k, const int64_t d) {
     return k + (k - 1) * (d - 1);
   };
 
+  const size_t filter_output_ind = transposed ? Layout::TransposedFilter::output : Layout::Filter::output;
+  const size_t filter_input_ind = transposed ? Layout::TransposedFilter::input : Layout::Filter::input;
+
   return {
-      align_up(filter[Layout::Filter::output], INT64_C(4)),
-      align_up(filter[Layout::Filter::input], INT64_C(4)),
+      align_up(filter[filter_output_ind], INT64_C(4)),
+      align_up(filter[filter_input_ind], INT64_C(4)),
       effective(
           filter[Layout::Filter::height],
           dilation[Layout::Parameter::height]),
@@ -417,7 +500,8 @@ bool available(
           ((bias->device().is_cpu()) ||
            (c10::DeviceType::Vulkan == bias->device().type())) &&
           (kFloat == bias->scalar_type()) &&
-          (transposed ? false /* to be addded in the future */
+          (transposed ? (weight.size(Layout::TransposedFilter::output) ==
+                             bias->size(Layout::Filter::output))
                       : (weight.size(Layout::Filter::output) ==
                             bias->size(Layout::Filter::output))))
       : true) &&
@@ -428,8 +512,10 @@
          (padding[Layout::Parameter::height] >= 0) &&
          (padding[Layout::Parameter::width] >= 0) &&
          // Dilation
-         (dilation[Layout::Parameter::height] > 0) &&
-         (dilation[Layout::Parameter::width] > 0) &&
+         (transposed ? (dilation[Layout::Parameter::height] == 1) &&
+                       (dilation[Layout::Parameter::width] == 1)
+                     : (dilation[Layout::Parameter::height] > 0) &&
+                       (dilation[Layout::Parameter::width] > 0)) &&
          // Groups
          (groups > 0) &&
          // Input
@@ -457,6 +543,28 @@ bool usable(const Tensor& input) {
       true;
 }

+static inline std::vector<int64_t> get_conv_output_size(
+    IntArrayRef input_size, IntArrayRef weight_size,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef(),
+    const bool transposed = false) {
+  if (transposed) {
+    auto dim = input_size.size();
+    std::vector<int64_t> output_size(dim);
+    output_size[0] = input_size[input_batch_size_dim];
+    output_size[1] = weight_size[weight_input_channels_dim];
+    for (size_t d = 2; d < dim; ++d) {
+      output_size[d] = stride[d - 2] * (input_size[d] - 1) + weight_size[d] - 2 * padding[d - 2];
+    }
+    return output_size;
+  }
+  return conv_output_size(
+      input_size,
+      weight_size,
+      padding,
+      stride,
+      dilation);
+}
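
For the transposed branch, each spatial dimension follows `stride * (in - 1) + kernel - 2 * padding`; with illustrative numbers, stride 2, a 2x2 kernel, and padding 1 map an 8x8 input to 2 * (8 - 1) + 2 - 2 = 14 per side. Note that `output_padding` is not consumed here, consistent with `Conv2dOpContext` ignoring it.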

 Tensor convolution(
     const Tensor& input,
@@ -496,16 +604,16 @@ Conv2dOpContext::Conv2dOpContext(
     const IntArrayRef stride,
     const IntArrayRef padding,
     const IntArrayRef dilation,
-    const bool /* transposed */,
+    const bool transposed,
     const IntArrayRef /* output_padding */,
     const int64_t groups,
     const Conv2dMethod method,
     const c10::optional<Scalar>& output_min,
     const c10::optional<Scalar>& output_max)
   : packed_{
-      pack_weights(weight, method),
+      pack_weights(weight, method, transposed),
       pack_biases(bias, weight),
-      pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2)),
+      pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2), transposed),
       pack_params(expand_param_if_needed(stride, "stride", 2)),
       pack_params(expand_param_if_needed(padding, "padding", 2)),
       pack_params(expand_param_if_needed(dilation, "dilation", 2)),
@@ -524,7 +632,8 @@
       output_min,
       output_max,
     },
-    method_(method) {
+    method_(method),
+    transposed_(transposed) {
 }
 
 Conv2dOpContext Conv2dOpContext::create(
@@ -565,7 +674,8 @@ Conv2dOpContext Conv2dOpContext::create(
       stride,
       padding,
       dilation,
-      groups);
+      groups,
+      transposed);
 
   // Pass in the originals
   return Conv2dOpContext{
@@ -606,7 +716,7 @@ void Conv2dOpContext::conv2d_sliding_window(
     ivec4 src_filter;
   } block {
     v_output.extents(),
-    safe_downcast<int32_t>(packed_.filter[Layout::Filter::input]),
+    safe_downcast<int32_t>(packed_.filter[Layout::Filter::input]), /* this is aligned up */
     {
       safe_downcast<int32_t>(packed_.filter[Layout::Filter::width]),
       safe_downcast<int32_t>(packed_.filter[Layout::Filter::height]),
@@ -811,12 +921,13 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const {

   vTensor v_output{
     context,
-    conv_output_size(
+    get_conv_output_size(
         v_input.sizes(),
         unpacked_.filter,
         packed_.padding,
         packed_.stride,
-        packed_.dilation),
+        packed_.dilation,
+        transposed_),
     input.options(),
   };

@@ -835,6 +946,12 @@
           v_output,
           v_input);
       break;
+    case Conv2dTranspose:
+      conv2d_sliding_window(
+          VK_KERNEL(conv_transpose2d),
+          v_output,
+          v_input);
+      break;
     default:
       conv2d_sliding_window(
           VK_KERNEL(conv2d),
2 changes: 2 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Convolution.h
@@ -15,6 +15,7 @@ enum Conv2dMethod {
   Conv2dPointwise,
   Conv2dSlidingWindow,
   Conv2dWinograd_2_3,
+  Conv2dTranspose,
 };

class Conv2dOpContext final : public torch::jit::CustomClassHolder {
@@ -93,6 +94,7 @@ class Conv2dOpContext final : public torch::jit::CustomClassHolder {
   } unpacked_;
 
   Conv2dMethod method_;
+  bool transposed_;
 };

 Tensor conv2d_clamp_run(