Skip to content

Commit

Permalink
[Pytorch][Vulkan] sum.dim_IntList with keepdim
Browse files Browse the repository at this point in the history
Summary:
Add Vulkan support for [sum](https://pytorch.org/docs/stable/generated/torch.sum.html).dim_IntList) with `keep_dim=true`

[sum.dim_IntList](https://www.internalfb.com/code/fbsource/[49b7951b7eb6]/xplat/caffe2/aten/src/ATen/native/native_functions.yaml?lines=5466)

```
if keepdim is true, the output tensor is of the same size as input except in the dimension(s) dim, where it is of size 1
otherwise, the dim is squeezed, result in the output tensor having 1 fewer dimension/s.
```

Test Plan:
```
lfq@lfq-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*.sum*"
Action graph will be rebuilt because files have been added or removed.
Parsing buck files: finished in 1.4 sec
Downloaded 4/58 artifacts, 3.08 Mbytes, 50.0% cache miss (for updated rules)
Building: finished in 41.2 sec (100%) 536/536 jobs, 13/536 updated
  Total time: 42.8 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *.sum*
[==========] Running 6 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 6 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.sum_dim_2d
[       OK ] VulkanAPITest.sum_dim_2d (558 ms)
[ RUN      ] VulkanAPITest.sum_dim_3d
[       OK ] VulkanAPITest.sum_dim_3d (7 ms)
[ RUN      ] VulkanAPITest.sum_dim_4d
[       OK ] VulkanAPITest.sum_dim_4d (14 ms)
[ RUN      ] VulkanAPITest.sum_dim_keepdim_2d
[       OK ] VulkanAPITest.sum_dim_keepdim_2d (4 ms)
[ RUN      ] VulkanAPITest.sum_dim_keepdim_3d
[       OK ] VulkanAPITest.sum_dim_keepdim_3d (7 ms)
[ RUN      ] VulkanAPITest.sum_dim_keepdim_4d
[       OK ] VulkanAPITest.sum_dim_keepdim_4d (18 ms)
[----------] 6 tests from VulkanAPITest (612 ms total)

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (612 ms total)
[  PASSED  ] 6 tests.
```

Reviewed By: SS-JIA

Differential Revision: D47652931

fbshipit-source-id: 62ce3a217338770e0401a7779c60661ac1067045
  • Loading branch information
lucylq authored and facebook-github-bot committed Jul 27, 2023
1 parent ca7ece9 commit 8382227
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 8 deletions.
70 changes: 70 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/sum_dim_keepdim.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// dim_info.x: dim to sum
// dim_info.y: size of dim (in the input)
uvec2 dim_info;
int channel;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* Returns a new tensor with values summed along dimension dim.
* Output and input have same number of dimensions.
* summed dimension is of size 1.
*/

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

int flattened_channels = int(ceil(uBlock.channel / 4.0));
vec4 out_texel = vec4(0, 0, 0, 0);

// Batch
if (uBlock.dim_info.x == 0) {
for (int batch = 0; batch < uBlock.dim_info.y; batch++) {
// src_n = batch
// src_c = pos.z
int src_z = batch * flattened_channels + pos.z;
out_texel += texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
}
imageStore(uOutput, pos, out_texel);
}

// Channel
else if (uBlock.dim_info.x == 1) {
for (int out_index = 0; out_index < 4; out_index++) {
for (int channel = 0; channel < uBlock.dim_info.y; channel++) {
// src_n = pos.z
// src_c = channel
int src_z = pos.z * flattened_channels + int(channel / 4);
vec4 v = texelFetch(uInput, ivec3(pos.x, pos.y, src_z), 0);
out_texel[out_index] += v[channel % 4];
}
}
imageStore(uOutput, pos, out_texel);
}

// Height, Width
else {
for (int hw = 0; hw < uBlock.dim_info.y; hw++) {
vec4 v = (uBlock.dim_info.x == 2)
? texelFetch(uInput, ivec3(pos.x, hw, pos.z), 0) // Height
: texelFetch(uInput, ivec3(hw, pos.y, pos.z), 0); // Width
out_texel += v;
}
imageStore(uOutput, pos, out_texel);
}
}
11 changes: 6 additions & 5 deletions aten/src/ATen/native/vulkan/ops/Sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Tensor sum_dim(
// Create the output texture
std::vector<int64_t> output_size = self.sizes().vec();
uint32_t dim_size = output_size[dim];
output_size.erase(output_size.begin() + dim);
if (keepdim) {
output_size[dim] = 1;
} else {
output_size.erase(output_size.begin() + dim);
}

ScalarType type = self.scalar_type();
if (dtype.has_value()) {
Expand Down Expand Up @@ -74,7 +78,7 @@ Tensor sum_dim(

context->submit_compute_job(
// shader descriptor
VK_KERNEL(sum_dim),
keepdim ? VK_KERNEL(sum_dim_keepdim) : VK_KERNEL(sum_dim),
// pipeline barrier
pipeline_barrier,
// global work group size
Expand Down Expand Up @@ -102,9 +106,6 @@ Tensor sum_dim_IntList(
TORCH_CHECK(
opt_dim.has_value(),
"Vulkan sum.dim_IntList without a dim arg is not implemented");
TORCH_CHECK(
keepdim == false,
"Vulkan sum.dim_IntList with keepdim=true is not implemented");

std::set<int64_t> dims_set;
if (opt_dim.has_value()) {
Expand Down
39 changes: 36 additions & 3 deletions aten/src/ATen/test/vulkan_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3702,12 +3702,12 @@ TEST_F(VulkanAPITest, sub_to_scalar_wrapped) {
ASSERT_TRUE(check);
}

void test_sum_dim(const at::IntArrayRef input_shape, const at::IntArrayRef dim_list) {
void test_sum_dim(const at::IntArrayRef input_shape, const at::IntArrayRef dim_list, bool keepdim=false) {
const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
const auto in_vulkan = in_cpu.vulkan();

const auto out_cpu = at::sum(in_cpu, dim_list);
const auto out_vulkan = at::sum(in_vulkan, dim_list);
const auto out_cpu = at::sum(in_cpu, dim_list, keepdim);
const auto out_vulkan = at::sum(in_vulkan, dim_list, keepdim);

const auto check = almostEqual(out_cpu, out_vulkan.cpu());
if (!check) {
Expand Down Expand Up @@ -3753,6 +3753,39 @@ TEST_F(VulkanAPITest, sum_dim_4d) {
test_sum_dim({10, 7, 5, 6}, {3, 2, 1});
}

TEST_F(VulkanAPITest, sum_dim_keepdim_2d) {
test_sum_dim({5, 7}, {-1}, true);
test_sum_dim({5, 7}, {-2}, true);
}

TEST_F(VulkanAPITest, sum_dim_keepdim_3d) {
test_sum_dim({9, 5, 7}, {-1}, true);
test_sum_dim({5, 9, 7}, {-2}, true);
test_sum_dim({7, 9, 5}, {-3}, true);

test_sum_dim({9, 5, 7}, {0, 1}, true);
test_sum_dim({5, 9, 7}, {0, 2}, true);
test_sum_dim({7, 9, 5}, {1, 2}, true);
}

TEST_F(VulkanAPITest, sum_dim_keepdim_4d) {
test_sum_dim({9, 5, 7, 11}, {-1}, true);
test_sum_dim({5, 9, 11, 7}, {-2}, true);
test_sum_dim({7, 11, 9, 5}, {-3}, true);
test_sum_dim({11, 7, 9, 5}, {-4}, true);

test_sum_dim({9, 5, 7, 11}, {0, 1}, true);
test_sum_dim({5, 9, 11, 7}, {0, 2}, true);
test_sum_dim({7, 11, 9, 5}, {0, 3}, true);
test_sum_dim({11, 7, 9, 5}, {1, 2}, true);
test_sum_dim({9, 5, 7, 11}, {1, 3}, true);
test_sum_dim({5, 9, 11, 7}, {2, 3}, true);

test_sum_dim({7, 11, 9, 5}, {-1, -2, -3}, true);
test_sum_dim({11, 7, 9, 5}, {-1, -2, -4}, true);
test_sum_dim({9, 5, 7, 11}, {-2, -3, -4}, true);
}

TEST_F(VulkanAPITest, uniform) {
float a_min = -8.2f;
float a_max = -1.4f;
Expand Down

0 comments on commit 8382227

Please sign in to comment.