diff --git a/aten/src/ATen/native/vulkan/VulkanOps.cpp b/aten/src/ATen/native/vulkan/VulkanOps.cpp
index 0e13dce41a7c..4fc256fa7355 100644
--- a/aten/src/ATen/native/vulkan/VulkanOps.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanOps.cpp
@@ -67,7 +67,7 @@ void upsample_nearest2d(
   WorkGroupSize workGroupSize{8, 8, 1};
   auto& computeUnit = context().computeUnitFactory().get(
-      GLSL_SPV(upsampleNearest2d), descriptorSetLayout, workGroupSize);
+      GLSL_SPV(upsample_nearest2d), descriptorSetLayout, workGroupSize);
   computeUnit.createCommandBuffer(descriptorSet);
   input.image()->addImageMemoryBarrierToShaderRead(computeUnit.commandBuffer());
   computeUnit.dispatchCommandBuffer(OW, OH, C, workGroupSize);
@@ -240,16 +240,16 @@ void avg_pool2d(
   auto device = context().device();
   const auto c = _n * _c;
   struct ConstBlock {
-    int32_t inputSize[4];
-    int32_t outputSize[4];
+    int32_t inputSize[3];
+    int32_t outputSize[3];
     int32_t kernelSize[2];
     int32_t stride[2];
     int32_t padding[2];
     int32_t dilate[2];
   };
   ConstBlock cb{
-      {iW, iH, c, 0},
-      {oW, oH, c, 0},
+      {iW, iH, c},
+      {oW, oH, c},
       {kW, kH},
       {dW, dH},
       {padW, padH},
@@ -1243,7 +1243,6 @@ void addmm(
     int32_t OW;
     int32_t OH;
     int32_t C_4;
-    int32_t C;
     float beta;
     float alpha;
     int32_t K;
@@ -1251,7 +1250,6 @@
   ConstBlock cb{safe_downcast<int32_t>(OW),
                 safe_downcast<int32_t>(OH),
                 safe_downcast<int32_t>(C_4),
-                safe_downcast<int32_t>(C),
                 beta,
                 alpha,
                 safe_downcast<int32_t>(K)};
diff --git a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl
index ab07da5e4897..a1f3d6f21df9 100644
--- a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl
@@ -1,26 +1,29 @@
 #version 450 core
 #define PRECISION $precision
+
 layout(std430) buffer;
 layout(std430) uniform;
-layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
-layout(set = 0, binding = 2) uniform constBlock {
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform restrict Block {
   int IW;
   int IH;
   int OW;
   int OH;
-}
-uConstBlock;
+} uBlock;

 layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

 void main() {
   ivec3 pos = ivec3(gl_GlobalInvocationID);
-  int ow = uConstBlock.OW;
-  int oh = uConstBlock.OH;
+  int ow = uBlock.OW;
+  int oh = uBlock.OH;
   if (pos.x < ow && pos.y < oh) {
-    int iw = uConstBlock.IW;
-    int ih = uConstBlock.IH;
+    int iw = uBlock.IW;
+    int ih = uBlock.IH;
     int sx = int(floor(float(pos.x * iw) / ow));
     int sy = int(floor(float(pos.y * ih) / oh));
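For reference (not part of the patch): the hunk above only shows the start-index math of the rewritten adaptive_avg_pool2d shader. Below is a standalone CPU sketch of the full per-output-cell window, assuming the conventional end index ceil((i + 1) * in / out); it is for intuition only, not the shader's code.

// CPU sketch of adaptive average pooling over one 2-D plane.
#include <cmath>
#include <cstdio>
#include <vector>

// Averages input (ih x iw, row-major) into output (oh x ow).
std::vector<float> adaptive_avg_pool2d_reference(
    const std::vector<float>& input, int ih, int iw, int oh, int ow) {
  std::vector<float> output(oh * ow, 0.f);
  for (int y = 0; y < oh; ++y) {
    const int sy = static_cast<int>(std::floor(float(y * ih) / oh));       // same start index as the shader
    const int ey = static_cast<int>(std::ceil(float((y + 1) * ih) / oh));  // assumed end index (not visible in the hunk)
    for (int x = 0; x < ow; ++x) {
      const int sx = static_cast<int>(std::floor(float(x * iw) / ow));
      const int ex = static_cast<int>(std::ceil(float((x + 1) * iw) / ow));
      float acc = 0.f;
      for (int iy = sy; iy < ey; ++iy)
        for (int ix = sx; ix < ex; ++ix)
          acc += input[iy * iw + ix];
      output[y * ow + x] = acc / ((ey - sy) * (ex - sx));
    }
  }
  return output;
}

int main() {
  std::vector<float> in(7 * 7, 1.f);  // a constant input must stay constant after pooling
  const auto out = adaptive_avg_pool2d_reference(in, 7, 7, 3, 3);
  std::printf("out[0] = %f\n", out[0]);  // expect 1.0
}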
diff --git a/aten/src/ATen/native/vulkan/glsl/addmm.glsl b/aten/src/ATen/native/vulkan/glsl/addmm.glsl
index 79987990e595..55fa3da02e0b 100644
--- a/aten/src/ATen/native/vulkan/glsl/addmm.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/addmm.glsl
@@ -2,24 +2,26 @@
 #define PRECISION $precision
 layout(std430) buffer;
 layout(std430) uniform;
-layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
-layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
-layout(set = 0, binding = 3) uniform constBlock {
-  ivec4 outputSize;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
+layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
+layout(set = 0, binding = 3) uniform restrict Block {
+  ivec3 WHC;
   float beta;
   float alpha;
   int K;
-}
-uConstBlock;
-layout(set = 0, binding = 4) uniform PRECISION sampler3D uT;
+} uBlock;
+layout(set = 0, binding = 4) uniform PRECISION sampler3D uT;

 layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

 void main() {
-  ivec3 pos = ivec3(gl_GlobalInvocationID);
-  if (all(lessThan(pos, uConstBlock.outputSize.xyz))) {
-    int K = uConstBlock.K;
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  if (all(lessThan(pos, uBlock.WHC))) {
+    const int K = uBlock.K;
     vec4 mmv = vec4(0);
     int ki = 0;
     for (; ki < K; ++ki) {
@@ -28,6 +30,6 @@ void main() {
       mmv += m1ki * m2ki;
     }
     vec4 tv = texelFetch(uT, pos, 0);
-    imageStore(uOutput, pos, uConstBlock.beta * tv + uConstBlock.alpha * mmv);
+    imageStore(uOutput, pos, uBlock.beta * tv + uBlock.alpha * mmv);
   }
 }
diff --git a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl
index 552e75c11d59..7de1455a9051 100644
--- a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl
@@ -1,17 +1,21 @@
 #version 450 core
+#define PRECISION $precision
+
 layout(std430) buffer;
 layout(std430) uniform;
-layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput;
-layout(set = 0, binding = 1) uniform highp sampler3D uInput;
-layout(set = 0, binding = 2) uniform constBlock {
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform restrict Block {
   ivec4 inputSize;
   ivec4 outputSize;
   ivec2 kernelSize;
   ivec2 stride;
   ivec2 padding;
   ivec2 dilate;
-}
-uConstBlock;
+} uBlock;

 #define UP_DIV(x, y) (((x) + (y)-1) / (y))

@@ -19,13 +23,10 @@ layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

 void main() {
   ivec3 pos = ivec3(gl_GlobalInvocationID);
-  ivec3 outputSize = uConstBlock.outputSize.xyz;
-  if (all(lessThan(pos, outputSize))) {
-    ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding;
-    ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate)));
-    ivec2 efxy =
-        min(uConstBlock.kernelSize,
-            UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate));
+  if (all(lessThan(pos, uBlock.outputSize.xyz))) {
+    ivec2 s0 = pos.xy * uBlock.stride - uBlock.padding;
+    ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uBlock.dilate)));
+    ivec2 efxy = min(uBlock.kernelSize, UP_DIV(uBlock.inputSize.xy - s0, uBlock.dilate));

     vec4 r = vec4(1.0) / float(efxy.x - sfxy.x) / float(efxy.x - sfxy.x);
     vec4 acc = vec4(0);
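For reference (not part of the patch): the window clamping in avg_pool2d.glsl is easier to follow in scalar form. The sketch below re-expresses the visible s0/sfxy/efxy arithmetic for a single axis; names mirror the shader, and the accumulation loop (outside the hunk) is omitted.

// Scalar re-expression of the pooling-window clamp in avg_pool2d.glsl.
#include <algorithm>
#include <cstdio>

constexpr int up_div(int x, int y) { return (x + y - 1) / y; }  // UP_DIV in the shader

struct Window { int first, last; };  // kernel-tap indices actually visited: [first, last)

Window pool_window(int out_pos, int stride, int padding, int kernel, int input, int dilate) {
  const int s0 = out_pos * stride - padding;                    // window origin, may be negative
  const int sf = std::max(0, up_div(-s0, dilate));              // first tap that lands inside the input
  const int ef = std::min(kernel, up_div(input - s0, dilate));  // one past the last valid tap
  return {sf, ef};
}

int main() {
  // 1-D example: input width 7, kernel 2, stride 1, no padding, dilation 1.
  const Window w = pool_window(/*out_pos=*/6, /*stride=*/1, /*padding=*/0,
                               /*kernel=*/2, /*input=*/7, /*dilate=*/1);
  std::printf("taps [%d, %d)\n", w.first, w.last);  // the last output column clamps to a single tap
}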
diff --git a/aten/src/ATen/native/vulkan/glsl/mm.glsl b/aten/src/ATen/native/vulkan/glsl/mm.glsl
index 771617d64b8a..2d39b28802e5 100644
--- a/aten/src/ATen/native/vulkan/glsl/mm.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/mm.glsl
@@ -2,23 +2,23 @@
 #define PRECISION $precision
 layout(std430) buffer;
 layout(std430) uniform;
-layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
-layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
-layout(set = 0, binding = 3) uniform constBlock {
-  ivec4 outputSize;
-  float beta;
-  float alpha;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1;
+layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2;
+layout(set = 0, binding = 3) uniform restrict Block {
+  ivec3 WHC;
   int K;
-}
-uConstBlock;
+} uBlock;

 layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

 void main() {
-  ivec3 pos = ivec3(gl_GlobalInvocationID);
-  if (all(lessThan(pos, uConstBlock.outputSize.xyz))) {
-    int K = uConstBlock.K;
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  if (all(lessThan(pos, uBlock.WHC))) {
+    const int K = uBlock.K;
     vec4 mmv = vec4(0);
     int ki = 0;
     for (; ki < K; ++ki) {
@@ -26,6 +26,6 @@ void main() {
       vec4 m2ki = texelFetch(uM2, ivec3(pos.x, ki, pos.z), 0);
       mmv += m1ki * m2ki;
     }
-    imageStore(uOutput, pos, uConstBlock.alpha * mmv);
+    imageStore(uOutput, pos, mmv);
   }
 }
diff --git a/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl
deleted file mode 100644
index d7e4619a283a..000000000000
--- a/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl
+++ /dev/null
@@ -1,35 +0,0 @@
-#version 450 core
-#define PRECISION $precision
-layout(std430) buffer;
-layout(std430) uniform;
-layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
-layout(set = 0, binding = 2) uniform constBlock {
-  int IW;
-  int IH;
-  int OW;
-  int OH;
-  float scaleX;
-  float scaleY;
-}
-uConstBlock;
-
-layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
-
-void main() {
-  ivec3 pos = ivec3(gl_GlobalInvocationID);
-  int ow = uConstBlock.OW;
-  int oh = uConstBlock.OH;
-  if (pos.x < ow && pos.y < oh) {
-    int iw = uConstBlock.IW;
-    int ih = uConstBlock.IH;
-    float srcX = float(pos.x) * uConstBlock.scaleX;
-    int x1 = int(floor(srcX));
-    int x11 = clamp(x1, 0, iw - 1);
-    float srcY = float(pos.y) * uConstBlock.scaleY;
-    int y1 = int(floor(srcY));
-    int y11 = clamp(y1, 0, ih - 1);
-    vec4 outValue = texelFetch(uInput, ivec3(x11, y11, pos.z), 0);
-    imageStore(uOutput, pos, outValue);
-  }
-}
diff --git a/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl
new file mode 100644
index 000000000000..9e0da8bf6211
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl
@@ -0,0 +1,39 @@
+#version 450 core
+#define PRECISION $precision
+
+layout(std430) buffer;
+layout(std430) uniform;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform restrict Block {
+  int input_width;
+  int input_height;
+  int output_width;
+  int output_height;
+  float scale_x;
+  float scale_y;
+}
+uBlock;
+
+layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
+
+void main() {
+  ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const int ow = uBlock.output_width;
+  const int oh = uBlock.output_height;
+  if (pos.x < ow && pos.y < oh) {
+    const int iw = uBlock.input_width;
+    const int ih = uBlock.input_height;
+    float srcX = float(pos.x) * uBlock.scale_x;
+    int x1 = int(floor(srcX));
+    int x11 = clamp(x1, 0, iw - 1);
+    float srcY = float(pos.y) * uBlock.scale_y;
+    int y1 = int(floor(srcY));
+    int y11 = clamp(y1, 0, ih - 1);
+    vec4 outValue = texelFetch(uInput, ivec3(x11, y11, pos.z), 0);
+    imageStore(uOutput, pos, outValue);
+  }
+}
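For reference (not part of the patch): the new upsample_nearest2d shader selects one source texel per output pixel. Below is a CPU sketch of that index selection, assuming scale_x/scale_y are output-to-input factors (roughly input_size / output_size), as the host code later in this patch feeds them.

// CPU sketch of the nearest-neighbor source-index selection in upsample_nearest2d.glsl.
#include <algorithm>
#include <cmath>
#include <cstdio>

int nearest_source_index(int dst, float scale, int input_size) {
  const int src = static_cast<int>(std::floor(dst * scale));  // srcX = pos.x * scale_x
  return std::clamp(src, 0, input_size - 1);                  // mirrors clamp(x1, 0, iw - 1)
}

int main() {
  const int iw = 3, ow = 6;
  const float scale_x = static_cast<float>(iw) / ow;  // 0.5
  for (int x = 0; x < ow; ++x) {
    std::printf("dst %d -> src %d\n", x, nearest_source_index(x, scale_x, iw));
  }
  // Expected mapping: 0 0 1 1 2 2 — each source column is replicated twice.
}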
diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h
index 121b40cbdb4b..91fc585fe193 100644
--- a/aten/src/ATen/native/vulkan/ops/Common.h
+++ b/aten/src/ATen/native/vulkan/ops/Common.h
@@ -6,4 +6,42 @@
 #include
 #include
+namespace at {
+namespace native {
+namespace vulkan {
+
+template <typename To, typename From>
+inline constexpr To safe_downcast_internal(const From v) {
+  typedef std::common_type_t<From, To> Type;
+  constexpr Type min{static_cast<Type>(std::numeric_limits<To>::lowest())};
+  constexpr Type max{static_cast<Type>(std::numeric_limits<To>::max())};
+  TORCH_CHECK(min <= v && v <= max, "Cast failed: out of range");
+  return static_cast<To>(v);
+}
+
+template <typename To, typename From>
+inline constexpr bool is_signed_to_unsigned() {
+  return std::is_signed<From>::value && std::is_unsigned<To>::value;
+}
+
+template <
+    typename To,
+    typename From,
+    std::enable_if_t<is_signed_to_unsigned<To, From>(), bool> = true>
+inline constexpr To safe_downcast(const From v) {
+  TORCH_CHECK(v >= From{}, "Cast failed: negative signed to unsigned");
+  return safe_downcast_internal<To, From>(v);
+}
+
+template <
+    typename To,
+    typename From,
+    std::enable_if_t<!is_signed_to_unsigned<To, From>(), bool> = true>
+inline constexpr To safe_downcast(const From v) {
+  return safe_downcast_internal<To, From>(v);
+}
+
+} // namespace vulkan
+} // namespace native
+} // namespace at

 #endif /* USE_VULKAN_API */
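For reference (not part of the patch): a usage sketch for the safe_downcast helper added above. The include path and the c10::Error catch are assumptions based on where the header lives in this patch and on how TORCH_CHECK normally reports failures.

// Usage sketch: safe_downcast throws instead of silently wrapping.
#include <ATen/native/vulkan/ops/Common.h>  // assumed location of safe_downcast
#include <c10/util/Exception.h>
#include <cstdint>
#include <iostream>

int main() {
  using at::native::vulkan::safe_downcast;

  const int64_t small = 42;
  const uint32_t ok = safe_downcast<uint32_t>(small);  // fits, returns 42

  try {
    const int64_t negative = -1;
    safe_downcast<uint32_t>(negative);  // signed -> unsigned with a negative value
  } catch (const c10::Error& e) {
    std::cout << "rejected: " << e.what() << "\n";
  }

  try {
    const int64_t huge = int64_t{1} << 40;
    safe_downcast<int32_t>(huge);  // exceeds INT32_MAX
  } catch (const c10::Error& e) {
    std::cout << "rejected: " << e.what() << "\n";
  }

  std::cout << "ok = " << ok << "\n";
}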
diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp
new file mode 100644
index 000000000000..cf39d845178a
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp
@@ -0,0 +1,139 @@
+#include <ATen/native/vulkan/ops/Common.h>
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+namespace {
+
+Tensor addmm(
+    const Tensor& self_arg,
+    const Tensor& mat1_arg,
+    const Tensor& mat2_arg,
+    const Scalar beta,
+    const Scalar alpha) {
+  api::Context* const context = api::context();
+
+  const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
+  const vTensor& v_self = convert(self);
+
+  const Tensor mat1 = mat1_arg.is_vulkan() ? mat1_arg : mat1_arg.vulkan();
+  const vTensor& v_mat1 = convert(mat1);
+
+  const Tensor mat2 = mat2_arg.is_vulkan() ? mat2_arg : mat2_arg.vulkan();
+  const vTensor& v_mat2 = convert(mat2);
+
+  vTensor v_output{
+    context,
+    {mat1.sizes()[0], mat2.sizes()[1]},
+    self.options()
+  };
+
+  api::Command::Buffer command_buffer = context->command().pool.allocate();
+  command_buffer.begin();
+  {
+    if (v_self.has_image()) {
+      const struct {
+        uint32_t width, height, channels;
+        float beta, alpha;
+        uint32_t k;
+      } block {
+        mat2_arg.sizes()[1],
+        mat1_arg.sizes()[0],
+        1u,
+        beta.to<float>(),
+        alpha.to<float>(),
+        mat1_arg.sizes()[1],
+      };
+
+      context->dispatch(
+          command_buffer,
+          {
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+          },
+          VK_KERNEL(addmm),
+          v_output.extents(),
+          v_output.image(command_buffer, vTensor::Access::Write),
+          v_mat1.image(command_buffer),
+          v_mat2.image(command_buffer),
+          context->resource().pool.uniform(block).object,
+          v_self.image(command_buffer));
+    } else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_buffer.end();
+  command_buffer.submit(context->gpu().queue);
+
+  return convert(v_output);
+}
+
+Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) {
+  api::Context* const context = api::context();
+
+  const Tensor mat1 = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
+  const vTensor& v_mat1 = convert(mat1);
+
+  const Tensor mat2 = mat2_arg.is_vulkan() ? mat2_arg : mat2_arg.vulkan();
+  const vTensor& v_mat2 = convert(mat2);
+
+  vTensor v_output{
+    context,
+    {mat1.sizes()[0], mat2.sizes()[1]},
+    mat1.options()
+  };
+
+  api::Command::Buffer command_buffer = context->command().pool.allocate();
+  command_buffer.begin();
+  {
+    if (v_mat1.has_image() && v_mat2.has_image()) {
+      const struct {
+        uint32_t width, height, channels, k;
+      } block {
+        mat2.sizes()[1],
+        mat1.sizes()[0],
+        1u,
+        mat1.sizes()[1],
+      };
+
+      context->dispatch(
+          command_buffer,
+          {
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+          },
+          VK_KERNEL(mm),
+          v_output.extents(),
+          v_output.image(command_buffer, vTensor::Access::Write),
+          v_mat1.image(command_buffer),
+          v_mat2.image(command_buffer),
+          context->resource().pool.uniform(block).object);
+    } else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_buffer.end();
+  command_buffer.submit(context->gpu().queue);
+
+  return convert(v_output);
+}
+
+#ifdef USE_VULKAN_API
+
+TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
+  m.impl("addmm", TORCH_FN(addmm));
+  m.impl("mm", TORCH_FN(mm));
+}
+
+#endif /* USE_VULKAN_API */
+
+} // namespace
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
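For reference (not part of the patch): the anonymous struct handed to context->dispatch has to line up with the GLSL block declared in addmm.glsl, { ivec3 WHC; float beta; float alpha; int K; }. Under the std430 default the shaders request (std140 gives the same result for this block), an ivec3 is 16-byte aligned but only 12 bytes long, so the next scalar may sit at offset 12 — the offsets the plain C++ struct below has. This is a compile-time sanity check of that assumption, not code from the patch.

// Layout sanity check for the host-side addmm uniform block.
#include <cstddef>
#include <cstdint>

struct AddmmBlock {
  uint32_t width, height, channels;  // maps to WHC in the shader
  float beta, alpha;
  uint32_t k;
};

static_assert(offsetof(AddmmBlock, beta) == 12, "beta should immediately follow the ivec3");
static_assert(offsetof(AddmmBlock, alpha) == 16, "alpha offset mismatch");
static_assert(offsetof(AddmmBlock, k) == 20, "K offset mismatch");
static_assert(sizeof(AddmmBlock) == 24, "no trailing padding expected");

int main() {}  // compiles only if the host struct matches the expected offsets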
diff --git a/aten/src/ATen/native/vulkan/ops/Pool.cpp b/aten/src/ATen/native/vulkan/ops/Pool.cpp
new file mode 100644
index 000000000000..8c2c05ff26a3
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Pool.cpp
@@ -0,0 +1,168 @@
+#include <ATen/native/Pool.h>
+#include <ATen/native/vulkan/ops/Common.h>
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+namespace {
+
+Tensor adaptive_avg_pool2d(const at::Tensor& input_arg, IntArrayRef output_size) {
+  TORCH_INTERNAL_ASSERT(
+      input_arg.dim() == 4,
+      "vulkan_adaptive_avg_pool2d expects 4-dimensional input");
+
+  api::Context* const context = api::context();
+  const vTensor& v_input = convert(input_arg);
+
+  vTensor v_output{
+    context,
+    {input_arg.sizes()[0], input_arg.sizes()[1], output_size[0], output_size[1]},
+    input_arg.options(),
+  };
+
+  api::Command::Buffer command_buffer = context->command().pool.allocate();
+  command_buffer.begin();
+  {
+    if (v_input.has_image()) {
+      const struct {
+        uint32_t input_width, input_height, output_width, output_height;
+      } block {
+        input_arg.sizes()[3],
+        input_arg.sizes()[2],
+        output_size[1],
+        output_size[0],
+      };
+
+      context->dispatch(
+          command_buffer,
+          {
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+          },
+          VK_KERNEL(adaptive_avg_pool2d),
+          v_output.extents(),
+          v_output.image(command_buffer, vTensor::Access::Write),
+          v_input.image(command_buffer),
+          context->resource().pool.uniform(block).object);
+    }
+    else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_buffer.end();
+  command_buffer.submit(context->gpu().queue);
+
+  return convert(v_output);
+}
+
+Tensor avg_pool2d(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    c10::optional<int64_t> divisor_override) {
+  TORCH_CHECK(
+      kernel_size.size() == 1 || kernel_size.size() == 2,
+      "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
+  const int kernel_height = safe_downcast<int>(kernel_size[0]);
+  const int kernel_width =
+      kernel_size.size() == 1 ? kernel_height : safe_downcast<int>(kernel_size[1]);
+
+  TORCH_CHECK(
+      stride.empty() || stride.size() == 1 || stride.size() == 2,
+      "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
+  const int dH = stride.empty() ? kernel_height : safe_downcast<int>(stride[0]);
+  const int dW = stride.empty()
+      ? kernel_width
+      : stride.size() == 1 ? dH : safe_downcast<int>(stride[1]);
+
+  TORCH_CHECK(
+      padding.size() == 1 || padding.size() == 2,
+      "avg_pool2d: padding must either be a single int, or a tuple of two ints");
+  const int padH = safe_downcast<int>(padding[0]);
+  const int padW = padding.size() == 1 ? padH : safe_downcast<int>(padding[1]);
+
+  const int64_t input_batch = self.sizes()[0];
+  const int64_t input_channels = self.sizes()[1];
+  const int64_t input_height = self.sizes()[2];
+  const int64_t input_width = self.sizes()[3];
+
+  const int64_t output_height =
+      pooling_output_shape<int64_t>(input_height, kernel_height, padH, dH, 1, ceil_mode);
+  const int64_t output_width =
+      pooling_output_shape<int64_t>(input_width, kernel_width, padW, dW, 1, ceil_mode);
+
+  pool2d_shape_check(
+      self,
+      kernel_height, kernel_width,
+      dH, dW,
+      padH, padW,
+      1, 1,
+      input_channels,
+      input_height, input_width,
+      output_height, output_width);
+
+  api::Context* const context = api::context();
+
+  const vTensor& v_self = convert(self);
+
+  vTensor v_output{
+    context,
+    {input_batch, input_channels, output_height, output_width},
+    self.options(),
+  };
+
+  api::Command::Buffer command_buffer = context->command().pool.allocate();
+  command_buffer.begin();
+  {
+    if (v_self.has_image()) {
+      const struct {
+        uint32_t input_width, input_height, input_channels, input_size_stub;
+        uint32_t output_width, output_height, output_channels, output_size_stub;
+        uint32_t kernel_width, kernel_height;
+        uint32_t stride_x, stride_y;
+        uint32_t padding_x, padding_y;
+        uint32_t dilate_x, dilate_y;
+      } block {
+        input_width, input_height, input_batch * input_channels, 0u,
+        output_width, output_height, input_batch * input_channels, 0u,
+        kernel_width, kernel_height,
+        dW, dH,
+        padW, padH,
+        1u, 1u
+      };
+
+      context->dispatch(
+          command_buffer,
+          {
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+          },
+          VK_KERNEL(avg_pool2d),
+          v_output.extents(),
+          v_output.image(command_buffer, vTensor::Access::Write),
+          v_self.image(command_buffer),
+          context->resource().pool.uniform(block).object);
+    }
+    else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_buffer.end();
+  command_buffer.submit(context->gpu().queue);
+
+  return convert(v_output);
+}
+
+#ifdef USE_VULKAN_API
+
+TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
+  m.impl("_adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d));
+  m.impl("avg_pool2d", TORCH_FN(avg_pool2d));
+}
+
+#endif /* USE_VULKAN_API */
+
+} // namespace
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
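For reference (not part of the patch): avg_pool2d above delegates output sizing to pooling_output_shape from ATen/native/Pool.h. The sketch below shows the standard formula that function is expected to apply for the call as written (dilation fixed at 1); treat it as an illustration, not the library's exact code.

// Sketch of the pooling output-size arithmetic: (in + 2*pad - kernel) / stride + 1,
// with an extra stride-1 worth of room when ceil_mode is set.
#include <cstdint>
#include <cstdio>

int64_t pooling_output_shape_sketch(
    int64_t input, int64_t kernel, int64_t pad, int64_t stride, bool ceil_mode) {
  const int64_t numerator = input + 2 * pad - kernel + (ceil_mode ? stride - 1 : 0);
  int64_t output = numerator / stride + 1;
  if (ceil_mode && (output - 1) * stride >= input + pad) {
    --output;  // the last window must start inside the (padded) input
  }
  return output;
}

int main() {
  // Matches the avg_pool2d test later in this patch: 7x7 input, 2x2 kernel, stride 1, no padding.
  std::printf("%lld\n", static_cast<long long>(
      pooling_output_shape_sketch(7, 2, 0, 1, /*ceil_mode=*/false)));  // prints 6
}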
diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp
new file mode 100644
index 000000000000..2a95751e59bd
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp
@@ -0,0 +1,80 @@
+#include <ATen/native/UpSample.h>
+#include <ATen/native/vulkan/ops/Common.h>
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+namespace {
+
+Tensor upsample_nearest2d(
+    const Tensor& input_arg,
+    const IntArrayRef output_sizes,
+    const c10::optional<double> scales_h,
+    const c10::optional<double> scales_w) {
+  api::Context* const context = api::context();
+
+  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
+  const vTensor& v_input = convert(input);
+
+  vTensor v_output{
+    context,
+    {input_arg.sizes()[0], input_arg.sizes()[1], output_sizes[0], output_sizes[1]},
+    input.options(),
+  };
+
+  api::Command::Buffer command_buffer = context->command().pool.allocate();
+  command_buffer.begin();
+  {
+    const float scale_x =
+        compute_scales_value<float>(scales_w, input_arg.sizes()[3], output_sizes[1]);
+    const float scale_y =
+        compute_scales_value<float>(scales_h, input_arg.sizes()[2], output_sizes[0]);
+    if (v_input.has_image()) {
+      const struct {
+        uint32_t input_width, input_height, output_width, output_height;
+        float scale_x, scale_y;
+      } block {
+        input_arg.sizes()[3],
+        input_arg.sizes()[2],
+        output_sizes[1],
+        output_sizes[0],
+        scale_x,
+        scale_y
+      };
+
+      context->dispatch(
+          command_buffer,
+          {
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+          },
+          VK_KERNEL(upsample_nearest2d),
+          v_output.extents(),
+          v_output.image(command_buffer, vTensor::Access::Write),
+          v_input.image(command_buffer),
+          context->resource().pool.uniform(block).object);
+    }
+    else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+  command_buffer.end();
+  command_buffer.submit(context->gpu().queue);
+
+  return convert(v_output);
+}
+
+#ifdef USE_VULKAN_API
+
+TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
+  m.impl("upsample_nearest2d", TORCH_FN(upsample_nearest2d));
+}
+
+#endif /* USE_VULKAN_API */
+
+} // namespace
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp
index 05febc539966..814b9ba667ae 100644
--- a/aten/src/ATen/test/vulkan_api_test.cpp
+++ b/aten/src/ATen/test/vulkan_api_test.cpp
@@ -103,6 +103,99 @@ TEST(VulkanAPITest, mul_scalar_) {
   ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu()));
 }

+TEST(VulkanTest, addmm) {
+  auto t_m1 = at::rand({2, 2}, at::device(at::kCPU).dtype(at::kFloat));
+  auto t_m2 = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat));
+  auto t_b = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat));
+
+  float beta = 100;
+  float alpha = 2;
+  auto t_out_expected = at::addmm(t_b, t_m1, t_m2, beta, alpha);
+
+  auto tv_m1 = t_m1.vulkan();
+  auto tv_m2 = t_m2.vulkan();
+  auto tv_b = t_b.vulkan();
+  auto tv_out = at::addmm(tv_b, tv_m1, tv_m2, beta, alpha);
+  auto t_out = tv_out.cpu();
+  const auto check = almostEqual(t_out, t_out_expected);
+  if (!check) {
+    std::cout << "expected:\n" << t_out_expected << std::endl;
+    std::cout << "got:\n" << t_out << std::endl;
+  }
+  ASSERT_TRUE(check);
+}
+
+TEST(VulkanTest, mm) {
+  auto t_m1 = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat));
+  auto t_m2 = at::rand({3, 2}, at::device(at::kCPU).dtype(at::kFloat));
+
+  auto t_out_expected = t_m1.mm(t_m2);
+
+  auto tv_m1 = t_m1.vulkan();
+  auto tv_m2 = t_m2.vulkan();
+  auto tv_out = tv_m1.mm(tv_m2);
+  auto t_out = tv_out.cpu();
+  const auto check = almostEqual(t_out, t_out_expected);
+  if (!check) {
+    std::cout << "expected:\n" << t_out_expected << std::endl;
+    std::cout << "got:\n" << t_out << std::endl;
+  }
+  ASSERT_TRUE(check);
+}
+
+TEST(VulkanTest, adaptive_avg_pool2d) {
+  auto t_in =
+      at::rand({1, 2, 7, 7}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+  auto t_out_expected = at::adaptive_avg_pool2d(t_in, {3, 3});
+  auto tv_in = t_in.vulkan();
+
+  auto tv_out = at::adaptive_avg_pool2d(tv_in, {3, 3});
+  auto t_out = tv_out.cpu();
+
+  const auto check = almostEqual(t_out, t_out_expected);
+  if (!check) {
+    std::cout << "expected:\n" << t_out_expected << std::endl;
+    std::cout << "got:\n" << t_out << std::endl;
+  }
+  ASSERT_TRUE(check);
+}
+
+TEST(VulkanTest, upsample_nearest2d) {
+  auto t_in = at::rand({1, 2, 2, 3}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+  auto t_out_expected = at::upsample_nearest2d(t_in, {4, 6});
+  auto tv_in = t_in.vulkan();
+
+  auto tv_out = at::upsample_nearest2d(tv_in, {4, 6});
+  auto t_out = tv_out.cpu();
+
+  const auto check = almostEqual(t_out, t_out_expected);
+  if (!check) {
+    std::cout << "expected:\n" << t_out_expected << std::endl;
+    std::cout << "got:\n" << t_out << std::endl;
+  }
+  ASSERT_TRUE(check);
+}
+
+TEST(VulkanTest, avg_pool2d) {
+  if (!at::is_vulkan_available())
+    return;
+
+  auto t_in =
+      at::rand({1, 3, 7, 7}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+  auto t_out_expected = at::avg_pool2d(t_in, {2, 2}, {1}, {0}, {1});
+  auto tv_in = t_in.vulkan();
+
+  auto tv_out = at::avg_pool2d(tv_in, {2, 2}, {1}, {0}, {1});
+  auto t_out = tv_out.cpu();
+
+  const auto check = almostEqual(t_out, t_out_expected);
+  if (!check) {
+    std::cout << "expected:\n" << t_out_expected << std::endl;
+    std::cout << "got:\n" << t_out << std::endl;
+  }
+  ASSERT_TRUE(check);
+}
+
 TEST(VulkanAPITest, copy) {
   const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat));
   ASSERT_TRUE(exactlyEqual(cpu, cpu.vulkan().cpu()));