diff --git a/docs/faq.rst b/docs/faq.rst index 132b4955d..f70ef5528 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -3,7 +3,7 @@ Frequently Asked Questions - Is Pika thread safe? - Pika does not have any notion of threading in the code. If you want to use Pika with threading, make sure you have a Pika connection per thread, created in that thread. It is not safe to share one Pika connection across threads. + Pika does not have any notion of threading in the code. If you want to use Pika with threading, make sure you have a Pika connection per thread, created in that thread. It is not safe to share one Pika connection across threads, with one exception: you may call the connection method `add_callback_threadsafe` from another thread to schedule a callback within an active pika connection. - How do I report a bug with Pika? diff --git a/pika/adapters/asyncio_connection.py b/pika/adapters/asyncio_connection.py index b2ac57050..57cb2d443 100644 --- a/pika/adapters/asyncio_connection.py +++ b/pika/adapters/asyncio_connection.py @@ -20,6 +20,17 @@ def __init__(self, loop): self.readers = set() self.writers = set() + def close(self): + """Release ioloop's resources. + + This method is intended to be called by the application or test code + only after the ioloop's outermost `start()` call returns. After calling + `close()`, no other interaction with the closed instance of ioloop + should be performed. + + """ + self.loop.close() + def add_timeout(self, deadline, callback_method): """Add the callback_method to the EventLoop timer to fire after deadline seconds. Returns a Handle to the timeout. @@ -41,6 +52,28 @@ def remove_timeout(handle): """ return handle.cancel() + def add_callback_threadsafe(self, callback): + """Requests a call to the given function as soon as possible in the + context of this IOLoop's thread. + + NOTE: This is the only thread-safe method offered by the IOLoop adapter. + All other manipulations of the IOLoop adapter and its parent connection + must be performed from the connection's thread. + + For example, a thread may request a call to the + `channel.basic_ack` method of a connection that is running in a + different thread via + + ``` + connection.add_callback_threadsafe( + functools.partial(channel.basic_ack, delivery_tag=...)) + ``` + + :param method callback: The callback method; must be callable. + + """ + self.loop.call_soon_threadsafe(callback) + def add_handler(self, fd, cb, event_state): """ Registers the given handler to receive the given events for ``fd``. diff --git a/pika/adapters/base_connection.py b/pika/adapters/base_connection.py index d6e84077e..cc3836d03 100644 --- a/pika/adapters/base_connection.py +++ b/pika/adapters/base_connection.py @@ -135,6 +135,32 @@ def remove_timeout(self, timeout_id): """ self.ioloop.remove_timeout(timeout_id) + def add_callback_threadsafe(self, callback): + """Requests a call to the given function as soon as possible in the + context of this connection's IOLoop thread. + + NOTE: This is the only thread-safe method offered by the connection. All + other manipulations of the connection must be performed from the + connection's thread. + + For example, a thread may request a call to the + `channel.basic_ack` method of a connection that is running in a + different thread via + + ``` + connection.add_callback_threadsafe( + functools.partial(channel.basic_ack, delivery_tag=...)) + ``` + + :param method callback: The callback method; must be callable. + + """ + if not callable(callback): + raise TypeError( + 'callback must be a callable, but got %r' % (callback,)) + + self.ioloop.add_callback_threadsafe(callback) + def _adapter_connect(self): """Connect to the RabbitMQ broker, returning True if connected. diff --git a/pika/adapters/blocking_connection.py b/pika/adapters/blocking_connection.py index 4f7751f29..29701c287 100644 --- a/pika/adapters/blocking_connection.py +++ b/pika/adapters/blocking_connection.py @@ -10,6 +10,9 @@ classes. """ +# Suppress too-many-lines +# pylint: disable=C0302 + # Disable "access to protected member warnings: this wrapper implementation is # a friend of those instances # pylint: disable=W0212 @@ -155,7 +158,7 @@ def elements(self): with `append_element` """ assert self._ready, '_CallbackResult was not set' - assert isinstance(self._values, list) and len(self._values) > 0, ( + assert isinstance(self._values, list) and self._values, ( '_CallbackResult value is incompatible with append_element: %r' % (self._values,)) @@ -378,7 +381,7 @@ def __repr__(self): def _cleanup(self): """Clean up members that might inhibit garbage collection""" - self._impl.ioloop.deactivate_poller() + self._impl.ioloop.close() self._ready_events.clear() self._opened_result.reset() self._open_error_result.reset() @@ -525,6 +528,18 @@ def _on_timer_ready(self, evt): """ self._ready_events.append(evt) + def _on_threadsafe_callback(self, user_callback): + """Handle callback that was registered via `add_callback_threadsafe`. + + :param user_callback: callback passed to `add_callback_threadsafe` by + the application. + + """ + # Turn it into a 0-delay timeout to take advantage of our existing logic + # that deals with reentrancy + self.add_timeout(0, user_callback) + + def _on_connection_blocked(self, user_callback, method_frame): """Handle Connection.Blocked notification from RabbitMQ broker @@ -632,6 +647,29 @@ def add_timeout(self, deadline, callback_method): return timer_id + def add_callback_threadsafe(self, callback): + """Requests a call to the given function as soon as possible in the + context of this connection's thread. + + NOTE: This is the only thread-safe method in `BlockingConnection`. All + other manipulations of `BlockingConnection` must be performed from the + connection's thread. + + For example, a thread may request a call to the + `BlockingChannel.basic_ack` method of a `BlockingConnection` that is + running in a different thread via + + ``` + connection.add_callback_threadsafe( + functools.partial(channel.basic_ack, delivery_tag=...)) + ``` + + :param method callback: The callback method; must be callable + + """ + self._impl.add_callback_threadsafe( + functools.partial(self._on_threadsafe_callback, callback)) + def remove_timeout(self, timeout_id): """Remove a timer if it's still in the timeout stack @@ -874,7 +912,7 @@ class _ConsumerCancellationEvt(_ChannelPendingEvt): `Basic.Cancel` """ - __slots__ = 'method_frame' + __slots__ = ('method_frame',) def __init__(self, method_frame): """ @@ -1798,7 +1836,8 @@ def consume(self, queue, no_ack=False, """Blocking consumption of a queue instead of via a callback. This method is a generator that yields each message as a tuple of method, properties, and body. The active generator iterator terminates when the - consumer is cancelled by client or broker. + consumer is cancelled by client via `BlockingChannel.cancel()` or by + broker. Example: @@ -2398,7 +2437,8 @@ def queue_declare(self, queue='', passive=False, durable=False, :param queue: The queue name :type queue: str or unicode; if empty string, the broker will create a unique queue name; - :param bool passive: Only check to see if the queue exists + :param bool passive: Only check to see if the queue exists and raise + `ChannelClosed` if it doesn't; :param bool durable: Survive reboots of the broker :param bool exclusive: Only allow access by the current connection :param bool auto_delete: Delete after consumer cancels or disconnects diff --git a/pika/adapters/select_connection.py b/pika/adapters/select_connection.py index bbe9e09f6..f6970d53b 100644 --- a/pika/adapters/select_connection.py +++ b/pika/adapters/select_connection.py @@ -3,6 +3,7 @@ """ import abc +import collections import errno import functools import heapq @@ -11,8 +12,6 @@ import time import threading -from collections import defaultdict - import pika.compat from pika.adapters.base_connection import BaseConnection @@ -163,6 +162,16 @@ def __init__(self): # collection of canceled timeouts self._num_cancellations = 0 + def close(self): + """Release resources. Don't use the `_Timer` instance after closing + it + """ + # Eliminate potential reference cycles to aid garbage-collection + if self._timeout_heap is not None: + for timeout in self._timeout_heap: + timeout.callback = None + self._timeout_heap = None + def call_later(self, delay, callback): """Schedule a one-shot timeout given delay seconds. @@ -282,9 +291,28 @@ class IOLoop(object): def __init__(self): self._timer = _Timer() - self._poller = self._get_poller(self._timer.get_remaining_interval, + # Callbacks requested via `add_callback` + self._callbacks = collections.deque() + + # Identity of this IOLoop's thread + self._thread_id = None + + self._poller = self._get_poller(self._get_remaining_interval, self.process_timeouts) + def close(self): + """Release IOLoop's resources. + + `IOLoop.close` is intended to be called by the application or test code + only after `IOLoop.start()` returns. After calling `close()`, no other + interaction with the closed instance of `IOLoop` should be performed. + + """ + if self._callbacks is not None: + self._poller.close() + self._timer.close() + self._callbacks = None + @staticmethod def _get_poller(get_wait_seconds, process_timeouts): """Determine the best poller to use for this environment and instantiate @@ -346,13 +374,57 @@ def remove_timeout(self, timeout_id): """ self._timer.remove_timeout(timeout_id) + def add_callback_threadsafe(self, callback): + """Requests a call to the given function as soon as possible in the + context of this IOLoop's thread. + + NOTE: This is the only thread-safe method in IOLoop. All other + manipulations of IOLoop must be performed from the IOLoop's thread. + + For example, a thread may request a call to the `stop` method of an + ioloop that is running in a different thread via + `ioloop.add_callback_threadsafe(ioloop.stop)` + + :param method callback: The callback method + + """ + if not callable(callback): + raise TypeError( + 'callback must be a callable, but got %r' % (callback,)) + + # NOTE: `deque.append` is atomic + self._callbacks.append(callback) + if threading.current_thread().ident != self._thread_id: + # Wake up the IOLoop running in another thread + self._poller.wake_threadsafe() + + LOGGER.debug('add_callback_threadsafe: added callback=%r', callback) + def process_timeouts(self): - """[Extension] Process pending timeouts, invoking callbacks for those + """[Extension] Process pending callbacks and timeouts, invoking those whose time has come. Internal use only. """ + # Avoid I/O starvation by postponing new callbacks to the next iteration + for _ in pika.compat.xrange(len(self._callbacks)): + self._callbacks.popleft()() + self._timer.process_timeouts() + def _get_remaining_interval(self): + """Get the remaining interval to the next callback or timeout + expiration. + + :returns: non-negative number of seconds until next callback or timer + expiration; None if there are no callbacks and timers + :rtype: float + + """ + if self._callbacks: + return 0 + + return self._timer.get_remaining_interval() + def add_handler(self, fileno, handler, events): """[API] Add a new fileno to the set to be monitored @@ -385,20 +457,31 @@ def start(self): exit. See `IOLoop.stop`. """ + self._thread_id = threading.current_thread().ident self._poller.start() def stop(self): """[API] Request exit from the ioloop. The loop is NOT guaranteed to - stop before this method returns. This is the only method that may be - called from another thread. + stop before this method returns. + + To invoke `stop()` safely from a thread other than this IOLoop's thread, + call it via `add_callback_threadsafe`; e.g., + + `ioloop.add_callback_threadsafe(ioloop.stop)` """ + if (self._thread_id is not None and + threading.current_thread().ident != self._thread_id): + LOGGER.warning('Use add_callback_threadsafe to request ' + 'ioloop.stop() from another thread') + self._poller.stop() def activate_poller(self): """[Extension] Activate the poller """ + self._thread_id = threading.current_thread().ident self._poller.activate_poller() def deactivate_poller(self): @@ -442,6 +525,10 @@ def __init__(self, get_wait_seconds, process_timeouts): self._get_wait_seconds = get_wait_seconds self._process_timeouts = process_timeouts + # We guard access to the waking file descriptors to avoid races from + # closing them while another thread is calling our `wake()` method. + self._waking_mutex = threading.Lock() + # fd-to-handler function mappings self._fd_handlers = dict() @@ -455,14 +542,60 @@ def __init__(self, get_wait_seconds, process_timeouts): self._stopping = False - # Mutex for controlling critical sections where ioloop-interrupt sockets - # are created, used, and destroyed. Needed in case `stop()` is called - # from a thread. - self._mutex = threading.Lock() + # Create ioloop-interrupt socket pair and register read handler. + self._r_interrupt, self._w_interrupt = self._get_interrupt_pair() + self.add_handler(self._r_interrupt.fileno(), self._read_interrupt, READ) + + def close(self): + """Release poller's resources. + + `close()` is intended to be called after the poller's `start()` method + returns. After calling `close()`, no other interaction with the closed + poller instance should be performed. + + """ + # Unregister and close ioloop-interrupt socket pair; mutual exclusion is + # necessary to avoid race condition with `wake_threadsafe` executing in + # another thread's context + assert self._start_nesting_levels == 0, \ + 'Cannot call close() before start() unwinds.' + + with self._waking_mutex: + if self._w_interrupt is not None: + self.remove_handler(self._r_interrupt.fileno()) # pylint: disable=E1101 + self._r_interrupt.close() + self._r_interrupt = None + self._w_interrupt.close() + self._w_interrupt = None + + self.deactivate_poller() + + self._fd_handlers = None + self._fd_events = None + self._processing_fd_event_map = None + + def wake_threadsafe(self): + """Wake up the poller as soon as possible. As the name indicates, this + method is thread-safe. + + """ + with self._waking_mutex: + if self._w_interrupt is None: + return + + try: + # Send byte to interrupt the poll loop, use send() instead of + # os.write for Windows compatibility + self._w_interrupt.send(b'X') + except pika.compat.SOCKET_ERROR as err: + if err.errno != errno.EWOULDBLOCK: + raise + except Exception as err: + # There's nothing sensible to do here, we'll exit the interrupt + # loop after POLL_TIMEOUT secs in worst case anyway. + LOGGER.warning("Failed to send interrupt to poller: %s", err) + raise - # ioloop-interrupt socket pair; initialized in start() - self._r_interrupt = None - self._w_interrupt = None def _get_max_wait(self): """Get the interval to the next timeout event, or a default interval @@ -557,7 +690,7 @@ def activate_poller(self): """ # Activate the underlying poller and register current events self._init_poller() - fd_to_events = defaultdict(int) + fd_to_events = collections.defaultdict(int) for event, file_descriptors in self._fd_events.items(): for fileno in file_descriptors: fd_to_events[fileno] |= event @@ -579,22 +712,10 @@ def start(self): if self._start_nesting_levels == 1: LOGGER.debug('Entering IOLoop') - self._stopping = False # Activate the underlying poller and register current events self.activate_poller() - # Create ioloop-interrupt socket pair and register read handler. - # NOTE: we defer their creation because some users (e.g., - # BlockingConnection adapter) don't use the event loop and these - # sockets would get reported as leaks - with self._mutex: - assert self._r_interrupt is None - self._r_interrupt, self._w_interrupt = self._get_interrupt_pair( - ) - self.add_handler(self._r_interrupt.fileno(), - self._read_interrupt, READ) - else: LOGGER.debug('Reentering IOLoop at nesting level=%s', self._start_nesting_levels) @@ -609,46 +730,26 @@ def start(self): self._start_nesting_levels -= 1 if self._start_nesting_levels == 0: - LOGGER.debug('Cleaning up IOLoop') - # Unregister and close ioloop-interrupt socket pair - with self._mutex: - self.remove_handler(self._r_interrupt.fileno()) - self._r_interrupt.close() - self._r_interrupt = None - self._w_interrupt.close() - self._w_interrupt = None - - # Deactivate the underlying poller - self.deactivate_poller() + try: + LOGGER.debug('Deactivating poller') + + # Deactivate the underlying poller + self.deactivate_poller() + finally: + self._stopping = False else: LOGGER.debug('Leaving IOLoop with %s nesting levels remaining', self._start_nesting_levels) def stop(self): """Request exit from the ioloop. The loop is NOT guaranteed to stop - before this method returns. This is the only method that may be called - from another thread. + before this method returns. """ LOGGER.debug('Stopping IOLoop') self._stopping = True - with self._mutex: - if self._w_interrupt is None: - return - - try: - # Send byte to interrupt the poll loop, use send() instead of - # os.write for Windows compatibility - self._w_interrupt.send(b'X') - except pika.compat.SOCKET_ERROR as err: - if err.errno != errno.EWOULDBLOCK: - raise - except Exception as err: - # There's nothing sensible to do here, we'll exit the interrupt - # loop after POLL_TIMEOUT secs in worst case anyway. - LOGGER.warning("Failed to send ioloop interrupt: %s", err) - raise + self.wake_threadsafe() @abc.abstractmethod def poll(self): @@ -796,7 +897,7 @@ def poll(self): # Build an event bit mask for each fileno we've received an event for - fd_event_map = defaultdict(int) + fd_event_map = collections.defaultdict(int) for fd_set, evt in zip((read, write, error), (READ, WRITE, ERROR)): for fileno in fd_set: fd_event_map[fileno] |= evt @@ -855,9 +956,8 @@ class KQueuePoller(_PollerBase): def __init__(self, get_wait_seconds, process_timeouts): """Create an instance of the KQueuePoller """ - super(KQueuePoller, self).__init__(get_wait_seconds, process_timeouts) - self._kqueue = None + super(KQueuePoller, self).__init__(get_wait_seconds, process_timeouts) @staticmethod def _map_event(kevent): @@ -893,7 +993,7 @@ def poll(self): else: raise - fd_event_map = defaultdict(int) + fd_event_map = collections.defaultdict(int) for event in kevents: fd_event_map[event.ident] |= self._map_event(event) @@ -907,8 +1007,9 @@ def _init_poller(self): def _uninit_poller(self): """Notify the implementation to release the poller resource""" - self._kqueue.close() - self._kqueue = None + if self._kqueue is not None: + self._kqueue.close() + self._kqueue = None def _register_fd(self, fileno, events): """The base class invokes this method to notify the implementation to @@ -1012,7 +1113,7 @@ def poll(self): else: raise - fd_event_map = defaultdict(int) + fd_event_map = collections.defaultdict(int) for fileno, event in events: fd_event_map[fileno] |= event @@ -1026,10 +1127,11 @@ def _init_poller(self): def _uninit_poller(self): """Notify the implementation to release the poller resource""" - if hasattr(self._poll, "close"): - self._poll.close() + if self._poll is not None: + if hasattr(self._poll, "close"): + self._poll.close() - self._poll = None + self._poll = None def _register_fd(self, fileno, events): """The base class invokes this method to notify the implementation to diff --git a/pika/adapters/tornado_connection.py b/pika/adapters/tornado_connection.py index ce407d1f8..db34dfd99 100644 --- a/pika/adapters/tornado_connection.py +++ b/pika/adapters/tornado_connection.py @@ -94,3 +94,29 @@ def remove_timeout(self, timeout_id): """ return self.ioloop.remove_timeout(timeout_id) + + def add_callback_threadsafe(self, callback): + """Requests a call to the given function as soon as possible in the + context of this connection's IOLoop thread. + + NOTE: This is the only thread-safe method offered by the connection. All + other manipulations of the connection must be performed from the + connection's thread. + + For example, a thread may request a call to the + `channel.basic_ack` method of a connection that is running in a + different thread via + + ``` + connection.add_callback_threadsafe( + functools.partial(channel.basic_ack, delivery_tag=...)) + ``` + + :param method callback: The callback method; must be callable. + + """ + if not callable(callback): + raise TypeError( + 'callback must be a callable, but got %r' % (callback,)) + + self.ioloop.add_callback(callback) diff --git a/pika/adapters/twisted_connection.py b/pika/adapters/twisted_connection.py index 053ddf6e3..1dac51f44 100644 --- a/pika/adapters/twisted_connection.py +++ b/pika/adapters/twisted_connection.py @@ -225,6 +225,28 @@ def remove_timeout(self, call): """ call.cancel() + def add_callback_threadsafe(self, callback): + """Requests a call to the given function as soon as possible in the + context of this IOLoop's thread. + + NOTE: This is the only thread-safe method offered by the IOLoop adapter. + All other manipulations of the IOLoop adapter and its parent connection + must be performed from the connection's thread. + + For example, a thread may request a call to the + `channel.basic_ack` method of a connection that is running in a + different thread via + + ``` + connection.add_callback_threadsafe( + functools.partial(channel.basic_ack, delivery_tag=...)) + ``` + + :param method callback: The callback method; must be callable. + + """ + self.reactor.callFromThread(callback) + def stop(self): # Guard against stopping the reactor multiple times if not self.started: diff --git a/pika/connection.py b/pika/connection.py index e371a03ac..0c4e2a7f0 100644 --- a/pika/connection.py +++ b/pika/connection.py @@ -1238,8 +1238,11 @@ def close(self, reply_code=200, reply_text='Normal shutdown'): LOGGER.warning('Suppressing close request on %s', self) return + # NOTE The connection is either in opening or open state + # Initiate graceful closing of channels that are OPEN or OPENING - self._close_channels(reply_code, reply_text) + if self._channels: + self._close_channels(reply_code, reply_text) # Set our connection state self._set_connection_state(self.CONNECTION_CLOSING) diff --git a/tests/acceptance/async_adapter_tests.py b/tests/acceptance/async_adapter_tests.py index edb27b941..53004f69f 100644 --- a/tests/acceptance/async_adapter_tests.py +++ b/tests/acceptance/async_adapter_tests.py @@ -10,6 +10,8 @@ # Suppress pylint warning about unused argument # pylint: disable=W0613 +import functools +import threading import time import uuid @@ -460,6 +462,114 @@ def on_closed(self, connection, reply_code, reply_text): reply_text) +class TestAddCallbackThreadsafeRequestBeforeIOLoopStarts(AsyncTestCase, AsyncAdapters): + DESCRIPTION = "Test add_callback_threadsafe request before ioloop starts." + + def _run_ioloop(self, *args, **kwargs): # pylint: disable=W0221 + """We intercept this method from AsyncTestCase in order to call + add_callback_threadsafe before AsyncTestCase starts the ioloop. + + """ + self.my_start_time = time.time() + # Request a callback from our current (ioloop's) thread + self.connection.add_callback_threadsafe(self.on_requested_callback) + + return super( + TestAddCallbackThreadsafeRequestBeforeIOLoopStarts, self)._run_ioloop( + *args, **kwargs) + + def start(self, *args, **kwargs): # pylint: disable=W0221 + self.loop_thread_ident = threading.current_thread().ident + self.my_start_time = None + self.got_callback = False + super(TestAddCallbackThreadsafeRequestBeforeIOLoopStarts, self).start(*args, **kwargs) + self.assertTrue(self.got_callback) + + def begin(self, channel): + self.stop() + + def on_requested_callback(self): + self.assertEqual(threading.current_thread().ident, + self.loop_thread_ident) + self.assertLess(time.time() - self.my_start_time, 0.25) + self.got_callback = True + + +class TestAddCallbackThreadsafeFromIOLoopThread(AsyncTestCase, AsyncAdapters): + DESCRIPTION = "Test add_callback_threadsafe request from same thread." + + def start(self, *args, **kwargs): + self.loop_thread_ident = threading.current_thread().ident + self.my_start_time = None + self.got_callback = False + super(TestAddCallbackThreadsafeFromIOLoopThread, self).start(*args, **kwargs) + self.assertTrue(self.got_callback) + + def begin(self, channel): + self.my_start_time = time.time() + # Request a callback from our current (ioloop's) thread + channel.connection.add_callback_threadsafe(self.on_requested_callback) + + def on_requested_callback(self): + self.assertEqual(threading.current_thread().ident, + self.loop_thread_ident) + self.assertLess(time.time() - self.my_start_time, 0.25) + self.got_callback = True + self.stop() + + +class TestAddCallbackThreadsafeFromAnotherThread(AsyncTestCase, AsyncAdapters): + DESCRIPTION = "Test add_callback_threadsafe request from another thread." + + def start(self, *args, **kwargs): + self.loop_thread_ident = threading.current_thread().ident + self.my_start_time = None + self.got_callback = False + super(TestAddCallbackThreadsafeFromAnotherThread, self).start(*args, **kwargs) + self.assertTrue(self.got_callback) + + def begin(self, channel): + self.my_start_time = time.time() + # Request a callback from ioloop while executing in another thread + timer = threading.Timer( + 0, + lambda: channel.connection.add_callback_threadsafe( + self.on_requested_callback)) + self.addCleanup(timer.cancel) + timer.start() + + def on_requested_callback(self): + self.assertEqual(threading.current_thread().ident, + self.loop_thread_ident) + self.assertLess(time.time() - self.my_start_time, 0.25) + self.got_callback = True + self.stop() + + +class TestIOLoopStopBeforeIOLoopStarts(AsyncTestCase, AsyncAdapters): + DESCRIPTION = "Test ioloop.stop() before ioloop starts causes ioloop to exit quickly." + + def _run_ioloop(self, *args, **kwargs): # pylint: disable=W0221 + """We intercept this method from AsyncTestCase in order to call + ioloop.stop() before AsyncTestCase starts the ioloop. + """ + # Request ioloop to stop before it starts + self.my_start_time = time.time() + self.stop_ioloop_only() + + return super( + TestIOLoopStopBeforeIOLoopStarts, self)._run_ioloop(*args, **kwargs) + + def start(self, *args, **kwargs): # pylint: disable=W0221 + self.loop_thread_ident = threading.current_thread().ident + self.my_start_time = None + super(TestIOLoopStopBeforeIOLoopStarts, self).start(*args, **kwargs) + self.assertLess(time.time() - self.my_start_time, 0.25) + + def begin(self, channel): + pass + + class TestViabilityOfMultipleTimeoutsWithSameDeadlineAndCallback(AsyncTestCase, AsyncAdapters): # pylint: disable=C0103 DESCRIPTION = "Test viability of multiple timeouts with same deadline and callback" diff --git a/tests/acceptance/async_test_base.py b/tests/acceptance/async_test_base.py index 33f5d3dad..6a80b0c9b 100644 --- a/tests/acceptance/async_test_base.py +++ b/tests/acceptance/async_test_base.py @@ -68,34 +68,55 @@ def begin(self, channel): # pylint: disable=R0201,W0613 """Extend to start the actual tests on the channel""" self.fail("AsyncTestCase.begin_test not extended") - def start(self, adapter=None): + def start(self, adapter, ioloop_factory): self.logger.info('start at %s', datetime.datetime.utcnow()) self.adapter = adapter or self.ADAPTER - self.connection = self.adapter(self.parameters, self.on_open, - self.on_open_error, self.on_closed) - self.timeout = self.connection.add_timeout(self.TIMEOUT, - self.on_timeout) - self.connection.ioloop.start() - self.assertFalse(self._timed_out) + self.connection = self.adapter(self.parameters, + self.on_open, + self.on_open_error, + self.on_closed, + custom_ioloop=ioloop_factory()) + try: + self.timeout = self.connection.add_timeout(self.TIMEOUT, + self.on_timeout) + self._run_ioloop() + self.assertFalse(self._timed_out) + finally: + self.connection.ioloop.close() + self.connection = None + + def stop_ioloop_only(self): + """Request stopping of the connection's ioloop to end the test without + closing the connection + """ + self._safe_remove_test_timeout() + self.connection.ioloop.stop() def stop(self): """close the connection and stop the ioloop""" self.logger.info("Stopping test") - if self.timeout is not None: - self.connection.remove_timeout(self.timeout) - self.timeout = None - self.connection.close() + self._safe_remove_test_timeout() + self.connection.close() # NOTE: on_closed() will stop the ioloop + + def _run_ioloop(self): + """Some tests need to subclass this in order to bootstrap their test + logic after we instantiate the connection and assign it to + `self.connection`, but before we run the ioloop + """ + self.connection.ioloop.start() - def _stop(self): + def _safe_remove_test_timeout(self): if hasattr(self, 'timeout') and self.timeout is not None: self.logger.info("Removing timeout") self.connection.remove_timeout(self.timeout) self.timeout = None - if hasattr(self, 'connection') and self.connection: + + def _stop(self): + self._safe_remove_test_timeout() + if hasattr(self, 'connection') and self.connection is not None: self.logger.info("Stopping ioloop") self.connection.ioloop.stop() - self.connection = None def on_closed(self, connection, reply_code, reply_text): """called when the connection has finished closing""" @@ -124,12 +145,12 @@ def on_timeout(self): class BoundQueueTestCase(AsyncTestCase): - def start(self, adapter=None): + def start(self, adapter, ioloop_factory): # PY3 compat encoding self.exchange = 'e-' + self.__class__.__name__ + ':' + uuid.uuid1().hex self.queue = 'q-' + self.__class__.__name__ + ':' + uuid.uuid1().hex self.routing_key = self.__class__.__name__ - super(BoundQueueTestCase, self).start(adapter) + super(BoundQueueTestCase, self).start(adapter, ioloop_factory) def begin(self, channel): self.channel.exchange_declare(self.on_exchange_declared, self.exchange, @@ -164,20 +185,31 @@ def on_ready(self, frame): class AsyncAdapters(object): - def start(self, adapter_class): + def start(self, adapter_class, ioloop_factory): + """ + + :param adapter_class: pika connection adapter class to test. + :param ioloop_factory: to be called without args to instantiate a + non-shared ioloop to be passed as the `custom_ioloop` arg to the + `adapter_class` constructor. This is needed because some of the + adapters default to using a singleton ioloop, which results in + tests errors after prior tests close the ioloop to release resources, + in order to eliminate ResourceWarning warnings concerning unclosed + sockets from our adapters. + :return: + """ raise NotImplementedError def select_default_test(self): """SelectConnection:DefaultPoller""" - with mock.patch.multiple(select_connection, SELECT_TYPE=None): - self.start(adapters.SelectConnection) + self.start(adapters.SelectConnection, select_connection.IOLoop) def select_select_test(self): """SelectConnection:select""" with mock.patch.multiple(select_connection, SELECT_TYPE='select'): - self.start(adapters.SelectConnection) + self.start(adapters.SelectConnection, select_connection.IOLoop) @unittest.skipIf( not hasattr(select, 'poll') or @@ -186,27 +218,36 @@ def select_poll_test(self): """SelectConnection:poll""" with mock.patch.multiple(select_connection, SELECT_TYPE='poll'): - self.start(adapters.SelectConnection) + self.start(adapters.SelectConnection, select_connection.IOLoop) @unittest.skipIf(not hasattr(select, 'epoll'), "epoll not supported") def select_epoll_test(self): """SelectConnection:epoll""" with mock.patch.multiple(select_connection, SELECT_TYPE='epoll'): - self.start(adapters.SelectConnection) + self.start(adapters.SelectConnection, select_connection.IOLoop) @unittest.skipIf(not hasattr(select, 'kqueue'), "kqueue not supported") def select_kqueue_test(self): """SelectConnection:kqueue""" with mock.patch.multiple(select_connection, SELECT_TYPE='kqueue'): - self.start(adapters.SelectConnection) + self.start(adapters.SelectConnection, select_connection.IOLoop) def tornado_test(self): """TornadoConnection""" - self.start(adapters.TornadoConnection) + ioloop_factory = None + if adapters.TornadoConnection is not None: + import tornado.ioloop + ioloop_factory = tornado.ioloop.IOLoop + self.start(adapters.TornadoConnection, ioloop_factory) @unittest.skipIf(sys.version_info < (3, 4), "Asyncio available for Python 3.4+") def asyncio_test(self): """AsyncioConnection""" - self.start(adapters.AsyncioConnection) + ioloop_factory = None + if adapters.AsyncioConnection is not None: + import asyncio + ioloop_factory = asyncio.new_event_loop + + self.start(adapters.AsyncioConnection, ioloop_factory) diff --git a/tests/acceptance/blocking_adapter_test.py b/tests/acceptance/blocking_adapter_test.py index 7a7e0ce39..c27be2efd 100644 --- a/tests/acceptance/blocking_adapter_test.py +++ b/tests/acceptance/blocking_adapter_test.py @@ -1,7 +1,9 @@ """blocking adapter test""" from datetime import datetime +import functools import logging import socket +import threading import time import unittest import uuid @@ -449,6 +451,48 @@ def test(self): 'Blocked connection timeout expired')) +class TestAddCallbackThreadsafeFromSameThread(BlockingTestCaseBase): + + def test(self): + """BlockingConnection.add_callback_threadsafe from same thread""" + connection = self._connect() + + # Test timer completion + start_time = time.time() + rx_callback = [] + connection.add_callback_threadsafe( + lambda: rx_callback.append(time.time())) + while not rx_callback: + connection.process_data_events(time_limit=None) + + self.assertEqual(len(rx_callback), 1) + elapsed = time.time() - start_time + self.assertLess(elapsed, 0.25) + + +class TestAddCallbackThreadsafeFromAnotherThread(BlockingTestCaseBase): + + def test(self): + """BlockingConnection.add_callback_threadsafe from another thread""" + connection = self._connect() + + # Test timer completion + start_time = time.time() + rx_callback = [] + timer = threading.Timer( + 0, + functools.partial(connection.add_callback_threadsafe, + lambda: rx_callback.append(time.time()))) + self.addCleanup(timer.cancel) + timer.start() + while not rx_callback: + connection.process_data_events(time_limit=None) + + self.assertEqual(len(rx_callback), 1) + elapsed = time.time() - start_time + self.assertLess(elapsed, 0.25) + + class TestAddTimeoutRemoveTimeout(BlockingTestCaseBase): def test(self): @@ -1495,7 +1539,7 @@ def test(self): # pylint: disable=R0914 mandatory=True) self.assertEqual(res, True) - # Flush channel to force Basic.Return + # Flush connection to force Basic.Return connection.channel().close() # Deposit a routable message in the queue @@ -1654,7 +1698,7 @@ def test(self): # pylint: disable=R0914,R0915 queue=q_name, expected_count=0) - # Attempt to cosume again with a short timeout + # Attempt to consume again with a short timeout connection.process_data_events(time_limit=0.005) self.assertEqual(len(rx_messages), 2) @@ -1669,6 +1713,197 @@ def test(self): # pylint: disable=R0914,R0915 self.assertEqual(frame.method.consumer_tag, consumer_tag) +class TestBasicConsumeWithAckFromAnotherThread(BlockingTestCaseBase): + + def test(self): # pylint: disable=R0914,R0915 + """BlockingChannel.basic_consume with ack from another thread and \ + requesting basic_ack via add_callback_threadsafe + """ + # This test simulates processing of a message on another thread and + # then requesting an ACK to be dispatched on the connection's thread + # via BlockingConnection.add_callback_threadsafe + + connection = self._connect() + + ch = connection.channel() + + q_name = 'TestBasicConsumeWithAckFromAnotherThread_q' + uuid.uuid1().hex + exg_name = ('TestBasicConsumeWithAckFromAnotherThread_exg' + + uuid.uuid1().hex) + routing_key = 'TestBasicConsumeWithAckFromAnotherThread' + + # Place channel in publisher-acknowledgments mode so that publishing + # with mandatory=True will be synchronous (for convenience) + res = ch.confirm_delivery() + self.assertIsNone(res) + + # Declare a new exchange + ch.exchange_declare(exg_name, exchange_type='direct') + self.addCleanup(connection.channel().exchange_delete, exg_name) + + # Declare a new queue + ch.queue_declare(q_name, auto_delete=True) + self.addCleanup(self._connect().channel().queue_delete, q_name) + + # Bind the queue to the exchange using routing key + ch.queue_bind(q_name, exchange=exg_name, routing_key=routing_key) + + # Publish 2 messages with mandatory=True for synchronous processing + ch.publish(exg_name, routing_key, body='msg1', mandatory=True) + ch.publish(exg_name, routing_key, body='last-msg', mandatory=True) + + # Configure QoS for one message so that the 2nd message will be + # delivered only after the 1st one is ACKed + ch.basic_qos(prefetch_size=0, prefetch_count=1, all_channels=False) + + # Create a consumer + rx_messages = [] + def ackAndEnqueueMessageViaAnotherThread(rx_ch, + rx_method, + rx_properties, # pylint: disable=W0613 + rx_body): + LOGGER.debug( + '%s: Got message body=%r; delivery-tag=%r', + datetime.now(), rx_body, rx_method.delivery_tag) + + # Request ACK dispatch via add_callback_threadsafe from other + # thread; if last message, cancel consumer so that start_consuming + # can return + + def processOnConnectionThread(): + LOGGER.debug('%s: ACKing message body=%r; delivery-tag=%r', + datetime.now(), + rx_body, + rx_method.delivery_tag) + ch.basic_ack(delivery_tag=rx_method.delivery_tag, + multiple=False) + rx_messages.append(rx_body) + + # NOTE on python3, `b'last-msg' != 'last-msg'` + if rx_body == b'last-msg': + LOGGER.debug('%s: Canceling consumer consumer-tag=%r', + datetime.now(), + rx_method.consumer_tag) + rx_ch.basic_cancel(rx_method.consumer_tag) + + # Spawn a thread to initiate ACKing + timer = threading.Timer(0, + lambda: connection.add_callback_threadsafe( + processOnConnectionThread)) + self.addCleanup(timer.cancel) + timer.start() + + consumer_tag = ch.basic_consume( + ackAndEnqueueMessageViaAnotherThread, + q_name, + no_ack=False, + exclusive=False, + arguments=None) + + # Wait for both messages + LOGGER.debug('%s: calling start_consuming(); consumer tag=%r', + datetime.now(), + consumer_tag) + ch.start_consuming() + LOGGER.debug('%s: Returned from start_consuming(); consumer tag=%r', + datetime.now(), + consumer_tag) + + self.assertEqual(len(rx_messages), 2) + self.assertEqual(rx_messages[0], b'msg1') + self.assertEqual(rx_messages[1], b'last-msg') + + +class TestConsumeGeneratorWithAckFromAnotherThread(BlockingTestCaseBase): + + def test(self): # pylint: disable=R0914,R0915 + """BlockingChannel.consume and requesting basic_ack from another \ + thread via add_callback_threadsafe + """ + connection = self._connect() + + ch = connection.channel() + + q_name = ('TestConsumeGeneratorWithAckFromAnotherThread_q' + + uuid.uuid1().hex) + exg_name = ('TestConsumeGeneratorWithAckFromAnotherThread_exg' + + uuid.uuid1().hex) + routing_key = 'TestConsumeGeneratorWithAckFromAnotherThread' + + # Place channel in publisher-acknowledgments mode so that publishing + # with mandatory=True will be synchronous (for convenience) + res = ch.confirm_delivery() + self.assertIsNone(res) + + # Declare a new exchange + ch.exchange_declare(exg_name, exchange_type='direct') + self.addCleanup(connection.channel().exchange_delete, exg_name) + + # Declare a new queue + ch.queue_declare(q_name, auto_delete=True) + self.addCleanup(self._connect().channel().queue_delete, q_name) + + # Bind the queue to the exchange using routing key + ch.queue_bind(q_name, exchange=exg_name, routing_key=routing_key) + + # Publish 2 messages with mandatory=True for synchronous processing + ch.publish(exg_name, routing_key, body='msg1', mandatory=True) + ch.publish(exg_name, routing_key, body='last-msg', mandatory=True) + + # Configure QoS for one message so that the 2nd message will be + # delivered only after the 1st one is ACKed + ch.basic_qos(prefetch_size=0, prefetch_count=1, all_channels=False) + + # Create a consumer + rx_messages = [] + def ackAndEnqueueMessageViaAnotherThread(rx_ch, + rx_method, + rx_properties, # pylint: disable=W0613 + rx_body): + LOGGER.debug( + '%s: Got message body=%r; delivery-tag=%r', + datetime.now(), rx_body, rx_method.delivery_tag) + + # Request ACK dispatch via add_callback_threadsafe from other + # thread; if last message, cancel consumer so that consumer + # generator completes + + def processOnConnectionThread(): + LOGGER.debug('%s: ACKing message body=%r; delivery-tag=%r', + datetime.now(), + rx_body, + rx_method.delivery_tag) + ch.basic_ack(delivery_tag=rx_method.delivery_tag, + multiple=False) + rx_messages.append(rx_body) + + # NOTE on python3, `b'last-msg' != 'last-msg'` + if rx_body == b'last-msg': + LOGGER.debug('%s: Canceling consumer consumer-tag=%r', + datetime.now(), + rx_method.consumer_tag) + # NOTE Need to use cancel() for the consumer generator + # instead of basic_cancel() + rx_ch.cancel() + + # Spawn a thread to initiate ACKing + timer = threading.Timer(0, + lambda: connection.add_callback_threadsafe( + processOnConnectionThread)) + self.addCleanup(timer.cancel) + timer.start() + + for method, properties, body in ch.consume(q_name, no_ack=False): + ackAndEnqueueMessageViaAnotherThread(rx_ch=ch, + rx_method=method, + rx_properties=properties, + rx_body=body) + + self.assertEqual(len(rx_messages), 2) + self.assertEqual(rx_messages[0], b'msg1') + self.assertEqual(rx_messages[1], b'last-msg') + + class TestTwoBasicConsumersOnSameChannel(BlockingTestCaseBase): def test(self): # pylint: disable=R0914 @@ -1938,7 +2173,7 @@ def test(self): # pylint: disable=R0914,R0915 queue=q_name, expected_count=0) - # Attempt to cosume again with a short timeout + # Attempt to consume again with a short timeout connection.process_data_events(time_limit=0.005) self.assertEqual(len(rx_messages), 2) diff --git a/tests/unit/blocking_connection_tests.py b/tests/unit/blocking_connection_tests.py index 118aa2d79..ab046a07d 100644 --- a/tests/unit/blocking_connection_tests.py +++ b/tests/unit/blocking_connection_tests.py @@ -126,8 +126,7 @@ def test_flush_output_user_initiated_close(self, connection._flush_output(lambda: False, lambda: True) self.assertEqual(connection._impl.ioloop.activate_poller.call_count, 1) - self.assertEqual(connection._impl.ioloop.deactivate_poller.call_count, - 1) + self.assertEqual(connection._impl.ioloop.close.call_count, 1) @patch.object( blocking_connection, @@ -152,8 +151,7 @@ def test_flush_output_server_initiated_error_close( self.assertSequenceEqual(cm.exception.args, (404, 'not found')) self.assertEqual(connection._impl.ioloop.activate_poller.call_count, 1) - self.assertEqual(connection._impl.ioloop.deactivate_poller.call_count, - 1) + self.assertEqual(connection._impl.ioloop.close.call_count, 1) @patch.object( blocking_connection, @@ -178,8 +176,7 @@ def test_flush_output_server_initiated_no_error_close( self.assertSequenceEqual(cm.exception.args, (200, 'ok')) self.assertEqual(connection._impl.ioloop.activate_poller.call_count, 1) - self.assertEqual(connection._impl.ioloop.deactivate_poller.call_count, - 1) + self.assertEqual(connection._impl.ioloop.close.call_count, 1) @patch.object( blocking_connection, diff --git a/tests/unit/connection_timeout_tests.py b/tests/unit/connection_timeout_tests.py index b44161bfb..3118faba0 100644 --- a/tests/unit/connection_timeout_tests.py +++ b/tests/unit/connection_timeout_tests.py @@ -49,8 +49,13 @@ def test_asyncio_connection_timeout(self): connect=mock.Mock(side_effect=mock_timeout)) ) as create_sock_mock: params = pika.ConnectionParameters(socket_timeout=2.0) - conn = asyncio_connection.AsyncioConnection(params) + ioloop = asyncio_connection.asyncio.new_event_loop() + self.addCleanup(ioloop.close) + conn = asyncio_connection.AsyncioConnection( + params, + custom_ioloop=ioloop) conn._on_connect_timer() + create_sock_mock.return_value.settimeout.assert_called_with(2.0) self.assertIn('timeout', str(err_ctx.exception)) @@ -99,6 +104,7 @@ def test_select_connection_timeout(self): side_effect=mock_timeout))) as create_sock_mock: params = pika.ConnectionParameters(socket_timeout=2.0) conn = select_connection.SelectConnection(params) + self.addCleanup(conn.ioloop.close) conn._on_connect_timer() create_sock_mock.return_value.settimeout.assert_called_with(2.0) self.assertIn('timeout', str(err_ctx.exception)) @@ -113,7 +119,11 @@ def test_tornado_connection_timeout(self): connect=mock.Mock( side_effect=mock_timeout))) as create_sock_mock: params = pika.ConnectionParameters(socket_timeout=2.0) - conn = tornado_connection.TornadoConnection(params) + ioloop = tornado_connection.ioloop.IOLoop() + self.addCleanup(ioloop.close) + conn = tornado_connection.TornadoConnection( + params, + custom_ioloop=ioloop) conn._on_connect_timer() create_sock_mock.return_value.settimeout.assert_called_with(2.0) self.assertIn('timeout', str(err_ctx.exception)) diff --git a/tests/unit/select_connection_ioloop_tests.py b/tests/unit/select_connection_ioloop_tests.py index 6769668f0..d5c9cf168 100644 --- a/tests/unit/select_connection_ioloop_tests.py +++ b/tests/unit/select_connection_ioloop_tests.py @@ -20,6 +20,12 @@ from pika import compat from pika.adapters import select_connection +# protected-access +# pylint: disable=W0212 +# missing-docstring +# pylint: disable=C0111 + + EPOLL_SUPPORTED = hasattr(select, 'epoll') POLL_SUPPORTED = hasattr(select, 'poll') and hasattr(select.poll(), 'modify') KQUEUE_SUPPORTED = hasattr(select, 'kqueue') @@ -37,6 +43,7 @@ def setUp(self): self.ioloop = select_connection.IOLoop() self.addCleanup(setattr, self, 'ioloop', None) + self.addCleanup(self.ioloop.close) activate_poller_patch = mock.patch.object( self.ioloop._poller, @@ -74,13 +81,76 @@ def on_timeout(self): self.fail('Test timed out') +class IOLoopCloseClosesSubordinateObjectsTestSelect(IOLoopBaseTest): + """ Test ioloop being closed """ + SELECT_POLLER = 'select' + + def start_test(self): + with mock.patch.multiple(self.ioloop, + _timer=mock.DEFAULT, + _poller=mock.DEFAULT, + _callbacks=mock.DEFAULT) as mocks: + self.ioloop.close() + mocks['_timer'].close.assert_called_once() + mocks['_poller'].close.assert_called_once() + self.assertIsNone(self.ioloop._callbacks) + + +class IOLoopCloseAfterStartReturnsTestSelect(IOLoopBaseTest): + """ Test IOLoop.close() after normal return from start(). """ + SELECT_POLLER = 'select' + + def start_test(self): + self.ioloop.stop() # so start will terminate quickly + self.start() + self.ioloop.close() + self.assertIsNone(self.ioloop._callbacks) + + +class IOLoopCloseBeforeStartReturnsTestSelect(IOLoopBaseTest): + """ Test calling IOLoop.close() before return from start() raises exception. """ + SELECT_POLLER = 'select' + + def start_test(self): + callback_completed = [] + + def call_close_from_callback(): + with self.assertRaises(AssertionError) as cm: + self.ioloop.close() + + self.assertEqual(cm.exception.args[0], + 'Cannot call close() before start() unwinds.') + self.ioloop.stop() + callback_completed.append(1) + + self.ioloop.add_callback_threadsafe(call_close_from_callback) + self.start() + self.assertEqual(callback_completed, [1]) + + +class IOLoopThreadStopTestSelect(IOLoopBaseTest): + """ Test ioloop being stopped by another Thread. """ + SELECT_POLLER = 'select' + + def start_test(self): + """Starts a thread that stops ioloop after a while and start polling""" + timer = threading.Timer( + 0.1, + lambda: self.ioloop.add_callback_threadsafe(self.ioloop.stop)) + self.addCleanup(timer.cancel) + timer.start() + self.start() # NOTE: Normal return from `start()` constitutes success + + class IOLoopThreadStopTestSelect(IOLoopBaseTest): """ Test ioloop being stopped by another Thread. """ SELECT_POLLER = 'select' def start_test(self): """Starts a thread that stops ioloop after a while and start polling""" - timer = threading.Timer(0.1, self.ioloop.stop) + timer = threading.Timer( + 0.1, + lambda: self.ioloop.add_callback_threadsafe(self.ioloop.stop)) self.addCleanup(timer.cancel) timer.start() self.start() # NOTE: Normal return from `start()` constitutes success @@ -441,6 +511,7 @@ def test_eintr( timer = select_connection._Timer() self.poller = self.ioloop._get_poller(timer.get_remaining_interval, timer.process_timeouts) + self.addCleanup(self.poller.close) sockpair = self.poller._get_interrupt_pair() self.addCleanup(sockpair[0].close) @@ -496,6 +567,7 @@ def start_test(self): poller = select_connection.SelectPoller( get_wait_seconds=timer.get_remaining_interval, process_timeouts=timer.process_timeouts) + self.addCleanup(poller.close) timer_call_container = [] timer.call_later(0.00001, lambda: timer_call_container.append(1)) @@ -517,3 +589,46 @@ def start_test(self): break self.assertEqual(timer_call_container, [1]) + + +class PollerTestCaseSelect(unittest.TestCase): + SELECT_POLLER = 'select' + + def setUp(self): + select_type_patch = mock.patch.multiple( + select_connection, SELECT_TYPE=self.SELECT_POLLER) + select_type_patch.start() + self.addCleanup(select_type_patch.stop) + + timer = select_connection._Timer() + self.addCleanup(timer.close) + self.poller = select_connection.IOLoop._get_poller( + timer.get_remaining_interval, + timer.process_timeouts) + self.addCleanup(self.poller.close) + + def test_poller_close(self): + self.poller.close() + self.assertIsNone(self.poller._r_interrupt) + self.assertIsNone(self.poller._w_interrupt) + self.assertIsNone(self.poller._fd_handlers) + self.assertIsNone(self.poller._fd_events) + self.assertIsNone(self.poller._processing_fd_event_map) + + +@unittest.skipIf(not POLL_SUPPORTED, 'poll not supported') +class PollerTestCasePoll(PollerTestCaseSelect): + """Same as PollerTestCaseSelect but uses poll syscall""" + SELECT_POLLER = 'poll' + + +@unittest.skipIf(not EPOLL_SUPPORTED, 'epoll not supported') +class PollerTestCaseEPoll(PollerTestCaseSelect): + """Same as PollerTestCaseSelect but uses epoll syscall""" + SELECT_POLLER = 'epoll' + + +@unittest.skipIf(not KQUEUE_SUPPORTED, 'kqueue not supported') +class PollerTestCaseKqueue(PollerTestCaseSelect): + """Same as PollerTestCaseSelect but uses kqueue syscall""" + SELECT_POLLER = 'kqueue' diff --git a/tests/unit/select_connection_timer_tests.py b/tests/unit/select_connection_timer_tests.py index 1017e7ec3..72e3db596 100644 --- a/tests/unit/select_connection_timer_tests.py +++ b/tests/unit/select_connection_timer_tests.py @@ -91,6 +91,20 @@ def test_le_operator(self): class TimerClassTests(unittest.TestCase): """Test select_connection._Timer class""" + def test_close_empty(self): + timer = select_connection._Timer() + timer.close() + self.assertIsNone(timer._timeout_heap) + + def test_close_non_empty(self): + timer = select_connection._Timer() + t1 = timer.call_later(10, lambda: 10) + t2 = timer.call_later(20, lambda: 20) + timer.close() + self.assertIsNone(timer._timeout_heap) + self.assertIsNone(t1.callback) + self.assertIsNone(t2.callback) + def test_no_timeouts_remaining_interval_is_none(self): timer = select_connection._Timer() self.assertIsNone(timer.get_remaining_interval())