Skip to content

Commit

Permalink
[PyTorch Vulkan] add Vulkan support for aten::masked_fill (#104444)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #104444

Implemented `aten::masked_fill` for Vulkan backend, see https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill.html for the behavior of this operator.

Some explanation of the implementation:
- The shapes of the input tensor and mask should be broadcastable (see [broadcasting semantics](https://pytorch.org/docs/stable/notes/broadcasting.html)). For example, the input tensor is of shape [3, 1, 5] and mask of shape [2, 1]. Then the output is of shape [3, 2, 5].
- A straightforward implementation is to generate an output and a mask, both of shape [3, 2, 5], by applying `repeat` operations on the input tensor and mask respectively. Then we traverse the mask and fill elements of output with `value` where mask is `True`.
- However the `repeat` operation on mask is unnecessary and incurs extra time and space overhead. Instead we can keep the mask as it is and traverse the original mask and compute the corresponding broadcasted positions in the output tensor (see the shader file `masked_fill.glsl` for such computation).

Some explanation of the test:
- We test all possible broadcasting of the input tensor and mask. Manually setting all possible broadcastable shapes is intimidating. Instead we apply the following algorithm to automatically generate all possible cases which only requires one input of the shapes of the input tensor and mask.
  - First we set an identical shape for the `input_shape` and `mask_shape`, e.g. both are of [3, 5, 2, 3].
  - Then we truncate all possible proceeding dimensions of `input_shape` and `mask_shape` respectively. Denote the results as `curr_input_shape` and `curr_mask_shape`, e.g. `curr_input_shape = [5, 2, 3]` and `curr_mask_shape = [2, 3]`.
  - Next, for both `curr_input_shape` and `curr_mask_shape` we generate all possible subsets of the indices and set the corresponding elements to 1 for each subset. For example, for `curr_input_shape = [5, 2, 3]`, a possible `input_idx_subset = [0, 2]`. We set the 0th and 2nd elements of `curr_input_shape` to be 1, then `curr_input_shape = [1, 2, 1]`. Similarly for `curr_mask_shape = [2, 3]`, a possible `mask_idx_subset = [0]`, then the updated `curr_mask_shape = [1, 3]`.
  - In the end, we test `masked_fill` with the combinations of `curr_input_shape` and `curr_mask_shape`. In the example above, an output tensor of shape [1, 2, 3] will be generated.
  - In `vulkan_api_test.cpp`, a function `gen_all_subsets` is implemented to generate all possible subsets of a given set of indices through backtracking.

Test Plan:
Full test result is shown in P777851326. `masked_fill` related tests are shown below.

```
(base) luwei@luwei-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*mask*"
Building: finished in 0.1 sec (100%) 264/2820 jobs, 0/2820 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc
Note: Google Test filter = *mask*
[==========] Running 5 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 5 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.masked_fill_invalidinputs_exceptions
[       OK ] VulkanAPITest.masked_fill_invalidinputs_exceptions (35 ms)
[ RUN      ] VulkanAPITest.masked_fill_scalar_mult4ch
[       OK ] VulkanAPITest.masked_fill_scalar_mult4ch (582 ms)
[ RUN      ] VulkanAPITest.masked_fill_scalar_nonmult4ch
[       OK ] VulkanAPITest.masked_fill_scalar_nonmult4ch (592 ms)
[ RUN      ] VulkanAPITest.masked_fill_tensor_mult4ch
[       OK ] VulkanAPITest.masked_fill_tensor_mult4ch (0 ms)
[ RUN      ] VulkanAPITest.masked_fill_tensor_nonmult4ch
[       OK ] VulkanAPITest.masked_fill_tensor_nonmult4ch (0 ms)
[----------] 5 tests from VulkanAPITest (1212 ms total)

[----------] Global test environment tear-down
[==========] 5 tests from 1 test suite ran. (1212 ms total)
[  PASSED  ] 5 tests.
```

Reviewed By: SS-JIA

Differential Revision: D46423811

fbshipit-source-id: 8f9af31a5e3a400afaa32118a7480023682cbe2d
  • Loading branch information
copyrightly authored and facebook-github-bot committed Jun 30, 2023
1 parent 537a6c0 commit 4772bbe
Show file tree
Hide file tree
Showing 9 changed files with 721 additions and 19 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/vulkan/api/Resource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace api {
*/
VkFormat vk_format(const at::ScalarType dtype) {
switch (dtype) {
case c10::kBool:
return VK_FORMAT_R8G8B8A8_SINT;
case kFloat:
#ifdef USE_VULKAN_FP16_INFERENCE
return VK_FORMAT_R16G16B16A16_SFLOAT;
Expand Down
110 changes: 110 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/masked_fill.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/*
* Output Image
*/
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
* Input Textures
*/
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// output texture size (x=width,y=height,z=depth,w=unused)
ivec4 out_extents;
// mask texture size (x=width,y=height,z=depth,w=unused)
ivec4 mask_extents;
// output extent sizes (x=batch,y=channel,z=height,w=width)
uvec4 out_size_info;
// mask extent sizes (x=batch,y=channel,z=height,w=width)
uvec4 mask_size_info;
// x: size of output channel dim up-aligned to 4
// y: size of mask channel dim up-aligned to 4
uvec2 aligned_channel_info;
// value to replace
float value;
}
uBlock;

/*
* Local Work Group
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos_mask = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos_mask, uBlock.out_extents.xyz))) {
return;
}

ivec4 inval = texelFetch(uInput, pos_mask, 0);

bool mask_has_true = false;
for (uint i = 0; i < 4; ++i) {
if ((pos_mask.z * 4 + i) % uBlock.aligned_channel_info.y >=
uBlock.mask_size_info.y) {
break;
}
if (inval[i] == 1) {
mask_has_true = true;
}
}

// we traverse the elements of mask. If an element is True, we find the
// corresponding positions in the output according to broadcasting and fill
// the elements of output with value. Due to the padding at channel dimension,
// we have different ways to fill the value depending on whether the channel
// dimension is broadcasted or not
if (mask_has_true) {
bool mask_channel_is_broadcast =
uBlock.mask_size_info.y < uBlock.out_size_info.y;
uint tex_cnt_in_output_batch = uBlock.aligned_channel_info.x / 4;

for (uint batch = 0;
batch < uBlock.out_size_info.x / uBlock.mask_size_info.x;
++batch) {
for (uint height = 0;
height < uBlock.out_size_info.z / uBlock.mask_size_info.z;
++height) {
for (uint width = 0;
width < uBlock.out_size_info.w / uBlock.mask_size_info.w;
++width) {
if (mask_channel_is_broadcast) {
for (int tex_idx = 0; tex_idx < tex_cnt_in_output_batch;
++tex_idx) {
ivec3 write_pos = ivec3(
pos_mask.x + width,
pos_mask.y + height,
tex_cnt_in_output_batch * (batch + pos_mask.z) + tex_idx);
vec4 out_tex = imageLoad(uOutput, write_pos);
for (int i = 0; i < 4; ++i) {
if (tex_idx * 4 + i >= uBlock.out_size_info.y) {
break;
}
out_tex[i] = uBlock.value;
}
imageStore(uOutput, write_pos, out_tex);
}
} else {
ivec3 write_pos = ivec3(
pos_mask.x + width,
pos_mask.y + height,
pos_mask.z + tex_cnt_in_output_batch * batch);
vec4 out_tex = imageLoad(uOutput, write_pos);
out_tex = vec4(equal(inval, ivec4(1))) * uBlock.value + vec4(notEqual(inval, ivec4(1))) * out_tex;
imageStore(uOutput, write_pos, out_tex);
}
}
}
}
}
}
98 changes: 98 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/nchw_to_image_bool.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
* Output Image
*/
layout(set = 0, binding = 0, rgba8i) uniform PRECISION restrict writeonly iimage3D uImage;

/*
* Input Buffer
*/
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
int data[];
}
uBuffer;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// xyz contain the extents of the output texture, w contains HxW to help
// calculate buffer offsets
ivec4 out_extents;
// x: number of texels spanned by one batch
// y: number of channels
ivec2 c_info;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* Extends sign of int8
*/
int extend_sign(int x) {
if (x >> 7 == 1) {
return x | 0xFFFFFF00;
}
return x;
}

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
return;
}

const int n_index = int(pos.z / uBlock.c_info.x);
const int c_index = (pos.z % uBlock.c_info.x) * 4;
int d_offset = (n_index * uBlock.c_info.y) + c_index;

const int base_index =
pos.x + uBlock.out_extents.x * pos.y + uBlock.out_extents.w * d_offset;
const ivec4 buf_indices =
base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w;

int shift = (1 << 8) - 1;
ivec4 masks;
masks.x = shift << 8 * (buf_indices.x % 4);
masks.y = shift << 8 * (buf_indices.y % 4);
masks.z = shift << 8 * (buf_indices.z % 4);
masks.w = shift << 8 * (buf_indices.w % 4);

int buf_in_1 = uBuffer.data[buf_indices.x / 4];
int a_v = (buf_in_1 & masks.x) >> 8 * (buf_indices.x % 4);
a_v = extend_sign(a_v);

int buf_in_2 = uBuffer.data[buf_indices.y / 4];
int b_v = (buf_in_2 & masks.y) >> 8 * (buf_indices.y % 4);
b_v = extend_sign(b_v);

int buf_in_3 = uBuffer.data[buf_indices.z / 4];
int g_v = (buf_in_3 & masks.z) >> 8 * (buf_indices.z % 4);
g_v = extend_sign(g_v);

int buf_in_4 = uBuffer.data[buf_indices.w / 4];
int r_v = (buf_in_4 & masks.w) >> 8 * (buf_indices.w % 4);
r_v = extend_sign(r_v);

ivec4 texel = ivec4(a_v, b_v, g_v, r_v);

if (c_index + 3 >= uBlock.c_info.y) {
ivec4 c_ind = ivec4(c_index) + ivec4(0, 1, 2, 3);
ivec4 valid_c = ivec4(lessThan(c_ind, ivec4(uBlock.c_info.y)));
texel = texel * valid_c;
}

imageStore(uImage, pos, texel);
}
56 changes: 39 additions & 17 deletions aten/src/ATen/native/vulkan/impl/Packing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,24 @@ api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
}
}

switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(nchw_to_image2d);
default:
TORCH_CHECK(false, "No kernel available!");
if (v_dst.dtype() == at::kFloat) {
switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(nchw_to_image2d);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else if (v_dst.dtype() == at::kBool) {
switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image_bool);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else {
TORCH_CHECK(false, "Unsupported dtype!");
}
}

Expand All @@ -50,10 +61,10 @@ api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
switch (v_src.dtype()) {
case c10::ScalarType::QUInt8:
return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
: VK_KERNEL(image_to_nchw_quantized);
: VK_KERNEL(image_to_nchw_uint);
case c10::ScalarType::QInt8:
return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
: VK_KERNEL(image_to_nchw_quantized);
: VK_KERNEL(image_to_nchw_uint);
case c10::ScalarType::QInt32:
return VK_KERNEL(image_to_nchw_int32);
default:
Expand All @@ -70,13 +81,24 @@ api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
}
}

switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(image_to_nchw);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(image2d_to_nchw);
default:
TORCH_CHECK(false, "No kernel available!");
if (v_src.dtype() == at::kFloat) {
switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(image_to_nchw);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(image2d_to_nchw);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else if (v_src.dtype() == at::kBool) {
switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(image_to_nchw_uint);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else {
TORCH_CHECK(false, "Unsupported dtype!");
}
}

Expand Down Expand Up @@ -161,7 +183,7 @@ void record_image_to_nchw_op(
};

if (v_src.dtype() == c10::ScalarType::QUInt8 ||
v_src.dtype() == c10::ScalarType::QInt8) {
v_src.dtype() == c10::ScalarType::QInt8 || v_src.dtype() == at::kBool) {
if (plane_size % 4 == 0) {
global_size.data[0u] = plane_size / 4;
global_size.data[1u] = 1;
Expand Down
8 changes: 6 additions & 2 deletions aten/src/ATen/native/vulkan/ops/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping) {
memcpy_to_mapping_impl<c10::qint8>(src, dst_mapping);
} else if (src.dtype() == c10::kQInt32) {
memcpy_to_mapping_impl<c10::qint32>(src, dst_mapping);
} else if (src.dtype() == c10::kBool) {
memcpy_to_mapping_uint8(src, dst_mapping);
} else {
TORCH_CHECK(
false,
"Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
" at::kHalf or at::Float but got ",
" c10::kBool, at::kHalf, or at::Float but got ",
src.dtype());
}
}
Expand All @@ -43,11 +45,13 @@ void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst) {
memcpy_from_mapping_impl<c10::qint8>(src_mapping, dst);
} else if (dst.dtype() == c10::kQInt32) {
memcpy_from_mapping_impl<c10::qint32>(src_mapping, dst);
} else if (dst.dtype() == c10::kBool) {
memcpy_from_mapping_bool(src_mapping, dst);
} else {
TORCH_CHECK(
false,
"Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
" at::kHalf or at::Float but got ",
" c10::kBool, at::kHalf or at::Float but got ",
dst.dtype());
}
}
Expand Down
20 changes: 20 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ void memcpy_from_mapping_impl(api::MemoryMap& src_mapping, Tensor& dst) {
std::min(src_mapping.nbytes(), dst.nbytes()));
}

inline void memcpy_from_mapping_bool(api::MemoryMap& src_mapping, Tensor& dst) {
uint8_t* src_ptr = src_mapping.template data<uint8_t>();
bool* dst_ptr = dst.mutable_data_ptr<bool>();
for (int i = 0; (unsigned)i < std::min(src_mapping.nbytes(), dst.nbytes());
++i) {
dst_ptr[i] = static_cast<bool>(src_ptr[i]);
}
}

inline void memcpy_to_mapping_uint8(
const Tensor& src,
api::MemoryMap& dst_mapping) {
bool* src_ptr = src.mutable_data_ptr<bool>();
uint8_t* dst_ptr = dst_mapping.template data<uint8_t>();
for (int i = 0; (unsigned)i < std::min(dst_mapping.nbytes(), src.nbytes());
++i) {
dst_ptr[i] = static_cast<uint8_t>(src_ptr[i]);
}
}

void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping);

void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst);
Expand Down

0 comments on commit 4772bbe

Please sign in to comment.