-
Notifications
You must be signed in to change notification settings - Fork 1
/
treequeues.py
259 lines (209 loc) · 9.18 KB
/
treequeues.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import multiprocessing as mp
import multiprocessing.managers
from abc import ABC, abstractmethod
from typing import Any
from typing import Dict
from typing import Optional
from typing import TypeVar
import numpy as np
import tree # noqa
# Generic item type used in the abstract queue interface.
T = TypeVar("T")
# A nested structure (handled via the `tree` package) whose leaves are numpy arrays.
NestedArray = tree.StructureKV[str, np.ndarray]
class ArrayView:
    """Numpy view over a shared buffer that stores ``num_items`` equally shaped arrays.

    Item ``i`` is ``self._array_view[i]`` and has the shape and dtype of the
    ``numpy_array`` template. The numpy view itself is not pickled; it is rebuilt
    in ``__setstate__`` from the buffer object, so an unpickled copy keeps
    addressing the same buffer instead of a private copy of the data.
    """

    def __init__(
        self,
        multiprocessing_array: mp.Array,
        numpy_array: np.ndarray,
        num_items: int,
    ):
        """
        Args:
            multiprocessing_array: buffer-protocol object holding at least
                ``numpy_array.nbytes * num_items`` bytes (this module passes
                ``mp.Array(...).get_obj()``).
            numpy_array: template giving the shape and dtype of one item.
            num_items: number of item slots in the buffer.
        """
        self.num_items = num_items
        self.dtype = numpy_array.dtype
        self.shape = (num_items, *numpy_array.shape)
        self.nbytes: int = numpy_array.nbytes * num_items
        self._item_shape = numpy_array.shape
        self._multiprocessing_array = multiprocessing_array
        self._array_view = self._make_view()

    def _make_view(self) -> np.ndarray:
        """Wrap the buffer as a ``self.shape`` array of ``self.dtype`` without copying."""
        # np.prod instead of np.product: the latter was removed in NumPy 2.0.
        return np.frombuffer(
            buffer=self._multiprocessing_array,
            dtype=self.dtype,
            count=int(np.prod(self.shape)),
        ).reshape(self.shape)

    def __getstate__(self) -> Dict[str, Any]:
        """Return the pickle state: every attribute except the numpy view."""
        state = dict(self.__dict__)
        # Remove the view from the copy only; the live object keeps using it.
        del state['_array_view']
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore attributes and re-create the view over the restored buffer."""
        self.__dict__.update(state)
        self._array_view = self._make_view()

    def put(self, item: np.ndarray, index: int) -> None:
        """Copy ``item`` into slot ``index``; it must match the template's shape and dtype."""
        assert item.shape == self._item_shape and item.dtype == self.dtype
        self._array_view[index, ...] = item

    def get(self, index: int) -> np.ndarray:
        """Return a private copy of slot ``index`` (unaffected by later writes to the slot)."""
        return np.copy(self._array_view[index, ...])
class AbstractQueue(ABC):
    """Common base for the queues in this module.

    Stores ``maxsize`` and a bounded ``multiprocessing.Queue`` in ``self._queue``;
    the size queries below forward directly to that queue. Subclasses provide
    ``put`` and ``get``.
    """

    def __init__(self, maxsize: int):
        self.maxsize = maxsize
        index_queue = mp.Queue(maxsize=maxsize)
        self._queue = index_queue

    @abstractmethod
    def put(self, item: T) -> None:
        """Add ``item`` to the queue."""
        # https://stackoverflow.com/a/42778801
        raise NotImplementedError

    @abstractmethod
    def get(self, block: bool = True, timeout: Optional[float] = None) -> T:
        """Remove and return the next item."""
        raise NotImplementedError

    def empty(self) -> bool:
        """Forward to ``empty()`` of the underlying ``multiprocessing.Queue``."""
        underlying = self._queue
        return underlying.empty()

    def full(self) -> bool:
        """Forward to ``full()`` of the underlying ``multiprocessing.Queue``."""
        underlying = self._queue
        return underlying.full()

    def qsize(self) -> int:
        """Forward to ``qsize()`` of the underlying ``multiprocessing.Queue``."""
        underlying = self._queue
        return underlying.qsize()
class ArrayQueue(AbstractQueue):
    """Multiprocess FIFO queue of numpy arrays stored in a shared-memory buffer.

    Every item must have the shape and dtype of the ``array`` template. The data
    lives in ``maxsize`` shared slots; only slot indices travel through
    ``self._queue``. A slot is owned by exactly one party at a time (a writer,
    the index queue, or a reader), so the array data itself needs no lock:

    * ``_free_slots`` counts slots that may be written (``put`` blocks at zero),
    * ``_next_index`` hands out never-used slots ``0 .. maxsize - 1``,
    * ``_free_indices`` recycles slots after a reader has copied them out.

    An index is published only after its data is fully written. Items put by one
    process come out in the order that process put them (``multiprocessing.Queue``
    ordering). ``_put``/``_get`` give raw, unsynchronised slot access and are used
    directly by the tree queues in this module.
    """

    def __init__(self, array: np.ndarray, maxsize: int):
        """
        Args:
            array: template; every item must match its shape and dtype.
            maxsize: number of slots, i.e. the maximum number of items held at once.
        """
        super(ArrayQueue, self).__init__(maxsize=maxsize)
        self._free_slots = mp.Semaphore(maxsize)
        self._next_index = mp.Value('i', 0)
        self._free_indices = mp.Queue()
        self.nbytes: int = array.nbytes * maxsize
        self._array = mp.Array("c", self.nbytes)
        self._array_view = ArrayView(
            multiprocessing_array=self._array.get_obj(),
            numpy_array=array,
            num_items=maxsize,
        )

    def put(self, array: np.ndarray) -> None:
        """Copy ``array`` into a free slot and publish it; blocks while all slots are in use."""
        self._free_slots.acquire()
        index = self._claim_index()
        try:
            self._put(array=array, index=index)
        except BaseException:
            # Nothing was published: hand the slot back instead of leaking it.
            self._release_index(index)
            raise
        # Publish only after the write completes, so a reader never sees a partial slot.
        self._queue.put(index)

    def _put(self, array: np.ndarray, index: int) -> None:
        """Write ``array`` into slot ``index``; the caller is responsible for synchronisation."""
        self._array_view.put(array, index)

    def get(self, block: bool = True, timeout: Optional[float] = None) -> np.ndarray:
        """Remove the next published item and return a copy of it.

        Raises:
            queue.Empty: if nothing is available (``block=False``) or ``timeout`` expires.
        """
        index = self._queue.get(block=block, timeout=timeout)
        try:
            return self._get(index=index)
        finally:
            # The copy is done (or failed); either way the slot may be reused now.
            self._release_index(index)

    def _get(self, index: int) -> np.ndarray:
        """Return a copy of slot ``index``; the caller is responsible for synchronisation."""
        return self._array_view.get(index=index)

    def _claim_index(self) -> int:
        """Return a slot index nobody owns; call only after acquiring ``_free_slots``."""
        with self._next_index.get_lock():
            index = self._next_index.value
            if index < self.maxsize:
                self._next_index.value = index + 1
                return index
        # Every slot has been used once. Holding a ``_free_slots`` permit guarantees
        # a recycled index is (or is about to be) in ``_free_indices``.
        return self._free_indices.get()

    def _release_index(self, index: int) -> None:
        """Make slot ``index`` writable again."""
        # Enqueue the index before releasing the permit so ``_claim_index`` can rely on it.
        self._free_indices.put(index)
        self._free_slots.release()
class SimpleTreeQueue(AbstractQueue):
    """TreeQueue implemented with simple synchronisation techniques.

    Every leaf of ``nested_array`` gets its own shared-memory ``ArrayQueue``
    buffer, of which only the raw ``_put``/``_get`` slot accessors are used; slot
    indices travel through ``self._queue``. A slot is owned by exactly one party
    at a time (a writer, the index queue, or a reader), so the data needs no lock:

    * ``_free_slots`` counts slots that may be written (``put`` waits at zero),
    * ``_next_index`` hands out never-used slots ``0 .. maxsize - 1``,
    * ``_free_indices`` recycles slots after a reader has copied them out.

    An index is published only after every leaf has been written.
    """

    def __init__(self, nested_array: NestedArray, maxsize: int):
        """
        Args:
            nested_array: template structure; every item put must match its
                structure and each leaf's shape and dtype.
            maxsize: number of slots, i.e. the maximum number of items held at once.
        """
        super().__init__(maxsize=maxsize)
        self._free_slots = mp.Semaphore(maxsize)
        self._next_index = mp.Value('i', 0)
        self._free_indices = mp.Queue()
        self._nested_queue = tree.map_structure(
            lambda array: ArrayQueue(array=array, maxsize=maxsize), nested_array
        )
        self._nested_array = nested_array
        self.nbytes = sum([q.nbytes for q in tree.flatten(self._nested_queue)])

    def put(self, nested_array: NestedArray, block: bool = True, timeout: Optional[float] = None) -> None:
        """Copy ``nested_array`` into a free slot and publish it.

        Args:
            nested_array: structure matching the template given to ``__init__``.
            block, timeout: how long to wait for a free slot.

        Raises:
            queue.Full: if no slot is free (``block=False``) or ``timeout`` expires.
        """
        from queue import Full

        if not self._free_slots.acquire(block, timeout):
            raise Full
        index = self._claim_index()
        try:
            tree.map_structure(
                lambda leaf_queue, array: leaf_queue._put(array=array, index=index),  # noqa
                self._nested_queue,
                nested_array,
            )
        except BaseException:
            # Nothing was published: hand the slot back instead of leaking it.
            self._release_index(index)
            raise
        # Publish only after every leaf is written, so a reader never sees a partial item.
        self._queue.put(index)

    def get(self, block: bool = True, timeout: Optional[float] = None) -> NestedArray:
        """Remove the next published item and return a copy of it.

        Raises:
            queue.Empty: if nothing is available (``block=False``) or ``timeout`` expires.
        """
        index = self._queue.get(block=block, timeout=timeout)
        try:
            return tree.map_structure(
                lambda leaf_queue: leaf_queue._get(index=index),  # noqa
                self._nested_queue,
            )
        finally:
            # The copies are done (or failed); either way the slot may be reused now.
            self._release_index(index)

    def _claim_index(self) -> int:
        """Return a slot index nobody owns; call only after acquiring ``_free_slots``."""
        with self._next_index.get_lock():
            index = self._next_index.value
            if index < self.maxsize:
                self._next_index.value = index + 1
                return index
        # Every slot has been used once. Holding a ``_free_slots`` permit guarantees
        # a recycled index is (or is about to be) in ``_free_indices``.
        return self._free_indices.get()

    def _release_index(self, index: int) -> None:
        """Make slot ``index`` writable again."""
        # Enqueue the index before releasing the permit so ``_claim_index`` can rely on it.
        self._free_indices.put(index)
        self._free_slots.release()
class TreeQueue(AbstractQueue):
    """TreeQueue implemented with techniques allowing it to be more efficient when using many
    simultaneous processes and threads.

    Each leaf of ``nested_array`` gets its own ``ArrayQueue`` buffer, of which only
    the raw ``_put``/``_get`` slot accessors are used; ``self._queue`` carries the
    indices of written slots. Producers reserve the next ring index under
    ``_put_lock`` and consumers take published indices under ``_get_lock``; the
    slot copies themselves run outside those locks. Two manager dicts record
    which indices are currently being written (``_active_put_index_dict``) or
    read (``_active_get_index_dict``), and ``_condition`` is notified whenever an
    entry is removed. The intent is that a producer does not write a slot that a
    consumer is still copying.

    NOTE(review): ``put`` does not check that the item previously published at the
    reserved index has already been dequeued. With ``maxsize`` items waiting, the
    next ``put`` overwrites the oldest unread item's data before it blocks in
    ``self._queue.put``; that item is lost and the new data is delivered in its
    place.
    """
    def __init__(self, nested_array: NestedArray, maxsize: int):
        super().__init__(maxsize=maxsize)
        # Serialises the dequeue step of ``get``.
        self._get_lock = mp.Lock()
        # Serialises index reservation in ``put``.
        self._put_lock = mp.Lock()
        # Notified after an index is removed from either active dict.
        self._condition = mp.Condition()
        # Next ring index handed to a producer (wraps at maxsize).
        self._next_index = mp.Value('i', 0)
        # mp.Manager() starts a separate server process that hosts the two dicts below.
        self._manager = mp.Manager()
        # index -> True while a consumer is copying that slot.
        self._active_get_index_dict = self._manager.dict()
        # index -> True while a producer is writing that slot.
        self._active_put_index_dict = self._manager.dict()
        self._nested_queue = tree.map_structure(
            lambda array: ArrayQueue(array=array, maxsize=maxsize), nested_array
        )
        self._nested_array = nested_array
        # Total bytes across all leaf buffers.
        self.nbytes = sum([q.nbytes for q in tree.flatten(self._nested_queue)])
    def put(self, nested_array: NestedArray, block: bool = True, timeout: Optional[float] = None) -> None:
        """Write ``nested_array`` into the next ring slot, then publish the slot index.

        Args:
            nested_array: structure matching the template given to ``__init__``.
            block, timeout: forwarded to ``self._queue.put`` when publishing; they do
                not bound the waits that happen before the data is written.

        NOTE(review): the membership test in the ``while`` loop below runs outside
        ``self._condition``. A ``get`` that deletes the index and calls
        ``notify_all`` between that test and ``wait()`` is missed, and this ``put``
        then sleeps until some later notification.
        NOTE(review): if ``self._queue.put`` raises (e.g. ``queue.Full`` with
        ``block=False``), the index is never removed from
        ``_active_put_index_dict``, so any later producer that reserves the same
        index waits in ``wait_and_add`` indefinitely.
        """
        with self._put_lock:
            # Read-and-advance the ring counter atomically.
            index = self._next_index.value
            self._next_index.value = (index + 1) % self.maxsize
        # Mark the slot as being written (waits while another producer still writes it).
        self.wait_and_add(index, self._active_put_index_dict, self._condition)
        # Wait until no consumer is copying this slot.
        while index in self._active_get_index_dict.keys():
            with self._condition:
                self._condition.wait()
        tree.map_structure(
            lambda queue, array: queue._put(  # noqa
                array=array,
                index=index
            ),
            self._nested_queue, nested_array
        )
        # Publish the index only after the data is fully written into every leaf buffer.
        self._queue.put(index, block=block, timeout=timeout)
        del self._active_put_index_dict[index]
        # Wake producers/consumers waiting on an active-dict change.
        with self._condition:
            self._condition.notify_all()
    def get(self, block: bool = True, timeout: Optional[float] = None) -> NestedArray:
        """Take the next published index and return a copy of its data.

        ``_get_lock`` is held while waiting in ``self._queue.get``, so a
        ``block=False`` call still waits for the lock if another consumer is
        blocked there.

        Raises:
            queue.Empty: propagated from ``self._queue.get`` when ``block=False``
                or ``timeout`` expires.
        """
        with self._get_lock:
            index = self._queue.get(block=block, timeout=timeout)
        # The index in the queue are always index of element that are finished to be transferred,
        # we therefore don't need to enquiry if the index is in the put dict.
        # Mark the slot as being read so producers wait before overwriting it.
        self.wait_and_add(index, self._active_get_index_dict, self._condition)
        nested_array = tree.map_structure(
            lambda queue: queue._get(index=index),  # noqa
            self._nested_queue
        )
        del self._active_get_index_dict[index]
        # Wake producers waiting for this slot to be released.
        with self._condition:
            self._condition.notify_all()
        return nested_array
    @staticmethod
    def wait_and_add(
        index: int,
        dictionary: Dict[int, bool],
        condition: multiprocessing.Condition,
    ) -> None:
        """Wait until ``index`` is absent from ``dictionary``, then insert it.

        The membership test and the insertion both run while ``condition`` is
        held, so callers that go through this method with the same condition
        can never both insert the same ``index``. If the index is present, the
        caller waits on ``condition`` (which releases it) and re-checks after the
        next notification.

        This is an equivalent to multiprocessing.Condition().wait_for() except that the lock is held
        to do an action, in that case adding an entry to the dictionary.

        References for the code:
            Acquire lock inside a try-finally block:
                https://stackoverflow.com/a/14137638
            While loop with condition until predicate is met (and snippet condition with-wait):
                https://stackoverflow.com/a/23116848
            Equivalent while True and while loop:
                https://stackoverflow.com/a/27512815
            multiprocessing.Condition().wait_for():
                https://docs.python.org/3/library/threading.html#threading.Condition.wait_for

        Args:
            index: index that will be tested for its presence in the dict before being added
            dictionary: multiprocessing (manager) dict shared between processes
            condition: multiprocessing condition guarding the test-and-insert
        """
        while True:
            condition.acquire()
            try:
                if index not in dictionary.keys():
                    dictionary[index] = True
                    break
                # Releases the condition while sleeping; re-acquired before returning.
                condition.wait()
            finally:
                condition.release()