-
Notifications
You must be signed in to change notification settings - Fork 22.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: We implement [`torch.mean(input, dim, keepdim)`](https://pytorch.org/docs/stable/generated/torch.mean.html) for tensors of 2d to 4d. Since 0-dim tensor hasn't been supported yet, we only support `dim.size() < input.dim()` for now. We will support following cases in the future work: - `dim.size() == input.dim()` - `input.dim() == 1` Test Plan: ``` [luwei@devbig984.prn1 /data/users/luwei/fbsource (970fcd90c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*mean*" Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated Total time: 0.1 sec BUILD SUCCEEDED Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc Note: Google Test filter = *mean* [==========] Running 7 tests from 1 test suite. [----------] Global test environment set-up. [----------] 7 tests from VulkanAPITest [ RUN ] VulkanAPITest.mean_invalid_inputs [ OK ] VulkanAPITest.mean_invalid_inputs (46 ms) [ RUN ] VulkanAPITest.mean_dim_2d [ OK ] VulkanAPITest.mean_dim_2d (127 ms) [ RUN ] VulkanAPITest.mean_dim_3d [ OK ] VulkanAPITest.mean_dim_3d (103 ms) [ RUN ] VulkanAPITest.mean_dim_4d [ OK ] VulkanAPITest.mean_dim_4d (89 ms) [ RUN ] VulkanAPITest.mean_dim_keepdim_2d [ OK ] VulkanAPITest.mean_dim_keepdim_2d (66 ms) [ RUN ] VulkanAPITest.mean_dim_keepdim_3d [ OK ] VulkanAPITest.mean_dim_keepdim_3d (127 ms) [ RUN ] VulkanAPITest.mean_dim_keepdim_4d [ OK ] VulkanAPITest.mean_dim_keepdim_4d (4 ms) [----------] 7 tests from VulkanAPITest (564 ms total) [----------] Global test environment tear-down [==========] 7 tests from 1 test suite ran. (564 ms total) [ PASSED ] 7 tests. ``` Reviewed By: yipjustin Differential Revision: D50312990
- Loading branch information
1 parent
6662435
commit 7fbfdbf
Showing
6 changed files
with
351 additions
and
229 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#version 450 core | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
layout(std430) buffer; | ||
|
||
/* Qualifiers: layout - storage - precision - memory */ | ||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; | ||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; | ||
layout(set = 0, binding = 2) uniform PRECISION restrict Block { | ||
// dim_info.x: dim to compute mean | ||
// dim_info.y: size of dim (in the input) | ||
uvec2 dim_info; | ||
int channel; | ||
} | ||
uBlock; | ||
|
||
/* | ||
* Local Work Group Size | ||
*/ | ||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
/* | ||
* Returns a new tensor with values averaged along dimension dim | ||
* Dimension dim is squeezed | ||
* For each pos: | ||
* - Iterate over the out_texel and the averaged dimension | ||
* - For H,W; rearrange pos.x, pos.y | ||
* - For C,H,W; | ||
* When CHW are averaged, batch moves into channel | ||
* The src N is determined by pos.z * 4 + out_index | ||
*/ | ||
|
||
void main() { | ||
const ivec3 pos = ivec3(gl_GlobalInvocationID); | ||
|
||
int flattened_channels = int(ceil(uBlock.channel / 4.0)); | ||
vec4 out_texel = vec4(0, 0, 0, 0); | ||
|
||
// Batch | ||
if (uBlock.dim_info.x == 0) { | ||
for (int batch = 0; batch < uBlock.dim_info.y; batch++) { | ||
// src_n = batch | ||
// src_c = pos.z | ||
int src_z = batch * flattened_channels + pos.z; | ||
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0); | ||
out_texel += v; | ||
} | ||
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y); | ||
} | ||
|
||
// Channel | ||
else if (uBlock.dim_info.x == 1) { | ||
for (int out_index = 0; out_index < 4; out_index++) { | ||
for (int channel = 0; channel < uBlock.dim_info.y; channel++) { | ||
// src_n = pos.z * 4 + out_index | ||
// src_c = channel | ||
int src_z = | ||
(pos.z * 4 + out_index) * flattened_channels + int(channel / 4); | ||
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0); | ||
out_texel[out_index] += v[channel % 4]; | ||
} | ||
} | ||
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y); | ||
} | ||
|
||
// Height, Width | ||
else { | ||
for (int out_index = 0; out_index < 4; out_index++) { | ||
// src_n = pos.z * 4 + out_index | ||
// src_c = pos.y | ||
int src_z = (pos.z * 4 + out_index) * flattened_channels + pos.y / 4; | ||
for (int hw = 0; hw < uBlock.dim_info.y; hw++) { | ||
vec4 v = (uBlock.dim_info.x == 2) | ||
? texelFetch(uInput, ivec3(pos.x, hw, src_z), 0) // Height | ||
: texelFetch(uInput, ivec3(hw, pos.x, src_z), 0); // Width | ||
out_texel[out_index] += v[pos.y % 4]; | ||
} | ||
} | ||
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#version 450 core | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
layout(std430) buffer; | ||
|
||
/* Qualifiers: layout - storage - precision - memory */ | ||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; | ||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; | ||
layout(set = 0, binding = 2) uniform PRECISION restrict Block { | ||
// dim_info.x: dim to compute mean | ||
// dim_info.y: size of dim (in the input) | ||
uvec2 dim_info; | ||
int channel; | ||
} | ||
uBlock; | ||
|
||
/* | ||
* Local Work Group Size | ||
*/ | ||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
/* | ||
* Returns a new tensor with values averaged along dimension dim. | ||
* Output and input have same number of dimensions. | ||
* averaged dimension is of size 1. | ||
*/ | ||
|
||
void main() { | ||
const ivec3 pos = ivec3(gl_GlobalInvocationID); | ||
|
||
int flattened_channels = int(ceil(uBlock.channel / 4.0)); | ||
vec4 out_texel = vec4(0, 0, 0, 0); | ||
|
||
// Batch | ||
if (uBlock.dim_info.x == 0) { | ||
for (int batch = 0; batch < uBlock.dim_info.y; batch++) { | ||
// src_n = batch | ||
// src_c = pos.z | ||
int src_z = batch * flattened_channels + pos.z; | ||
out_texel += texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0); | ||
} | ||
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y); | ||
} | ||
|
||
// Channel | ||
else if (uBlock.dim_info.x == 1) { | ||
for (int out_index = 0; out_index < 4; out_index++) { | ||
for (int channel = 0; channel < uBlock.dim_info.y; channel++) { | ||
// src_n = pos.z | ||
// src_c = channel | ||
int src_z = pos.z * flattened_channels + int(channel / 4); | ||
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0); | ||
out_texel[out_index] += v[channel % 4]; | ||
} | ||
} | ||
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y); | ||
} | ||
|
||
// Height, Width | ||
else { | ||
for (int hw = 0; hw < uBlock.dim_info.y; hw++) { | ||
vec4 v = (uBlock.dim_info.x == 2) | ||
? texelFetch(uInput, ivec3(pos.x, hw, pos.z), 0) // Height | ||
: texelFetch(uInput, ivec3(hw, pos.y, pos.z), 0); // Width | ||
out_texel += v; | ||
} | ||
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y); | ||
} | ||
} |
Oops, something went wrong.