Skip to content

Commit

Permalink
Merge pull request #1058 from maggyero/patch-1
Browse files Browse the repository at this point in the history
 Add and correct the implementations of comparison methods
  • Loading branch information
lukebakken committed Aug 28, 2018
2 parents 706ac7e + 08bcb30 commit 4ec4a42
Show file tree
Hide file tree
Showing 6 changed files with 483 additions and 104 deletions.
33 changes: 28 additions & 5 deletions pika/adapters/select_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import abc
import collections
import errno
import functools
import heapq
import logging
import select
Expand Down Expand Up @@ -139,7 +138,6 @@ def connection_factory(params):
on_done=on_done)


@functools.total_ordering
class _Timeout(object):
"""Represents a timeout"""

Expand All @@ -166,15 +164,40 @@ def __init__(self, deadline, callback):

def __eq__(self, other):
"""NOTE: not supporting sort stability"""
return self.deadline == other.deadline
if isinstance(other, _Timeout):
return self.deadline == other.deadline
return NotImplemented

def __ne__(self, other):
"""NOTE: not supporting sort stability"""
result = self.__eq__(other)
if result is not NotImplemented:
return not result
return NotImplemented

def __lt__(self, other):
"""NOTE: not supporting sort stability"""
return self.deadline < other.deadline
if isinstance(other, _Timeout):
return self.deadline < other.deadline
return NotImplemented

def __gt__(self, other):
"""NOTE: not supporting sort stability"""
if isinstance(other, _Timeout):
return self.deadline > other.deadline
return NotImplemented

def __le__(self, other):
"""NOTE: not supporting sort stability"""
return self.deadline <= other.deadline
if isinstance(other, _Timeout):
return self.deadline <= other.deadline
return NotImplemented

def __ge__(self, other):
"""NOTE: not supporting sort stability"""
if isinstance(other, _Timeout):
return self.deadline >= other.deadline
return NotImplemented


class _Timer(object):
Expand Down
11 changes: 11 additions & 0 deletions pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,17 @@ def __repr__(self):
(self.__class__.__name__, self.host, self.port,
self.virtual_host, bool(self.ssl_options)))

def __eq__(self, other):
if isinstance(other, Parameters):
return (self._host == other._host and self._port == other._port)
return NotImplemented

def __ne__(self, other):
result = self.__eq__(other)
if result is not NotImplemented:
return not result
return NotImplemented

@property
def backpressure_detection(self):
"""
Expand Down
24 changes: 16 additions & 8 deletions pika/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,17 @@ def __init__(self, username, password, erase_on_connect=False):
self.erase_on_connect = erase_on_connect

def __eq__(self, other):
return (isinstance(other, PlainCredentials) and
other.username == self.username and
other.password == self.password and
other.erase_on_connect == self.erase_on_connect)
if isinstance(other, PlainCredentials):
return (self.username == other.username
and self.password == other.password
and self.erase_on_connect == other.erase_on_connect)
return NotImplemented

def __ne__(self, other):
return not self == other
result = self.__eq__(other)
if result is not NotImplemented:
return not result
return NotImplemented

def response_for(self, start):
"""Validate that this type of authentication is supported
Expand Down Expand Up @@ -94,11 +98,15 @@ def __init__(self):
self.erase_on_connect = False

def __eq__(self, other):
return (isinstance(other, ExternalCredentials) and
other.erase_on_connect == self.erase_on_connect)
if isinstance(other, ExternalCredentials):
return self.erase_on_connect == other.erase_on_connect
return NotImplemented

def __ne__(self, other):
return not self == other
result = self.__eq__(other)
if result is not NotImplemented:
return not result
return NotImplemented

def response_for(self, start):
"""Validate that this type of authentication is supported
Expand Down
115 changes: 106 additions & 9 deletions tests/unit/connection_parameters_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,21 @@
and issubclass(type(getattr(connection.Parameters, attr)), property))


class _ParametersTestsBase(unittest.TestCase):
class ChildParameters(connection.Parameters):

def __init__(self, *args, **kwargs):
super(ChildParameters, self).__init__(*args, **kwargs)
self.extra = 'e'

def __eq__(self, other):
if isinstance(other, ChildParameters):
return self.extra == other.extra and super(
ChildParameters, self).__eq__(other)
return NotImplemented


class ParametersTestsBase(unittest.TestCase):

def setUp(self):
warnings.resetwarnings()
self.addCleanup(warnings.resetwarnings)
Expand Down Expand Up @@ -101,8 +115,89 @@ def assert_default_parameter_values(self, params):
value))


class ParametersTests(_ParametersTestsBase):
"""Test `pika.connection.Parameters`"""
class ParametersTests(ParametersTestsBase):

def test_eq(self):
params_1 = connection.Parameters()
params_2 = connection.Parameters()
params_3 = ChildParameters()

self.assertEqual(params_1, params_2)
self.assertEqual(params_2, params_1)

params_1.host = 'localhost'
params_1.port = 5672
params_1.virtual_host = '/'
params_1.credentials = credentials.PlainCredentials('u', 'p')
params_2.host = 'localhost'
params_2.port = 5672
params_2.virtual_host = '//'
params_2.credentials = credentials.PlainCredentials('uu', 'pp')
self.assertEqual(params_1, params_2)
self.assertEqual(params_2, params_1)

params_1.host = 'localhost'
params_1.port = 5672
params_1.virtual_host = '/'
params_1.credentials = credentials.PlainCredentials('u', 'p')
params_3.host = 'localhost'
params_3.port = 5672
params_3.virtual_host = '//'
params_3.credentials = credentials.PlainCredentials('uu', 'pp')
self.assertEqual(params_1, params_3)
self.assertEqual(params_3, params_1)

class Foreign(object):

def __eq__(self, other):
return 'foobar'

self.assertEqual(params_1 == Foreign(), 'foobar')
self.assertEqual(Foreign() == params_1, 'foobar')

def test_ne(self):
params_1 = connection.Parameters()
params_2 = connection.Parameters()
params_3 = ChildParameters()

params_1.host = 'localhost'
params_1.port = 5672
params_2.host = 'myserver.com'
params_2.port = 5672
self.assertNotEqual(params_1, params_2)
self.assertNotEqual(params_2, params_1)

params_1.host = 'localhost'
params_1.port = 5672
params_2.host = 'localhost'
params_2.port = 5671
self.assertNotEqual(params_1, params_2)
self.assertNotEqual(params_2, params_1)

params_1.host = 'localhost'
params_1.port = 5672
params_3.host = 'myserver.com'
params_3.port = 5672
self.assertNotEqual(params_1, params_3)
self.assertNotEqual(params_3, params_1)

params_1.host = 'localhost'
params_1.port = 5672
params_3.host = 'localhost'
params_3.port = 5671
self.assertNotEqual(params_1, params_3)
self.assertNotEqual(params_3, params_1)

self.assertNotEqual(params_1, dict(host='localhost', port=5672))
self.assertNotEqual(dict(host='localhost', port=5672), params_1)

class Foreign(object):

def __ne__(self, other):
return 'foobar'

self.assertEqual(params_1 != Foreign(), 'foobar')
self.assertEqual(Foreign() != params_1, 'foobar')

def test_default_property_values(self):
self.assert_default_parameter_values(connection.Parameters())
Expand Down Expand Up @@ -265,11 +360,11 @@ def test_host(self):
params.host = '127.0.0.1'
self.assertEqual(params.host, '127.0.0.1')

params.host = 'my.server.com'
self.assertEqual(params.host, 'my.server.com')
params.host = 'myserver.com'
self.assertEqual(params.host, 'myserver.com')

params.host = u'my.server.com'
self.assertEqual(params.host, u'my.server.com')
params.host = u'myserver.com'
self.assertEqual(params.host, u'myserver.com')

with self.assertRaises(TypeError):
params.host = 127
Expand Down Expand Up @@ -392,7 +487,8 @@ def test_tcp_options(self):
params.tcp_options = str(opt)


class ConnectionParametersTests(_ParametersTestsBase):
class ConnectionParametersTests(ParametersTestsBase):

def test_default_property_values(self):
self.assert_default_parameter_values(connection.ConnectionParameters())

Expand Down Expand Up @@ -547,7 +643,8 @@ def test_parameters_accept_unicode_locale(self):
self.assertEqual(parameters.locale, 'en_US')


class URLParametersTests(_ParametersTestsBase):
class URLParametersTests(ParametersTestsBase):

def test_default_property_values(self):
params = connection.URLParameters('')
self.assert_default_parameter_values(params)
Expand Down

0 comments on commit 4ec4a42

Please sign in to comment.