-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathmem_object.h
466 lines (423 loc) · 15.3 KB
/
mem_object.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
/*
* Copyright Codeplay Software Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PORTDNN_INCLUDE_MEM_OBJECT_H_
#define PORTDNN_INCLUDE_MEM_OBJECT_H_
/**
* \file
* Provides the \ref sycldnn::USMMemObject and \ref sycldnn::BufferMemObject
* classes, along with the \ref sycldnn::make_mem_object helper function.
*/
#include <CL/sycl.hpp>
#include <type_traits>
#include "portdnn/accessor_types.h"
#include "portdnn/helpers/macros.h"
#include "portdnn/helpers/sycl_language_helpers.h"
namespace sycldnn {
/**
 * Forward declaration of USMMemObject.
 */
template <typename T>
class USMMemObject;
/**
 * Forward declaration of BufferMemObject.
 */
template <typename T>
class BufferMemObject;
/**
 * Templated struct to check if type T is of type USMMemObject specialized for
 * U. Top-level const/volatile qualifiers on T are ignored; references are not
 * stripped, so a reference type never matches.
 */
template <typename T, typename U>
struct is_usm_obj
    : std::bool_constant<
          std::is_same_v<USMMemObject<U>, std::remove_cv_t<T>>> {};
/**
 * Helper variable template for is_usm_obj, for use in enable_if checks.
 */
template <typename T, typename U>
inline constexpr bool is_usm_obj_v = is_usm_obj<T, U>::value;
/**
 * Templated struct to check if type T is of type BufferMemObject specialized
 * for U. Top-level const/volatile qualifiers on T are ignored; references are
 * not stripped, so a reference type never matches.
 */
template <typename T, typename U>
struct is_buffer_obj
    : std::bool_constant<
          std::is_same_v<BufferMemObject<U>, std::remove_cv_t<T>>> {};
/**
 * Helper variable template for is_buffer_obj, for use in enable_if checks.
 */
template <typename T, typename U>
inline constexpr bool is_buffer_obj_v = is_buffer_obj<T, U>::value;
/**
 * Templated struct to check if type T is of type USMMemObject or
 * BufferMemObject specialized for U.
 */
template <typename T, typename U>
struct is_mem_obj
    : std::bool_constant<is_usm_obj_v<T, U> || is_buffer_obj_v<T, U>> {};
/**
 * Helper variable template for is_mem_obj, for use in enable_if checks.
 */
template <typename T, typename U>
inline constexpr bool is_mem_obj_v = is_mem_obj<T, U>::value;
/**
 * Helper function to create BufferMemObjects.
 *
 * This is useful as it can automatically deduce the template types, enabling
 * BufferMemObjects to be constructed as simply as:
 * \code
 * auto mem_object = make_buffer_mem_object(buffer, size, offset);
 * \endcode
 *
 * \param buffer The SYCL buffer to use as the underlying memory object.
 * \param extent The overall number of elements in the buffer to provide
 *               access to.
 * \param offset The offset from the start of the buffer (in number of
 *               elements) to use as the initial index for the memory object.
 *
 * \return A BufferMemObject that provides access to the given SYCL buffer.
 */
template <typename T>
BufferMemObject<T> make_buffer_mem_object(cl::sycl::buffer<T, 1> buffer,
                                          size_t extent, size_t offset = 0) {
  // Checked as two comparisons instead of `size >= extent + offset` so that a
  // very large extent or offset cannot wrap around size_t and pass the check.
  SNN_ASSERT(offset <= buffer.size() && extent <= buffer.size() - offset,
             "Buffer must contain at least extent + offset elements");
  return BufferMemObject<T>{buffer, extent, offset};
}
/**
 * Convenience factory for USMMemObject.
 *
 * The element type is deduced from the pointer argument, so a USM memory
 * object can be created without spelling out the template parameter:
 * \code
 * auto mem_object = make_usm_mem_object(ptr, size, offset);
 * \endcode
 *
 * \param ptr    USM pointer wrapped by the returned memory object.
 * \param extent Total number of elements in the memory block.
 * \param offset Starting index into the memory block, in elements. Defaults
 *               to 0.
 *
 * \return A USMMemObject wrapping \p ptr with the given extent and offset.
 */
template <typename T>
USMMemObject<T> make_usm_mem_object(T* ptr, size_t extent, size_t offset = 0) {
  USMMemObject<T> mem_object(ptr, extent, offset);
  return mem_object;
}
/**
 * Helper function to create a BufferMemObject from a SYCL buffer.
 *
 * This is useful as it can automatically deduce the template types, enabling
 * BufferMemObjects to be constructed as simply as:
 * \code
 * auto mem_object = make_mem_object(buffer, size, offset);
 * \endcode
 *
 * This overload forwards to make_buffer_mem_object.
 *
 * \param buffer The SYCL buffer to use as the underlying memory object.
 * \param extent The overall number of elements in the buffer to provide
 *               access to.
 * \param offset The offset from the start of the buffer (in number of
 *               elements) to use as the initial index for the memory object.
 *
 * \return A BufferMemObject that provides access to the given SYCL buffer.
 */
template <typename T>
BufferMemObject<T> make_mem_object(cl::sycl::buffer<T, 1> buffer, size_t extent,
                                   size_t offset = 0) {
  return make_buffer_mem_object(buffer, extent, offset);
}
/**
 * Helper function to create a read-only BufferMemObject from a SYCL buffer of
 * non-const elements.
 *
 * The element type cannot be deduced from the buffer argument here, so it must
 * be given explicitly as a const-qualified type:
 * \code
 * auto mem_object = make_mem_object<float const>(buffer, size, offset);
 * \endcode
 *
 * This overload only participates in overload resolution when T is
 * const-qualified.
 *
 * \tparam T The const-qualified element type of the resulting memory object.
 * \param buffer The SYCL buffer of non-const elements to use as the underlying
 *               memory object. It is reinterpreted as a buffer of T.
 * \param extent The overall number of elements in the buffer to provide
 *               access to.
 * \param offset The offset from the start of the buffer (in number of
 *               elements) to use as the initial index for the memory object.
 *
 * \return A BufferMemObject<T> providing access to the given SYCL buffer.
 */
// `enable_if_t` (rather than bare `enable_if`) is required: only a missing
// `::type` triggers SFINAE, so this is what removes the overload for non-const T.
template <typename T, typename = std::enable_if_t<std::is_const_v<T>>>
BufferMemObject<T> make_mem_object(
    cl::sycl::buffer<std::remove_const_t<T>, 1> buffer, size_t extent,
    size_t offset = 0) {
  return make_buffer_mem_object(buffer.template reinterpret<T>(), extent,
                                offset);
}
/**
 * Helper function to create a USMMemObject from a USM pointer.
 *
 * This is useful as it can automatically deduce the template types, enabling
 * USMMemObjects to be constructed as simply as:
 * \code
 * auto mem_object = make_mem_object(ptr, size, offset);
 * \endcode
 *
 * This overload forwards to make_usm_mem_object.
 *
 * \param ptr The SYCL USM pointer to use as the underlying memory object.
 * \param extent The overall number of elements in the memory block.
 * \param offset The offset from the start of the USM address (in number of
 *               elements).
 *
 * \return A USMMemObject that provides access to the given SYCL USM pointer.
 */
template <typename T>
USMMemObject<T> make_mem_object(T* ptr, size_t extent, size_t offset = 0) {
  return make_usm_mem_object(ptr, extent, offset);
}
/**
 * The implementation of USMMemObject for SYCL USM pointers.
 *
 * Stores a raw pointer together with an extent and an offset (both in
 * elements). It defines no destructor, so it never frees the memory it points
 * to; the caller keeps ownership of the allocation.
 */
template <typename T>
class USMMemObject {
 private:
  /** The datatype stored in the memory object. */
  using DataType = T;
  /** Alias for the SYCL command group handler. */
  using Handler = cl::sycl::handler;

 public:
  /**
   * Construct a USMMemObject wrapper around the given SYCL pointer.
   *
   * \param ptr SYCL pointer to use as underlying memory.
   * \param extent The overall number of elements the pointer spans.
   * \param offset The offset from the start of the pointer (in number of
   *               elements) to use as the initial index for the memory
   *               object.
   */
  USMMemObject(DataType* ptr, size_t extent, size_t offset)
      : ptr_(ptr), extent_(extent), offset_(offset) {}

  /**
   * Returns the underlying USM pointer.
   *
   * \return The underlying USM pointer.
   */
  DataType* get_pointer() const { return ptr_; }

  /**
   * Get the extent of this USMMemObject. This is the number of elements in the
   * SYCL pointer that have been allocated.
   * \return The extent of this USMMemObject.
   */
  size_t get_extent() const { return extent_; }

  /**
   * Get the offset of this USMMemObject.
   * \return The number of elements offset from the start of the pointer.
   */
  size_t get_offset() const { return offset_; }

  /**
   * Get a read only generic memory object to the underlying memory object.
   *
   * \param cgh The SYCL command group handler.
   * \return A ReadMem wrapper containing a SYCL pointer.
   */
  ReadMem<T, /*IsUSM*/ true> read_mem(Handler& cgh) {
    return {ptr_, cgh, extent_, offset_};
  }

  /**
   * Get a read-write generic memory object to the underlying memory object.
   * Only available when DataType is not const-qualified.
   *
   * \param cgh The SYCL command group handler.
   * \return A ReadWriteMem wrapper containing a SYCL pointer.
   */
  // `enable_if_t` (not bare `enable_if`) so that the U == DataType guard
  // actually takes effect; U only exists to make the const check SFINAE-able.
  template <typename U = DataType,
            typename = std::enable_if_t<std::is_same<U, DataType>::value>>
  std::enable_if_t<!std::is_const<U>::value, ReadWriteMem<U, /*IsUSM*/ true>>
  read_write_mem(Handler& cgh) {
    return {ptr_, cgh, extent_, offset_};
  }

  /**
   * Get a write only generic memory object to the underlying memory object.
   * Only available when DataType is not const-qualified.
   *
   * \param cgh The SYCL command group handler.
   * \return A WriteMem wrapper containing a SYCL pointer.
   */
  template <typename U = DataType,
            typename = std::enable_if_t<std::is_same<U, DataType>::value>>
  std::enable_if_t<!std::is_const<U>::value, WriteMem<U, /*IsUSM*/ true>>
  write_mem(Handler& cgh) {
    return {ptr_, cgh, extent_, offset_};
  }

  /**
   * Return a new USMMemObject with a pointer casted to a new type of the same
   * size. The extent and offset are carried over unchanged.
   * \return Casted USMMemObject.
   */
  template <typename NewDataType,
            std::enable_if_t<sizeof(NewDataType) == sizeof(DataType), int> = 0>
  USMMemObject<NewDataType> cast() const {
    return USMMemObject<NewDataType>(reinterpret_cast<NewDataType*>(ptr_),
                                     extent_, offset_);
  }

  /**
   * Return the same USMMemObject as a read-only one.
   * \return Read-only USMMemObject.
   */
  USMMemObject<DataType const> as_const() const {
    return this->cast<DataType const>();
  }

 private:
  /** The underlying SYCL pointer. */
  DataType* ptr_;
  /** The number of elements the pointer spans. */
  size_t extent_;
  /** The offset from the start of the pointer (in number of elements). */
  size_t offset_;
};
/**
 * The implementation of BufferMemObject for SYCL buffers.
 *
 * Wraps a one-dimensional SYCL buffer together with an extent and an offset
 * (both in elements). These describe the region of the buffer that is exposed
 * through the accessor and generic memory wrappers returned by this class.
 */
template <typename T>
class BufferMemObject {
 private:
  /** The datatype stored in the memory object. */
  using DataType = T;
  /** The type of the underlying SYCL buffer. */
  using Buffer = cl::sycl::buffer<T, 1>;
  /** Alias for the SYCL command group handler. */
  using Handler = cl::sycl::handler;

 public:
  /**
   * Construct a BufferMemObject wrapper around the given SYCL buffer.
   *
   * \param buffer SYCL buffer to use as underlying memory.
   * \param extent The overall number of elements in the buffer to provide
   *               access to.
   * \param offset The offset from the start of the buffer (in number of
   *               elements) to use as the initial index for the memory
   *               object.
   */
  BufferMemObject(Buffer buffer, size_t extent, size_t offset)
      : buffer_(buffer), extent_(extent), offset_(offset) {
    // Checked as two comparisons instead of `size >= extent + offset` so that
    // a very large extent or offset cannot wrap around size_t and pass.
    SNN_ASSERT(offset_ <= buffer_.size() && extent_ <= buffer_.size() - offset_,
               "Buffer must contain at least extent + offset elements");
  }

  /**
   * Get a reference to the SYCL buffer referred to by this MemObject.
   * \return A reference to the SYCL buffer.
   */
  Buffer const& get_buffer() const { return buffer_; }

  /**
   * Get the extent of this MemObject. This is the number of elements in the
   * SYCL buffer that are available to a user when a SYCL accessor is
   * requested.
   * \return The extent of this MemObject.
   */
  size_t get_extent() const { return extent_; }

  /**
   * Get the offset of this MemObject into its Buffer.
   * \return The number of elements offset from the start of the Buffer.
   */
  size_t get_offset() const { return offset_; }

  /**
   * Get a read only accessor to the underlying memory object.
   *
   * \param cgh The SYCL command group handler to bind the buffer accessor to.
   * \return A ReadAccessor wrapper containing a SYCL accessor.
   */
  ReadAccessor<DataType> read_accessor(Handler& cgh) {
    return {buffer_, cgh, extent_, offset_};
  }

  /**
   * Get a read-write accessor to the underlying memory object.
   * Only available when DataType is not const-qualified.
   *
   * \param cgh The SYCL command group handler to bind the buffer accessor to.
   * \return A ReadWriteAccessor wrapper containing a SYCL accessor.
   */
  // `enable_if_t` (not bare `enable_if`) so that the U == DataType guard
  // actually takes effect; U only exists to make the const check SFINAE-able.
  template <typename U = DataType,
            typename = std::enable_if_t<std::is_same<U, DataType>::value>>
  std::enable_if_t<!std::is_const<U>::value, ReadWriteAccessor<U>>
  read_write_accessor(Handler& cgh) {
    return {buffer_, cgh, extent_, offset_};
  }

  /**
   * Get a write only accessor to the underlying memory object.
   * Only available when DataType is not const-qualified.
   *
   * \param cgh The SYCL command group handler to bind the buffer accessor to.
   * \return A WriteAccessor wrapper containing a SYCL accessor.
   */
  template <typename U = DataType,
            typename = std::enable_if_t<std::is_same<U, DataType>::value>>
  std::enable_if_t<!std::is_const<U>::value, WriteAccessor<U>> write_accessor(
      Handler& cgh) {
    return {buffer_, cgh, extent_, offset_};
  }

  /**
   * Get a read only generic memory object to the underlying memory object.
   *
   * \param cgh The SYCL command group handler.
   * \return A ReadMem wrapper containing a SYCL buffer.
   */
  ReadMem<T, /*IsUSM*/ false> read_mem(Handler& cgh) {
    return {buffer_, cgh, extent_, offset_};
  }

  /**
   * Get a read-write generic memory object to the underlying memory object.
   * Only available when DataType is not const-qualified.
   *
   * \param cgh The SYCL command group handler.
   * \return A ReadWriteMem wrapper containing a SYCL buffer.
   */
  template <typename U = DataType,
            typename = std::enable_if_t<std::is_same<U, DataType>::value>>
  std::enable_if_t<!std::is_const<U>::value, ReadWriteMem<U, /*IsUSM*/ false>>
  read_write_mem(Handler& cgh) {
    return {buffer_, cgh, extent_, offset_};
  }

  /**
   * Get a write only generic memory object to the underlying memory object.
   * Only available when DataType is not const-qualified.
   *
   * \param cgh The SYCL command group handler.
   * \return A WriteMem wrapper containing a SYCL buffer.
   */
  template <typename U = DataType,
            typename = std::enable_if_t<std::is_same<U, DataType>::value>>
  std::enable_if_t<!std::is_const<U>::value, WriteMem<U, /*IsUSM*/ false>>
  write_mem(Handler& cgh) {
    return {buffer_, cgh, extent_, offset_};
  }

  /**
   * Return a new MemObject with a buffer casted to a new type of the same
   * size. The extent and offset are carried over unchanged.
   * \return Casted BufferMemObject.
   */
  template <typename NewDataType,
            std::enable_if_t<sizeof(NewDataType) == sizeof(DataType), int> = 0>
  BufferMemObject<NewDataType> cast() const {
    return BufferMemObject<NewDataType>(
        buffer_.template reinterpret<NewDataType>(), extent_, offset_);
  }

  /**
   * Return a new MemObject with a buffer casted to a const of the DataType.
   * \return const BufferMemObject.
   */
  BufferMemObject<DataType const> as_const() const {
    return this->cast<DataType const>();
  }

 private:
  /** The underlying SYCL buffer. */
  Buffer buffer_;
  /** The number of elements to expose in the SYCL buffer. */
  size_t extent_;
  /** The offset from the start of the buffer (in elements). */
  size_t offset_;
};
} // namespace sycldnn
#endif // PORTDNN_INCLUDE_MEM_OBJECT_H_