[vulkan] Add mean.dim op for vulkan (#47312)
Summary: Pull Request resolved: #47312

Test Plan:
```
cd ~/pytorch
BUILD_CUSTOM_PROTOBUF=OFF \
  BUILD_TEST=ON \
  USE_EIGEN_FOR_BLAS=OFF \
  USE_FBGEMM=OFF \
  USE_MKLDNN=OFF \
  USE_NNPACK=OFF \
  USE_NUMPY=OFF \
  USE_OBSERVERS=OFF \
  USE_PYTORCH_QNNPACK=OFF \
  USE_QNNPACK=OFF \
  USE_VULKAN=ON \
  USE_VULKAN_API=ON \
  USE_VULKAN_SHADERC_RUNTIME=ON \
  USE_VULKAN_WRAPPER=OFF \
  MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python3 setup.py develop --cmake && ./build/bin/vulkan_api_test
```
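
As a quick spot check of the new op, something along these lines can be run from C++ (a sketch in the spirit of vulkan_api_test, not the commit's actual test; it assumes a Vulkan-enabled build):

```cpp
// Compare the Vulkan mean.dim kernel against the CPU reference.
// Sketch only: the shape is illustrative, and a Vulkan-enabled ATen
// build is assumed.
#include <ATen/ATen.h>

bool mean_dim_matches_cpu() {
  const auto in_cpu =
      at::rand({5, 3, 9, 9}, at::device(at::kCPU).dtype(at::kFloat));
  // The Vulkan kernel only supports reducing over the image dims {2, 3}.
  const auto out_cpu = at::mean(in_cpu, {2, 3}, /*keepdim=*/true);
  const auto out_vulkan = at::mean(in_cpu.vulkan(), {2, 3}, /*keepdim=*/true);
  return at::allclose(out_cpu, out_vulkan.cpu());
}
```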

Reviewed By: IvanKobzarev

Differential Revision: D24713617

Pulled By: SS-JIA

fbshipit-source-id: 20c0f411fb390ad2114c7deff27cc6fc77448089
SS-JIA authored and facebook-github-bot committed Nov 4, 2020
1 parent 9b168a1 commit 464c569
Showing 4 changed files with 256 additions and 80 deletions.
28 changes: 9 additions & 19 deletions aten/src/ATen/native/vulkan/glsl/mean.glsl
```diff
@@ -2,36 +2,26 @@
 #define PRECISION $precision
 layout(std430) buffer;
 layout(std430) uniform;
-layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
-layout(set = 0, binding = 2) uniform constBlock {
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION restrict Block {
   int W;
   int H;
-  int OW;
-  int OH;
-}
-uConstBlock;
+} uBlock;
 
 layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
 
 void main() {
   ivec3 pos = ivec3(gl_GlobalInvocationID);
-  int W = uConstBlock.W;
-  int H = uConstBlock.H;
-  int OW = uConstBlock.OW;
-  int OH = uConstBlock.OH;
-  vec4 r = vec4(1.0) / float(W) / float(H);
+  vec4 r = vec4(1.0) / (float(uBlock.W) * float(uBlock.H));
   vec4 acc = vec4(0);
   int xi, yi;
-  for (xi = 0; xi < W; ++xi) {
-    for (yi = 0; yi < H; ++yi) {
+  for (yi = 0; yi < uBlock.H; ++yi) {
+    for (xi = 0; xi < uBlock.W; ++xi) {
       acc += texelFetch(uInput, ivec3(xi, yi, pos.z), 0);
     }
   }
   vec4 outValue = r * acc;
-  for (int vi = 0; vi < 4; ++vi) {
-    int oy = (4 * pos.z + vi) / OW;
-    int ox = (4 * pos.z + vi) % OW;
-    imageStore(uOutput, ivec3(ox, oy, 0), vec4(outValue[vi], 0, 0, 0));
-  }
+
+  imageStore(uOutput, pos, outValue);
 }
```
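
After this change, the mean shader assumes a keepdim-style {N, C, 1, 1} output and reduces each channel texel over the whole W×H plane. For reference, a plain C++ sketch of the semantics being implemented (illustrative only, not part of the commit):

```cpp
// Reference semantics for image-wide mean on an NCHW tensor:
// every (n, c) plane collapses to one value, i.e. mean over dims {2, 3}.
// Sketch only; the shader operates on channel-packed RGBA texels instead.
#include <vector>

std::vector<float> mean_hw(
    const std::vector<float>& input, int N, int C, int H, int W) {
  std::vector<float> output(N * C, 0.f); // one value per (n, c) plane
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      float acc = 0.f;
      for (int h = 0; h < H; ++h) {
        for (int w = 0; w < W; ++w) {
          acc += input[((n * C + c) * H + h) * W + w];
        }
      }
      output[n * C + c] = acc / (float(W) * float(H));
    }
  }
  return output;
}
```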
29 changes: 29 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/mean2d.glsl
```diff
@@ -0,0 +1,29 @@
+#version 450 core
+#define PRECISION $precision
+layout(std430) buffer;
+layout(std430) uniform;
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION restrict Block {
+  int W;
+  int H;
+} uBlock;
+
+layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
+
+void main() {
+  ivec3 pos = ivec3(gl_GlobalInvocationID);
+  vec4 r = vec4(1.0) / (float(uBlock.W) * float(uBlock.H));
+  vec4 acc = vec4(0);
+  int xi, yi;
+  int zi = (imageSize(uOutput).x*pos.y + pos.x)/4;
+  int zo = (imageSize(uOutput).x*pos.y + pos.x)%4;
+  for (yi = 0; yi < uBlock.H; ++yi) {
+    for (xi = 0; xi < uBlock.W; ++xi) {
+      acc += texelFetch(uInput, ivec3(xi, yi, zi), 0);
+    }
+  }
+  vec4 outValue = r * acc;
+
+  imageStore(uOutput, pos, vec4(outValue[zo], 0,0,0));
+}
```
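
The mean2d variant handles keepdim=false, where the {N, C} output is itself packed four channels per texel; the zi/zo arithmetic maps each output position back to the input texel layer (zi) and the lane within it (zo) that it averages. A standalone C++ sketch of that index math (the output width is an illustrative value, not taken from the commit):

```cpp
// Walk a small output image and show which packed input texel (zi)
// and which RGBA lane (zo) each output position reads. Sketch only.
#include <cstdio>

int main() {
  const int output_width = 3; // stands in for imageSize(uOutput).x
  for (int y = 0; y < 2; ++y) {
    for (int x = 0; x < output_width; ++x) {
      const int flat = output_width * y + x; // flattened (x, y) position
      const int zi = flat / 4; // input texel layer holding this channel
      const int zo = flat % 4; // lane (r/g/b/a) within that texel
      std::printf("out(%d,%d) <- layer %d, lane %d\n", x, y, zi, zo);
    }
  }
  return 0;
}
```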
192 changes: 150 additions & 42 deletions aten/src/ATen/native/vulkan/ops/Pool.cpp
```diff
@@ -1,5 +1,5 @@
-#include <ATen/native/vulkan/ops/Common.h>
 #include <ATen/native/Pool.h>
+#include <ATen/native/vulkan/ops/Common.h>
 #include <torch/library.h>
 
 namespace at {
```
```diff
@@ -8,17 +8,104 @@ namespace vulkan {
 namespace ops {
 namespace {
 
-Tensor adaptive_avg_pool2d(const at::Tensor& input_arg, IntArrayRef output_size) {
+int64_t normalize_dim(int64_t d, int64_t n) {
+  return (d % n + n) % n;
+}
+
+Tensor mean(
+    const at::Tensor& input_arg,
+    const IntArrayRef dim,
+    const bool keepdim,
+    const optional<ScalarType> dtype) {
+  TORCH_INTERNAL_ASSERT(
+      input_arg.dim() == 4, "vulkan_mean expects 4-dimensional input");
+  static const std::unordered_set<int64_t> expected_dims_set({2, 3});
+  std::unordered_set<int64_t> dims_set;
+  for (const auto& d : dim) {
+    dims_set.insert(normalize_dim(d, 4));
+  }
+  TORCH_INTERNAL_ASSERT(
+      dims_set == expected_dims_set,
+      "vulkan_mean currently only supported for image-wide reduction");
+
+  std::vector<int64_t> output_dims{input_arg.sizes()[0], input_arg.sizes()[1]};
+  if (keepdim) {
+    output_dims.push_back(1);
+    output_dims.push_back(1);
+  }
+
+  api::Context* const context = api::context();
+  const vTensor& v_input = convert(input_arg);
+  vTensor v_output{
+      context,
+      output_dims,
+      input_arg.options(),
+  };
+
+  api::Command::Buffer command_buffer = context->command().pool.allocate();
+  command_buffer.begin();
+  {
+    if (v_input.has_image()) {
+      const struct {
+        uint32_t input_width, input_height;
+      } block{
+          input_arg.sizes()[3],
+          input_arg.sizes()[2],
+      };
+
+      if (keepdim) {
+        context->dispatch(
+            command_buffer,
+            {
+                VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+                VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+                VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+            },
+            VK_KERNEL(mean),
+            v_output.extents(),
+            v_output.image(command_buffer, vTensor::Access::Write),
+            v_input.image(command_buffer),
+            context->resource().pool.uniform(block).object);
+      } else {
+        context->dispatch(
+            command_buffer,
+            {
+                VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+                VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+                VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+            },
+            VK_KERNEL(mean2d),
+            v_output.extents(),
+            v_output.image(command_buffer, vTensor::Access::Write),
+            v_input.image(command_buffer),
+            context->resource().pool.uniform(block).object);
+      }
+    } else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_buffer.end();
+  command_buffer.submit(context->gpu().queue);
+
+  return convert(v_output);
+}
+
+Tensor adaptive_avg_pool2d(
+    const at::Tensor& input_arg,
+    IntArrayRef output_size) {
   TORCH_INTERNAL_ASSERT(
       input_arg.dim() == 4,
       "vulkan_adaptive_avg_pool2d expects 4-dimensional input");
 
   api::Context* const context = api::context();
   const vTensor& v_input = convert(input_arg);
   vTensor v_output{
-    context,
-    {input_arg.sizes()[0], input_arg.sizes()[1], output_size[0], output_size[1]},
-    input_arg.options(),
+      context,
+      {input_arg.sizes()[0],
+       input_arg.sizes()[1],
+       output_size[0],
+       output_size[1]},
+      input_arg.options(),
   };
 
   api::Command::Buffer command_buffer = context->command().pool.allocate();
```
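
The normalize_dim helper above wraps negative dims into [0, n), which is why mean() accepts {-2, -1} as well as {2, 3}. A standalone check (illustrative only):

```cpp
// Negative dims wrap around: on a 4-D tensor, -1 -> 3 and -2 -> 2,
// so both {2, 3} and {-2, -1} select the image-wide reduction.
#include <cassert>
#include <cstdint>

int64_t normalize_dim(int64_t d, int64_t n) {
  return (d % n + n) % n;
}

int main() {
  assert(normalize_dim(-1, 4) == 3);
  assert(normalize_dim(-2, 4) == 2);
  assert(normalize_dim(2, 4) == 2); // non-negative dims are unchanged
  return 0;
}
```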
```diff
@@ -27,27 +114,26 @@ Tensor adaptive_avg_pool2d(const at::Tensor& input_arg, IntArrayRef output_size)
     if (v_input.has_image()) {
       const struct {
         uint32_t input_width, input_height, output_width, output_height;
-      } block {
-        input_arg.sizes()[3],
-        input_arg.sizes()[2],
-        output_size[1],
-        output_size[0],
+      } block{
+          input_arg.sizes()[3],
+          input_arg.sizes()[2],
+          output_size[1],
+          output_size[0],
       };
 
       context->dispatch(
           command_buffer,
           {
-            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
-            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
-            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+              VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+              VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+              VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
           },
           VK_KERNEL(adaptive_avg_pool2d),
           v_output.extents(),
           v_output.image(command_buffer, vTensor::Access::Write),
           v_input.image(command_buffer),
           context->resource().pool.uniform(block).object);
-    }
-    else {
+    } else {
       TORCH_CHECK(false, "Not implemented!");
     }
   }
```
```diff
@@ -69,16 +155,17 @@ Tensor avg_pool2d(
       kernel_size.size() == 1 || kernel_size.size() == 2,
       "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
   const int kernel_height = safe_downcast<int>(kernel_size[0]);
-  const int kernel_width =
-      kernel_size.size() == 1 ? kernel_height : safe_downcast<int>(kernel_size[1]);
+  const int kernel_width = kernel_size.size() == 1
+      ? kernel_height
+      : safe_downcast<int>(kernel_size[1]);
 
   TORCH_CHECK(
       stride.empty() || stride.size() == 1 || stride.size() == 2,
       "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
   const int dH = stride.empty() ? kernel_height : safe_downcast<int>(stride[0]);
-  const int dW = stride.empty()
-      ? kernel_width
-      : stride.size() == 1 ? dH : safe_downcast<int>(stride[1]);
+  const int dW = stride.empty() ? kernel_width
+      : stride.size() == 1    ? dH
+                              : safe_downcast<int>(stride[1]);
 
   TORCH_CHECK(
       padding.size() == 1 || padding.size() == 2,
```
```diff
@@ -91,22 +178,35 @@ Tensor avg_pool2d(
   const int64_t input_height = self.sizes()[2];
   const int64_t input_width = self.sizes()[3];
 
-  const int64_t output_height =
-      pooling_output_shape<int64_t>(input_height, kernel_height, padH, dH, 1, ceil_mode);
-  const int64_t output_width =
-      pooling_output_shape<int64_t>(input_width, kernel_width, padW, dW, 1, ceil_mode);
+  const int64_t output_height = pooling_output_shape<int64_t>(
+      input_height, kernel_height, padH, dH, 1, ceil_mode);
+  const int64_t output_width = pooling_output_shape<int64_t>(
+      input_width, kernel_width, padW, dW, 1, ceil_mode);
 
   pool2d_shape_check(
-      self, kernel_height, kernel_width, dH, dW, padH, padW, 1, 1, input_channels, input_height, input_width, output_height, output_width);
+      self,
+      kernel_height,
+      kernel_width,
+      dH,
+      dW,
+      padH,
+      padW,
+      1,
+      1,
+      input_channels,
+      input_height,
+      input_width,
+      output_height,
+      output_width);
 
   api::Context* const context = api::context();
 
   const vTensor& v_self = convert(self);
 
   vTensor v_output{
-    context,
-    {input_batch, input_channels, output_height, output_width},
-    self.options(),
+      context,
+      {input_batch, input_channels, output_height, output_width},
+      self.options(),
   };
 
   api::Command::Buffer command_buffer = context->command().pool.allocate();
```
```diff
@@ -120,41 +220,49 @@ Tensor avg_pool2d(
         uint32_t stride_x, stride_y;
         uint32_t padding_x, padding_y;
         uint32_t dilate_x, dilate_y;
-      } block {
-        input_width, input_height, input_batch * input_channels, 0u,
-        output_width, output_height, input_batch * input_channels, 0u,
-        kernel_width, kernel_height,
-        dW, dH,
-        padW, padH,
-        1u, 1u
-      };
+      } block{
+          input_width,
+          input_height,
+          input_batch * input_channels,
+          0u,
+          output_width,
+          output_height,
+          input_batch * input_channels,
+          0u,
+          kernel_width,
+          kernel_height,
+          dW,
+          dH,
+          padW,
+          padH,
+          1u,
+          1u};
 
       context->dispatch(
           command_buffer,
           {
-            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
-            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
-            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+              VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+              VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+              VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
           },
           VK_KERNEL(avg_pool2d),
           v_output.extents(),
           v_output.image(command_buffer, vTensor::Access::Write),
           v_self.image(command_buffer),
           context->resource().pool.uniform(block).object);
-    }
-    else {
+    } else {
       TORCH_CHECK(false, "Not implemented!");
     }
   }
   command_buffer.end();
   command_buffer.submit(context->gpu().queue);
 
   return convert(v_output);
-
 }
 #ifdef USE_VULKAN_API
 
 TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
+  m.impl("mean.dim", TORCH_FN(mean));
   m.impl("_adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d));
   m.impl("avg_pool2d", TORCH_FN(avg_pool2d));
 }
```