Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request marshaling error should not corrupt a channel #991

Merged
merged 6 commits into from Apr 27, 2018
Merged
10 changes: 7 additions & 3 deletions pika/channel.py
Expand Up @@ -1347,7 +1347,7 @@ def _drain_blocked_methods_on_remote_close(self):
sent, and thus its completion callback would never be called.

"""
LOGGER.debug('Draining %i blocked frames due to remote Channel.Close',
LOGGER.debug('Draining %i blocked frames due to broker-requested Channel.Close',
len(self._blocked))
while self._blocked:
method = self._blocked.popleft()[0]
Expand Down Expand Up @@ -1408,6 +1408,12 @@ def _rpc(self, method, callback=None, acceptable_replies=None):
self._blocked.append([method, callback, acceptable_replies])
return

# Note: _send_method can throw exceptions if there are framing errors
# or invalid data passed in. Call it here to prevent self._blocking
# from being set if an exception is thrown. This also prevents
# acceptable_replies registering callbacks when exceptions are thrown
self._send_method(method)

# If acceptable replies are set, add callbacks
if acceptable_replies:
# Block until a response frame is received for synchronous frames
Expand All @@ -1430,8 +1436,6 @@ def _rpc(self, method, callback=None, acceptable_replies=None):
self.callbacks.add(self.channel_number, reply, callback,
arguments=arguments)

self._send_method(method)

def _raise_if_not_open(self):
"""If channel is not in the OPEN state, raises ChannelClosed with
`reply_code` and `reply_text` corresponding to current state. If channel
Expand Down
34 changes: 26 additions & 8 deletions pika/connection.py
Expand Up @@ -2306,11 +2306,7 @@ def _send_frame(self, frame_value):
'Attempted to send a frame on closed connection.')

marshaled_frame = frame_value.marshal()
self.bytes_sent += len(marshaled_frame)
self.frames_sent += 1
self._adapter_emit_data(marshaled_frame)
if self.params.backpressure_detection:
self._detect_backpressure()
self._output_marshaled_frames([marshaled_frame])

def _send_method(self, channel_number, method, content=None):
"""Constructs a RPC method frame and then sends it to the broker.
Expand All @@ -2336,8 +2332,14 @@ def _send_message(self, channel_number, method_frame, content):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incidentally, the method_frame parameter to _send_message() is misnamed in this legacy code. What's passed in should be called method. The method_frame is created inside the function when it constructs frame.Method.

"""
length = len(content[1])
self._send_frame(frame.Method(channel_number, method_frame))
self._send_frame(frame.Header(channel_number, length, content[0]))
marshaled_body_frames = []

# Note: we construct the Method, Header and Content objects, marshal them
# *then* output in case the marshaling operation throws an exception
frame_method = frame.Method(channel_number, method_frame)
frame_header = frame.Header(channel_number, length, content[0])
marshaled_body_frames.append(frame_method.marshal())
marshaled_body_frames.append(frame_header.marshal())

if content[1]:
chunks = int(math.ceil(float(length) / self._body_max_length))
Expand All @@ -2346,7 +2348,10 @@ def _send_message(self, channel_number, method_frame, content):
end = start + self._body_max_length
if end > length:
end = length
self._send_frame(frame.Body(channel_number, content[1][start:end]))
frame_body = frame.Body(channel_number, content[1][start:end])
marshaled_body_frames.append(frame_body.marshal())

self._output_marshaled_frames(marshaled_body_frames)

def _set_connection_state(self, connection_state):
"""Set the connection state.
Expand Down Expand Up @@ -2382,3 +2387,16 @@ def _trim_frame_buffer(self, byte_count):
"""
self._frame_buffer = self._frame_buffer[byte_count:]
self.bytes_received += byte_count

def _output_marshaled_frames(self, marshaled_frames):
"""Output list of marshaled frames to buffer and update stats

:param list marshaled_frames: A list of frames marshaled to bytes

"""
for marshaled_frame in marshaled_frames:
self.bytes_sent += len(marshaled_frame)
self.frames_sent += 1
self._adapter_emit_data(marshaled_frame)
if self.params.backpressure_detection:
self._detect_backpressure()
59 changes: 47 additions & 12 deletions tests/acceptance/async_adapter_tests.py
Expand Up @@ -625,8 +625,9 @@ def on_bad_result(self, frame):
raise AssertionError("Should not have received an Exchange.DeclareOk")


class TestPassiveExchangeDeclareWithConcurrentClose(AsyncTestCase, AsyncAdapters):
DESCRIPTION = "should close channel: declare passive exchange with close"
class TestNoDeadlockWhenClosingChannelWithPendingBlockedRequestsAndConcurrentChannelCloseFromBroker(
AsyncTestCase, AsyncAdapters):
DESCRIPTION = "No deadlock when closing a channel with pending blocked requests and concurrent Channel.Close from broker."

# To observe the behavior that this is testing, comment out this line
# in pika/channel.py - _on_close:
Expand All @@ -636,10 +637,12 @@ class TestPassiveExchangeDeclareWithConcurrentClose(AsyncTestCase, AsyncAdapters
# With the above line commented out, this test will hang

def begin(self, channel):
self.name = self.__class__.__name__ + ':' + uuid.uuid1().hex
base_exch_name = self.__class__.__name__ + ':' + uuid.uuid1().hex
self.channel.add_on_close_callback(self.on_channel_closed)
for i in range(0, 99):
exch_name = self.name + ':' + str(i)
# Passively declare a non-existent exchange to force Channel.Close
# from broker
exch_name = base_exch_name + ':' + str(i)
cb = functools.partial(self.on_bad_result, exch_name)
channel.exchange_declare(exch_name,
exchange_type='direct',
Expand All @@ -648,15 +651,49 @@ def begin(self, channel):
channel.close()

def on_channel_closed(self, channel, reply_code, reply_text):
# The close is expected because the requested exchange doesn't exist
self.stop()

def on_bad_result(self, exch_name, frame):
self.channel.exchange_delete(exch_name)
raise AssertionError("Should not have received an Exchange.DeclareOk")
self.fail("Should not have received an Exchange.DeclareOk")


class TestQueueDeclareAndDelete(AsyncTestCase, AsyncAdapters):
DESCRIPTION = "Create and delete a queue"
class TestClosingAChannelPermitsBlockedRequestToComplete(AsyncTestCase,
AsyncAdapters):
DESCRIPTION = "Closing a channel permits blocked requests to complete."

def begin(self, channel):
self._queue_deleted = False

channel.add_on_close_callback(self.on_channel_closed)

q_name = self.__class__.__name__ + ':' + uuid.uuid1().hex
# NOTE we pass callback to make it a blocking request
channel.queue_declare(q_name,
exclusive=True,
callback=lambda _frame: None)

self.assertIsNotNone(channel._blocking)

# The Queue.Delete should block on completion of Queue.Declare
channel.queue_delete(q_name, callback=self.on_queue_deleted)
self.assertTrue(channel._blocked)

# This Channel.Close should allow the blocked Queue.Delete to complete
# Before closing the channel
channel.close()

def on_queue_deleted(self, _frame):
# Getting this callback shows that the blocked request was processed
self._queue_deleted = True

def on_channel_closed(self, _channel, _reply_code, _reply_text):
self.assertTrue(self._queue_deleted)
self.stop()


class TestQueueUnnamedDeclareAndDelete(AsyncTestCase, AsyncAdapters):
DESCRIPTION = "Create and delete an unnamed queue"

def begin(self, channel):
channel.queue_declare(queue='',
Expand All @@ -673,11 +710,11 @@ def on_queue_declared(self, frame):

def on_queue_delete(self, frame):
self.assertIsInstance(frame.method, spec.Queue.DeleteOk)
# NOTE: with event loops that suppress exceptions from callbacks
self.stop()



class TestQueueNameDeclareAndDelete(AsyncTestCase, AsyncAdapters):
class TestQueueNamedDeclareAndDelete(AsyncTestCase, AsyncAdapters):
DESCRIPTION = "Create and delete a named queue"

def begin(self, channel):
Expand All @@ -701,7 +738,6 @@ def on_queue_delete(self, frame):
self.stop()



class TestQueueRedeclareWithDifferentValues(AsyncTestCase, AsyncAdapters):
DESCRIPTION = "Should close chan: re-declared queue w/ diff params"

Expand Down Expand Up @@ -745,7 +781,6 @@ def on_complete(self, frame):
self.stop()



class TestTX2_Commit(AsyncTestCase, AsyncAdapters): # pylint: disable=C0103
DESCRIPTION = "Start a transaction, and commit it"

Expand Down
11 changes: 10 additions & 1 deletion tests/acceptance/blocking_adapter_test.py
Expand Up @@ -50,7 +50,6 @@ def setUpModule():
logging.basicConfig(level=logging.DEBUG)


#@unittest.skip('SKIPPING WHILE DEBUGGING SOME CHANGES. DO NOT MERGE LIKE THIS')
class BlockingTestCaseBase(unittest.TestCase):

TIMEOUT = DEFAULT_TIMEOUT
Expand Down Expand Up @@ -355,6 +354,16 @@ def test(self):
self.assertFalse(ch._impl._consumers)


class TestUsingInvalidQueueArgument(BlockingTestCaseBase):
def test(self):
"""BlockingConnection raises expected exception when invalid queue parameter is used
"""
connection = self._connect()
ch = connection.channel()
with self.assertRaises(AssertionError):
ch.queue_declare(queue=[1, 2, 3])


class TestSuddenBrokerDisconnectBeforeChannel(BlockingTestCaseBase):

def test(self):
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/channel_tests.py
Expand Up @@ -1587,3 +1587,18 @@ def test_validate_callback_raises_value_error_not_callable(
self.assertRaises(TypeError,
self.obj._validate_rpc_completion_callback,
'foo')

def test_no_side_effects_from_send_method_error(self):
self.obj._set_state(self.obj.OPEN)

self.assertIsNone(self.obj._blocking)

with mock.patch.object(self.obj.callbacks, 'add') as cb_add_mock:
with mock.patch.object(self.obj, '_send_method',
side_effect=TypeError) as send_method_mock:
with self.assertRaises(TypeError):
self.obj.queue_delete('', callback=lambda _frame: None)

self.assertEqual(send_method_mock.call_count, 1)
self.assertIsNone(self.obj._blocking)
self.assertEqual(cb_add_mock.call_count, 0)
28 changes: 28 additions & 0 deletions tests/unit/connection_tests.py
Expand Up @@ -983,3 +983,31 @@ def test_send_message_updates_frames_sent_and_bytes_sent(

# Make sure _detect_backpressure doesn't throw
self.connection._detect_backpressure()


def test_no_side_effects_from_message_marshal_error(self):
# Verify that frame buffer is empty on entry
self.assertEqual(b'', self.connection._frame_buffer)

# Use Basic.Public with invalid body to trigger marshalling error
method = spec.Basic.Publish()
properties = spec.BasicProperties()
# Verify that marshalling of method and header won't trigger error
frame.Method(1, method).marshal()
frame.Header(1, body_size=10, props=properties).marshal()
# Create bogus body that should trigger an error during marshalling
body = [1,2,3,4]
# Verify that frame body can be created using the bogus body, but
# that marshalling will fail
frame.Body(1, body)
with self.assertRaises(TypeError):
frame.Body(1, body).marshal()

# Now, attempt to send the method with the bogus body
with self.assertRaises(TypeError):
self.connection._send_method(channel_number=1,
method=method,
content=(properties, body))

# Now make sure that nothing is enqueued on frame buffer
self.assertEqual(b'', self.connection._frame_buffer)