diff --git a/txredis/protocol.py b/txredis/protocol.py index d31fb8a..e7aeb04 100644 --- a/txredis/protocol.py +++ b/txredis/protocol.py @@ -104,7 +104,7 @@ class RedisBase(protocol.Protocol, policies.TimeoutMixin, object): def __init__(self, db=None, password=None, charset='utf8', errors='strict'): self.charset = charset - self.db = db + self.db = db if db is not None else 0 self.password = password self.errors = errors self._buffer = '' @@ -193,9 +193,22 @@ def failRequests(self, reason): def connectionMade(self): """ Called when incoming connections is made to the server. """ - self._disconnected = False + d = defer.succeed(True) + + # if we have a password set, make sure we auth if self.password: - return self.auth(self.password) + d.addCallback(lambda _res : self.auth(self.password)) + + # select the db passsed in + if self.db: + d.addCallback(lambda _res : self.select(self.db)) + + def done_connecting(_res): + # set our state as soon as we're properly connected + self._disconnected = False + d.addCallback(done_connecting) + + return d def connectionLost(self, reason): """Called when the connection is lost.