Skip to content

Commit

Permalink
Merge pull request #132 from Lenka42/master
Browse files Browse the repository at this point in the history
Fix unsubscribe cleans subscriptions storage
  • Loading branch information
Lenka42 committed Jul 27, 2021
2 parents 46996b1 + ba61805 commit 50bc083
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 61 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dist/

# virtualenvs
env/
venv/
pyenv/

# pytest
Expand Down
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
language: python
python:
- "3.9"
- "3.8"
- "3.7"
- "3.6"
Expand Down
4 changes: 2 additions & 2 deletions gmqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

__credits__ = [
"Mikhail Turchunovich",
"Elena Nikolaichik"
"Elena Shylko"
]
__version__ = "0.6.9"
__version__ = "0.6.10"


__all__ = [
Expand Down
126 changes: 68 additions & 58 deletions gmqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,74 @@ def __init__(self, topic, qos=0, no_local=False, retain_as_published=False, reta
self.subscription_identifier = subscription_identifier


class Client(MqttPackageHandler):
class SubscriptionsHandler:
def __init__(self):
self.subscriptions = []

def update_subscriptions_with_subscription_or_topic(
self, subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs):

sentinel = object()
subscription_identifier = kwargs.get('subscription_identifier', sentinel)

if isinstance(subscription_or_topic, Subscription):

if subscription_identifier is not sentinel:
subscription_or_topic.subscription_identifier = subscription_identifier

subscriptions = [subscription_or_topic]
elif isinstance(subscription_or_topic, (tuple, list)):

if subscription_identifier is not sentinel:
for sub in subscription_or_topic:
sub.subscription_identifier = subscription_identifier

subscriptions = subscription_or_topic
elif isinstance(subscription_or_topic, str):

if subscription_identifier is sentinel:
subscription_identifier = None

subscriptions = [Subscription(subscription_or_topic, qos=qos, no_local=no_local,
retain_as_published=retain_as_published,
retain_handling_options=retain_handling_options,
subscription_identifier=subscription_identifier)]
else:
raise ValueError('Bad subscription: must be string or Subscription or list of Subscriptions')
self.subscriptions.extend(subscriptions)
return subscriptions

def _remove_subscriptions(self, topic: Union[str, Sequence[str]]):
if isinstance(topic, str):
self.subscriptions = [s for s in self.subscriptions if s.topic != topic]
else:
self.subscriptions = [s for s in self.subscriptions if s.topic not in topic]

def subscribe(self, subscription_or_topic: Union[str, Subscription, Sequence[Subscription]],
qos=0, no_local=False, retain_as_published=False, retain_handling_options=0, **kwargs):

# Warn: if you will pass a few subscriptions objects, and each will be have different
# subscription identifier - the only first will be used as identifier
# if only you will not pass the identifier in kwargs

subscriptions = self.update_subscriptions_with_subscription_or_topic(
subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs)
return self._connection.subscribe(subscriptions, **kwargs)

def resubscribe(self, subscription: Subscription, **kwargs):
# send subscribe packet for subscription,that's already in client's subscription list
if 'subscription_identifier' in kwargs:
subscription.subscription_identifier = kwargs['subscription_identifier']
elif subscription.subscription_identifier is not None:
kwargs['subscription_identifier'] = subscription.subscription_identifier
return self._connection.subscribe([subscription], **kwargs)

def unsubscribe(self, topic: Union[str, Sequence[str]], **kwargs):
self._remove_subscriptions(topic)
return self._connection.unsubscribe(topic, **kwargs)


class Client(MqttPackageHandler, SubscriptionsHandler):
def __init__(self, client_id, clean_session=True, optimistic_acknowledgement=True,
will_message=None, **kwargs):
super(Client, self).__init__(optimistic_acknowledgement=optimistic_acknowledgement)
Expand Down Expand Up @@ -89,8 +156,6 @@ def __init__(self, client_id, clean_session=True, optimistic_acknowledgement=Tru

self._resend_task = asyncio.ensure_future(self._resend_qos_messages())

self.subscriptions = []

def get_subscription_by_identifier(self, subscription_identifier):
return next((sub for sub in self.subscriptions if sub.subscription_identifier == subscription_identifier), None)

Expand Down Expand Up @@ -222,61 +287,6 @@ async def _disconnect(self, reason_code=0, **properties):
self._connection.send_disconnect(reason_code=reason_code, **properties)
await self._connection.close()

def subscribe(self, subscription_or_topic: Union[str, Subscription, Sequence[Subscription]],
qos=0, no_local=False, retain_as_published=False, retain_handling_options=0, **kwargs):

# Warn: if you will pass a few subscriptions objects, and each will be have different
# subscription identifier - the only first will be used as identifier
# if only you will not pass the identifier in kwargs

subscriptions = self.update_subscriptions_with_subscription_or_topic(
subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs)
return self._connection.subscribe(subscriptions, **kwargs)

def update_subscriptions_with_subscription_or_topic(
self, subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs):

sentinel = object()
subscription_identifier = kwargs.get('subscription_identifier', sentinel)

if isinstance(subscription_or_topic, Subscription):

if subscription_identifier is not sentinel:
subscription_or_topic.subscription_identifier = subscription_identifier

subscriptions = [subscription_or_topic]
elif isinstance(subscription_or_topic, (tuple, list)):

if subscription_identifier is not sentinel:
for sub in subscription_or_topic:
sub.subscription_identifier = subscription_identifier

subscriptions = subscription_or_topic
elif isinstance(subscription_or_topic, str):

if subscription_identifier is sentinel:
subscription_identifier = None

subscriptions = [Subscription(subscription_or_topic, qos=qos, no_local=no_local,
retain_as_published=retain_as_published,
retain_handling_options=retain_handling_options,
subscription_identifier=subscription_identifier)]
else:
raise ValueError('Bad subscription: must be string or Subscription or list of Subscriptions')
self.subscriptions.extend(subscriptions)
return subscriptions

def resubscribe(self, subscription: Subscription, **kwargs):
# send subscribe packet for subscription,that's already in client's subscription list
if 'subscription_identifier' in kwargs:
subscription.subscription_identifier = kwargs['subscription_identifier']
elif subscription.subscription_identifier is not None:
kwargs['subscription_identifier'] = subscription.subscription_identifier
return self._connection.subscribe([subscription], **kwargs)

def unsubscribe(self, topic, **kwargs):
return self._connection.unsubscribe(topic, **kwargs)

def publish(self, message_or_topic, payload=None, qos=0, retain=False, **kwargs):
if isinstance(message_or_topic, Message):
message = message_or_topic
Expand Down
6 changes: 5 additions & 1 deletion tests/test_mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ async def test_unsubscribe(init_clients):
bclient.subscribe(TOPICS[2])
bclient.subscribe(TOPICS[3])
await asyncio.sleep(1)
assert len(bclient.subscriptions) == 3

aclient.publish(TOPICS[1], b"topic 0 - subscribed", 1, retain=False)
aclient.publish(TOPICS[2], b"topic 1", 1, retain=False)
Expand All @@ -242,6 +243,7 @@ async def test_unsubscribe(init_clients):
callback2.clear()
# Unsubscribe from one topic
bclient.unsubscribe(TOPICS[1])
assert len(bclient.subscriptions) == 2
await asyncio.sleep(3)

aclient.publish(TOPICS[1], b"topic 0 - unsubscribed", 1, retain=False)
Expand Down Expand Up @@ -478,7 +480,9 @@ async def test_reconnection_with_failure(init_clients):
disconnect_mock.side_effect = ConnectionAbortedError("error")
await aclient.reconnect()

await asyncio.sleep(3)

# Check aclient is still working after reconnection
aclient.publish(TOPICS[0], b"test")
await asyncio.sleep(5)
await asyncio.sleep(3)
assert len(callback2.messages) == 1

0 comments on commit 50bc083

Please sign in to comment.