Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pytorch][Vulkan] sum.dim_IntList with keepdim #106159

Closed
wants to merge 1 commit into from

Commits on Jul 27, 2023

  1. [Pytorch][Vulkan] sum.dim_IntList with keepdim

    Summary:
    Add Vulkan support for [sum](https://pytorch.org/docs/stable/generated/torch.sum.html).dim_IntList) with `keep_dim=true`
    
    [sum.dim_IntList](https://www.internalfb.com/code/fbsource/[49b7951b7eb6]/xplat/caffe2/aten/src/ATen/native/native_functions.yaml?lines=5466)
    
    ```
    if keepdim is true, the output tensor is of the same size as input except in the dimension(s) dim, where it is of size 1
    otherwise, the dim is squeezed, result in the output tensor having 1 fewer dimension/s.
    ```
    
    Test Plan:
    ```
    lfq@lfq-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*.sum*"
    Action graph will be rebuilt because files have been added or removed.
    Parsing buck files: finished in 1.4 sec
    Downloaded 4/58 artifacts, 3.08 Mbytes, 50.0% cache miss (for updated rules)
    Building: finished in 41.2 sec (100%) 536/536 jobs, 13/536 updated
      Total time: 42.8 sec
    BUILD SUCCEEDED
    Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
    Note: Google Test filter = *.sum*
    [==========] Running 6 tests from 1 test suite.
    [----------] Global test environment set-up.
    [----------] 6 tests from VulkanAPITest
    [ RUN      ] VulkanAPITest.sum_dim_2d
    [       OK ] VulkanAPITest.sum_dim_2d (558 ms)
    [ RUN      ] VulkanAPITest.sum_dim_3d
    [       OK ] VulkanAPITest.sum_dim_3d (7 ms)
    [ RUN      ] VulkanAPITest.sum_dim_4d
    [       OK ] VulkanAPITest.sum_dim_4d (14 ms)
    [ RUN      ] VulkanAPITest.sum_dim_keepdim_2d
    [       OK ] VulkanAPITest.sum_dim_keepdim_2d (4 ms)
    [ RUN      ] VulkanAPITest.sum_dim_keepdim_3d
    [       OK ] VulkanAPITest.sum_dim_keepdim_3d (7 ms)
    [ RUN      ] VulkanAPITest.sum_dim_keepdim_4d
    [       OK ] VulkanAPITest.sum_dim_keepdim_4d (18 ms)
    [----------] 6 tests from VulkanAPITest (612 ms total)
    
    [----------] Global test environment tear-down
    [==========] 6 tests from 1 test suite ran. (612 ms total)
    [  PASSED  ] 6 tests.
    ```
    
    Reviewed By: SS-JIA
    
    Differential Revision: D47652931
    
    fbshipit-source-id: 62ce3a217338770e0401a7779c60661ac1067045
    lucylq authored and facebook-github-bot committed Jul 27, 2023
    Configuration menu
    Copy the full SHA
    8382227 View commit details
    Browse the repository at this point in the history