-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyTorch Vulkan] add Vulkan support for
aten::masked_fill
(#104444)
Summary: Pull Request resolved: #104444 Implemented `aten::masked_fill` for Vulkan backend, see https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill.html for the behavior of this operator. Some explanation of the implementation: - The shapes of the input tensor and mask should be broadcastable (see [broadcasting semantics](https://pytorch.org/docs/stable/notes/broadcasting.html)). For example, the input tensor is of shape [3, 1, 5] and mask of shape [2, 1]. Then the output is of shape [3, 2, 5]. - A straightforward implementation is to generate an output and a mask, both of shape [3, 2, 5], by applying `repeat` operations on the input tensor and mask respectively. Then we traverse the mask and fill elements of output with `value` where mask is `True`. - However the `repeat` operation on mask is unnecessary and incurs extra time and space overhead. Instead we can keep the mask as it is and traverse the original mask and compute the corresponding broadcasted positions in the output tensor (see the shader file `masked_fill.glsl` for such computation). Some explanation of the test: - We test all possible broadcasting of the input tensor and mask. Manually setting all possible broadcastable shapes is intimidating. Instead we apply the following algorithm to automatically generate all possible cases which only requires one input of the shapes of the input tensor and mask. - First we set an identical shape for the `input_shape` and `mask_shape`, e.g. both are of [3, 5, 2, 3]. - Then we truncate all possible proceeding dimensions of `input_shape` and `mask_shape` respectively. Denote the results as `curr_input_shape` and `curr_mask_shape`, e.g. `curr_input_shape = [5, 2, 3]` and `curr_mask_shape = [2, 3]`. - Next, for both `curr_input_shape` and `curr_mask_shape` we generate all possible subsets of the indices and set the corresponding elements to 1 for each subset. For example, for `curr_input_shape = [5, 2, 3]`, a possible `input_idx_subset = [0, 2]`. We set the 0th and 2nd elements of `curr_input_shape` to be 1, then `curr_input_shape = [1, 2, 1]`. Similarly for `curr_mask_shape = [2, 3]`, a possible `mask_idx_subset = [0]`, then the updated `curr_mask_shape = [1, 3]`. - In the end, we test `masked_fill` with the combinations of `curr_input_shape` and `curr_mask_shape`. In the example above, an output tensor of shape [1, 2, 3] will be generated. - In `vulkan_api_test.cpp`, a function `gen_all_subsets` is implemented to generate all possible subsets of a given set of indices through backtracking. Test Plan: Full test result is shown in P777851326. `masked_fill` related tests are shown below. ``` (base) luwei@luwei-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*mask*" Building: finished in 0.1 sec (100%) 264/2820 jobs, 0/2820 updated Total time: 0.1 sec BUILD SUCCEEDED Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc Note: Google Test filter = *mask* [==========] Running 5 tests from 1 test suite. [----------] Global test environment set-up. [----------] 5 tests from VulkanAPITest [ RUN ] VulkanAPITest.masked_fill_invalidinputs_exceptions [ OK ] VulkanAPITest.masked_fill_invalidinputs_exceptions (35 ms) [ RUN ] VulkanAPITest.masked_fill_scalar_mult4ch [ OK ] VulkanAPITest.masked_fill_scalar_mult4ch (582 ms) [ RUN ] VulkanAPITest.masked_fill_scalar_nonmult4ch [ OK ] VulkanAPITest.masked_fill_scalar_nonmult4ch (592 ms) [ RUN ] VulkanAPITest.masked_fill_tensor_mult4ch [ OK ] VulkanAPITest.masked_fill_tensor_mult4ch (0 ms) [ RUN ] VulkanAPITest.masked_fill_tensor_nonmult4ch [ OK ] VulkanAPITest.masked_fill_tensor_nonmult4ch (0 ms) [----------] 5 tests from VulkanAPITest (1212 ms total) [----------] Global test environment tear-down [==========] 5 tests from 1 test suite ran. (1212 ms total) [ PASSED ] 5 tests. ``` Reviewed By: SS-JIA Differential Revision: D46423811 fbshipit-source-id: ec76aa614498981aded737fcc953abee8a9071ed
- Loading branch information
1 parent
d0509fe
commit 2495101
Showing
9 changed files
with
721 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
#version 450 core | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
layout(std430) buffer; | ||
|
||
/* | ||
* Output Image | ||
*/ | ||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput; | ||
|
||
/* | ||
* Input Textures | ||
*/ | ||
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput; | ||
|
||
/* | ||
* Params Buffer | ||
*/ | ||
layout(set = 0, binding = 2) uniform PRECISION restrict Block { | ||
// output texture size (x=width,y=height,z=depth,w=unused) | ||
ivec4 out_extents; | ||
// mask texture size (x=width,y=height,z=depth,w=unused) | ||
ivec4 mask_extents; | ||
// output extent sizes (x=batch,y=channel,z=height,w=width) | ||
uvec4 out_size_info; | ||
// mask extent sizes (x=batch,y=channel,z=height,w=width) | ||
uvec4 mask_size_info; | ||
// x: size of output channel dim up-aligned to 4 | ||
// y: size of mask channel dim up-aligned to 4 | ||
uvec2 aligned_channel_info; | ||
// value to replace | ||
float value; | ||
} | ||
uBlock; | ||
|
||
/* | ||
* Local Work Group | ||
*/ | ||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
void main() { | ||
const ivec3 pos_mask = ivec3(gl_GlobalInvocationID); | ||
|
||
if (any(greaterThanEqual(pos_mask, uBlock.out_extents.xyz))) { | ||
return; | ||
} | ||
|
||
ivec4 inval = texelFetch(uInput, pos_mask, 0); | ||
|
||
bool mask_has_true = false; | ||
for (uint i = 0; i < 4; ++i) { | ||
if ((pos_mask.z * 4 + i) % uBlock.aligned_channel_info.y >= | ||
uBlock.mask_size_info.y) { | ||
break; | ||
} | ||
if (inval[i] == 1) { | ||
mask_has_true = true; | ||
} | ||
} | ||
|
||
// we traverse the elements of mask. If an element is True, we find the | ||
// corresponding positions in the output according to broadcasting and fill | ||
// the elements of output with value. Due to the padding at channel dimension, | ||
// we have different ways to fill the value depending on whether the channel | ||
// dimension is broadcasted or not | ||
if (mask_has_true) { | ||
bool mask_channel_is_broadcast = | ||
uBlock.mask_size_info.y < uBlock.out_size_info.y; | ||
uint tex_cnt_in_output_batch = uBlock.aligned_channel_info.x / 4; | ||
|
||
for (uint batch = 0; | ||
batch < uBlock.out_size_info.x / uBlock.mask_size_info.x; | ||
++batch) { | ||
for (uint height = 0; | ||
height < uBlock.out_size_info.z / uBlock.mask_size_info.z; | ||
++height) { | ||
for (uint width = 0; | ||
width < uBlock.out_size_info.w / uBlock.mask_size_info.w; | ||
++width) { | ||
if (mask_channel_is_broadcast) { | ||
for (int tex_idx = 0; tex_idx < tex_cnt_in_output_batch; | ||
++tex_idx) { | ||
ivec3 write_pos = ivec3( | ||
pos_mask.x + width, | ||
pos_mask.y + height, | ||
tex_cnt_in_output_batch * (batch + pos_mask.z) + tex_idx); | ||
vec4 out_tex = imageLoad(uOutput, write_pos); | ||
for (int i = 0; i < 4; ++i) { | ||
if (tex_idx * 4 + i >= uBlock.out_size_info.y) { | ||
break; | ||
} | ||
out_tex[i] = uBlock.value; | ||
} | ||
imageStore(uOutput, write_pos, out_tex); | ||
} | ||
} else { | ||
ivec3 write_pos = ivec3( | ||
pos_mask.x + width, | ||
pos_mask.y + height, | ||
pos_mask.z + tex_cnt_in_output_batch * batch); | ||
vec4 out_tex = imageLoad(uOutput, write_pos); | ||
out_tex = vec4(equal(inval, ivec4(1))) * uBlock.value + vec4(notEqual(inval, ivec4(1))) * out_tex; | ||
imageStore(uOutput, write_pos, out_tex); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
#version 450 core | ||
#define PRECISION $precision | ||
#define FORMAT $format | ||
|
||
layout(std430) buffer; | ||
|
||
/* Qualifiers: layout - storage - precision - memory */ | ||
|
||
/* | ||
* Output Image | ||
*/ | ||
layout(set = 0, binding = 0, rgba8i) uniform PRECISION restrict writeonly iimage3D uImage; | ||
|
||
/* | ||
* Input Buffer | ||
*/ | ||
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { | ||
int data[]; | ||
} | ||
uBuffer; | ||
|
||
/* | ||
* Params Buffer | ||
*/ | ||
layout(set = 0, binding = 2) uniform PRECISION restrict Block { | ||
// xyz contain the extents of the output texture, w contains HxW to help | ||
// calculate buffer offsets | ||
ivec4 out_extents; | ||
// x: number of texels spanned by one batch | ||
// y: number of channels | ||
ivec2 c_info; | ||
} | ||
uBlock; | ||
|
||
/* | ||
* Local Work Group Size | ||
*/ | ||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
/* | ||
* Extends sign of int8 | ||
*/ | ||
int extend_sign(int x) { | ||
if (x >> 7 == 1) { | ||
return x | 0xFFFFFF00; | ||
} | ||
return x; | ||
} | ||
|
||
void main() { | ||
const ivec3 pos = ivec3(gl_GlobalInvocationID); | ||
|
||
if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) { | ||
return; | ||
} | ||
|
||
const int n_index = int(pos.z / uBlock.c_info.x); | ||
const int c_index = (pos.z % uBlock.c_info.x) * 4; | ||
int d_offset = (n_index * uBlock.c_info.y) + c_index; | ||
|
||
const int base_index = | ||
pos.x + uBlock.out_extents.x * pos.y + uBlock.out_extents.w * d_offset; | ||
const ivec4 buf_indices = | ||
base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w; | ||
|
||
int shift = (1 << 8) - 1; | ||
ivec4 masks; | ||
masks.x = shift << 8 * (buf_indices.x % 4); | ||
masks.y = shift << 8 * (buf_indices.y % 4); | ||
masks.z = shift << 8 * (buf_indices.z % 4); | ||
masks.w = shift << 8 * (buf_indices.w % 4); | ||
|
||
int buf_in_1 = uBuffer.data[buf_indices.x / 4]; | ||
int a_v = (buf_in_1 & masks.x) >> 8 * (buf_indices.x % 4); | ||
a_v = extend_sign(a_v); | ||
|
||
int buf_in_2 = uBuffer.data[buf_indices.y / 4]; | ||
int b_v = (buf_in_2 & masks.y) >> 8 * (buf_indices.y % 4); | ||
b_v = extend_sign(b_v); | ||
|
||
int buf_in_3 = uBuffer.data[buf_indices.z / 4]; | ||
int g_v = (buf_in_3 & masks.z) >> 8 * (buf_indices.z % 4); | ||
g_v = extend_sign(g_v); | ||
|
||
int buf_in_4 = uBuffer.data[buf_indices.w / 4]; | ||
int r_v = (buf_in_4 & masks.w) >> 8 * (buf_indices.w % 4); | ||
r_v = extend_sign(r_v); | ||
|
||
ivec4 texel = ivec4(a_v, b_v, g_v, r_v); | ||
|
||
if (c_index + 3 >= uBlock.c_info.y) { | ||
ivec4 c_ind = ivec4(c_index) + ivec4(0, 1, 2, 3); | ||
ivec4 valid_c = ivec4(lessThan(c_ind, ivec4(uBlock.c_info.y))); | ||
texel = texel * valid_c; | ||
} | ||
|
||
imageStore(uImage, pos, texel); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.