Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge remote-tracking branch 'origin/sasl-support' into security

Conflicts:
	pycassa/connection.py
	pycassa/pool.py
	pycassa/system_manager.py
  • Loading branch information...
commit a1a886270149c3200fbb3466ab18751ec167fbba 2 parents c3e4c07 + b8951ec
@thobbs thobbs authored
View
157 pycassa/connection.py
@@ -1,5 +1,9 @@
-from thrift.transport import TTransport
-from thrift.transport import TSocket, TSSLSocket
+import struct
+from cStringIO import StringIO
+
+from thrift.transport import TTransport, TSocket, TSSLSocket
+from thrift.transport.TTransport import (TTransportBase, CReadableTransport,
+ TTransportException)
from thrift.protocol import TBinaryProtocol
from pycassa.cassandra import Cassandra
@@ -16,11 +20,20 @@ def default_socket_factory(host, port):
return TSocket.TSocket(host, port)
+def default_transport_factory(tsocket, host, port):
+ """
+ Returns a normal :class:`TFramedTransport` instance wrapping `tsocket`.
+ """
+ return TTransport.TFramedTransport(tsocket)
+
+
class Connection(Cassandra.Client):
"""Encapsulation of a client session."""
def __init__(self, keyspace, server, framed_transport=True, timeout=None,
- credentials=None, socket_factory=default_socket_factory):
+ credentials=None,
+ socket_factory=default_socket_factory,
+ transport_factory=default_transport_factory):
self.keyspace = None
self.server = server
server = server.split(':')
@@ -32,10 +45,7 @@ def __init__(self, keyspace, server, framed_transport=True, timeout=None,
socket = socket_factory(host, int(port))
if timeout is not None:
socket.setTimeout(timeout * 1000.0)
- if framed_transport:
- self.transport = TTransport.TFramedTransport(socket)
- else:
- self.transport = TTransport.TBufferedTransport(socket)
+ self.transport = transport_factory(socket, host, port)
protocol = TBinaryProtocol.TBinaryProtocolAccelerated(self.transport)
Cassandra.Client.__init__(self, protocol)
self.transport.open()
@@ -70,3 +80,136 @@ def ssl_socket_factory(host, port):
return TSSLSocket.TSSLSocket(host, port, ca_certs=ca_certs, validate=validate)
return ssl_socket_factory
+
+
+class TSaslClientTransport(TTransportBase, CReadableTransport):
+
+ START = 1
+ OK = 2
+ BAD = 3
+ ERROR = 4
+ COMPLETE = 5
+
+ def __init__(self, transport, host, service,
+ mechanism='GSSAPI', **sasl_kwargs):
+
+ from puresasl.client import SASLClient
+
+ self.transport = transport
+ self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+
+ self.__wbuf = StringIO()
+ self.__rbuf = StringIO()
+
+ def open(self):
+ if not self.transport.isOpen():
+ self.transport.open()
+
+ self.send_sasl_msg(self.START, self.sasl.mechanism)
+ self.send_sasl_msg(self.OK, self.sasl.process())
+
+ while True:
+ status, challenge = self.recv_sasl_msg()
+ if status == self.OK:
+ self.send_sasl_msg(self.OK, self.sasl.process(challenge))
+ elif status == self.COMPLETE:
+ if not self.sasl.complete:
+ raise TTransportException("The server erroneously indicated "
+ "that SASL negotiation was complete")
+ else:
+ break
+ else:
+ raise TTransportException("Bad SASL negotiation status: %d (%s)"
+ % (status, challenge))
+
+ def send_sasl_msg(self, status, body):
+ header = struct.pack(">BI", status, len(body))
+ self.transport.write(header + body)
+ self.transport.flush()
+
+ def recv_sasl_msg(self):
+ header = self.transport.readAll(5)
+ status, length = struct.unpack(">BI", header)
+ if length > 0:
+ payload = self.transport.readAll(length)
+ else:
+ payload = ""
+ return status, payload
+
+ def write(self, data):
+ self.__wbuf.write(data)
+
+ def flush(self):
+ data = self.__wbuf.getvalue()
+ encoded = self.sasl.wrap(data)
+ # Note stolen from TFramedTransport:
+ # N.B.: Doing this string concatenation is WAY cheaper than making
+ # two separate calls to the underlying socket object. Socket writes in
+ # Python turn out to be REALLY expensive, but it seems to do a pretty
+ # good job of managing string buffer operations without excessive copies
+ self.transport.write(''.join((struct.pack("!i", len(encoded)), encoded)))
+ self.transport.flush()
+ self.__wbuf = StringIO()
+
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
+
+ self._read_frame()
+ return self.__rbuf.read(sz)
+
+ def _read_frame(self):
+ header = self.transport.readAll(4)
+ length, = struct.unpack('!i', header)
+ encoded = self.transport.readAll(length)
+ self.__rbuf = StringIO(self.sasl.unwrap(encoded))
+
+ def close(self):
+ self.sasl.dispose()
+ self.transport.close()
+
+ # Implement the CReadableTransport interface.
+ # Stolen shamelessly from TFramedTransport
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
+
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self._read_frame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = StringIO(prefix)
+ return self.__rbuf
+
+
+def make_sasl_transport_factory(credential_factory):
+ """
+ A convenience function for creating a SASL transport factory.
+
+ `credential_factory` should be a function taking two args: `host` and
+ `port`. It should return a ``dict`` of kwargs that will be passed
+ to :func:`puresasl.client.SASLClient.__init__()`.
+
+ Example usage::
+
+ >>> def make_credentials(host, port):
+ ... return {'host': host,
+ ... 'service': 'cassandra',
+ ... 'principal': 'user/role@FOO.EXAMPLE.COM',
+ ... 'mechanism': 'GSSAPI'}
+ >>>
+ >>> factory = make_sasl_transport_factory(make_credentials)
+ >>> pool = ConnectionPool(..., transport_factory=factory)
+
+ """
+
+ def sasl_transport_factory(tsocket, host, port):
+ sasl_kwargs = credential_factory(host, port)
+ sasl_transport = TSaslClientTransport(tsocket, **sasl_kwargs)
+ return TTransport.TFramedTransport(sasl_transport)
+
+ return sasl_transport_factory
View
18 pycassa/pool.py
@@ -13,7 +13,8 @@
from thrift import Thrift
from thrift.transport.TTransport import TTransportException
-from connection import Connection, default_socket_factory
+from connection import (Connection, default_socket_factory,
+ default_transport_factory)
from logging.pool_logger import PoolLogger
from util import as_interface
from cassandra.ttypes import TimedOutException, UnavailableException
@@ -258,6 +259,15 @@ def _set_max_overflow(self, max_overflow):
By default, this is function is :func:`~connection.default_socket_factory`.
"""
+ transport_factory = default_transport_factory
+ """ A function that creates the transport for each connection in the pool.
+ This function should take three arguments: `tsocket`, a TSocket object for the
+ transport, `host`, the host the connection is being made to, and `port`,
+ the destination port.
+
+ By default, this is function is :func:`~connection.default_transport_factory`.
+ """
+
def __init__(self, keyspace,
server_list=['localhost:9160'],
credentials=None,
@@ -266,6 +276,7 @@ def __init__(self, keyspace,
pool_size=5,
prefill=True,
socket_factory=default_socket_factory,
+ transport_factory=default_transport_factory,
**kwargs):
"""
All connections in the pool will be opened to `keyspace`.
@@ -325,6 +336,7 @@ def __init__(self, keyspace,
self.credentials = credentials
self.timeout = timeout
self.socket_factory = socket_factory
+ self.transport_factory = transport_factory
if use_threadlocal:
self._tlocal = threading.local()
@@ -439,10 +451,10 @@ def fill(self):
def _get_new_wrapper(self, server):
return ConnectionWrapper(self, self.max_retries,
self.keyspace, server,
- framed_transport=True,
timeout=self.timeout,
credentials=self.credentials,
- socket_factory=self.socket_factory)
+ socket_factory=self.socket_factory,
+ transport_factory=self.transport_factory)
def _replace_wrapper(self):
"""Try to replace the connection."""
View
9 pycassa/system_manager.py
@@ -1,6 +1,7 @@
import time
-from pycassa.connection import Connection, default_socket_factory
+from pycassa.connection import (Connection, default_socket_factory,
+ default_transport_factory)
from pycassa.cassandra.ttypes import IndexType, KsDef, CfDef, ColumnDef,\
SchemaDisagreementException
import pycassa.marshal as marshal
@@ -66,8 +67,10 @@ class SystemManager(object):
"""
def __init__(self, server='localhost:9160', credentials=None, framed_transport=True,
- timeout=_DEFAULT_TIMEOUT, socket_factory=default_socket_factory):
- self._conn = Connection(None, server, framed_transport, timeout, credentials, socket_factory)
+ timeout=_DEFAULT_TIMEOUT, socket_factory=default_socket_factory,
+ transport_factory=default_transport_factory):
+ self._conn = Connection(None, server, framed_transport, timeout,
+ credentials, socket_factory, transport_factory)
def close(self):
""" Closes the underlying connection """
View
32 sasl_demo.py
@@ -0,0 +1,32 @@
+#!/usr/bin/python
+
+from pycassa.pool import ConnectionPool
+from pycassa.columnfamily import ColumnFamily
+from pycassa.connection import make_sasl_transport_factory
+from pycassa.system_manager import SystemManager
+
+def make_creds(host, port):
+ # typically, you would use the passed-in host, but my kerberos test setup
+ # is not that sophisticated
+ return {'host': 'thobbs-laptop2',
+ 'service': 'host',
+ 'mechanism': 'GSSAPI'}
+
+transport_factory = make_sasl_transport_factory(make_creds)
+
+sysman = SystemManager(transport_factory=transport_factory)
+if 'Keyspace1' not in sysman.list_keyspaces():
+ sysman.create_keyspace('Keyspace1', 'SimpleStrategy', {'replication_factor': '1'})
+ sysman.create_column_family('Keyspace1', 'Standard1')
+sysman.close()
+
+pool = ConnectionPool('Keyspace1', transport_factory=transport_factory)
+cf = ColumnFamily(pool, 'Standard1')
+
+for i in range(100):
+ cf.insert('key%d' % i, {'col': 'val'})
+
+for i in range(100):
+ print 'key%d:' % i, cf.get('key%d' % i)
+
+pool.dispose()
Please sign in to comment.
Something went wrong with that request. Please try again.