-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Pytorch] Add Vulkan support for aten::unsqueeze for 2d to 3d
Summary: Unsqueeze operator: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html#torch.unsqueeze Test Plan: Unsqueeze tests: https://www.internalfb.com/phabricator/paste/view/P738187802 ``` lfq@lfq-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*unsqueeze*" Downloaded 0/2 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules) Building: finished in 15.0 sec (100%) 455/455 jobs, 2/455 updated Total time: 15.0 sec BUILD SUCCEEDED Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc Note: Google Test filter = *unsqueeze* [==========] Running 3 tests from 1 test suite. [----------] Global test environment set-up. [----------] 3 tests from VulkanAPITest [ RUN ] VulkanAPITest.unsqueeze_dim0 [ OK ] VulkanAPITest.unsqueeze_dim0 (96 ms) [ RUN ] VulkanAPITest.unsqueeze_dim1 [ OK ] VulkanAPITest.unsqueeze_dim1 (2 ms) [ RUN ] VulkanAPITest.unsqueeze_dim2 [ OK ] VulkanAPITest.unsqueeze_dim2 (3 ms) [----------] 3 tests from VulkanAPITest (101 ms total) [----------] Global test environment tear-down [==========] 3 tests from 1 test suite ran. (101 ms total) [ PASSED ] 3 tests. ``` All tests: buck run //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 https://www.internalfb.com/phabricator/paste/view/P738255852 Reviewed By: SS-JIA Differential Revision: D45893511 fbshipit-source-id: 36fbfb1c2274c981effb7083581c292657dd700c
- Loading branch information
1 parent
e3c66de
commit cb2fe9d
Showing
3 changed files
with
200 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#version 450 core | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
layout(std430) buffer; | ||
|
||
/* | ||
* Output Image | ||
*/ | ||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; | ||
|
||
/* | ||
* Input Sampler | ||
*/ | ||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uImage; | ||
|
||
/* | ||
* Params Buffer | ||
*/ | ||
// Uniform parameter block bound at set 0, binding 2; main() reads it via uBlock. | ||
layout(set = 0, binding = 2) uniform PRECISION restrict Block { | ||
// dim: dimension to insert at. Only dim.x is read by main(); dim.y is not | ||
// referenced anywhere in this shader. | ||
ivec2 dim; | ||
} | ||
uBlock; | ||
|
||
/* | ||
* Local Work Group Size | ||
*/ | ||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
/*
 * Writes one output texel per invocation so that the output equals the input
 * with a size-one dimension inserted at uBlock.dim.x.
 *
 *  - dim 0 (or -3): the texel at the same position is copied unchanged.
 *  - dim 1 (or -2): the output texel gathers component 0 of the four input
 *    texels at rows (pos.z * 4 + 0..3), column pos.x, depth 0.
 *  - dim 2 (or -1): same gather, but the input column is pos.y.
 *
 * Any other dim value writes nothing.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
  const int insert_at = uBlock.dim.x;

  const bool at_front = (insert_at == 0 || insert_at == -3);
  const bool at_middle = (insert_at == 1 || insert_at == -2);
  const bool at_back = (insert_at == 2 || insert_at == -1);

  // Front insertion: straight texel copy.
  if (at_front) {
    imageStore(uOutput, out_pos, texelFetch(uImage, out_pos, 0));
    return;
  }

  // Unrecognised dim: nothing is written.
  if (!at_middle && !at_back) {
    return;
  }

  // The two remaining cases differ only in which output coordinate selects
  // the input column.
  const int in_x = at_middle ? out_pos.x : out_pos.y;

  vec4 texel = vec4(0, 0, 0, 0);
  for (int lane = 0; lane < 4; ++lane) {
    // Four consecutive input rows feed the four components of one output texel.
    const ivec3 in_pos = ivec3(in_x, out_pos.z * 4 + lane, 0);
    texel[lane] = texelFetch(uImage, in_pos, 0).x;
  }
  imageStore(uOutput, out_pos, texel);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
|
||
#include <ATen/native/vulkan/ops/Common.h> | ||
#include <ATen/native/vulkan/ops/Utils.h> | ||
#include <torch/library.h> | ||
|
||
namespace at { | ||
namespace native { | ||
namespace vulkan { | ||
namespace ops { | ||
namespace { | ||
|
||
using namespace api::utils; | ||
|
||
// Host-side parameter struct passed to api::UniformParamsBuffer by | ||
// unsqueeze_2dto3d(). Only the first component of dim is set explicitly there. | ||
struct Block final { | ||
ivec2 dim; | ||
}; | ||
|
||
// Inserts a size-one dimension at position `dim` of `input_arg` by dispatching
// the unsqueeze_2dto3d compute shader, and returns the result as a Tensor.
//
// input_arg: tensor to unsqueeze; converted with .vulkan() when it is not
//            already a Vulkan tensor.
// dim:       insertion position. Negative values are wrapped by adding 3 (the
//            output rank). No range check is done here; unsqueeze() validates
//            `dim` before calling this function.
Tensor unsqueeze_2dto3d(const at::Tensor& input_arg, int64_t dim) {
  // Get the global Vulkan context
  api::Context* const context = api::context();

  // Cast the input Tensor to a vTensor
  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const vTensor& v_input = convert(input);

  // Output sizes are the input sizes with a 1 inserted at `dim`.
  std::vector<int64_t> output_size = input_arg.sizes().vec();
  if (dim < 0) {
    dim += 3;
  }
  output_size.insert(output_size.begin() + dim, 1);

  // Create the output texture
  vTensor v_output{
      context,
      output_size,
      input_arg.scalar_type(),
  };

  // Required to determine how to insert memory barriers in the command buffer
  api::PipelineBarrier pipeline_barrier{};

  // Total number of work items is equal to the size of the output texture
  uvec3 global_size = v_output.extents();
  // Adaptively determine local work group size, will usually be {4, 4, 4}
  uvec3 local_size = adaptive_work_group_size(global_size);

  // Create the params buffer; the shader receives the already-wrapped,
  // non-negative dim.
  struct Block block {
    {static_cast<int32_t>(dim)}
  };
  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(unsqueeze_2dto3d),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}
|
||
// Entry point registered for aten::unsqueeze: validates the input rank and
// `dim`, then forwards to unsqueeze_2dto3d().
//
// self: tensor to unsqueeze. Only 2D inputs are currently accepted.
// dim:  insertion position; must lie in [-self.dim() - 1, self.dim()].
//
// Fails via TORCH_CHECK when the rank or `dim` is out of range.
Tensor unsqueeze(const at::Tensor& self, int64_t dim) {
  // Both bounds must hold; with `||` this condition was always true and the
  // check could never fire.
  TORCH_CHECK(
      self.dim() >= 1 && self.dim() <= 3,
      "Vulkan unsqueeze supports 1d, 2d, 3d tensors as input!");
  TORCH_CHECK(
      dim >= -self.dim() - 1 && dim <= self.dim(),
      "Vulkan unsqueeze dimension out of range (expected to be in range of [",
      -self.dim() - 1,
      ",",
      self.dim(),
      "], but got ",
      dim,
      ")");
  // Remove this when 1d->2d and 3d->4d are supported.
  TORCH_CHECK(self.dim() == 2, "Vulkan unsqueeze expects input dimension = 2!");
  return unsqueeze_2dto3d(self, dim);
}
|
||
#ifdef USE_VULKAN_API | ||
|
||
// Registers unsqueeze() as the implementation of aten::unsqueeze under the | ||
// Vulkan key of the aten library. | ||
TORCH_LIBRARY_IMPL(aten, Vulkan, m) { | ||
m.impl(TORCH_SELECTIVE_NAME("aten::unsqueeze"), TORCH_FN(unsqueeze)); | ||
} | ||
|
||
#endif /* USE_VULKAN_API */ | ||
|
||
} // namespace | ||
} // namespace ops | ||
} // namespace vulkan | ||
} // namespace native | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters