forked from llvm-mirror/llvm
/
SYCLSerializeArguments.cpp
348 lines (301 loc) · 12.8 KB
/
SYCLSerializeArguments.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
//===- SYCLSerializeArguments.cpp ---------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Replace the call to the kernel task marker function by a call associating a
// kernel name with the task, and replace the call to the SYCL kernel
// instantiating code by calls to functions serializing its arguments.
//
// Basically, we look for the functions containing a call to the kernel task
// marker and we transform
// \code
// tail call void @_ZN2cl4sycl6detail22set_kernel_task_markerERNS1_4taskE(%"struct.cl::sycl::detail::task"* nonnull dereferenceable(240) %t) #2
// [...]
// tail call fastcc void @"_ZN2cl4sycl6detail18instantiate_kernelIDnZZ9test_mainiPPcENK3$_1clERNS0_7handlerEEUlvE_EEvT0_"(i32* %agg.tmp.idx.val.idx.val) #2
// \endcode
// into
// \code
// call void @_ZN2cl4sycl3drt10set_kernelERNS0_6detail4taskEPKcS6_(%"struct.cl::sycl::detail::task"* %t, i8* getelementptr inbounds ([94 x i8], [94 x i8]* @0, i32 0, i32 0), i8* getelementptr inbounds ([N x i8], [N x i8]* @1, i32 0, i32 0))
// [...]
// %15 = bitcast i32* %agg.tmp.idx.val.idx.val.c to i8*
// call void @_ZN2cl4sycl3drt22serialize_accessor_argERNS0_6detail4taskEmPvm(%"struct.cl::sycl::detail::task"* %t, i64 0, i8* %15, i64 4)
// \endcode
// by including also the effect of the SYCLArgsFlattening pass.
//
// ===---------------------------------------------------------------------===//
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/Pass.h"
#include "llvm/SYCL.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cstddef>
#include <functional>
#include <utility>
using namespace llvm;

/// Debug output of this pass is enabled with the -debug or -debug-only=SYCL
/// options, or by setting DebugFlag=1 from a debugger (DebugFlag=0 turns it
/// off again)
#define DEBUG_TYPE "SYCL"

/// Number of kernel call sites rewritten by this pass, displayed with -stats.
/// It is incremented once per serialized kernel call, so a kernel whose call
/// site was duplicated by earlier CFG transformations is counted several times
STATISTIC(SYCLKernelProcessed, "Number of SYCL kernel functions processed");
// Put the code in an anonymous namespace to avoid polluting the global
// namespace
namespace {
/// Replace a SYCL kernel code by a function serializing its arguments
struct SYCLSerializeArguments : public ModulePass {
static char ID; // Pass identification, replacement for typeid
/// The mangled name of the function marking the task to be used to launch the
/// kernel.
///
/// Note that it has to be defined in some include files so this pass can
/// find it.
///
/// The function is defined in
/// triSYCL/include/CL/sycl/detail/instantiate_kernel.hpp
///
/// \code
/// extern void set_kernel_task_marker(detail::task &task)
/// \endcode
static auto constexpr SetKernelTaskMarkerFunctionName =
"_ZN2cl4sycl6detail22set_kernel_task_markerERNS1_4taskE";
/// The mangled name of the serialization function to use.
///
/// Note that it has to be defined in some include files so this pass can use
/// it.
///
/// The function is defined in
/// triSYCL/include/CL/sycl/device_runtime.hpp
///
/// \code
/// TRISYCL_WEAK_ATTRIB_PREFIX void TRISYCL_WEAK_ATTRIB_SUFFIX
/// serialize_arg(detail::task &task,
/// std::size_t index,
/// void *arg,
/// std::size_t arg_size)
/// \endcode
static auto constexpr SerializationFunctionName =
"_ZN2cl4sycl3drt13serialize_argERNS0_6detail4taskEmPvm";
/// The mangled name of the serialization function to use.
///
/// Note that it has to be defined in some include files so this pass can use
/// it.
///
/// The function is defined in
/// triSYCL/include/CL/sycl/device_runtime.hpp
///
/// \code
/// TRISYCL_WEAK_ATTRIB_PREFIX void TRISYCL_WEAK_ATTRIB_SUFFIX
/// serialize_accessor_arg(detail::task &task,
/// std::size_t index,
/// void *arg,
/// std::size_t arg_size)
/// \endcode
static auto constexpr AccessorSerializationFunctionName =
"_ZN2cl4sycl3drt22serialize_accessor_argERNS0_6detail4taskEmPvm";
/// The mangled name of the kernel launching function to use.
///
/// Note that it has to be defined in some include files so this pass can use
/// it.
///
/// The function is defined in
/// triSYCL/include/CL/sycl/device_runtime.hpp
///
/// \endcode
/// TRISYCL_WEAK_ATTRIB_PREFIX void TRISYCL_WEAK_ATTRIB_SUFFIX
/// set_kernel(detail::task &task,
/// const char *kernel_name,
/// const char *kernel_short_name)
/// \endcode
static auto constexpr SetKernelFunctionName =
"_ZN2cl4sycl3drt10set_kernelERNS0_6detail4taskEPKcS6_";
SYCLSerializeArguments() : ModulePass(ID) {}
bool doInitialization(Module &M) override {
LLVM_DEBUG(errs() << "Enter: " << M.getModuleIdentifier() << "\n\n");
// Do not change the code
return false;
}
bool doFinalization(Module &M) override {
LLVM_DEBUG(errs() << "Exit: " << M.getModuleIdentifier() << "\n\n");
// Do not change the code
return false;
}
/// Replace the kernel call instructions by the serialization of its arguments
///
/// \param[inout] F is a function containing a call to
/// \c cl::sycl::detail::set_kernel_task_marker
///
/// \param[in] Task is the pointer to the \c cl::sycl::detail::task
///
/// \param[inout] KernelCall is the instruction calling
/// the kernel instantiation
///
/// There might be more than 1 calling to the same kernel instance because of
/// some CFG restructuration made by Clang/LLVM before, specially if the
/// accessors are not simplified DRT ones...
void serializeKernelArguments(Function &F,
Value &Task,
Instruction &KernelCall) {
// Need the data layout of the target to measure object size
auto M = F.getParent();
auto DL = M->getDataLayout();
// Get the predefined serialization functions to use
auto SF = M->getValueSymbolTable().lookup(SerializationFunctionName);
assert(SF && "Serialization function not found");
auto ASF =
M->getValueSymbolTable().lookup(AccessorSerializationFunctionName);
assert(ASF && "Accessor serialization function not found");
// Use an IRBuilder to ease IR creation in the basic block
auto BB = KernelCall.getParent();
IRBuilder<> Builder { BB };
// Insert the future new instructions before the current kernel call
Builder.SetInsertPoint(&KernelCall);
CallSite KernelCallSite { &KernelCall };
// The index used to number the arguments in the serialization
std::size_t IndexNumber = 0;
// Iterate on the kernel call arguments
for (auto &A : KernelCallSite.args()) {
LLVM_DEBUG(dbgs() << "Serializing '" << A->getName() << "'.\n");
LLVM_DEBUG(dbgs() << "Size '" << DL.getTypeAllocSize(A->getType()) << "'.\n");
// An IR version of the index number
auto Index = Builder.getInt64(IndexNumber);
// \todo Refactor/fuse the then/else part
if (auto PTy = dyn_cast<PointerType>(A->getType())) {
LLVM_DEBUG(dbgs() << " pointer to\n");
LLVM_DEBUG(PTy->getElementType()->dump());
// The pointer argument casted to a void *
auto Arg =
Builder.CreatePointerCast(&*A, Type::getInt8PtrTy(F.getContext()));
// The size of the pointee type
auto ArgSize = DL.getTypeAllocSize(PTy->getElementType());
// Insert the call to the serialization function with the 3 required
// arguments
Value * Args[] { &Task, Index, Arg, Builder.getInt64(ArgSize) };
// \todo add an initializer list to makeArrayRef
Builder.CreateCall(ASF, makeArrayRef(Args));
}
else {
// Create an intermediate memory place to pass the value by address
auto Alloca = Builder.CreateAlloca(A->getType());
Builder.CreateStore(&*A, Alloca);
auto Arg =
Builder.CreatePointerCast(Alloca,
Type::getInt8PtrTy(F.getContext()));
// The size of the argument
auto ArgSize = DL.getTypeAllocSize(A->getType());
// Insert the call to the serialization function with the 3 required
// arguments
Value * Args[] { &Task, Index, Arg, Builder.getInt64(ArgSize) };
// \todo add an initializer list to makeArrayRef
Builder.CreateCall(SF, makeArrayRef(Args));
}
++IndexNumber;
}
// Now remove the initial kernel call
KernelCall.eraseFromParent();
// Count the number of kernel appearance. Note that a kernel call might
// happen several times because of CFG massaging...
++SYCLKernelProcessed;
}
/// Replace the kernel call instructions by the serialization of its arguments
///
/// \param[inout] F is a function containing a call to
/// cl::sycl::detail::set_kernel_task_marker
///
/// \param[inout] MarkerCall is the instruction calling
/// cl::sycl::detail::set_kernel_task_marker
///
/// There might be more than 1 calling to the same kernel instance because of
/// some CFG restructuration made by Clang/LLVM before...
void setKernelTask(Function &F, Instruction &MarkerCall) {
StringRef KernelName;
// Now find the kernel calling sites independently to avoid rewriting the
// world we iterate on
SmallVector<std::reference_wrapper<Instruction>, 3> KernelCallSites;
// Look for calls by this function
for (BasicBlock &BB : F)
for (Instruction &I : BB)
if (auto CS = CallSite { &I })
// If we call a kernel, it is a kernel call site
if (auto CF = CS.getCalledFunction())
if (sycl::isKernel(*CF)) {
KernelCallSites.emplace_back(I);
// Use the kernel instantiating function name as the kernel name
KernelName = CF->getName();
}
auto CS = CallSite { &MarkerCall };
assert(CS && "Kernel task marker function not found");
/* Get the cl::sycl::detail::task address which is passed as the argument of
the marking function */
auto &Task = *CS.getArgument(0);
// Use an IRBuilder to ease IR creation in the basic block
auto BB = MarkerCall.getParent();
IRBuilder<> Builder { BB };
// Insert the future new instructions before the current task marking call
Builder.SetInsertPoint(&MarkerCall);
// Get the predefined kernel setting function to use
auto SKF = F.getParent()->getValueSymbolTable()
.lookup(SetKernelFunctionName);
assert(SKF && "Kernel setting function not found");
// Create a global string variable with the name of the kernel itself
// and return a char * on it
auto Name = Builder.CreateGlobalStringPtr(KernelName);
// Create a a global string variable with the short name of the kernel
// itself and return a char * on it
auto ShortName = Builder.CreateGlobalStringPtr(
sycl::registerSYCLKernelAndGetShortName(KernelName));
// Add the setting of the kernel
Value * Args[] { &Task, Name, ShortName };
// \todo add an initializer list to makeArrayRef
Builder.CreateCall(SKF, makeArrayRef(Args));
// Now that we have used the task parameter, we can discard the useless
// call to the marking function
MarkerCall.eraseFromParent();
// Then serialize the arguments of the detected kernels
for (auto &KernelCall : KernelCallSites)
serializeKernelArguments(F, Task, KernelCall);
}
/// Visit all the functions of the module
bool runOnModule(Module &M) override {
// First find the kernel calling site independently to avoid rewriting the
// world we iterate on
SmallVector<std::pair<std::reference_wrapper<Function>,
std::reference_wrapper<Instruction>>,
8> KernelMarkerCallSites;
for (auto &F : M.functions())
// Look for calls by this function
for (BasicBlock &BB : F)
for (Instruction &I : BB)
if (auto CS = CallSite { &I })
// If we call a kernel, it is a kernel call site
if (auto CF = CS.getCalledFunction())
if (CF->getName().equals(SetKernelTaskMarkerFunctionName))
KernelMarkerCallSites.emplace_back(F, I);
// Then serialize the calls to the detected kernels
for (auto &MarkerCall : KernelMarkerCallSites)
setKernelTask(MarkerCall.first, MarkerCall.second);
// The module changed if there were some kernels
return !KernelMarkerCallSites.empty();
}
};
}
// Definition of the pass identification variable declared in the pass class
char SYCLSerializeArguments::ID = 0;

/// Make the pass available under the "SYCL-serialize-arguments" name.
/// The last two arguments spell out the defaults: the pass is neither
/// CFG-only nor an analysis
static RegisterPass<SYCLSerializeArguments>
    X("SYCL-serialize-arguments",
      "pass to serialize arguments of a SYCL kernel",
      false /* CFGOnly */,
      false /* is_analysis */);