Skip to content

Commit

Permalink
[Pytorch][Vulkan] mean.dim
Browse files Browse the repository at this point in the history
Summary:
We implement [`torch.mean(input, dim, keepdim)`](https://pytorch.org/docs/stable/generated/torch.mean.html) for tensors of 2d to 4d.

Since 0-dim tensor hasn't been supported yet, we only support `dim.size() < input.dim()` for now. We will support following cases in the future work:
- `dim.size() == input.dim()`
- `input.dim() == 1`

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (970fcd90c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*mean*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *mean*
[==========] Running 7 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 7 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.mean_invalid_inputs
[       OK ] VulkanAPITest.mean_invalid_inputs (46 ms)
[ RUN      ] VulkanAPITest.mean_dim_2d
[       OK ] VulkanAPITest.mean_dim_2d (127 ms)
[ RUN      ] VulkanAPITest.mean_dim_3d
[       OK ] VulkanAPITest.mean_dim_3d (103 ms)
[ RUN      ] VulkanAPITest.mean_dim_4d
[       OK ] VulkanAPITest.mean_dim_4d (89 ms)
[ RUN      ] VulkanAPITest.mean_dim_keepdim_2d
[       OK ] VulkanAPITest.mean_dim_keepdim_2d (66 ms)
[ RUN      ] VulkanAPITest.mean_dim_keepdim_3d
[       OK ] VulkanAPITest.mean_dim_keepdim_3d (127 ms)
[ RUN      ] VulkanAPITest.mean_dim_keepdim_4d
[       OK ] VulkanAPITest.mean_dim_keepdim_4d (4 ms)
[----------] 7 tests from VulkanAPITest (564 ms total)

[----------] Global test environment tear-down
[==========] 7 tests from 1 test suite ran. (564 ms total)
[  PASSED  ] 7 tests.
```

Reviewed By: yipjustin

Differential Revision: D50312990
  • Loading branch information
copyrightly authored and facebook-github-bot committed Oct 19, 2023
1 parent 6662435 commit 7fbfdbf
Show file tree
Hide file tree
Showing 6 changed files with 351 additions and 229 deletions.
77 changes: 0 additions & 77 deletions aten/src/ATen/native/vulkan/glsl/mean.glsl

This file was deleted.

90 changes: 0 additions & 90 deletions aten/src/ATen/native/vulkan/glsl/mean2d.glsl

This file was deleted.

82 changes: 82 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/mean_dim.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// dim_info.x: dim to compute mean
// dim_info.y: size of dim (in the input)
uvec2 dim_info;
int channel;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* Returns a new tensor with values averaged along dimension dim
* Dimension dim is squeezed
* For each pos:
* - Iterate over the out_texel and the averaged dimension
* - For H,W; rearrange pos.x, pos.y
* - For C,H,W;
* When CHW are averaged, batch moves into channel
* The src N is determined by pos.z * 4 + out_index
*/

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

int flattened_channels = int(ceil(uBlock.channel / 4.0));
vec4 out_texel = vec4(0, 0, 0, 0);

// Batch
if (uBlock.dim_info.x == 0) {
for (int batch = 0; batch < uBlock.dim_info.y; batch++) {
// src_n = batch
// src_c = pos.z
int src_z = batch * flattened_channels + pos.z;
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
out_texel += v;
}
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y);
}

// Channel
else if (uBlock.dim_info.x == 1) {
for (int out_index = 0; out_index < 4; out_index++) {
for (int channel = 0; channel < uBlock.dim_info.y; channel++) {
// src_n = pos.z * 4 + out_index
// src_c = channel
int src_z =
(pos.z * 4 + out_index) * flattened_channels + int(channel / 4);
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
out_texel[out_index] += v[channel % 4];
}
}
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y);
}

// Height, Width
else {
for (int out_index = 0; out_index < 4; out_index++) {
// src_n = pos.z * 4 + out_index
// src_c = pos.y
int src_z = (pos.z * 4 + out_index) * flattened_channels + pos.y / 4;
for (int hw = 0; hw < uBlock.dim_info.y; hw++) {
vec4 v = (uBlock.dim_info.x == 2)
? texelFetch(uInput, ivec3(pos.x, hw, src_z), 0) // Height
: texelFetch(uInput, ivec3(hw, pos.x, src_z), 0); // Width
out_texel[out_index] += v[pos.y % 4];
}
}
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y);
}
}
70 changes: 70 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/mean_dim_keepdim.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// dim_info.x: dim to compute mean
// dim_info.y: size of dim (in the input)
uvec2 dim_info;
int channel;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* Returns a new tensor with values averaged along dimension dim.
* Output and input have same number of dimensions.
* averaged dimension is of size 1.
*/

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

int flattened_channels = int(ceil(uBlock.channel / 4.0));
vec4 out_texel = vec4(0, 0, 0, 0);

// Batch
if (uBlock.dim_info.x == 0) {
for (int batch = 0; batch < uBlock.dim_info.y; batch++) {
// src_n = batch
// src_c = pos.z
int src_z = batch * flattened_channels + pos.z;
out_texel += texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
}
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y);
}

// Channel
else if (uBlock.dim_info.x == 1) {
for (int out_index = 0; out_index < 4; out_index++) {
for (int channel = 0; channel < uBlock.dim_info.y; channel++) {
// src_n = pos.z
// src_c = channel
int src_z = pos.z * flattened_channels + int(channel / 4);
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
out_texel[out_index] += v[channel % 4];
}
}
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y);
}

// Height, Width
else {
for (int hw = 0; hw < uBlock.dim_info.y; hw++) {
vec4 v = (uBlock.dim_info.x == 2)
? texelFetch(uInput, ivec3(pos.x, hw, pos.z), 0) // Height
: texelFetch(uInput, ivec3(hw, pos.y, pos.z), 0); // Width
out_texel += v;
}
imageStore(uOutput, pos, out_texel / uBlock.dim_info.y);
}
}
Loading

0 comments on commit 7fbfdbf

Please sign in to comment.