-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathlaunch.cc
158 lines (137 loc) · 5.86 KB
/
launch.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/*
* Copyright Codeplay Software Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "portdnn/mem_object.h"
#include "portdnn/status.h"
#include "portdnn/conv2d/conv_type.h"
#include "portdnn/depthwise_conv2d/params.h"
#include "src/depthwise_conv2d/kernel_params.h"
#include "src/depthwise_conv2d/output_size.h"
#include "src/depthwise_conv2d/queue_depthwise_conv2d.h"
#include <stddef.h>
#include <cstdint>
#include <limits>
#include <CL/sycl.hpp>
#include "portdnn/export.h"
namespace sycldnn {
namespace depthwise_conv2d {
namespace internal {
namespace {
template <typename ConvType>
bool can_vectorize(DepthwiseConv2DParams const& p, int vector_width) {
// TODO(dmcbain): depthwise convolutions do not support vectorisation
// for channel multipliers that are not 1
if (p.channel_multiplier != 1) {
return false;
}
return (p.channels * p.channel_multiplier) % vector_width == 0;
}
template <typename ConvType, typename T, typename Index, int VectorWidth,
template <typename> class MemObj>
struct Launcher {
static SNNStatus launch(MemObj<T const>& input, MemObj<T const>& filter,
MemObj<T>& output,
DepthwiseConv2DParams const& params,
Index output_size, cl::sycl::queue& queue,
const std::vector<cl::sycl::event>& events) {
return queue_kernel<ConvType, VectorWidth>(input, filter, output, params,
output_size, queue, events);
}
};
template <typename T, typename Index, int VectorWidth,
template <typename> class MemObj>
struct Launcher<conv2d::conv_type::FilterBackprop, T, Index, VectorWidth,
MemObj> {
static SNNStatus launch(MemObj<T const>& input, MemObj<T const>& filter,
MemObj<T>& output,
DepthwiseConv2DParams const& params,
Index output_size, cl::sycl::queue& queue,
const std::vector<cl::sycl::event>& events) {
return queue_kernel_fil_bk<VectorWidth>(input, filter, output, params,
output_size, queue, events);
}
};
template <typename ConvType, typename T, typename IndexType,
template <typename> class MemObj,
typename = std::enable_if<is_mem_obj_v<MemObj<T>, T>>>
SNNStatus launch_vectorised(MemObj<T const>& input, MemObj<T const>& filter,
MemObj<T>& output,
DepthwiseConv2DParams const& params,
IndexType output_size, cl::sycl::queue& queue,
const std::vector<cl::sycl::event>& events) {
if (can_vectorize<ConvType>(params, 4)) {
return Launcher<ConvType, T, IndexType, 4, MemObj>::launch(
input, filter, output, params, output_size, queue, events);
} else if (can_vectorize<ConvType>(params, 2)) {
return Launcher<ConvType, T, IndexType, 2, MemObj>::launch(
input, filter, output, params, output_size, queue, events);
} else {
return Launcher<ConvType, T, IndexType, 1, MemObj>::launch(
input, filter, output, params, output_size, queue, events);
}
}
} // namespace
template <typename ConvType, typename T, template <typename> class MemObj,
typename = std::enable_if<is_mem_obj_v<MemObj<T>, T>>>
SNNStatus launch(MemObj<T const>& input, MemObj<T const>& filter,
MemObj<T>& output, DepthwiseConv2DParams const& params,
cl::sycl::queue& queue,
const std::vector<cl::sycl::event>& events) {
size_t output_size = get_output_size<ConvType>(params);
auto kernel_params = get_kernel_params<ConvType>(params);
if (output_size > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
#ifdef SNN_USE_INT64
return launch_vectorised<ConvType, T, int64_t>(
input, filter, output, kernel_params, static_cast<int64_t>(output_size),
queue, events);
#else
return StatusCode::IndexExceeded;
#endif // SNN_USE_INT64
} else {
return launch_vectorised<ConvType, T, int32_t>(
input, filter, output, kernel_params, static_cast<int32_t>(output_size),
queue, events);
}
}
#define INSTANTIATE_LAUNCHER(DTYPE, DIRECTION, MEM_OBJ) \
template SNN_EXPORT SNNStatus launch<DIRECTION, DTYPE>( \
MEM_OBJ<DTYPE const> & input, MEM_OBJ<DTYPE const> & filter, \
MEM_OBJ<DTYPE> & output, DepthwiseConv2DParams const& params, \
cl::sycl::queue& queue, const std::vector<cl::sycl::event>& events)
#define INSTANTIATE_FOR_TYPE(DTYPE, MEM_OBJ) \
INSTANTIATE_LAUNCHER(DTYPE, conv2d::conv_type::Forward, MEM_OBJ); \
INSTANTIATE_LAUNCHER(DTYPE, conv2d::conv_type::InputBackprop, MEM_OBJ); \
INSTANTIATE_LAUNCHER(DTYPE, conv2d::conv_type::FilterBackprop, MEM_OBJ)
#ifdef SNN_ENABLE_USM
INSTANTIATE_FOR_TYPE(float, USMMemObject);
#endif
INSTANTIATE_FOR_TYPE(float, BufferMemObject);
#ifdef SNN_USE_DOUBLE
#ifdef SNN_ENABLE_USM
INSTANTIATE_FOR_TYPE(double, USMMemObject);
#endif
INSTANTIATE_FOR_TYPE(double, BufferMemObject);
#endif
#ifdef SNN_USE_HALF
#ifdef SNN_ENABLE_USM
INSTANTIATE_FOR_TYPE(cl::sycl::half, USMMemObject);
#endif
INSTANTIATE_FOR_TYPE(cl::sycl::half, BufferMemObject);
#endif
#undef INSTANTIATE_FOR_TYPE
#undef INSTANTIATE_LAUNCHER
} // namespace internal
} // namespace depthwise_conv2d
} // namespace sycldnn