-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
Clamp.cpp
142 lines (125 loc) · 4.15 KB
/
Clamp.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <ATen/native/vulkan/ops/Common.h>
#include <torch/library.h>
namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {
// Out-of-place clamp for the Vulkan backend.
//
// Allocates a fresh Vulkan tensor with the sizes and options of the input,
// records a dispatch of the `clamp` kernel into a newly allocated command
// buffer, submits it to the GPU queue, and returns the result.
//
// self_arg:  input tensor; moved to Vulkan via `.vulkan()` if not already there.
// min_value: optional lower bound; -infinity is passed to the kernel if absent.
// max_value: optional upper bound; +infinity is passed to the kernel if absent.
//
// Raises (through TORCH_CHECK) when both bounds are absent, or when either the
// input or the output tensor has no image representation.
Tensor clamp(
    const Tensor& self_arg,
    const c10::optional<Scalar> min_value,
    const c10::optional<Scalar> max_value) {
  // At least one bound is required.
  if (!(min_value.has_value() || max_value.has_value())) {
    TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None");
  }

  api::Context* const ctx = api::context();

  const Tensor input = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
  const vTensor& v_input = convert(input);

  vTensor v_result{
      ctx,
      input.sizes(),
      input.options(),
  };

  api::Command::Buffer cmd = ctx->command().pool.allocate();
  cmd.begin();

  // Only the image-backed path is implemented.
  if (!(v_result.has_image() && v_input.has_image())) {
    TORCH_CHECK(false, "Not implemented!");
  }

  // A missing bound is replaced by the matching infinity.
  const float lower = min_value.has_value()
      ? min_value->to<float>()
      : -std::numeric_limits<float>::infinity();
  const float upper = max_value.has_value()
      ? max_value->to<float>()
      : std::numeric_limits<float>::infinity();

  const auto& out_extents = v_result.extents();

  // Parameter block handed to the kernel as a uniform buffer; the member
  // order is kept exactly as the kernel was originally fed.
  struct Block final {
    uint32_t width, height, channels;
    float min_value;
    float max_value;
  };

  const Block block{
      out_extents.width,
      out_extents.height,
      out_extents.depth,
      lower,
      upper,
  };

  ctx->dispatch(
      cmd,
      {
          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
          VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
      },
      VK_KERNEL(clamp),
      v_result.extents(),
      // Output is requested with write-only access: synchronization is
      // bypassed, but barriers are still inserted where required.
      v_result.image(cmd, vTensor::Access::Write),
      // A const tensor implies read-only access, which triggers an async
      // synchronization if one is needed.
      v_input.image(cmd),
      // The resource pool owns the uniform buffer, so the handle does not
      // need to be tracked here.
      ctx->resource().pool.uniform(block).object);

  cmd.end();
  cmd.submit(ctx->gpu().queue);

  return convert(v_result);
}
// In-place clamp for the Vulkan backend.
//
// Records a dispatch of the `clamp_` kernel over the tensor's own image into a
// newly allocated command buffer, submits it to the GPU queue, and returns the
// same tensor reference that was passed in.
//
// self_arg:  tensor to modify; must already be a Vulkan tensor.
// min_value: optional lower bound; -infinity is passed to the kernel if absent.
// max_value: optional upper bound; +infinity is passed to the kernel if absent.
//
// Raises (through TORCH_CHECK) when both bounds are absent, when `self_arg` is
// not a Vulkan tensor, or when the tensor has no image representation.
Tensor& clamp_(
    Tensor& self_arg,
    const c10::optional<Scalar> min_value,
    const c10::optional<Scalar> max_value) {
  api::Context* const ctx = api::context();

  // At least one bound is required.
  if (!(min_value.has_value() || max_value.has_value())) {
    TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None");
  }

  TORCH_CHECK(
      self_arg.is_vulkan(),
      "Vulkan: In-place clamp is only supported on Vulkan tensors.");

  vTensor& v_target = convert(self_arg);

  api::Command::Buffer cmd = ctx->command().pool.allocate();
  cmd.begin();

  // Only the image-backed path is implemented.
  if (!v_target.has_image()) {
    TORCH_CHECK(false, "Not implemented!");
  }

  // A missing bound is replaced by the matching infinity.
  const float lower = min_value.has_value()
      ? min_value->to<float>()
      : -std::numeric_limits<float>::infinity();
  const float upper = max_value.has_value()
      ? max_value->to<float>()
      : std::numeric_limits<float>::infinity();

  const auto& target_extents = v_target.extents();

  // Parameter block handed to the kernel as a uniform buffer; the member
  // order is kept exactly as the kernel was originally fed.
  struct Block final {
    uint32_t width, height, channels;
    float min_value;
    float max_value;
  };

  const Block block{
      target_extents.width,
      target_extents.height,
      target_extents.depth,
      lower,
      upper,
  };

  ctx->dispatch(
      cmd,
      {
          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
          VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
      },
      VK_KERNEL(clamp_),
      v_target.extents(),
      // Read-write access triggers an async synchronization if necessary and
      // inserts the appropriate barriers when hazards are detected.
      v_target.image(cmd, vTensor::Access::Read | vTensor::Access::Write),
      // The resource pool owns the uniform buffer, so the handle does not
      // need to be tracked here.
      ctx->resource().pool.uniform(block).object);

  cmd.end();
  cmd.submit(ctx->gpu().queue);

  return self_arg;
}
// Operator registration: binds the file-local `clamp` and `clamp_` functions
// to the "clamp" and "clamp_" entries of the `aten` library under the Vulkan
// key. The registration is compiled only when USE_VULKAN_API is defined.
#ifdef USE_VULKAN_API
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl("clamp", TORCH_FN(clamp));
m.impl("clamp_", TORCH_FN(clamp_));
}
#endif /* USE_VULKAN_API */
} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at