Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vulkan clamp op #47196

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions aten/src/ATen/native/vulkan/VulkanOps.cpp
Expand Up @@ -1163,14 +1163,14 @@ void clamp(
int32_t W;
int32_t H;
int32_t C_4;
int32_t C;
//int32_t C;
float min;
float max;
};
ConstBlock cb{safe_downcast<int32_t>(W),
safe_downcast<int32_t>(H),
safe_downcast<int32_t>(C_4),
safe_downcast<int32_t>(C),
//safe_downcast<int32_t>(C),
min,
max};
VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb));
Expand Down
20 changes: 10 additions & 10 deletions aten/src/ATen/native/vulkan/glsl/clamp.glsl
Expand Up @@ -2,22 +2,22 @@
#define PRECISION $precision
layout(std430) buffer;
layout(std430) uniform;
layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform constBlock {
ivec4 size;
layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform restrict Block {
ivec3 WHC;
float minValue;
float maxValue;
}
uConstBlock;
} uBlock;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
ivec3 pos = ivec3(gl_GlobalInvocationID);
if (all(lessThan(pos, uConstBlock.size.xyz))) {
vec4 v = texelFetch(uInput, pos, 0);
const ivec3 pos = ivec3(gl_GlobalInvocationID);
if (all(lessThan(pos, uBlock.WHC))) {
imageStore(
uOutput, pos, clamp(v, uConstBlock.minValue, uConstBlock.maxValue));
uOutput,
pos,
clamp(texelFetch(uInput, pos, 0), uBlock.minValue, uBlock.maxValue));
}
}
22 changes: 22 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/clamp_.glsl
@@ -0,0 +1,22 @@
#version 450 core
#define PRECISION $precision
layout(std430) buffer;
layout(std430) uniform;
layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput;
layout(set = 0, binding = 1) uniform restrict Block {
ivec3 WHC;
float minValue;
float maxValue;
} uBlock;

layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
if (all(lessThan(pos, uBlock.WHC))) {
imageStore(
uOutput,
pos,
clamp(imageLoad(uOutput, pos), uBlock.minValue, uBlock.maxValue));
}
}
142 changes: 142 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Clamp.cpp
@@ -0,0 +1,142 @@
#include <ATen/native/vulkan/ops/Common.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

Tensor clamp(
const Tensor& self_arg,
const c10::optional<Scalar> min_value,
const c10::optional<Scalar> max_value) {
if (!min_value && !max_value) {
TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None");
}

api::Context* const context = api::context();

const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
const vTensor& v_self = convert(self);

vTensor v_output{
context,
self.sizes(),
self.options(),
};

api::Command::Buffer command_buffer = context->command().pool.allocate();
command_buffer.begin();
{
if (v_output.has_image() && v_self.has_image()) {
const struct {
uint32_t width, height, channels;
float min_value;
float max_value;
} block {
v_output.extents().width,
v_output.extents().height,
v_output.extents().depth,
min_value ? min_value->to<float>() : -std::numeric_limits<float>::infinity(),
max_value ? max_value->to<float>() : std::numeric_limits<float>::infinity(),
};

context->dispatch(
command_buffer,
{
VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
},
VK_KERNEL(clamp),
v_output.extents(),
// Write-only access bypasses synchronization but inserts appropriate
// barriers if necessary.
v_output.image(command_buffer, vTensor::Access::Write),
// Read-only access is implied on const tensors and triggers an async
// synchronization if necessary.
v_self.image(command_buffer),
// Object lifetime is managed by the resource pool.
// It is OK not to keep track of the handle.
context->resource().pool.uniform(block).object);
}
else {
TORCH_CHECK(false, "Not implemented!");
}
}
command_buffer.end();
command_buffer.submit(context->gpu().queue);

return convert(v_output);
}

Tensor& clamp_(
Tensor& self_arg,
const c10::optional<Scalar> min_value,
const c10::optional<Scalar> max_value) {
api::Context* const context = api::context();
if (!min_value && !max_value) {
TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None");
}
TORCH_CHECK(
self_arg.is_vulkan(),
"Vulkan: In-place clamp is only supported on Vulkan tensors.");

vTensor& v_self = convert(self_arg);

api::Command::Buffer command_buffer = context->command().pool.allocate();
command_buffer.begin();
{
if (v_self.has_image()) {
const struct {
uint32_t width, height, channels;
float min_value;
float max_value;
} block {
v_self.extents().width,
v_self.extents().height,
v_self.extents().depth,
min_value ? min_value->to<float>() : -std::numeric_limits<float>::infinity(),
max_value ? max_value->to<float>() : std::numeric_limits<float>::infinity(),
};

context->dispatch(
command_buffer,
{
VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
},
VK_KERNEL(clamp_),
v_self.extents(),
// Read-Write access triggers an async synchronization if necessory
// and inserts appropriate barriers if hazards are detected.
v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write),
// Object lifetime is managed by the resource pool.
// It is OK not to keep track of the handle.
context->resource().pool.uniform(block).object);
}
else {
TORCH_CHECK(false, "Not implemented!");
}
}
command_buffer.end();
command_buffer.submit(context->gpu().queue);

return self_arg;
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl("clamp", TORCH_FN(clamp));
m.impl("clamp_", TORCH_FN(clamp_));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at
2 changes: 1 addition & 1 deletion aten/src/ATen/native/vulkan/ops/Mul.cpp
Expand Up @@ -71,7 +71,7 @@ Tensor& mul_scalar_(

TORCH_CHECK(
self_arg.is_vulkan(),
"Vulkan: In-place add is only supported on Vulkan tensors.");
"Vulkan: In-place mul_scalar is only supported on Vulkan tensors.");

vTensor& v_self = convert(self_arg);

Expand Down
26 changes: 26 additions & 0 deletions aten/src/ATen/test/vulkan_api_test.cpp
Expand Up @@ -103,6 +103,32 @@ TEST(VulkanAPITest, mul_scalar_) {
ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu()));
}

TEST(VulkanAPITest, clamp) {
const auto a_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
const auto a_vulkan = a_cpu.vulkan();

const float min_value = 0.2f;
const float max_value = 0.8f;

const auto c_cpu = at::clamp(a_cpu, min_value, max_value);
const auto c_vulkan = at::clamp(a_vulkan, min_value, max_value);

ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu()));
}

TEST(VulkanAPITest, clamp_) {
const auto a_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
const auto a_vulkan = a_cpu.vulkan();

const float min_value = 0.2f;
const float max_value = 0.8f;

a_cpu.clamp_(min_value, max_value);
a_vulkan.clamp_(min_value, max_value);

ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu()));
}

TEST(VulkanAPITest, copy) {
const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat));
ASSERT_TRUE(exactlyEqual(cpu, cpu.vulkan().cpu()));
Expand Down