Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch Vulkan] add Vulkan support for aten::masked_fill #104444

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions aten/src/ATen/native/vulkan/api/Resource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace api {
*/
VkFormat vk_format(const at::ScalarType dtype) {
switch (dtype) {
case c10::kBool:
return VK_FORMAT_R8G8B8A8_SINT;
case kFloat:
#ifdef USE_VULKAN_FP16_INFERENCE
return VK_FORMAT_R16G16B16A16_SFLOAT;
Expand Down
110 changes: 110 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/masked_fill.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/*
 * masked_fill shader resources.
 *
 * Implements aten::masked_fill for Vulkan: every output element whose
 * corresponding (broadcast) mask element is True is overwritten with
 * uBlock.value. One invocation processes one texel of the mask texture.
 */

/*
 * Output Image
 * Float texture holding the output tensor; loaded, selectively filled
 * with uBlock.value, and stored back in place.
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;

/*
 * Input Textures
 * Integer sampler over the boolean mask texture; lanes are expected to
 * hold 0/1 (True is tested as == 1 in main).
 */
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// output texture size (x=width,y=height,z=depth,w=unused)
ivec4 out_extents;
// mask texture size (x=width,y=height,z=depth,w=unused)
ivec4 mask_extents;
// output extent sizes (x=batch,y=channel,z=height,w=width)
uvec4 out_size_info;
// mask extent sizes (x=batch,y=channel,z=height,w=width)
uvec4 mask_size_info;
// x: size of output channel dim up-aligned to 4
// y: size of mask channel dim up-aligned to 4
uvec2 aligned_channel_info;
// value to replace
float value;
}
uBlock;

/*
 * Local Work Group
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
// Position of this invocation's texel in the MASK texture.
// NOTE(review): the guard below compares against out_extents, not
// mask_extents — presumably the dispatch is sized to the mask texture;
// verify, since texelFetch past the mask extents is undefined.
const ivec3 pos_mask = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos_mask, uBlock.out_extents.xyz))) {
return;
}

// The 4 channel lanes of the mask texel at this position.
ivec4 inval = texelFetch(uInput, pos_mask, 0);

// Scan the up-to-4 channel lanes of this texel; stop at lanes that fall
// into the channel padding (logical channel index beyond mask_size_info.y).
bool mask_has_true = false;
for (uint i = 0; i < 4; ++i) {
if ((pos_mask.z * 4 + i) % uBlock.aligned_channel_info.y >=
uBlock.mask_size_info.y) {
break;
}
if (inval[i] == 1) {
mask_has_true = true;
}
}

// we traverse the elements of mask. If an element is True, we find the
// corresponding positions in the output according to broadcasting and fill
// the elements of output with value. Due to the padding at channel dimension,
// we have different ways to fill the value depending on whether the channel
// dimension is broadcasted or not
if (mask_has_true) {
bool mask_channel_is_broadcast =
uBlock.mask_size_info.y < uBlock.out_size_info.y;
// Number of texels spanned by one batch of the output (channels / 4).
uint tex_cnt_in_output_batch = uBlock.aligned_channel_info.x / 4;

// Each loop count is out_size / mask_size: 1 where the mask matches the
// output, or the full output size where the mask dim is broadcast (a
// broadcast dim has size 1, so pos_mask is 0 there and pos + index
// enumerates every output position along that dim).
for (uint batch = 0;
batch < uBlock.out_size_info.x / uBlock.mask_size_info.x;
++batch) {
for (uint height = 0;
height < uBlock.out_size_info.z / uBlock.mask_size_info.z;
++height) {
for (uint width = 0;
width < uBlock.out_size_info.w / uBlock.mask_size_info.w;
++width) {
if (mask_channel_is_broadcast) {
// Mask channel dim is broadcast: a True anywhere in this texel
// fills ALL channels of the output batch, so walk every texel of
// the batch and fill lanes up to the real channel count.
for (int tex_idx = 0; tex_idx < tex_cnt_in_output_batch;
++tex_idx) {
ivec3 write_pos = ivec3(
pos_mask.x + width,
pos_mask.y + height,
tex_cnt_in_output_batch * (batch + pos_mask.z) + tex_idx);
vec4 out_tex = imageLoad(uOutput, write_pos);
for (int i = 0; i < 4; ++i) {
// Leave channel-padding lanes untouched.
if (tex_idx * 4 + i >= uBlock.out_size_info.y) {
break;
}
out_tex[i] = uBlock.value;
}
imageStore(uOutput, write_pos, out_tex);
}
} else {
// Channel dims match: fill only the lanes whose mask lane is 1,
// via a branchless per-lane select between value and the old texel.
ivec3 write_pos = ivec3(
pos_mask.x + width,
pos_mask.y + height,
pos_mask.z + tex_cnt_in_output_batch * batch);
vec4 out_tex = imageLoad(uOutput, write_pos);
out_tex = vec4(equal(inval, ivec4(1))) * uBlock.value + vec4(notEqual(inval, ivec4(1))) * out_tex;
imageStore(uOutput, write_pos, out_tex);
}
}
}
}
}
}
98 changes: 98 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/nchw_to_image_bool.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
 * nchw_to_image_bool: unpacks a contiguous NCHW bool tensor (one byte per
 * element, packed 4 bytes per 32-bit int by the std430 buffer below) into
 * a signed-integer 3D texture, 4 consecutive channels per texel.
 */

/*
 * Output Image
 * rgba8i texture receiving the unpacked values.
 */
layout(set = 0, binding = 0, rgba8i) uniform PRECISION restrict writeonly iimage3D uImage;

/*
 * Input Buffer
 * Raw NCHW bytes viewed as an array of 32-bit ints; main() extracts the
 * individual bytes.
 */
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
int data[];
}
uBuffer;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// xyz contain the extents of the output texture, w contains HxW to help
// calculate buffer offsets
ivec4 out_extents;
// x: number of texels spanned by one batch
// y: number of channels
ivec2 c_info;
}
uBlock;

/*
 * Local Work Group Size
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Sign-extends an int8 value held in the low byte of x to a full signed
 * 32-bit integer: values with bit 7 set get their upper 24 bits filled.
 */
int extend_sign(int x) {
  return ((x >> 7) == 1) ? int(x | 0xFFFFFF00) : x;
}

void main() {
// Texel position in the output image; one invocation per texel.
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
return;
}

// Split the depth coordinate into batch index (c_info.x texels span one
// batch) and the texel-aligned channel index (4 channels per texel).
const int n_index = int(pos.z / uBlock.c_info.x);
const int c_index = (pos.z % uBlock.c_info.x) * 4;
int d_offset = (n_index * uBlock.c_info.y) + c_index;

// Linear NCHW element index of this texel's first channel; out_extents.w
// holds H*W, the per-channel plane size.
const int base_index =
pos.x + uBlock.out_extents.x * pos.y + uBlock.out_extents.w * d_offset;
// Element indices of the 4 consecutive channels packed into this texel.
const ivec4 buf_indices =
base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w;

// NOTE(review): despite the name, "shift" is a byte mask (0xFF) that is
// then shifted to each lane's byte position within the packed int.
int shift = (1 << 8) - 1;
ivec4 masks;
masks.x = shift << 8 * (buf_indices.x % 4);
masks.y = shift << 8 * (buf_indices.y % 4);
masks.z = shift << 8 * (buf_indices.z % 4);
masks.w = shift << 8 * (buf_indices.w % 4);

// For each lane: load the 32-bit int holding the element (4 bytes per
// int), isolate its byte, shift it down, and sign-extend to int32.
int buf_in_1 = uBuffer.data[buf_indices.x / 4];
int a_v = (buf_in_1 & masks.x) >> 8 * (buf_indices.x % 4);
a_v = extend_sign(a_v);

int buf_in_2 = uBuffer.data[buf_indices.y / 4];
int b_v = (buf_in_2 & masks.y) >> 8 * (buf_indices.y % 4);
b_v = extend_sign(b_v);

int buf_in_3 = uBuffer.data[buf_indices.z / 4];
int g_v = (buf_in_3 & masks.z) >> 8 * (buf_indices.z % 4);
g_v = extend_sign(g_v);

int buf_in_4 = uBuffer.data[buf_indices.w / 4];
int r_v = (buf_in_4 & masks.w) >> 8 * (buf_indices.w % 4);
r_v = extend_sign(r_v);

ivec4 texel = ivec4(a_v, b_v, g_v, r_v);

// Zero out lanes that fall past the real channel count (channel padding
// up to the next multiple of 4).
if (c_index + 3 >= uBlock.c_info.y) {
ivec4 c_ind = ivec4(c_index) + ivec4(0, 1, 2, 3);
ivec4 valid_c = ivec4(lessThan(c_ind, ivec4(uBlock.c_info.y)));
texel = texel * valid_c;
}

imageStore(uImage, pos, texel);
}
56 changes: 39 additions & 17 deletions aten/src/ATen/native/vulkan/impl/Packing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,24 @@ api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
}
}

switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(nchw_to_image2d);
default:
TORCH_CHECK(false, "No kernel available!");
if (v_dst.dtype() == at::kFloat) {
switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(nchw_to_image2d);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else if (v_dst.dtype() == at::kBool) {
switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image_bool);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else {
TORCH_CHECK(false, "Unsupported dtype!");
}
}

Expand All @@ -50,10 +61,10 @@ api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
switch (v_src.dtype()) {
case c10::ScalarType::QUInt8:
return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
: VK_KERNEL(image_to_nchw_quantized);
: VK_KERNEL(image_to_nchw_uint);
case c10::ScalarType::QInt8:
return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
: VK_KERNEL(image_to_nchw_quantized);
: VK_KERNEL(image_to_nchw_uint);
case c10::ScalarType::QInt32:
return VK_KERNEL(image_to_nchw_int32);
default:
Expand All @@ -70,13 +81,24 @@ api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
}
}

switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(image_to_nchw);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(image2d_to_nchw);
default:
TORCH_CHECK(false, "No kernel available!");
if (v_src.dtype() == at::kFloat) {
switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(image_to_nchw);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(image2d_to_nchw);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else if (v_src.dtype() == at::kBool) {
switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(image_to_nchw_uint);
default:
TORCH_CHECK(false, "No kernel available!");
}
} else {
TORCH_CHECK(false, "Unsupported dtype!");
}
}

Expand Down Expand Up @@ -161,7 +183,7 @@ void record_image_to_nchw_op(
};

if (v_src.dtype() == c10::ScalarType::QUInt8 ||
v_src.dtype() == c10::ScalarType::QInt8) {
v_src.dtype() == c10::ScalarType::QInt8 || v_src.dtype() == at::kBool) {
if (plane_size % 4 == 0) {
global_size.data[0u] = plane_size / 4;
global_size.data[1u] = 1;
Expand Down
8 changes: 6 additions & 2 deletions aten/src/ATen/native/vulkan/ops/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping) {
memcpy_to_mapping_impl<c10::qint8>(src, dst_mapping);
} else if (src.dtype() == c10::kQInt32) {
memcpy_to_mapping_impl<c10::qint32>(src, dst_mapping);
} else if (src.dtype() == c10::kBool) {
memcpy_to_mapping_uint8(src, dst_mapping);
} else {
TORCH_CHECK(
false,
"Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
" at::kHalf or at::Float but got ",
" c10::kBool, at::kHalf, or at::Float but got ",
src.dtype());
}
}
Expand All @@ -43,11 +45,13 @@ void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst) {
memcpy_from_mapping_impl<c10::qint8>(src_mapping, dst);
} else if (dst.dtype() == c10::kQInt32) {
memcpy_from_mapping_impl<c10::qint32>(src_mapping, dst);
} else if (dst.dtype() == c10::kBool) {
memcpy_from_mapping_bool(src_mapping, dst);
} else {
TORCH_CHECK(
false,
"Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
" at::kHalf or at::Float but got ",
" c10::kBool, at::kHalf or at::Float but got ",
dst.dtype());
}
}
Expand Down
20 changes: 20 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ void memcpy_from_mapping_impl(api::MemoryMap& src_mapping, Tensor& dst) {
std::min(src_mapping.nbytes(), dst.nbytes()));
}

/*
 * Copies a mapped Vulkan staging buffer into a CPU bool tensor.
 *
 * The staging buffer stores booleans as one byte per element; each byte is
 * narrowed to bool (any nonzero byte becomes true). Copies
 * min(src_mapping.nbytes(), dst.nbytes()) elements so neither side is
 * overrun.
 */
inline void memcpy_from_mapping_bool(api::MemoryMap& src_mapping, Tensor& dst) {
  const uint8_t* src_ptr = src_mapping.template data<uint8_t>();
  bool* dst_ptr = dst.mutable_data_ptr<bool>();
  // Hoist the bound out of the loop and iterate with size_t instead of the
  // original signed int + (unsigned) cast per-iteration comparison.
  const size_t count = std::min(src_mapping.nbytes(), dst.nbytes());
  for (size_t i = 0; i < count; ++i) {
    dst_ptr[i] = static_cast<bool>(src_ptr[i]);
  }
}

inline void memcpy_to_mapping_uint8(
const Tensor& src,
api::MemoryMap& dst_mapping) {
bool* src_ptr = src.mutable_data_ptr<bool>();
uint8_t* dst_ptr = dst_mapping.template data<uint8_t>();
for (int i = 0; (unsigned)i < std::min(dst_mapping.nbytes(), src.nbytes());
++i) {
dst_ptr[i] = static_cast<uint8_t>(src_ptr[i]);
}
}

void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping);

void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst);
Expand Down