From 9551dfd08e29728b5c77530536c54ad45d75a60e Mon Sep 17 00:00:00 2001 From: Nick Pope Date: Wed, 2 Sep 2020 20:43:31 +0100 Subject: [PATCH] Added support for connections over IPv6. Fixes #257. --- ADOPTERS.md | 1 + README.rst | 2 + docs/getting_started.rst | 54 +++++++++++++++------- pymemcache/client/base.py | 69 ++++++++++++++++++++++------- pymemcache/client/hash.py | 9 +++- pymemcache/test/test_client.py | 67 +++++++++++++++++++++++++--- pymemcache/test/test_client_hash.py | 17 +++++-- 7 files changed, 175 insertions(+), 44 deletions(-) diff --git a/ADOPTERS.md b/ADOPTERS.md index 657767d5..bcddbc86 100644 --- a/ADOPTERS.md +++ b/ADOPTERS.md @@ -4,4 +4,5 @@ This is an alphabetical list of people and organizations who are using this project. If you'd like to be included here, please send a Pull Request that adds your information to this file. +- [Django](https://www.djangoproject.com/) - [Pinterest](https://www.pinterest.com/) diff --git a/README.rst b/README.rst index 81d3cbfa..286b1c00 100644 --- a/README.rst +++ b/README.rst @@ -13,6 +13,7 @@ A comprehensive, fast, pure-Python memcached client. pymemcache supports the following features: * Complete implementation of the memcached text protocol. +* Connections using UNIX sockets, or TCP over IPv4 or IPv6. * Configurable timeouts for socket connect and send/recv calls. * Access to the "noreply" flag, which can significantly increase the speed of writes. * Flexible, modular and simple approach to serialization and deserialization. @@ -126,6 +127,7 @@ Credits * `Stephen Rosen `_ * `Feras Alazzeh `_ * `Moisés Guimarães de Medeiros `_ +* `Nick Pope `_ We're Hiring! ============= diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 36ee6a31..c6733bd1 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -9,20 +9,43 @@ Basic Usage from pymemcache.client.base import Client - client = Client(('localhost', 11211)) + client = Client('localhost') client.set('some_key', 'some_value') result = client.get('some_key') -Using UNIX domain sockets -------------------------- -You can also connect to a local memcached server over a UNIX domain socket by -passing the socket's path to the client's ``server`` parameter: +The server to connect to can be specified in a number of ways. + +If using TCP connections over IPv4 or IPv6, the ``server`` parameter can be +passed a ``host`` string, a ``host:port`` string, or a ``(host, port)`` +2-tuple. The host part may be a domain name, an IPv4 address, or an IPv6 +address. The port may be omitted, in which case it will default to ``11211``. .. code-block:: python - from pymemcache.client.base import Client + ipv4_client = Client('127.0.0.1') + ipv4_client_with_port = Client('127.0.0.1:11211') + ipv4_client_using_tuple = Client(('127.0.0.1', 11211)) + + ipv6_client = Client('[::1]') + ipv6_client_with_port = Client('[::1]:11211') + ipv6_client_using_tuple = Client(('::1', 11211)) + + domain_client = Client('localhost') + domain_client_with_port = Client('localhost:11211') + domain_client_using_tuple = Client(('localhost', 11211)) - client = Client('/var/run/memcached/memcached.sock') +Note that IPv6 may be used in preference to IPv4 when passing a domain name as +the host if an IPv6 address can be resolved for that domain. + +You can also connect to a local memcached server over a UNIX domain socket by +passing the socket's path to the client's ``server`` parameter. An optional +``unix:`` prefix may be used for compatibility in code that uses other client +libraries that require it. + +.. code-block:: python + + client = Client('/run/memcached/memcached.sock') + client_with_prefix = Client('unix:/run/memcached/memcached.sock') Using a client pool ------------------- @@ -35,7 +58,7 @@ clients for improved performance. from pymemcache.client.base import PooledClient - client = PooledClient(('127.0.0.1', 11211), max_pool_size=4) + client = PooledClient('127.0.0.1', max_pool_size=4) Using a memcached cluster ------------------------- @@ -48,8 +71,8 @@ on if a server goes down. from pymemcache.client.hash import HashClient client = HashClient([ - ('127.0.0.1', 11211), - ('127.0.0.1', 11212) + '127.0.0.1:11211', + '127.0.0.1:11212', ]) client.set('some_key', 'some value') result = client.get('some_key') @@ -74,7 +97,7 @@ To enable TLS in pymemcache, pass a valid TLS context to the client's cafile="my-ca-root.crt", ) - client = Client(('localhost', 11211), tls_context=context) + client = Client('localhost', tls_context=context) client.set('some_key', 'some_value') result = client.get('some_key') @@ -100,7 +123,7 @@ Serialization return json.loads(value) raise Exception("Unknown serialization format") - client = Client(('localhost', 11211), serde=JsonSerde()) + client = Client('localhost', serde=JsonSerde()) client.set('key', {'a':'b', 'c':'d'}) result = client.get('key') @@ -115,7 +138,7 @@ pymemcache provides a default class Foo(object): pass - client = Client(('localhost', 11211), serde=serde.pickle_serde) + client = Client('localhost', serde=serde.pickle_serde) client.set('key', Foo()) result = client.get('key') @@ -125,10 +148,7 @@ the version by explicitly instantiating :class:`pymemcache.serde.PickleSerde`: .. code-block:: python - client = Client( - ('localhost', 11211), - serde=serde.PickleSerde(pickle_version=2) - ) + client = Client('localhost', serde=serde.PickleSerde(pickle_version=2)) Deserialization with Python 3 diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 7b3ed6d6..3d790bc5 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -109,6 +109,27 @@ def check_key_helper(key, allow_unicode_keys, key_prefix=b''): return key +def normalize_server_spec(server): + if isinstance(server, tuple) or server is None: + return server + if isinstance(server, list): + return tuple(server) # Assume [host, port] provided. + if not isinstance(server, six.string_types): + raise ValueError('Unknown server provided: %r' % server) + if server.startswith('unix:'): + return server[5:] + if server.startswith('/'): + return server + if ':' not in server or server.endswith(']'): + host, port = server, 11211 + else: + host, port = server.rsplit(':', 1) + port = int(port) + if host.startswith('['): + host = host.strip('[]') + return (host, port) + + class Client(object): """ A client for a single memcached server. @@ -253,7 +274,7 @@ def __init__(self, The constructor does not make a connection to memcached. The first call to a method on the object will do that. """ - self.server = server + self.server = normalize_server_spec(server) self.serde = serde or LegacyWrappingSerde(serializer, deserializer) self.connect_timeout = connect_timeout self.timeout = timeout @@ -279,25 +300,41 @@ def check_key(self, key): def _connect(self): self.close() - if isinstance(self.server, (list, tuple)): - sock = self.socket_module.socket(self.socket_module.AF_INET, - self.socket_module.SOCK_STREAM) + s = self.socket_module + + if not isinstance(self.server, tuple): + sockaddr = self.server + sock = s.socket(s.AF_UNIX, s.SOCK_STREAM) - if self.tls_context: - sock = self.tls_context.wrap_socket( - sock, server_hostname=self.server[0] - ) else: - sock = self.socket_module.socket(self.socket_module.AF_UNIX, - self.socket_module.SOCK_STREAM) + sock = None + error = None + host, port = self.server + info = s.getaddrinfo(host, port, s.AF_UNSPEC, s.SOCK_STREAM, + s.IPPROTO_TCP) + for family, socktype, proto, _, sockaddr in info: + try: + sock = s.socket(family, socktype, proto) + if self.no_delay: + sock.setsockopt(s.IPPROTO_TCP, s.TCP_NODELAY, 1) + if self.tls_context: + context = self.tls_context + sock = context.wrap_socket(sock, server_hostname=host) + except Exception as e: + error = e + if sock is not None: + sock.close() + sock = None + else: + break + + if error is not None: + raise error + try: sock.settimeout(self.connect_timeout) - sock.connect(self.server) + sock.connect(sockaddr) sock.settimeout(self.timeout) - if self.no_delay and sock.family == self.socket_module.AF_INET: - sock.setsockopt(self.socket_module.IPPROTO_TCP, - self.socket_module.TCP_NODELAY, 1) - except Exception: sock.close() raise @@ -1030,7 +1067,7 @@ def __init__(self, allow_unicode_keys=False, encoding='ascii', tls_context=None): - self.server = server + self.server = normalize_server_spec(server) self.serde = serde or LegacyWrappingSerde(serializer, deserializer) self.connect_timeout = connect_timeout self.timeout = timeout diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 9c61cb63..9e273ee8 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -4,7 +4,12 @@ import logging import six -from pymemcache.client.base import Client, PooledClient, check_key_helper +from pymemcache.client.base import ( + Client, + PooledClient, + check_key_helper, + normalize_server_spec, +) from pymemcache.client.rendezvous import RendezvousHash from pymemcache.exceptions import MemcacheError @@ -103,7 +108,7 @@ def __init__( }) for server in servers: - self.add_server(server) + self.add_server(normalize_server_spec(server)) self.encoding = encoding self.tls_context = tls_context diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 96e65515..7b8ac3bb 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -21,12 +21,13 @@ import json import os import mock +import re import socket import unittest import pytest -from pymemcache.client.base import PooledClient, Client +from pymemcache.client.base import PooledClient, Client, normalize_server_spec from pymemcache.exceptions import ( MemcacheClientError, MemcacheServerError, @@ -52,7 +53,10 @@ def __init__(self, recv_bufs, connect_failure=None, close_failure=None): @property def family(self): - return socket.AF_INET + # TODO: Use ipaddress module when dropping support for Python < 3.3 + ipv6_re = re.compile(r'^[0-9a-f:]+$') + is_ipv6 = any(ipv6_re.match(c[0]) for c in self.connections) + return socket.AF_INET6 if is_ipv6 else socket.AF_INET def sendall(self, value): self.send_bufs.append(value) @@ -103,7 +107,7 @@ def __init__(self, connect_failure=None, close_failure=None): self.close_failure = close_failure self.sockets = [] - def socket(self, family, type): + def socket(self, family, type, proto=0, fileno=None): socket = MockSocket( [], connect_failure=self.connect_failure, @@ -1075,12 +1079,40 @@ def test_version_exception(self): @pytest.mark.unit() class TestClientSocketConnect(unittest.TestCase): - def test_socket_connect(self): - server = ("example.com", 11211) + def test_socket_connect_ipv4(self): + server = ('127.0.0.1', 11211) client = Client(server, socket_module=MockSocketModule()) client._connect() + print(client.sock.connections) assert client.sock.connections == [server] + assert client.sock.family == socket.AF_INET + + timeout = 2 + connect_timeout = 3 + client = Client( + server, connect_timeout=connect_timeout, timeout=timeout, + socket_module=MockSocketModule()) + client._connect() + assert client.sock.timeouts == [connect_timeout, timeout] + + client = Client(server, socket_module=MockSocketModule()) + client._connect() + assert client.sock.socket_options == [] + + client = Client( + server, socket_module=MockSocketModule(), no_delay=True) + client._connect() + assert client.sock.socket_options == [(socket.IPPROTO_TCP, + socket.TCP_NODELAY, 1)] + + def test_socket_connect_ipv6(self): + server = ('::1', 11211) + + client = Client(server, socket_module=MockSocketModule()) + client._connect() + assert client.sock.connections == [server + (0, 0)] + assert client.sock.family == socket.AF_INET6 timeout = 2 connect_timeout = 3 @@ -1330,3 +1362,28 @@ def test_recv(self): b'ue1\r\nEND\r\n', ]) assert client[b'key1'] == b'value1' + + +@pytest.mark.unit() +class TestNormalizeServerSpec(unittest.TestCase): + def test_normalize_server_spec(self): + f = normalize_server_spec + assert f(None) is None + assert f(('127.0.0.1', 12345)) == ('127.0.0.1', 12345) + assert f(['127.0.0.1', 12345]) == ('127.0.0.1', 12345) + assert f('unix:/run/memcached/socket') == '/run/memcached/socket' + assert f('/run/memcached/socket') == '/run/memcached/socket' + assert f('localhost') == ('localhost', 11211) + assert f('localhost:12345') == ('localhost', 12345) + assert f('[::1]') == ('::1', 11211) + assert f('[::1]:12345') == ('::1', 12345) + assert f('127.0.0.1') == ('127.0.0.1', 11211) + assert f('127.0.0.1:12345') == ('127.0.0.1', 12345) + + with pytest.raises(ValueError) as excinfo: + f({'host': 12345}) + assert str(excinfo.value) == "Unknown server provided: {'host': 12345}" + + with pytest.raises(ValueError) as excinfo: + f(12345) + assert str(excinfo.value) == "Unknown server provided: 12345" diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py index 4ffad5c5..5dd4ec47 100644 --- a/pymemcache/test/test_client_hash.py +++ b/pymemcache/test/test_client_hash.py @@ -372,12 +372,21 @@ class MyClient(Client): assert isinstance(c, MyClient) def test_mixed_inet_and_unix_sockets(self): - servers = [ + expected = { '/tmp/pymemcache.{pid}'.format(pid=os.getpid()), ('127.0.0.1', 11211), - ] - client = HashClient(servers) - assert set(servers) == {c.server for c in client.clients.values()} + ('::1', 11211), + } + client = HashClient([ + '/tmp/pymemcache.{pid}'.format(pid=os.getpid()), + '127.0.0.1', + '127.0.0.1:11211', + '[::1]', + '[::1]:11211', + ('127.0.0.1', 11211), + ('::1', 11211), + ]) + assert expected == {c.server for c in client.clients.values()} def test_legacy_add_remove_server_signature(self): server = ('127.0.0.1', 11211)