diff --git a/pika/channel.py b/pika/channel.py index fb67a0d34..282f53c0a 100644 --- a/pika/channel.py +++ b/pika/channel.py @@ -1347,7 +1347,7 @@ def _drain_blocked_methods_on_remote_close(self): sent, and thus its completion callback would never be called. """ - LOGGER.debug('Draining %i blocked frames due to remote Channel.Close', + LOGGER.debug('Draining %i blocked frames due to broker-requested Channel.Close', len(self._blocked)) while self._blocked: method = self._blocked.popleft()[0] @@ -1408,6 +1408,12 @@ def _rpc(self, method, callback=None, acceptable_replies=None): self._blocked.append([method, callback, acceptable_replies]) return + # Note: _send_method can throw exceptions if there are framing errors + # or invalid data passed in. Call it here to prevent self._blocking + # from being set if an exception is thrown. This also prevents + # acceptable_replies registering callbacks when exceptions are thrown + self._send_method(method) + # If acceptable replies are set, add callbacks if acceptable_replies: # Block until a response frame is received for synchronous frames @@ -1430,8 +1436,6 @@ def _rpc(self, method, callback=None, acceptable_replies=None): self.callbacks.add(self.channel_number, reply, callback, arguments=arguments) - self._send_method(method) - def _raise_if_not_open(self): """If channel is not in the OPEN state, raises ChannelClosed with `reply_code` and `reply_text` corresponding to current state. If channel diff --git a/pika/connection.py b/pika/connection.py index be2b1bc2d..271b19807 100644 --- a/pika/connection.py +++ b/pika/connection.py @@ -2306,11 +2306,7 @@ def _send_frame(self, frame_value): 'Attempted to send a frame on closed connection.') marshaled_frame = frame_value.marshal() - self.bytes_sent += len(marshaled_frame) - self.frames_sent += 1 - self._adapter_emit_data(marshaled_frame) - if self.params.backpressure_detection: - self._detect_backpressure() + self._output_marshaled_frames([marshaled_frame]) def _send_method(self, channel_number, method, content=None): """Constructs a RPC method frame and then sends it to the broker. @@ -2336,8 +2332,14 @@ def _send_message(self, channel_number, method_frame, content): """ length = len(content[1]) - self._send_frame(frame.Method(channel_number, method_frame)) - self._send_frame(frame.Header(channel_number, length, content[0])) + marshaled_body_frames = [] + + # Note: we construct the Method, Header and Content objects, marshal them + # *then* output in case the marshaling operation throws an exception + frame_method = frame.Method(channel_number, method_frame) + frame_header = frame.Header(channel_number, length, content[0]) + marshaled_body_frames.append(frame_method.marshal()) + marshaled_body_frames.append(frame_header.marshal()) if content[1]: chunks = int(math.ceil(float(length) / self._body_max_length)) @@ -2346,7 +2348,10 @@ def _send_message(self, channel_number, method_frame, content): end = start + self._body_max_length if end > length: end = length - self._send_frame(frame.Body(channel_number, content[1][start:end])) + frame_body = frame.Body(channel_number, content[1][start:end]) + marshaled_body_frames.append(frame_body.marshal()) + + self._output_marshaled_frames(marshaled_body_frames) def _set_connection_state(self, connection_state): """Set the connection state. @@ -2382,3 +2387,16 @@ def _trim_frame_buffer(self, byte_count): """ self._frame_buffer = self._frame_buffer[byte_count:] self.bytes_received += byte_count + + def _output_marshaled_frames(self, marshaled_frames): + """Output list of marshaled frames to buffer and update stats + + :param list marshaled_frames: A list of frames marshaled to bytes + + """ + for marshaled_frame in marshaled_frames: + self.bytes_sent += len(marshaled_frame) + self.frames_sent += 1 + self._adapter_emit_data(marshaled_frame) + if self.params.backpressure_detection: + self._detect_backpressure() diff --git a/tests/acceptance/async_adapter_tests.py b/tests/acceptance/async_adapter_tests.py index 2a51ae199..967e1096b 100644 --- a/tests/acceptance/async_adapter_tests.py +++ b/tests/acceptance/async_adapter_tests.py @@ -625,8 +625,9 @@ def on_bad_result(self, frame): raise AssertionError("Should not have received an Exchange.DeclareOk") -class TestPassiveExchangeDeclareWithConcurrentClose(AsyncTestCase, AsyncAdapters): - DESCRIPTION = "should close channel: declare passive exchange with close" +class TestNoDeadlockWhenClosingChannelWithPendingBlockedRequestsAndConcurrentChannelCloseFromBroker( + AsyncTestCase, AsyncAdapters): + DESCRIPTION = "No deadlock when closing a channel with pending blocked requests and concurrent Channel.Close from broker." # To observe the behavior that this is testing, comment out this line # in pika/channel.py - _on_close: @@ -636,10 +637,12 @@ class TestPassiveExchangeDeclareWithConcurrentClose(AsyncTestCase, AsyncAdapters # With the above line commented out, this test will hang def begin(self, channel): - self.name = self.__class__.__name__ + ':' + uuid.uuid1().hex + base_exch_name = self.__class__.__name__ + ':' + uuid.uuid1().hex self.channel.add_on_close_callback(self.on_channel_closed) for i in range(0, 99): - exch_name = self.name + ':' + str(i) + # Passively declare a non-existent exchange to force Channel.Close + # from broker + exch_name = base_exch_name + ':' + str(i) cb = functools.partial(self.on_bad_result, exch_name) channel.exchange_declare(exch_name, exchange_type='direct', @@ -648,15 +651,49 @@ def begin(self, channel): channel.close() def on_channel_closed(self, channel, reply_code, reply_text): + # The close is expected because the requested exchange doesn't exist self.stop() def on_bad_result(self, exch_name, frame): - self.channel.exchange_delete(exch_name) - raise AssertionError("Should not have received an Exchange.DeclareOk") + self.fail("Should not have received an Exchange.DeclareOk") -class TestQueueDeclareAndDelete(AsyncTestCase, AsyncAdapters): - DESCRIPTION = "Create and delete a queue" +class TestClosingAChannelPermitsBlockedRequestToComplete(AsyncTestCase, + AsyncAdapters): + DESCRIPTION = "Closing a channel permits blocked requests to complete." + + def begin(self, channel): + self._queue_deleted = False + + channel.add_on_close_callback(self.on_channel_closed) + + q_name = self.__class__.__name__ + ':' + uuid.uuid1().hex + # NOTE we pass callback to make it a blocking request + channel.queue_declare(q_name, + exclusive=True, + callback=lambda _frame: None) + + self.assertIsNotNone(channel._blocking) + + # The Queue.Delete should block on completion of Queue.Declare + channel.queue_delete(q_name, callback=self.on_queue_deleted) + self.assertTrue(channel._blocked) + + # This Channel.Close should allow the blocked Queue.Delete to complete + # Before closing the channel + channel.close() + + def on_queue_deleted(self, _frame): + # Getting this callback shows that the blocked request was processed + self._queue_deleted = True + + def on_channel_closed(self, _channel, _reply_code, _reply_text): + self.assertTrue(self._queue_deleted) + self.stop() + + +class TestQueueUnnamedDeclareAndDelete(AsyncTestCase, AsyncAdapters): + DESCRIPTION = "Create and delete an unnamed queue" def begin(self, channel): channel.queue_declare(queue='', @@ -673,11 +710,11 @@ def on_queue_declared(self, frame): def on_queue_delete(self, frame): self.assertIsInstance(frame.method, spec.Queue.DeleteOk) + # NOTE: with event loops that suppress exceptions from callbacks self.stop() - -class TestQueueNameDeclareAndDelete(AsyncTestCase, AsyncAdapters): +class TestQueueNamedDeclareAndDelete(AsyncTestCase, AsyncAdapters): DESCRIPTION = "Create and delete a named queue" def begin(self, channel): @@ -701,7 +738,6 @@ def on_queue_delete(self, frame): self.stop() - class TestQueueRedeclareWithDifferentValues(AsyncTestCase, AsyncAdapters): DESCRIPTION = "Should close chan: re-declared queue w/ diff params" @@ -745,7 +781,6 @@ def on_complete(self, frame): self.stop() - class TestTX2_Commit(AsyncTestCase, AsyncAdapters): # pylint: disable=C0103 DESCRIPTION = "Start a transaction, and commit it" diff --git a/tests/acceptance/blocking_adapter_test.py b/tests/acceptance/blocking_adapter_test.py index d79ded2f7..d0ed48e9a 100644 --- a/tests/acceptance/blocking_adapter_test.py +++ b/tests/acceptance/blocking_adapter_test.py @@ -50,7 +50,6 @@ def setUpModule(): logging.basicConfig(level=logging.DEBUG) -#@unittest.skip('SKIPPING WHILE DEBUGGING SOME CHANGES. DO NOT MERGE LIKE THIS') class BlockingTestCaseBase(unittest.TestCase): TIMEOUT = DEFAULT_TIMEOUT @@ -355,6 +354,16 @@ def test(self): self.assertFalse(ch._impl._consumers) +class TestUsingInvalidQueueArgument(BlockingTestCaseBase): + def test(self): + """BlockingConnection raises expected exception when invalid queue parameter is used + """ + connection = self._connect() + ch = connection.channel() + with self.assertRaises(AssertionError): + ch.queue_declare(queue=[1, 2, 3]) + + class TestSuddenBrokerDisconnectBeforeChannel(BlockingTestCaseBase): def test(self): diff --git a/tests/unit/channel_tests.py b/tests/unit/channel_tests.py index 10e594ea4..dc353efcc 100644 --- a/tests/unit/channel_tests.py +++ b/tests/unit/channel_tests.py @@ -1587,3 +1587,18 @@ def test_validate_callback_raises_value_error_not_callable( self.assertRaises(TypeError, self.obj._validate_rpc_completion_callback, 'foo') + + def test_no_side_effects_from_send_method_error(self): + self.obj._set_state(self.obj.OPEN) + + self.assertIsNone(self.obj._blocking) + + with mock.patch.object(self.obj.callbacks, 'add') as cb_add_mock: + with mock.patch.object(self.obj, '_send_method', + side_effect=TypeError) as send_method_mock: + with self.assertRaises(TypeError): + self.obj.queue_delete('', callback=lambda _frame: None) + + self.assertEqual(send_method_mock.call_count, 1) + self.assertIsNone(self.obj._blocking) + self.assertEqual(cb_add_mock.call_count, 0) diff --git a/tests/unit/connection_tests.py b/tests/unit/connection_tests.py index 19df8735d..04fd5438b 100644 --- a/tests/unit/connection_tests.py +++ b/tests/unit/connection_tests.py @@ -983,3 +983,31 @@ def test_send_message_updates_frames_sent_and_bytes_sent( # Make sure _detect_backpressure doesn't throw self.connection._detect_backpressure() + + + def test_no_side_effects_from_message_marshal_error(self): + # Verify that frame buffer is empty on entry + self.assertEqual(b'', self.connection._frame_buffer) + + # Use Basic.Public with invalid body to trigger marshalling error + method = spec.Basic.Publish() + properties = spec.BasicProperties() + # Verify that marshalling of method and header won't trigger error + frame.Method(1, method).marshal() + frame.Header(1, body_size=10, props=properties).marshal() + # Create bogus body that should trigger an error during marshalling + body = [1,2,3,4] + # Verify that frame body can be created using the bogus body, but + # that marshalling will fail + frame.Body(1, body) + with self.assertRaises(TypeError): + frame.Body(1, body).marshal() + + # Now, attempt to send the method with the bogus body + with self.assertRaises(TypeError): + self.connection._send_method(channel_number=1, + method=method, + content=(properties, body)) + + # Now make sure that nothing is enqueued on frame buffer + self.assertEqual(b'', self.connection._frame_buffer)