Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Add Vulkan support for at::softmax 1,2,3 dimension tensors #105012

Closed
wants to merge 1 commit into from

Commits on Jul 12, 2023

  1. [PyTorch] Add Vulkan support for at::softmax 1,2,3 dimension tensors (p…

    …ytorch#105012)
    
    Summary:
    bypass-github-export-checks
    
    Pull Request resolved: pytorch#105012
    
    This rounds out the support for the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html) on the Vulkan GPU backend. The test inputs of the 1,2,3 dimension cases are simply the truncated existing 4 dimension inputs. The existing shader algorithms are reused.
    
    Test Plan:
    1. `buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource  //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1` on Apple M1 MacBook
    2. Confirm all tests pass with no regression, and the added tests `*softmax*` pass under `-- --gtest_filter="*softmax*"`
    2a. All tests P782531732
    2b. `softmax` tests P782529114
    
    ```
    ~/fbsource » buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*softmax*"
    Buck UI: https://www.internalfb.com/buck2/692eb82d-c2ee-49bb-833f-3c11d6e2fea9
    Jobs completed: 4. Time elapsed: 0.1s.
    Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc
    Note: Google Test filter = *softmax*
    [==========] Running 1 test from 1 test suite.
    [----------] Global test environment set-up.
    [----------] 1 test from VulkanAPITest
    [ RUN      ] VulkanAPITest.softmax
    [       OK ] VulkanAPITest.softmax (42 ms)
    [ DISABLED ] VulkanAPITest.DISABLED_log_softmax
    [----------] 1 test from VulkanAPITest (42 ms total)
    
    [----------] Global test environment tear-down
    [==========] 1 test from 1 test suite ran. (42 ms total)
    [  PASSED  ] 1 test.
    
      YOU HAVE 1 DISABLED TEST
    
    ```
    
    Reviewed By: SS-JIA
    
    Differential Revision: D46985319
    
    fbshipit-source-id: 73fc0d4223bdf815b37ffb7ec80faf92b696cb38
    liuk22 authored and facebook-github-bot committed Jul 12, 2023
    Configuration menu
    Copy the full SHA
    7939905 View commit details
    Browse the repository at this point in the history