Skip to content

Commit

Permalink
Merge b1b7be7 into 475c9d6
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrokiefer committed Apr 2, 2019
2 parents 475c9d6 + b1b7be7 commit 9856503
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 9 deletions.
74 changes: 66 additions & 8 deletions aiostomp/aiostomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,31 @@ async def run(self):
self.print_stats()


class AutoAckContextManager:
def __init__(self, protocol, ack_mode='auto', enabled=True):
self.protocol = protocol
self.enabled = enabled
self.ack_mode = ack_mode
self.result = None
self.frame = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.enabled:
return

if not self.frame:
return

if self.ack_mode in ['client', 'client-individual']:
if self.result:
self.protocol.ack(self.frame)
else:
self.protocol.nack(self.frame)


class AioStomp:

def __init__(self, host, port,
Expand Down Expand Up @@ -186,7 +211,7 @@ def connection_lost(self, exc):
logger.info('Connection lost, will retry.')
asyncio.ensure_future(self._reconnect(), loop=self._loop)

def subscribe(self, destination, ack='auto', extra_headers=None, handler=None):
def subscribe(self, destination, ack='auto', extra_headers=None, handler=None, auto_ack=True):
extra_headers = {} if extra_headers is None else extra_headers
self._last_subscribe_id += 1

Expand All @@ -195,7 +220,8 @@ def subscribe(self, destination, ack='auto', extra_headers=None, handler=None):
id=self._last_subscribe_id,
ack=ack,
extra_headers=extra_headers,
handler=handler)
handler=handler,
auto_ack=auto_ack)

self._subscriptions[str(self._last_subscribe_id)] = subscription

Expand Down Expand Up @@ -230,6 +256,32 @@ def send(self, destination, body='', headers=None, send_content_length=True):

return self._protocol.send(headers, body)

def _subscription_auto_ack(self, frame):
key = frame.headers.get('subscription')

subscription = self._subscriptions.get(key)
if not subscription:
logger.warn('Subscription %s not found.' % key)
return True

if subscription.auto_ack:
logger.warn('Auto ack/nack is enabled. Ignoring call.')
return True

return False

def ack(self, frame):
if self._subscription_auto_ack(frame):
return

return self._protocol.ack(frame)

def nack(self, frame):
if self._subscription_auto_ack(frame):
return

return self._protocol.nack(frame)

def get(self, key):
return self._subscriptions.get(key)

Expand Down Expand Up @@ -365,13 +417,13 @@ async def _handle_message(self, frame):
if self._stats:
self._stats.increment('rec_msg')

result = await subscription.handler(frame, frame.body)
with AutoAckContextManager(self,
ack_mode=subscription.ack,
enabled=subscription.auto_ack) as ack_context:
result = await subscription.handler(frame, frame.body)

if subscription.ack in ['client', 'client-individual']:
if result:
self.ack(frame)
else:
self.nack(frame)
ack_context.frame = frame
ack_context.result = result

async def _handle_error(self, frame):
message = frame.headers.get('message')
Expand Down Expand Up @@ -460,3 +512,9 @@ def unsubscribe(self, subscription):

def send(self, headers, body):
return self._protocol.send_frame('SEND', headers, body)

def ack(self, frame):
return self._protocol.ack(frame)

def nack(self, frame):
return self._protocol.nack(frame)
3 changes: 2 additions & 1 deletion aiostomp/subscription.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
class Subscription(object):

def __init__(self, destination, id, ack, extra_headers, handler):
def __init__(self, destination, id, ack, extra_headers, handler, auto_ack=True):
self.destination = destination
self.id = id
self.ack = ack
self.extra_headers = extra_headers
self.handler = handler
self.auto_ack = auto_ack
63 changes: 63 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,69 @@ def test_can_send_message_without_body(self):
'my-header': 'my-value',
}, '')

def test_can_ack_a_frame(self):
self.stomp._protocol.subscribe = Mock()
self.stomp._protocol.ack = Mock()

self.stomp.subscribe('/queue/test', auto_ack=False)
self.assertEqual(len(self.stomp._subscriptions), 1)

frame = Frame('MESSAGE', {'subscription': '1'}, 'data')

self.stomp.ack(frame)

self.stomp._protocol.ack.assert_called_with(frame)

def test_can_nack_a_frame(self):
self.stomp._protocol.subscribe = Mock()
self.stomp._protocol.nack = Mock()

self.stomp.subscribe('/queue/test', auto_ack=False)
self.assertEqual(len(self.stomp._subscriptions), 1)

frame = Frame('MESSAGE', {'subscription': '1'}, 'data')

self.stomp.nack(frame)

self.stomp._protocol.nack.assert_called_with(frame)

@patch('aiostomp.aiostomp.logger')
def test_cannot_ack_an_unsubscribed_frame(self, logger_mock):
self.stomp._protocol.ack = Mock()
self.assertEqual(len(self.stomp._subscriptions), 0)

frame = Frame('MESSAGE', {'subscription': '1'}, 'data')

self.stomp.ack(frame)
logger_mock.warn.assert_called_with('Subscription 1 not found.')
self.stomp._protocol.ack.assert_not_called()

@patch('aiostomp.aiostomp.logger')
def test_cannot_nack_an_unsubscribed_frame(self, logger_mock):
self.stomp._protocol.nack = Mock()
self.assertEqual(len(self.stomp._subscriptions), 0)

frame = Frame('MESSAGE', {'subscription': '1'}, 'data')

self.stomp.nack(frame)
logger_mock.warn.assert_called_with('Subscription 1 not found.')
self.stomp._protocol.nack.assert_not_called()

@patch('aiostomp.aiostomp.logger')
def test_cannot_ack_an_auto_ack_frame(self, logger_mock):
self.stomp._protocol.subscribe = Mock()
self.stomp._protocol.ack = Mock()

self.stomp.subscribe('/queue/test', auto_ack=True)
self.assertEqual(len(self.stomp._subscriptions), 1)

frame = Frame('MESSAGE', {'subscription': '1'}, 'data')

self.stomp.ack(frame)

logger_mock.warn.assert_called_with('Auto ack/nack is enabled. Ignoring call.')
self.stomp._protocol.ack.assert_not_called()


class TestStompProtocol(AsyncTestCase):

Expand Down

0 comments on commit 9856503

Please sign in to comment.