Skip to content

Commit

Permalink
[vulkan_api][ops] Mm, Pool, Upsample
Browse files Browse the repository at this point in the history
ghstack-source-id: 13da256a2f788240980fc58a3eceb4612cccc09d
Pull Request resolved: #47063
  • Loading branch information
IvanKobzarev committed Oct 30, 2020
1 parent ecdbea7 commit 7c974cc
Show file tree
Hide file tree
Showing 12 changed files with 616 additions and 90 deletions.
12 changes: 5 additions & 7 deletions aten/src/ATen/native/vulkan/VulkanOps.cpp
Expand Up @@ -66,7 +66,7 @@ void upsample_nearest2d(

WorkGroupSize workGroupSize{8, 8, 1};
auto& computeUnit = context().computeUnitFactory().get(
GLSL_SPV(upsampleNearest2d), descriptorSetLayout, workGroupSize);
GLSL_SPV(upsample_nearest2d), descriptorSetLayout, workGroupSize);
computeUnit.createCommandBuffer(descriptorSet);
input.image()->addImageMemoryBarrierToShaderRead(computeUnit.commandBuffer());
computeUnit.dispatchCommandBuffer(OW, OH, C, workGroupSize);
Expand Down Expand Up @@ -239,16 +239,16 @@ void avg_pool2d(
auto device = context().device();
const auto c = _n * _c;
struct ConstBlock {
int32_t inputSize[4];
int32_t outputSize[4];
int32_t inputSize[3];
int32_t outputSize[3];
int32_t kernelSize[2];
int32_t stride[2];
int32_t padding[2];
int32_t dilate[2];
};
ConstBlock cb{
{iW, iH, c, 0},
{oW, oH, c, 0},
{iW, iH, c},
{oW, oH, c},
{kW, kH},
{dW, dH},
{padW, padH},
Expand Down Expand Up @@ -1245,15 +1245,13 @@ void addmm(
int32_t OW;
int32_t OH;
int32_t C_4;
int32_t C;
float beta;
float alpha;
int32_t K;
};
ConstBlock cb{safe_downcast<int32_t>(OW),
safe_downcast<int32_t>(OH),
safe_downcast<int32_t>(C_4),
safe_downcast<int32_t>(C),
beta,
alpha,
safe_downcast<int32_t>(K)};
Expand Down
21 changes: 12 additions & 9 deletions aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl
@@ -1,26 +1,29 @@
#version 450 core
#define PRECISION $precision

layout(std430) buffer;
layout(std430) uniform;
layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform constBlock {

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform restrict Block {
int IW;
int IH;
int OW;
int OH;
}
uConstBlock;
} uBlock;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
ivec3 pos = ivec3(gl_GlobalInvocationID);
int ow = uConstBlock.OW;
int oh = uConstBlock.OH;
int ow = uBlock.OW;
int oh = uBlock.OH;
if (pos.x < ow && pos.y < oh) {
int iw = uConstBlock.IW;
int ih = uConstBlock.IH;
int iw = uBlock.IW;
int ih = uBlock.IH;

int sx = int(floor(float(pos.x * iw) / ow));
int sy = int(floor(float(pos.y * ih) / oh));
Expand Down
26 changes: 14 additions & 12 deletions aten/src/ATen/native/vulkan/glsl/addmm.glsl
Expand Up @@ -2,24 +2,26 @@
#define PRECISION $precision
layout(std430) buffer;
layout(std430) uniform;
layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
layout(set = 0, binding = 3) uniform constBlock {
ivec4 outputSize;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
layout(set = 0, binding = 3) uniform restrict Block {
ivec3 WHC;
float beta;
float alpha;
int K;
}
uConstBlock;
layout(set = 0, binding = 4) uniform PRECISION sampler3D uT;
} uBlock;
layout(set = 0, binding = 4) uniform PRECISION sampler3D uT;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
ivec3 pos = ivec3(gl_GlobalInvocationID);
if (all(lessThan(pos, uConstBlock.outputSize.xyz))) {
int K = uConstBlock.K;
const ivec3 pos = ivec3(gl_GlobalInvocationID);
if (all(lessThan(pos, uBlock.WHC))) {
const int K = uBlock.K;
vec4 mmv = vec4(0);
int ki = 0;
for (; ki < K; ++ki) {
Expand All @@ -28,6 +30,6 @@ void main() {
mmv += m1ki * m2ki;
}
vec4 tv = texelFetch(uT, pos, 0);
imageStore(uOutput, pos, uConstBlock.beta * tv + uConstBlock.alpha * mmv);
imageStore(uOutput, pos, uBlock.beta * tv + uBlock.alpha * mmv);
}
}
29 changes: 15 additions & 14 deletions aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl
@@ -1,31 +1,32 @@
#version 450 core
#define PRECISION $precision

layout(std430) buffer;
layout(std430) uniform;
layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput;
layout(set = 0, binding = 1) uniform highp sampler3D uInput;
layout(set = 0, binding = 2) uniform constBlock {
ivec4 inputSize;
ivec4 outputSize;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform restrict Block {
ivec3 inputSize;
ivec3 outputSize;
ivec2 kernelSize;
ivec2 stride;
ivec2 padding;
ivec2 dilate;
}
uConstBlock;
} uBlock;

#define UP_DIV(x, y) (((x) + (y)-1) / (y))

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
ivec3 pos = ivec3(gl_GlobalInvocationID);
ivec3 outputSize = uConstBlock.outputSize.xyz;
if (all(lessThan(pos, outputSize))) {
ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding;
ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate)));
ivec2 efxy =
min(uConstBlock.kernelSize,
UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate));
if (all(lessThan(pos, uBlock.outputSize))) {
ivec2 s0 = pos.xy * uBlock.stride - uBlock.padding;
ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uBlock.dilate)));
ivec2 efxy = min(uBlock.kernelSize, UP_DIV(uBlock.inputSize.xy - s0, uBlock.dilate));

vec4 r = vec4(1.0) / float(efxy.x - sfxy.x) / float(efxy.x - sfxy.x);
vec4 acc = vec4(0);
Expand Down
26 changes: 13 additions & 13 deletions aten/src/ATen/native/vulkan/glsl/mm.glsl
Expand Up @@ -2,30 +2,30 @@
#define PRECISION $precision
layout(std430) buffer;
layout(std430) uniform;
layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
layout(set = 0, binding = 3) uniform constBlock {
ivec4 outputSize;
float beta;
float alpha;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
layout(set = 0, binding = 3) uniform restrict Block {
ivec3 WHC;
int K;
}
uConstBlock;
} uBlock;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
ivec3 pos = ivec3(gl_GlobalInvocationID);
if (all(lessThan(pos, uConstBlock.outputSize.xyz))) {
int K = uConstBlock.K;
const ivec3 pos = ivec3(gl_GlobalInvocationID);
if (all(lessThan(pos, uBlock.WHC))) {
const int K = uBlock.K;
vec4 mmv = vec4(0);
int ki = 0;
for (; ki < K; ++ki) {
vec4 m1ki = texelFetch(uM1, ivec3(ki, pos.y, pos.z), 0);
vec4 m2ki = texelFetch(uM2, ivec3(pos.x, ki, pos.z), 0);
mmv += m1ki * m2ki;
}
imageStore(uOutput, pos, uConstBlock.alpha * mmv);
imageStore(uOutput, pos, mmv);
}
}
35 changes: 0 additions & 35 deletions aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl

This file was deleted.

39 changes: 39 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl
@@ -0,0 +1,39 @@
#version 450 core
#define PRECISION $precision

layout(std430) buffer;
layout(std430) uniform;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform restrict Block {
int input_width;
int input_height;
int output_width;
int output_height;
float scale_x;
float scale_y;
}
uBlock;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
ivec3 pos = ivec3(gl_GlobalInvocationID);
const int ow = uBlock.output_width;
const int oh = uBlock.output_height;
if (pos.x < ow && pos.y < oh) {
const int iw = uBlock.input_width;
const int ih = uBlock.input_height;
float srcX = float(pos.x) * uBlock.scale_x;
int x1 = int(floor(srcX));
int x11 = clamp(x1, 0, iw - 1);
float srcY = float(pos.y) * uBlock.scale_y;
int y1 = int(floor(srcY));
int y11 = clamp(y1, 0, ih - 1);
vec4 outValue = texelFetch(uInput, ivec3(x11, y11, pos.z), 0);
imageStore(uOutput, pos, outValue);
}
}
38 changes: 38 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Common.h
Expand Up @@ -6,4 +6,42 @@
#include <ATen/native/vulkan/api/api.h>
#include <ATen/native/vulkan/ops/Tensor.h>

namespace at {
namespace native {
namespace vulkan {

template <typename To, typename From>
inline constexpr To safe_downcast_internal(const From v) {
typedef std::common_type_t<From, To> Type;
constexpr Type min{static_cast<Type>(std::numeric_limits<To>::lowest())};
constexpr Type max{static_cast<Type>(std::numeric_limits<To>::max())};
TORCH_CHECK(min <= v && v <= max, "Cast failed: out of range");
return static_cast<To>(v);
}

template <typename To, typename From>
inline constexpr bool is_signed_to_unsigned() {
return std::is_signed<From>::value && std::is_unsigned<To>::value;
}

template <
typename To,
typename From,
std::enable_if_t<is_signed_to_unsigned<To, From>(), bool> = true>
inline constexpr To safe_downcast(const From v) {
TORCH_CHECK(v >= From{}, "Cast failed: negative signed to unsigned");
return safe_downcast_internal<To, From>(v);
}

template <
typename To,
typename From,
std::enable_if_t<!is_signed_to_unsigned<To, From>(), bool> = true>
inline constexpr To safe_downcast(const From v) {
return safe_downcast_internal<To, From>(v);
}

} // namespace vulkan
} // namespace native
} // namespace at
#endif /* USE_VULKAN_API */

0 comments on commit 7c974cc

Please sign in to comment.