forked from preda/gpuowl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBuffer.h
129 lines (100 loc) · 3.4 KB
/
Buffer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright (C) Mihai Preda.
#pragma once
#include "clwrap.h"
#include "AllocTrac.h"
#include "Context.h"
#include "Queue.h"
#include <memory>
#include <string>
#include <vector>
// A device-side OpenCL buffer holding `size` elements of type T, owned by this object.
// Derived classes (e.g. Buffer) add queue-bound operations.
template<typename T>
class ConstBuffer {
  // Owning handle to the cl_mem allocation.
  // NOTE(review): get() returns ptr.get() as a cl_mem, which presumes clwrap.h customizes
  // std::default_delete<cl_mem> (its pointer type and release function); confirm there.
  std::unique_ptr<cl_mem> ptr;

public:
  // Number of elements of type T held by the buffer (not bytes).
  const size_t size{};
  // Label identifying this buffer.
  const std::string name;

private:
  // Accounting object constructed with the allocation's byte size (size * sizeof(T)).
  // Declared last, so it is constructed after the cl_mem has been created.
  AllocTrac allocTrac;

protected:
  // Creates the cl_mem through makeBuf_ with memory flags `kind` and byte size size * sizeof(T).
  // `ptr` is optional host data forwarded to makeBuf_ (used e.g. with CL_MEM_COPY_HOST_PTR).
  ConstBuffer(cl_context context, std::string_view name, unsigned kind, size_t size, const T* ptr = nullptr)
    : ptr{makeBuf_(context, kind, size * sizeof(T), ptr)}
    , size(size)
    , name(name)
    , allocTrac(size * sizeof(T))
  {}

public:
  using type = T;

  ConstBuffer() = delete;

  ConstBuffer(const Context& context, std::string_view name, unsigned kind, size_t size, const T* ptr = nullptr)
    : ConstBuffer(context.get(), name, kind, size, ptr)
  {}

  // Read-only device copy of `vect`, initialized from host memory at creation time; the flags
  // forbid any host access afterwards.
  ConstBuffer(const Context& context, std::string_view name, const std::vector<T>& vect)
    : ConstBuffer(context.get(), name, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR | CL_MEM_HOST_NO_ACCESS, vect.size(), vect.data())
  {}

  ConstBuffer(ConstBuffer&& rhs) = default;

  // Takes over rhs's device allocation; the allocation previously held by *this is released by
  // the unique_ptr. Only equal-sized buffers may be move-assigned. `name` and `allocTrac` are left
  // as they are (the sizes match, so the accounting stays consistent).
  // noexcept: unique_ptr move-assignment cannot throw, so advertise that to callers/containers.
  ConstBuffer& operator=(ConstBuffer&& rhs) noexcept {
    assert(size == rhs.size);
    ptr = std::move(rhs.ptr);
    return *this;
  }

  virtual ~ConstBuffer() = default;

  // Non-owning raw handle for passing to OpenCL calls.
  cl_mem get() const { return ptr.get(); }

  // Releases the device allocation early; get() returns a null handle afterwards.
  void reset() { ptr.reset(); }
};
// A read-write device buffer bound to a queue, which is used for its device-side operations
// (fill and device-to-device copy). The public constructor creates it with CL_MEM_HOST_NO_ACCESS;
// derived classes may pass other memory flags through the protected constructor.
template<typename T>
class Buffer : public ConstBuffer<T> {
protected:
  // Queue passed to fillBuf/copyBuf for this buffer's operations.
  QueuePtr queue;

  // `kind` is the cl_mem flags bitfield forwarded to ConstBuffer.
  Buffer(QueuePtr queue, std::string_view name, size_t size, unsigned kind)
    : ConstBuffer<T>{getQueueContext(queue->get()), name, kind, size}
      // The base is initialized first and is done with the parameter by now, so it can be moved
      // into the member instead of copied (avoids a redundant QueuePtr copy).
    , queue{std::move(queue)}
  {}

public:
  Buffer(QueuePtr queue, std::string_view name, size_t size)
    : Buffer(std::move(queue), name, size, CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS) {}

  Buffer(Buffer&& rhs) = default;

  // Fills the first `len` elements with T(0) via fillBuf on this buffer's queue;
  // len == 0 means the whole buffer.
  void zero(size_t len = 0) {
    assert(len <= this->size);
    T zero = 0;
    fillBuf(queue->get(), this->get(), &zero, sizeof(T), (len ? len : this->size) * sizeof(T));
  }

  // Zeroes the whole buffer, then fills the first sizeof(T) bytes with `value`
  // (element 0, presuming fillBuf's offset defaults to 0).
  void set(T value) {
    zero();
    fillBuf(queue->get(), this->get(), &value, sizeof(T), sizeof(T));
  }

  // device-side copy: copies all of `rhs` into this buffer via copyBuf on this buffer's queue.
  // Both buffers must have the same number of elements.
  void operator<<(const ConstBuffer<T>& rhs) {
    assert(this->size == rhs.size);
    copyBuf(queue->get(), rhs.get(), this->get(), this->size * sizeof(T));
  }
};
// A read-write buffer whose contents can be transferred to and from host memory.
// Unlike Buffer<T>'s public constructor, it is created without CL_MEM_HOST_NO_ACCESS.
template<typename T>
class HostAccessBuffer : public Buffer<T> {
public:
  // using Buffer<T>::operator=;
  using Buffer<T>::operator<<;

  HostAccessBuffer(QueuePtr queue, std::string_view name, size_t size)
    : Buffer<T>(std::move(queue), name, size, CL_MEM_READ_WRITE) {}

  // sync read: returns the first `sizeOrFull` elements (0 means the whole buffer).
  std::vector<T> read(size_t sizeOrFull = 0) const {
    auto readSize = sizeOrFull ? sizeOrFull : this->size;
    assert(readSize <= this->size);
    std::vector<T> ret(readSize);
    ::read(this->queue->get(), true, this->get(), readSize * sizeof(T), ret.data());
    return ret;
  }

  // async read: reads `sizeOrFull` elements starting at element `start` into `out`, which is
  // resized to fit. sizeOrFull == 0 means "from `start` to the end of the buffer".
  // The read is issued non-blocking, so `out` must not be resized or destroyed until the queue
  // has completed it.
  void readAsync(std::vector<T>& out, size_t sizeOrFull = 0, size_t start = 0) const {
    assert(start <= this->size);
    auto readSize = sizeOrFull ? sizeOrFull : this->size - start;
    // Bound the whole range [start, start + readSize), not just its length.
    assert(readSize <= this->size - start);
    out.resize(readSize);
    // OpenCL rejects zero-byte transfers (CL_INVALID_VALUE), so an empty range is a no-op.
    if (readSize == 0) { return; }
    ::read(this->queue->get(), false, this->get(), readSize * sizeof(T), out.data(), start * sizeof(T));
  }

  // sync write: copies `vect` into the start of the buffer (no offset is passed to ::write).
  void write(const std::vector<T>& vect) {
    assert(this->size >= vect.size());
    // OpenCL rejects zero-byte transfers (CL_INVALID_VALUE), so writing an empty vector is a no-op.
    if (vect.empty()) { return; }
    ::write(this->queue->get(), true, this->get(), vect.size() * sizeof(T), vect.data());
  }

  // Implicit conversion performs a full synchronous read.
  operator std::vector<T>() const { return read(); }

  // async read
  // void operator>>(vector<T>& out) const { readAsync(out); }
};