Skip to content

Commit

Permalink
add authentication support
Browse files Browse the repository at this point in the history
* handle authentication transparently on each connection
  • Loading branch information
jehiah committed Oct 10, 2011
1 parent 6e5c771 commit 193df57
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 24 deletions.
96 changes: 87 additions & 9 deletions asyncmongo/connection.py
Expand Up @@ -24,36 +24,48 @@

import tornado.iostream
import socket
import helpers
import struct
import logging

from errors import ProgrammingError, IntegrityError, InterfaceError
from bson import SON
from errors import ProgrammingError, IntegrityError, InterfaceError, AuthenticationError
import message
import helpers

class Connection(object):
"""
:Parameters:
- `host`: hostname or ip of mongo host
- `port`: port to connect to
- `dbuser`: db user to connect with
- `dbpass`: db password
- `autoreconnect` (optional): auto reconnect on interface errors
"""
def __init__(self, host, port, autoreconnect=True, pool=None):
def __init__(self, host, port, dbuser=None, dbpass=None, autoreconnect=True, pool=None):
assert isinstance(host, (str, unicode))
assert isinstance(port, int)
assert isinstance(autoreconnect, bool)
assert isinstance(dbuser, (str, unicode, None.__class__))
assert isinstance(dbpass, (str, unicode, None.__class__))
assert pool
self.__host = host
self.__port = port
self.__dbuser = dbuser
self.__dbpass = dbpass
self.__stream = None
self.__callback = None
self.__alive = False
self.__connect()
self.__authenticate = False
self.__autoreconnect = autoreconnect
self.__pool = pool
self.__deferred_message = None
self.__deferred_callback = None
self.usage_count = 0
self.__connect()

def __connect(self):
self.usage_count = 0
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect((self.__host, self.__port))
Expand All @@ -62,6 +74,9 @@ def __connect(self):
self.__alive = True
except socket.error, error:
raise InterfaceError(error)

if self.__dbuser and self.__dbpass:
self.__authenticate = True

def _socket_close(self):
"""cleanup after the socket is closed by the other end"""
Expand All @@ -88,8 +103,6 @@ def close(self):
def send_message(self, message, callback):
""" send a message over the wire; callback=None indicates a safe=False call where we write and forget about it"""

self.usage_count +=1
# TODO: handle reconnect
if self.__callback is not None:
raise ProgrammingError('connection already in use')

Expand All @@ -99,13 +112,22 @@ def send_message(self, message, callback):
else:
raise InterfaceError('connection invalid. autoreconnect=False')

self.__callback=callback
if self.__authenticate:
self.__deferred_message = message
self.__deferred_callback = callback
self._get_nonce(self._start_authentication)
else:
self.__callback = callback
self._send_message(message)

def _send_message(self, message):
self.usage_count +=1
# __request_id used by get_more()
(self.__request_id, data) = message
# logging.info('request id %d writing %r' % (self.__request_id, data))
try:
self.__stream.write(data)
if callback:
if self.__callback:
self.__stream.read_bytes(16, callback=self._parse_header)
else:
self.__request_id = None
Expand Down Expand Up @@ -140,7 +162,11 @@ def _parse_response(self, response):
request_id = self.__request_id
self.__request_id = None
self.__callback = None
self.__pool.cache(self)
if not self.__deferred_message:
# skip adding to the cache because there is something else
# that needs to be called on this connection for this request
# (ie: we authenticted, but still have to send the real req)
self.__pool.cache(self)

try:
response = helpers._unpack_response(response, request_id) # TODO: pass tz_awar
Expand All @@ -156,3 +182,55 @@ def _parse_response(self, response):
# logging.info('response: %s' % response)
callback(response)

def _start_authentication(self, response, error=None):
# this is the nonce response
if error:
logging.error(error)
logging.error(response)
raise AuthenticationError(error)
nonce = response['data'][0]['nonce']
key = helpers._auth_key(nonce, self.__dbuser, self.__dbpass)

self.__callback = self._finish_authentication
self._send_message(
message.query(0,
"%s.$cmd" % self.__pool._dbname,
0,
1,
SON([('authenticate', 1), ('user' , self.__dbuser), ('nonce' , nonce), ('key' , key)]),
SON({})))

def _finish_authentication(self, response, error=None):
if error:
self.__deferred_message = None
self.__deferred_callback = None
raise AuthenticationError(error)
assert response['number_returned'] == 1
response = response['data'][0]
if response['ok'] != 1:
logging.error('Failed authentication %s' % response['errmsg'])
self.__deferred_message = None
self.__deferred_callback = None
raise AuthenticationError(response['errmsg'])

message = self.__deferred_message
callback = self.__deferred_callback
self.__deferred_message = None
self.__deferred_callback = None
self.__callback = callback
# continue the original request
self._send_message(message)

def _get_nonce(self, callback):
assert self.__callback is None
self.__callback = callback
self._send_message(
message.query(0,
"%s.$cmd" % self.__pool._dbname,
0,
1,
SON({'getnonce' : 1}),
SON({})
))


13 changes: 9 additions & 4 deletions asyncmongo/cursor.py
Expand Up @@ -377,10 +377,15 @@ def find(self, spec=None, fields=None, skip=0, limit=0,
connection.send_message(
message.query(self.__query_options(),
self.full_collection_name,
self.__skip, self.__limit,
self.__query_spec(), self.__fields), callback=functools.partial(self._handle_response, orig_callback=callback))
except:
self.__skip,
self.__limit,
self.__query_spec(),
self.__fields),
callback=functools.partial(self._handle_response, orig_callback=callback))
except Exception as e:
logging.error('Error sending query %s' % e)
connection.close()
raise

def _handle_response(self, result, error=None, orig_callback=None):
if error:
Expand All @@ -398,7 +403,7 @@ def __query_options(self):
options = 0
if self.__tailable:
options |= _QUERY_OPTIONS["tailable_cursor"]
if self.__slave_okay or self.__pool.slave_okay:
if self.__slave_okay or self.__pool._slave_okay:
options |= _QUERY_OPTIONS["slave_okay"]
if not self.__timeout:
options |= _QUERY_OPTIONS["no_timeout"]
Expand Down
3 changes: 3 additions & 0 deletions asyncmongo/errors.py
Expand Up @@ -54,3 +54,6 @@ class NotSupportedError(DatabaseError):

class TooManyConnections(Error):
pass

class AuthenticationError(Error):
pass
22 changes: 22 additions & 0 deletions asyncmongo/helpers.py
@@ -1,3 +1,4 @@
import hashlib

import bson
from bson.son import SON
Expand Down Expand Up @@ -78,3 +79,24 @@ def _index_document(index_list):
"DESCENDING, or GEO2D")
index[key] = value
return index

def _password_digest(username, password):
"""Get a password digest to use for authentication.
"""
if not isinstance(password, basestring):
raise TypeError("password must be an instance of basestring")
if not isinstance(username, basestring):
raise TypeError("username must be an instance of basestring")

md5hash = hashlib.md5()
md5hash.update("%s:mongo:%s" % (username.encode('utf-8'),
password.encode('utf-8')))
return unicode(md5hash.hexdigest())

def _auth_key(nonce, username, password):
"""Get an auth key to use for authentication.
"""
digest = _password_digest(username, password)
md5hash = hashlib.md5()
md5hash.update("%s%s%s" % (nonce, unicode(username), digest))
return unicode(md5hash.hexdigest())
13 changes: 3 additions & 10 deletions asyncmongo/pool.py
Expand Up @@ -19,6 +19,7 @@
from errors import TooManyConnections, ProgrammingError
from connection import Connection


class ConnectionPools(object):
""" singleton to keep track of named connection pools """
@classmethod
Expand Down Expand Up @@ -57,6 +58,7 @@ class ConnectionPool(object):
- `maxconnections` (optional): maximum open connections for this pool. 0 for unlimited
- `maxusage` (optional): number of requests allowed on a connection before it is closed. 0 for unlimited
- `dbname`: mongo database name
- `slave_okay` (optional): is it okay to connect directly to and perform queries on a slave instance
- `**kwargs`: passed to `connection.Connection`
"""
Expand Down Expand Up @@ -89,6 +91,7 @@ def __init__(self,
self._dbname = dbname
self._slave_okay = slave_okay
self._connections = 0


# Establish an initial number of idle database connections:
idle = [self.connection() for i in range(mincached)]
Expand Down Expand Up @@ -156,14 +159,4 @@ def close(self):
finally:
self._condition.release()

def __get_slave_okay(self):
"""Is it OK to perform queries on a secondary or slave?
"""
return self._slave_okay

def __set_slave_okay(self, value):
"""Property setter for slave_okay"""
assert isinstance(value, bool)
self._slave_okay = value

slave_okay = property(__get_slave_okay, __set_slave_okay)
54 changes: 54 additions & 0 deletions test/sample_app/sample_app2.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python

# mkdir /tmp/asyncmongo_sample_app2
# mongod --port 27017 --oplogSize 10 --dbpath /tmp/asyncmongo_sample_app2

# $mongo
# >>>use test;
# db.addUser("testuser", "testpass");

# ab -n 1000 -c 16 http://127.0.0.1:8888/

import sys
import logging
import os
app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
if app_dir not in sys.path:
logging.debug('adding %r to sys.path' % app_dir)
sys.path.insert(0, app_dir)

import asyncmongo
# make sure we get the local asyncmongo
assert asyncmongo.__file__.startswith(app_dir)

import tornado.ioloop
import tornado.web
import tornado.options

class MainHandler(tornado.web.RequestHandler):
@tornado.web.asynchronous
def get(self):
db.users.find_one({"user_id" : 1}, callback=self._on_response)

def _on_response(self, response, error):
assert not error
self.write(str(response))
self.finish()


if __name__ == "__main__":
tornado.options.parse_command_line()
application = tornado.web.Application([
(r"/?", MainHandler)
])
application.listen(8888)
db = asyncmongo.Client(pool_id="test",
host='127.0.0.1',
port=27017,
mincached=5,
maxcached=15,
maxconnections=30,
dbname='test',
dbuser='testuser',
dbpass='testpass')
tornado.ioloop.IOLoop.instance().start()
50 changes: 50 additions & 0 deletions test/test_authentication.py
@@ -0,0 +1,50 @@
import tornado.ioloop
import time
import logging
import subprocess

import test_shunt
import asyncmongo

TEST_TIMESTAMP = int(time.time())

class AuthenticationTest(test_shunt.MongoTest):
def setUp(self):
super(AuthenticationTest, self).setUp()
logging.info('creating user')
pipe = subprocess.Popen('''echo -e 'use test;\n db.addUser("testuser", "testpass");\n exit;' | mongo --port 27017 --host 127.0.0.1''', shell=True)
pipe.wait()

def test_authentication(self):
try:
test_shunt.setup()
db = asyncmongo.Client(pool_id='testauth', host='127.0.0.1', port=27017, dbname='test', dbuser='testuser', dbpass='testpass', maxconnections=2)

def update_callback(response, error):
logging.info(response)
assert len(response) == 1
test_shunt.register_called('update')
tornado.ioloop.IOLoop.instance().stop()

db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, upsert=True, callback=update_callback)

tornado.ioloop.IOLoop.instance().start()
test_shunt.assert_called('update')

def query_callback(response, error):
logging.info(response)
logging.info(error)
assert error is None
assert isinstance(response, dict)
assert response['_id'] == TEST_TIMESTAMP
assert response['test_count'] == 1
test_shunt.register_called('retrieved')
tornado.ioloop.IOLoop.instance().stop()

db.test_stats.find_one({"_id" : TEST_TIMESTAMP}, callback=query_callback)
tornado.ioloop.IOLoop.instance().start()
test_shunt.assert_called('retrieved')
except:
tornado.ioloop.IOLoop.instance().stop()
raise

2 changes: 1 addition & 1 deletion test/test_shunt.py
Expand Up @@ -33,7 +33,7 @@ def setUp(self):
os.makedirs(dirname)
self.temp_dirs.append(dirname)

options = ['mongod', '--bind_ip', '127.0.0.1', '--oplogSize', '10', '--dbpath', dirname] + list(options)
options = ['mongod', '--bind_ip', '127.0.0.1', '--oplogSize', '10', '--dbpath', dirname, '-v'] + list(options)
logging.debug(options)
pipe = subprocess.Popen(options)
self.mongods.append(pipe)
Expand Down

0 comments on commit 193df57

Please sign in to comment.