Skip to content

Commit

Permalink
Merge pull request #3 from portugueslab/portablequeue
Browse files Browse the repository at this point in the history
Addressing UNIX problems with Queue.qsize()
  • Loading branch information
vigji committed Feb 17, 2021
2 parents 2592a04 + 184ad20 commit 0537ef7
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 51 deletions.
4 changes: 4 additions & 0 deletions arrayqueues/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@
__author__ = "Vilim Stih @portugueslab"

__version__ = "1.2.0"

from arrayqueues.shared_arrays import ArrayQueue

# from arrayqueues.portable_queue import PortableQueue
88 changes: 88 additions & 0 deletions arrayqueues/portable_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""The code below has been entirely taken from the following GitHub gist:
https://gist.github.com/FanchenBao/d8577599c46eab1238a81857bb7277c9
"""

from multiprocessing import Value, get_context, queues


class SharedCounter(object):
"""A synchronized shared counter.
The locking done by multiprocessing.Value ensures that only a single
process or thread may read or write the in-memory ctypes object. However,
in order to do n += 1, Python performs a read followed by a write, so a
second process may read the old value before the new one is written by the
first process. The solution is to use a multiprocessing.Lock to guarantee
the atomicity of the modifications to Value.
This class comes almost entirely from Eli Bendersky's blog:
http://eli.thegreenplace.net/2012/01/04/shared-counter-with-pythons-multiprocessing/
"""

def __init__(self, n=0):
self.count = Value("i", n)

def increment(self, n=1):
""" Increment the counter by n (default = 1) """
with self.count.get_lock():
self.count.value += n

@property
def value(self):
""" Return the value of the counter """
return self.count.value


class PortableQueue(queues.Queue):
"""A portable implementation of multiprocessing.Queue.
Because of multithreading / multiprocessing semantics, Queue.qsize() may
raise the NotImplementedError exception on Unix platforms like Mac OS X
where sem_getvalue() is not implemented. This subclass addresses this
problem by using a synchronized shared counter (initialized to zero) and
increasing / decreasing its value every time the put() and get() methods
are called, respectively. This not only prevents NotImplementedError from
being raised, but also allows us to implement a reliable version of both
qsize() and empty().
"""

def __init__(self, *args, **kwargs):
self.size = SharedCounter(0)
super(PortableQueue, self).__init__(*args, ctx=get_context(), **kwargs)

def __getstate__(self):
state = super(PortableQueue, self).__getstate__()
return state + (self.size,)

def __setstate__(self, state):
(
self._ignore_epipe,
self._maxsize,
self._reader,
self._writer,
self._rlock,
self._wlock,
self._sem,
self._opid,
self.size,
) = state
super(PortableQueue, self)._after_fork()

def put(self, *args, **kwargs):
super(PortableQueue, self).put(*args, **kwargs)
self.size.increment(1)

def get(self, *args, **kwargs):
retrived_val = super(PortableQueue, self).get(*args, **kwargs)
self.size.increment(-1)
return retrived_val

def qsize(self):
""" Reliable implementation of multiprocessing.Queue.qsize() """
return self.size.value

def empty(self):
""" Reliable implementation of multiprocessing.Queue.empty() """
return not self.qsize()
24 changes: 15 additions & 9 deletions arrayqueues/shared_arrays.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from multiprocessing import Queue, Array
import numpy as np
from datetime import datetime
from multiprocessing import Array
from queue import Empty, Full

# except AttributeError:
# from multiprocessing import Queue
import numpy as np

# try:
from arrayqueues.portable_queue import PortableQueue # as Queue


class ArrayView:
def __init__(self, array, max_bytes, dtype, el_shape, i_item=0):
Expand All @@ -26,11 +32,8 @@ def push(self, element):
self.view[self.i_item, ...] = element
i_inserted = self.i_item
self.i_item = (self.i_item + 1) % self.n_items
return (
self.dtype,
self.el_shape,
i_inserted,
) # a tuple is returned to maximise performance
# a tuple is returned to maximise performance
return self.dtype, self.el_shape, i_inserted

def pop(self, i_item):
return self.view[i_item, ...]
Expand All @@ -56,8 +59,8 @@ def __init__(self, max_mbytes=10):
self.maxbytes = int(max_mbytes * 1000000)
self.array = Array("c", self.maxbytes)
self.view = None
self.queue = Queue()
self.read_queue = Queue()
self.queue = PortableQueue()
self.read_queue = PortableQueue()
self.last_item = 0

def check_full(self):
Expand Down Expand Up @@ -116,6 +119,9 @@ def clear(self):
def empty(self):
return self.queue.empty()

def qsize(self):
return self.queue.qsize()


class TimestampedArrayQueue(ArrayQueue):
"""A small extension to support timestamps saved alongside arrays"""
Expand Down
Empty file added arrayqueues/tests/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions arrayqueues/tests/test_portable_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import time
from multiprocessing import Process
from queue import Empty

from arrayqueues.portable_queue import PortableQueue


class SourceProcess(Process):
def __init__(self, n_elements):
super().__init__()
self.n_elements = n_elements
self.source_queue = PortableQueue()

def run(self):
for i in range(self.n_elements):
self.source_queue.put(1)


class SinkProcess(Process):
def __init__(self, source_queue):
super().__init__()
self.source_queue = source_queue

def run(self):
while True:
try:
_ = self.source_queue.get(timeout=0.5)
except Empty:
break


def test_portable_queue():
N_ELEMENTS = 5

p1 = SourceProcess(N_ELEMENTS)
p2 = SinkProcess(source_queue=p1.source_queue)
p1.start()
time.sleep(0.5)
assert p1.source_queue.qsize() == N_ELEMENTS
p2.start()
time.sleep(0.5)
assert p1.source_queue.qsize() == 0
p2.join()
p1.join()
86 changes: 44 additions & 42 deletions arrayqueues/tests/test_shared_arrays.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import time
from multiprocessing import Process
from queue import Empty, Full

import numpy as np

from arrayqueues.shared_arrays import (
ArrayQueue,
TimestampedArrayQueue,
IndexedArrayQueue,
TimestampedArrayQueue,
)
from multiprocessing import Process
import numpy as np
from queue import Empty, Full
import unittest
import time


class SourceProcess(Process):
Expand Down Expand Up @@ -92,39 +93,40 @@ def run(self):
break


class TestSample(unittest.TestCase):
def test_shared_queues(self):
p1 = SourceProcess(100)
p2 = SinkProcess(source_array=p1.source_array)
p1.start()
p2.start()
p1.join()
p2.join()

def test_shared_timestamped_queues(self):
p1 = SourceProcess(100, timestamped=True)
p2 = TimestampedSinkProcess(source_array=p1.source_array)
p1.start()
p2.start()
p1.join()
p2.join()

def test_full_queue(self):
# Here we intentionally overfill the queue to test if the right
# exception is raised
p1 = SourceProcess(40, n_mbytes=0.2, wait=0.1, test_full=True)
p2 = SinkProcess(source_array=p1.source_array, limit=4)
p1.start()
p2.start()
p2.join()
p1.join()

def test_clearing_queue(self):
# Here we intentionally overfill the queue to test if the right
# exception is raised
p1 = SourceProcess(5, n_mbytes=10)
p1.start()
p1.join()
p1.source_array.clear()
time.sleep(1.0)
assert p1.source_array.empty()
def test_sample():
p1 = SourceProcess(100)
p2 = SinkProcess(source_array=p1.source_array)
p1.start()
p2.start()
p1.join()
p2.join()


def test_shared_timestamped_queues():
p1 = SourceProcess(100, timestamped=True)
p2 = TimestampedSinkProcess(source_array=p1.source_array)
p1.start()
p2.start()
p1.join()
p2.join()


def test_full_queue():
# Here we intentionally overfill the queue to test if the right
# exception is raised
# TODO is this actually completed?
p1 = SourceProcess(40, n_mbytes=0.2, wait=0.1, test_full=True)
p2 = SinkProcess(source_array=p1.source_array, limit=4)
p1.start()
p2.start()
p2.join()
p1.join()


def test_clearing_queue():
p1 = SourceProcess(5, n_mbytes=10)
p1.start()
p1.join()
p1.source_array.clear()
time.sleep(1.0)
assert p1.source_array.empty()
24 changes: 24 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[tool.black]
target-version = ['py36', 'py37', 'py38']
skip-string-normalization = false
exclude = '''
(
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
| examples
)/
)
'''

[tool.isort]
multi_line_output = 3
include_trailing_comma = true
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from distutils.core import setup

from setuptools import find_packages

with open("requirements_dev.txt") as f:
Expand Down

0 comments on commit 0537ef7

Please sign in to comment.