From 7922e076c2b5ded29156b9523dc5cf8178c3f1ce Mon Sep 17 00:00:00 2001 From: Patrick Hayes Date: Tue, 19 Apr 2016 13:00:19 -0700 Subject: [PATCH] __enter__ / __exit__ on BlockingChannel --- pika/adapters/blocking_connection.py | 9 +++++++++ tests/unit/blocking_channel_tests.py | 24 +++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pika/adapters/blocking_connection.py b/pika/adapters/blocking_connection.py index c05c3ce84..2582517a8 100644 --- a/pika/adapters/blocking_connection.py +++ b/pika/adapters/blocking_connection.py @@ -1172,6 +1172,15 @@ def __int__(self): def __repr__(self): return '<%s impl=%r>' % (self.__class__.__name__, self._impl) + def __enter__(self): + return self + + def __exit__(self, exc_type, value, traceback): + try: + self.close() + except exceptions.ChannelClosed: + pass + def _cleanup(self): """Clean up members that might inhibit garbage collection""" self._message_confirmation_result.reset() diff --git a/tests/unit/blocking_channel_tests.py b/tests/unit/blocking_channel_tests.py index 9a81f9b92..de7267211 100644 --- a/tests/unit/blocking_channel_tests.py +++ b/tests/unit/blocking_channel_tests.py @@ -19,6 +19,7 @@ from pika.adapters import blocking_connection from pika import callback from pika import channel +from pika import exceptions from pika import frame from pika import spec @@ -65,4 +66,25 @@ def test_basic_consume(self): self.obj.basic_consume(mock.Mock(), "queue") self.assertEqual(self.obj._consumer_infos['ctag0'].state, - blocking_connection._ConsumerInfo.ACTIVE) \ No newline at end of file + blocking_connection._ConsumerInfo.ACTIVE) + + def test_context_manager(self): + with self.obj as channel: + self.assertFalse(channel._impl.close.called) + channel._impl.close.assert_called_once_with(reply_code=0, reply_text='Normal Shutdown') + + def test_context_manager_does_not_suppress_exception(self): + class TestException(Exception): + pass + + with self.assertRaises(TestException): + with self.obj as channel: + self.assertFalse(channel._impl.close.called) + raise TestException() + channel._impl.close.assert_called_once_with(reply_code=0, reply_text='Normal Shutdown') + + def test_context_manager_exit_with_closed_channel(self): + with self.obj as channel: + self.assertFalse(channel._impl.close.called) + channel.close() + channel._impl.close.assert_called_with(reply_code=0, reply_text='Normal Shutdown')