forked from preda/gpuowl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkernel.h
67 lines (54 loc) · 1.96 KB
/
kernel.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// Copyright Mihai Preda.
#pragma once
#include "Queue.h"
#include "Buffer.h"
#include "timeutil.h"
#include "common.h"
#include <string>
#include <stdexcept>
/// Bundles a kernel handle with the queue it runs on and its launch geometry
/// (work-group size and total work size), so that a kernel can be invoked
/// like a function: `kernel(arg0, arg1, ...)`.
///
/// If makeKernel() yields an empty holder, construction still succeeds with
/// groupSize == 0; the failure is reported later, by run(), as a
/// std::runtime_error naming the kernel.
class Kernel {
  KernelHolder kernel;
  int groupSize;      // 0 when `kernel` is empty
  QueuePtr queue;
  size_t workSize;    // total work size passed to queue->run()
  std::string name;

public:
  /// Builds the kernel `name` from `program`, sizing the launch as
  /// `nWorkGroups` work-groups of the size reported by getWorkGroupSize().
  ///
  /// @param program     program passed to makeKernel()
  /// @param queue       queue used by run(); ownership is moved into this object
  /// @param device      device passed to getWorkGroupSize()
  /// @param nWorkGroups number of work-groups to launch
  /// @param name        kernel name, also used in the error raised by run()
  Kernel(cl_program program, QueuePtr queue, cl_device_id device, u32 nWorkGroups, const std::string &name) :
    kernel(makeKernel(program, name.c_str())),
    groupSize(kernel ? getWorkGroupSize(kernel.get(), device, name.c_str()) : 0),
    queue(std::move(queue)),
    // Widen before multiplying: a u32 * int product is computed in 32 bits and
    // would wrap before being stored into the (possibly wider) size_t.
    workSize(static_cast<size_t>(nWorkGroups) * static_cast<size_t>(groupSize)),
    name(name)
  {}

  /// Builds the kernel `name` from `program` with an explicit total work size.
  ///
  /// @param program  program passed to makeKernel()
  /// @param queue    queue used by run(); ownership is moved into this object
  /// @param device   device passed to getWorkGroupSize()
  /// @param name     kernel name, also used in the error raised by run()
  /// @param workSize total work size; asserted to be a multiple of the
  ///                 work-group size whenever that size is non-zero
  Kernel(cl_program program, QueuePtr queue, cl_device_id device, const std::string &name, size_t workSize) :
    kernel(makeKernel(program, name.c_str())),
    groupSize(kernel ? getWorkGroupSize(kernel.get(), device, name.c_str()) : 0),
    queue(std::move(queue)),
    workSize(workSize),
    name(name)
  {
    // groupSize == 0 means the kernel is empty; skip the check (and the
    // division by zero) in that case, run() will report the problem.
    assert(groupSize == 0 || (workSize % groupSize == 0));
  }

  /// Sets consecutive kernel arguments starting at index `pos`, without
  /// running the kernel.
  template<typename... Args> void setFixedArgs(int pos, const Args &...tail) { setArgs(pos, tail...); }

  /// Sets `args` as kernel arguments 0, 1, 2, ... and then runs the kernel.
  ///
  /// @throws std::runtime_error if the kernel handle is empty.
  template<typename... Args> void operator()(const Args &...args) {
    setArgs(0, args...);
    run();
  }

  /// Returns a copy of the kernel name given at construction.
  std::string getName() const { return name; }

private:
  // Buffer wrappers are passed to the kernel as whatever their get() returns.
  template<typename T> void setArgs(int pos, const ConstBuffer<T>& buf) { setArgs(pos, buf.get()); }
  template<typename T> void setArgs(int pos, const Buffer<T>& buf) { setArgs(pos, buf.get()); }
  template<typename T> void setArgs(int pos, const HostAccessBuffer<T>& buf) { setArgs(pos, buf.get()); }

  // Base case: forward a single value to the free function ::setArg().
  template<typename T> void setArgs(int pos, const T &arg) { ::setArg(kernel.get(), pos, arg); }

  // Recursive case: set the first argument at `pos`, the rest from `pos + 1`.
  template<typename T, typename... Args> void setArgs(int pos, const T &arg, const Args &...tail) {
    setArgs(pos, arg);
    setArgs(pos + 1, tail...);
  }

  // Submits the kernel to the queue, or throws if the kernel handle is empty.
  void run() {
    if (!kernel) {
      throw std::runtime_error(std::string("OpenCL kernel ") + name + " not found");
    }
    queue->run(kernel.get(), groupSize, workSize, name);
  }
};