Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Pytorch][Vulkan] sum.dim_IntList with keepdim
Summary: Add Vulkan support for [sum](https://pytorch.org/docs/stable/generated/torch.sum.html).dim_IntList) with `keep_dim=true` [sum.dim_IntList](https://www.internalfb.com/code/fbsource/[49b7951b7eb6]/xplat/caffe2/aten/src/ATen/native/native_functions.yaml?lines=5466) ``` if keepdim is true, the output tensor is of the same size as input except in the dimension(s) dim, where it is of size 1 otherwise, the dim is squeezed, result in the output tensor having 1 fewer dimension/s. ``` Test Plan: ``` lfq@lfq-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*.sum*" Action graph will be rebuilt because files have been added or removed. Parsing buck files: finished in 1.4 sec Downloaded 4/58 artifacts, 3.08 Mbytes, 50.0% cache miss (for updated rules) Building: finished in 41.2 sec (100%) 536/536 jobs, 13/536 updated Total time: 42.8 sec BUILD SUCCEEDED Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc Note: Google Test filter = *.sum* [==========] Running 6 tests from 1 test suite. [----------] Global test environment set-up. [----------] 6 tests from VulkanAPITest [ RUN ] VulkanAPITest.sum_dim_2d [ OK ] VulkanAPITest.sum_dim_2d (558 ms) [ RUN ] VulkanAPITest.sum_dim_3d [ OK ] VulkanAPITest.sum_dim_3d (7 ms) [ RUN ] VulkanAPITest.sum_dim_4d [ OK ] VulkanAPITest.sum_dim_4d (14 ms) [ RUN ] VulkanAPITest.sum_dim_keepdim_2d [ OK ] VulkanAPITest.sum_dim_keepdim_2d (4 ms) [ RUN ] VulkanAPITest.sum_dim_keepdim_3d [ OK ] VulkanAPITest.sum_dim_keepdim_3d (7 ms) [ RUN ] VulkanAPITest.sum_dim_keepdim_4d [ OK ] VulkanAPITest.sum_dim_keepdim_4d (18 ms) [----------] 6 tests from VulkanAPITest (612 ms total) [----------] Global test environment tear-down [==========] 6 tests from 1 test suite ran. (612 ms total) [ PASSED ] 6 tests. ``` Reviewed By: SS-JIA Differential Revision: D47652931 fbshipit-source-id: 62ce3a217338770e0401a7779c60661ac1067045
- Loading branch information