Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/thefab/tornadis
Browse files Browse the repository at this point in the history
  • Loading branch information
thefab committed Jun 15, 2015
2 parents dd8767b + 9001d29 commit 6b07a26
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 28 deletions.
8 changes: 8 additions & 0 deletions tests/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@ def test_redis_or_raise_skiptest(host="localhost", port=6379):
except socket.error:
raise unittest.SkipTest("redis must be launched on %s:%i" % (host,
port))


def test_redis_uds_or_raise_skiptest(uds="/tmp/redis.sock"):
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
s.connect(uds)
except socket.error:
raise unittest.SkipTest("redis must listen on %s" % uds)
51 changes: 38 additions & 13 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tornadis.connection import Connection
from tornadis.utils import format_args_in_redis_protocol
from support import test_redis_or_raise_skiptest
from support import test_redis_uds_or_raise_skiptest
import hiredis
import functools
import random
Expand Down Expand Up @@ -97,23 +98,13 @@ def fake_socket_constructor(cls, *args, **kwargs):
return cls(*args, **kwargs)


class ConnectionTestCase(tornado.testing.AsyncTestCase):
class AbstractConnectionTestCase(tornado.testing.AsyncTestCase):

def setUp(self):
test_redis_or_raise_skiptest()
super(ConnectionTestCase, self).setUp()
test_redis_uds_or_raise_skiptest()
super(AbstractConnectionTestCase, self).setUp()
self.reader = hiredis.Reader()
self.reply_queue = toro.Queue()
self.replies = []

def get_new_ioloop(self):
return tornado.ioloop.IOLoop.instance()

@tornado.testing.gen_test
def test_init(self):
c = Connection(self._read_cb, self._close_cb)
yield c.connect()
c.disconnect()

def _close_cb(self):
pass
Expand All @@ -127,6 +118,40 @@ def _read_cb(self, data):
else:
break


class UDSConnectionTestCase(AbstractConnectionTestCase):

def setUp(self):
test_redis_uds_or_raise_skiptest()
super(UDSConnectionTestCase, self).setUp()

def get_new_ioloop(self):
return tornado.ioloop.IOLoop.instance()

@tornado.testing.gen_test
def test_init(self):
c = Connection(self._read_cb, self._close_cb,
unix_domain_socket="/tmp/redis.sock")
yield c.connect()
c.disconnect()


class ConnectionTestCase(AbstractConnectionTestCase):

def setUp(self):
test_redis_or_raise_skiptest()
super(ConnectionTestCase, self).setUp()
self.replies = []

def get_new_ioloop(self):
return tornado.ioloop.IOLoop.instance()

@tornado.testing.gen_test
def test_init(self):
c = Connection(self._read_cb, self._close_cb)
yield c.connect()
c.disconnect()

@tornado.testing.gen_test
def test_init_with_tcp_nodelay(self):
c = Connection(self._read_cb, self._close_cb, tcp_nodelay=True)
Expand Down
17 changes: 15 additions & 2 deletions tornadis/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def get_parameters():
default="127.0.0.1")
parser.add_argument('-p', '--port', help="Server port (default 6379)",
default=6379)
parser.add_argument('-u', '--unix_domain_socket',
help="path to a unix socket to connect to (if set "
", overrides host/port parameters)")
parser.add_argument('-a', '--password', help="Password for Redis Auth")
parser.add_argument('-c', '--clients',
help="Number of parallel connections (default 5)",
Expand Down Expand Up @@ -60,7 +63,12 @@ def __init__(self, params):

@tornado.gen.coroutine
def multiple_set(self, client_number):
client = tornadis.Client()
uds = self.params.unix_domain_socket
client = tornadis.Client(host=self.params.hostname,
port=self.params.port,
unix_domain_socket=uds,
autoconnect=False,
tcp_nodelay=True)
print_("Connect client", client_number)
yield client.connect()
print_("Client", client_number, "connected")
Expand Down Expand Up @@ -101,7 +109,12 @@ def _call_pipeline(self, client, pipeline, client_number):

@tornado.gen.coroutine
def pipelined_multiple_set(self, client_number):
client = tornadis.Client()
uds = self.params.unix_domain_socket
client = tornadis.Client(host=self.params.hostname,
port=self.params.port,
unix_domain_socket=uds,
autoconnect=False,
tcp_nodelay=True)
print_("Connect client", client_number)
yield client.connect()
print_("Client", client_number, "connected")
Expand Down
11 changes: 10 additions & 1 deletion tornadis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Client(object):
Attributes:
host (string): the host name to connect to.
port (int): the port to connect to.
unix_domain_socket (string): path to a unix socket to connect to
(if set, overrides host/port parameters).
read_page_size (int): page size for reading.
write_page_size (int): page size for writing.
connect_timeout (int): timeout (in seconds) for connecting.
Expand All @@ -41,6 +43,7 @@ class Client(object):
"""

def __init__(self, host=tornadis.DEFAULT_HOST, port=tornadis.DEFAULT_PORT,
unix_domain_socket=None,
read_page_size=tornadis.DEFAULT_READ_PAGE_SIZE,
write_page_size=tornadis.DEFAULT_WRITE_PAGE_SIZE,
connect_timeout=tornadis.DEFAULT_CONNECT_TIMEOUT,
Expand All @@ -51,6 +54,8 @@ def __init__(self, host=tornadis.DEFAULT_HOST, port=tornadis.DEFAULT_PORT,
Args:
host (string): the host name to connect to.
port (int): the port to connect to.
unix_domain_socket (string): path to a unix socket to connect to
(if set, overrides host/port parameters).
read_page_size (int): page size for reading.
write_page_size (int): page size for writing.
connect_timeout (int): timeout (in seconds) for connecting.
Expand All @@ -63,6 +68,7 @@ def __init__(self, host=tornadis.DEFAULT_HOST, port=tornadis.DEFAULT_PORT,
"""
self.host = host
self.port = port
self.unix_domain_socket = unix_domain_socket
self.read_page_size = read_page_size
self.write_page_size = write_page_size
self.connect_timeout = connect_timeout
Expand Down Expand Up @@ -104,8 +110,11 @@ def connect(self):
self.__callback_queue = collections.deque()
self._reply_list = []
self.__reader = hiredis.Reader()
uds = self.unix_domain_socket
self.__connection = Connection(cb1, cb2, host=self.host,
port=self.port, ioloop=self.__ioloop,
port=self.port,
unix_domain_socket=uds,
ioloop=self.__ioloop,
read_page_size=self.read_page_size,
write_page_size=self.write_page_size,
connect_timeout=self.connect_timeout,
Expand Down
46 changes: 34 additions & 12 deletions tornadis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# See the LICENSE file for more information.

import socket
import os
import tornado.iostream
import tornado.gen
from tornado.util import errno_from_exception
Expand Down Expand Up @@ -41,6 +42,8 @@ class Connection(object):
Attributes:
host (string): the host name to connect to.
port (int): the port to connect to.
unix_domain_socket (string): path to a unix socket to connect to
(if set, overrides host/port parameters).
read_page_size (int): page size for reading.
write_page_size (int): page size for writing.
connect_timeout (int): timeout (in seconds) for connecting.
Expand All @@ -51,7 +54,7 @@ class Connection(object):

def __init__(self, read_callback, close_callback,
host=tornadis.DEFAULT_HOST,
port=tornadis.DEFAULT_PORT,
port=tornadis.DEFAULT_PORT, unix_domain_socket=None,
read_page_size=tornadis.DEFAULT_READ_PAGE_SIZE,
write_page_size=tornadis.DEFAULT_WRITE_PAGE_SIZE,
connect_timeout=tornadis.DEFAULT_CONNECT_TIMEOUT,
Expand All @@ -63,6 +66,8 @@ def __init__(self, read_callback, close_callback,
close_callback: callback called when the connection is closed.
host (string): the host name to connect to.
port (int): the port to connect to.
unix_domain_socket (string): path to a unix socket to connect to
(if set, overrides host/port parameters).
read_page_size (int): page size for reading.
write_page_size (int): page size for writing.
connect_timeout (int): timeout (in seconds) for connecting.
Expand All @@ -73,6 +78,7 @@ def __init__(self, read_callback, close_callback,
"""
self.host = host
self.port = port
self.unix_domain_socket = unix_domain_socket
self._state = ConnectionState()
self.__ioloop = ioloop or tornado.ioloop.IOLoop.instance()
cb = tornado.ioloop.PeriodicCallback(self._on_every_second, 1000,
Expand All @@ -88,6 +94,11 @@ def __init__(self, read_callback, close_callback,
self._write_buffer = WriteBuffer()
self._listened_events = 0

def _redis_server(self):
if self.unix_domain_socket:
return self.unix_domain_socket
return "%s:%i" % (self.host, self.port)

def is_connecting(self):
"""Returns True if the object is connecting."""
return self._state.is_connecting()
Expand All @@ -106,29 +117,40 @@ def connect(self):
"""
if self.is_connected() or self.is_connecting():
raise tornado.gen.Return(True)
self.__socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if self.unix_domain_socket is None:
self.__socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if self.tcp_nodelay:
self.__socket.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY, 1)
else:
if not os.path.exists(self.unix_domain_socket):
LOG.warning("can't connect to %s, file does not exist",
self.unix_domain_socket)
raise tornado.gen.Return(False)
self.__socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.__socket.setblocking(0)
if self.tcp_nodelay:
self.__socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.__periodic_callback.start()
try:
LOG.debug("connecting to %s:%i...", self.host, self.port)
LOG.debug("connecting to %s...", self._redis_server())
self._state.set_connecting()
self.__socket.connect((self.host, self.port))
if self.unix_domain_socket is None:
self.__socket.connect((self.host, self.port))
else:
self.__socket.connect(self.unix_domain_socket)
except socket.error as e:
if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
self.disconnect()
LOG.warning("can't connect to %s:%i", self.host, self.port)
LOG.warning("can't connect to %s", self._redis_server())
raise tornado.gen.Return(False)
self.__socket_fileno = self.__socket.fileno()
self._register_or_update_event_handler()
yield self._state.get_changed_state_future()
if not self.is_connected():
LOG.warning("can't connect to %s:%i", self.host, self.port)
LOG.warning("can't connect to %s", self._redis_server())
raise tornado.gen.Return(False)
else:
LOG.debug("connected to %s:%i", self.host, self.port)
LOG.debug("connected to %s", self._redis_server())
self.__socket_fileno = self.__socket.fileno()
self._state.set_connected()
self._register_or_update_event_handler()
Expand Down Expand Up @@ -170,7 +192,7 @@ def disconnect(self):
"""
if not self.is_connected() and not self.is_connecting():
return
LOG.debug("disconnecting from %s:%i...", self.host, self.port)
LOG.debug("disconnecting from %s...", self._redis_server())
self.__periodic_callback.stop()
try:
self.__ioloop.remove_handler(self.__socket_fileno)
Expand All @@ -184,7 +206,7 @@ def disconnect(self):
pass
self._state.set_disconnected()
self._close_callback()
LOG.debug("disconnected from %s:%i", self.host, self.port)
LOG.debug("disconnected from %s", self._redis_server())

def _handle_events(self, fd, event):
if self.is_connecting():
Expand All @@ -194,7 +216,7 @@ def _handle_events(self, fd, event):
self.disconnect()
return
self._state.set_connected()
LOG.debug("connected to %s:%i", self.host, self.port)
LOG.debug("connected to %s", self._redis_server())
if not self.is_connected():
return
if event & self.__ioloop.READ:
Expand Down

0 comments on commit 6b07a26

Please sign in to comment.