Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyTorch Vulkan] fix bug of
aten::cat
for concatenation of 3D tenso…
…rs at channel dim with channels as multiple of 4 (#103718) Summary: Pull Request resolved: #103718 The original `cat_feature_mult4ch` assumes input tensors are of 4d and use `tensor.sizes()[1]` to obtain the channel info of the tensor. This will cause bugs when the input tensors are of 3D. We generalize `cat_feature_mult4ch` to make it cover both 3D and 4D. Test Plan: Test for 3D tensors with channels as multiple of 4 is show below. Full test result is in P771032677. ``` (base) luwei@luwei-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*cat_3d_dim0_mult4ch_success*" Building: finished in 0.1 sec (100%) 263/2812 jobs, 0/2812 updated Total time: 0.1 sec BUILD SUCCEEDED Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc Note: Google Test filter = *cat_3d_dim0_mult4ch_success* [==========] Running 1 test from 1 test suite. [----------] Global test environment set-up. [----------] 1 test from VulkanAPITest [ RUN ] VulkanAPITest.cat_3d_dim0_mult4ch_success [ OK ] VulkanAPITest.cat_3d_dim0_mult4ch_success (129 ms) [----------] 1 test from VulkanAPITest (129 ms total) [----------] Global test environment tear-down [==========] 1 test from 1 test suite ran. (129 ms total) [ PASSED ] 1 test. ``` Reviewed By: SS-JIA Differential Revision: D46755034 fbshipit-source-id: 112216ff29bbe50ef7e8723782c7d0322beddc63
- Loading branch information