…104444)
Summary:
Pull Request resolved: pytorch#104444
Implemented `aten::masked_fill` for Vulkan backend, see https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill.html for the behavior of this operator.
Some explanation of the implementation:
- The shapes of the input tensor and mask should be broadcastable (see [broadcasting semantics](https://pytorch.org/docs/stable/notes/broadcasting.html)). For example, the input tensor is of shape [3, 1, 5] and mask of shape [2, 1]. Then the output is of shape [3, 2, 5].
- A straightforward implementation is to generate an output and a mask, both of shape [3, 2, 5], by applying `repeat` operations on the input tensor and mask respectively. Then we traverse the mask and fill elements of output with `value` where mask is `True`.
- However the `repeat` operation on mask is unnecessary and incurs extra time and space overhead. Instead we can keep the mask as it is and traverse the original mask and compute the corresponding broadcasted positions in the output tensor (see the shader file `masked_fill.glsl` for such computation).
Some explanation of the test:
- We test all possible broadcasting of the input tensor and mask. Manually setting all possible broadcastable shapes is intimidating. Instead we apply the following algorithm to automatically generate all possible cases which only requires one input of the shapes of the input tensor and mask.
- First we set an identical shape for the `input_shape` and `mask_shape`, e.g. both are of [3, 5, 2, 3].
- Then we truncate all possible proceeding dimensions of `input_shape` and `mask_shape` respectively. Denote the results as `curr_input_shape` and `curr_mask_shape`, e.g. `curr_input_shape = [5, 2, 3]` and `curr_mask_shape = [2, 3]`.
- Next, for both `curr_input_shape` and `curr_mask_shape` we generate all possible subsets of the indices and set the corresponding elements to 1 for each subset. For example, for `curr_input_shape = [5, 2, 3]`, a possible `input_idx_subset = [0, 2]`. We set the 0th and 2nd elements of `curr_input_shape` to be 1, then `curr_input_shape = [1, 2, 1]`. Similarly for `curr_mask_shape = [2, 3]`, a possible `mask_idx_subset = [0]`, then the updated `curr_mask_shape = [1, 3]`.
- In the end, we test `masked_fill` with the combinations of `curr_input_shape` and `curr_mask_shape`. In the example above, an output tensor of shape [1, 2, 3] will be generated.
- In `vulkan_api_test.cpp`, a function `gen_all_subsets` is implemented to generate all possible subsets of a given set of indices through backtracking.
Test Plan:
Full test result is shown in P777851326. `masked_fill` related tests are shown below.
```
(base) luwei@luwei-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*mask*"
Building: finished in 0.1 sec (100%) 264/2820 jobs, 0/2820 updated
Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc
Note: Google Test filter = *mask*
[==========] Running 5 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 5 tests from VulkanAPITest
[ RUN ] VulkanAPITest.masked_fill_invalidinputs_exceptions
[ OK ] VulkanAPITest.masked_fill_invalidinputs_exceptions (35 ms)
[ RUN ] VulkanAPITest.masked_fill_scalar_mult4ch
[ OK ] VulkanAPITest.masked_fill_scalar_mult4ch (582 ms)
[ RUN ] VulkanAPITest.masked_fill_scalar_nonmult4ch
[ OK ] VulkanAPITest.masked_fill_scalar_nonmult4ch (592 ms)
[ RUN ] VulkanAPITest.masked_fill_tensor_mult4ch
[ OK ] VulkanAPITest.masked_fill_tensor_mult4ch (0 ms)
[ RUN ] VulkanAPITest.masked_fill_tensor_nonmult4ch
[ OK ] VulkanAPITest.masked_fill_tensor_nonmult4ch (0 ms)
[----------] 5 tests from VulkanAPITest (1212 ms total)
[----------] Global test environment tear-down
[==========] 5 tests from 1 test suite ran. (1212 ms total)
[ PASSED ] 5 tests.
```
Reviewed By: SS-JIA
Differential Revision: D46423811
fbshipit-source-id: a0af853da09d2ee53ba050e9624dd529bd52ed74