-
Notifications
You must be signed in to change notification settings - Fork 21.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pytorch] Add support for "height" and "width" dimension for the "select" operator on pytorch vulkan backend (#94612)
Summary: Add support for "height" and "width" dimension for the "select" operator on pytorch vulkan backend. Test Plan: ``` yipjustin@yipjustin-mbp fbsource % buck run -c pt.vulkan_full_precision=1 --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -- --gtest_filter="*select_3d*" Downloaded 1/2 artifacts, 1.29 Mbytes, 0.0% cache miss (for updated rules) Building: finished in 3.7 sec (100%) 450/450 jobs, 2/450 updated Total time: 3.8 sec BUILD SUCCEEDED Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc Note: Google Test filter = *select_3d* [==========] Running 9 tests from 1 test suite. [----------] Global test environment set-up. [----------] 9 tests from VulkanAPITest [ RUN ] VulkanAPITest.select_3d_depth_small [ OK ] VulkanAPITest.select_3d_depth_small (30 ms) [ RUN ] VulkanAPITest.select_3d_depth_medium [ OK ] VulkanAPITest.select_3d_depth_medium (0 ms) [ RUN ] VulkanAPITest.select_3d_depth_large [ OK ] VulkanAPITest.select_3d_depth_large (1 ms) [ RUN ] VulkanAPITest.select_3d_height_small [ OK ] VulkanAPITest.select_3d_height_small (0 ms) [ RUN ] VulkanAPITest.select_3d_height_medium [ OK ] VulkanAPITest.select_3d_height_medium (0 ms) [ RUN ] VulkanAPITest.select_3d_height_large [ OK ] VulkanAPITest.select_3d_height_large (3 ms) [ RUN ] VulkanAPITest.select_3d_width_small [ OK ] VulkanAPITest.select_3d_width_small (0 ms) [ RUN ] VulkanAPITest.select_3d_width_medium [ OK ] VulkanAPITest.select_3d_width_medium (0 ms) [ RUN ] VulkanAPITest.select_3d_width_large [ OK ] VulkanAPITest.select_3d_width_large (1 ms) [----------] 9 tests from VulkanAPITest (40 ms total) [----------] Global test environment tear-down [==========] 9 tests from 1 test suite ran. (40 ms total) [ PASSED ] 9 tests. 
``` Reviewed By: SS-JIA Differential Revision: D43020796 Pull Request resolved: #94612 Approved by: https://github.com/SS-JIA
- Loading branch information
1 parent
fa1ea9f
commit f2c2642
Showing
4 changed files
with
240 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#version 450 core | ||
// PRECISION and FORMAT expand to the $precision / $format placeholders; | ||
// presumably these are substituted by the shader build step - TODO confirm. | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
layout(std430) buffer; | ||
|
||
/* Qualifiers: layout - storage - precision - memory */ | ||
|
||
// Binding 0: write-only 3D image that main() stores the result into. | ||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; | ||
// Binding 1: 3D sampler that main() fetches source texels from. | ||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; | ||
// Binding 2: parameters. main() stops writing once the output y coordinate | ||
// reaches size.y, and uses index as the fixed y coordinate of every source fetch. | ||
layout(set = 0, binding = 2) uniform PRECISION restrict Block { | ||
ivec3 size; | ||
int index; | ||
} uBlock; | ||
|
||
// Workgroup dimensions are taken from specialization constants 0, 1 and 2. | ||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
/*
 * Select along the height axis.
 *
 * Each invocation fetches one texel from uInput at
 * (x = invocation x, y = uBlock.index, z = invocation y) and scatters its four
 * components into four consecutive rows of uOutput, always at z = 0 and in
 * the first component of the stored value.
 *
 * NOTE(review): only the output row is range-checked against uBlock.size.y;
 * the invocation x is used unchecked - confirm the dispatch extents cover
 * exactly the output width.
 */
void main() {
  const ivec3 gid = ivec3(gl_GlobalInvocationID);

  // Source coordinate: w follows the invocation, h is pinned to the selected
  // index, c comes from the invocation's y.
  const ivec3 src = ivec3(gid.x, uBlock.index, gid.y);
  const vec4 texel = texelFetch(uInput, src, 0);

  // One fetched texel feeds rows [first_row, first_row + 4) of the output.
  const int first_row = gid.y * 4;

  for (int lane = 0; lane < 4; ++lane) {
    const int row = first_row + lane;

    // Rows at or beyond size.y do not exist in the output; nothing follows
    // the loop, so leaving it here ends the invocation.
    if (row >= uBlock.size.y) {
      break;
    }

    imageStore(uOutput, ivec3(gid.x, row, 0), vec4(texel[lane], 0, 0, 0));
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#version 450 core | ||
// PRECISION and FORMAT expand to the $precision / $format placeholders; | ||
// presumably these are substituted by the shader build step - TODO confirm. | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
layout(std430) buffer; | ||
|
||
/* Qualifiers: layout - storage - precision - memory */ | ||
|
||
// Binding 0: write-only 3D image that main() stores the result into. | ||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; | ||
// Binding 1: 3D sampler that main() fetches source texels from. | ||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; | ||
// Binding 2: parameters. main() stops writing once the output y coordinate | ||
// reaches size.y, and uses index as the fixed x coordinate of every source fetch. | ||
layout(set = 0, binding = 2) uniform PRECISION restrict Block { | ||
ivec3 size; | ||
int index; | ||
} uBlock; | ||
|
||
// Workgroup dimensions are taken from specialization constants 0, 1 and 2. | ||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
/*
 * Select along the width axis.
 *
 * Each invocation fetches one texel from uInput at
 * (x = uBlock.index, y = invocation x, z = invocation y) and scatters its four
 * components into four consecutive rows of uOutput, always at z = 0 and in
 * the first component of the stored value.
 *
 * NOTE(review): only the output row is range-checked against uBlock.size.y;
 * the invocation x is used unchecked - confirm the dispatch extents cover
 * exactly the output width.
 */
void main() {
  const ivec3 gid = ivec3(gl_GlobalInvocationID);

  // Source coordinate: w is pinned to the selected index, h follows the
  // invocation's x, c comes from the invocation's y.
  const ivec3 src = ivec3(uBlock.index, gid.x, gid.y);
  const vec4 texel = texelFetch(uInput, src, 0);

  // One fetched texel feeds rows [first_row, first_row + 4) of the output.
  const int first_row = gid.y * 4;

  for (int lane = 0; lane < 4; ++lane) {
    const int row = first_row + lane;

    // Rows at or beyond size.y do not exist in the output; nothing follows
    // the loop, so leaving it here ends the invocation.
    if (row >= uBlock.size.y) {
      break;
    }

    imageStore(uOutput, ivec3(gid.x, row, 0), vec4(texel[lane], 0, 0, 0));
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters