Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge branch 'security'

  • Loading branch information...
commit 43120533a6c914f3263f456d47f1220fcddb9b1b 2 parents cd53f63 + 487aa26
Tyler Hobbs thobbs authored
Showing with 210 additions and 16 deletions.
  1. +176 −8 pycassa/connection.py
  2. +28 −5 pycassa/pool.py
  3. +6 −3 pycassa/system_manager.py
184 pycassa/connection.py
View
@@ -1,5 +1,9 @@
-from thrift.transport import TTransport
-from thrift.transport import TSocket
+import struct
+from cStringIO import StringIO
+
+from thrift.transport import TTransport, TSocket, TSSLSocket
+from thrift.transport.TTransport import (TTransportBase, CReadableTransport,
+ TTransportException)
from thrift.protocol import TBinaryProtocol
from pycassa.cassandra import Cassandra
@@ -8,11 +12,28 @@
DEFAULT_SERVER = 'localhost:9160'
DEFAULT_PORT = 9160
+
+def default_socket_factory(host, port):
+ """
+ Returns a normal :class:`TSocket` instance.
+ """
+ return TSocket.TSocket(host, port)
+
+
+def default_transport_factory(tsocket, host, port):
+ """
+ Returns a normal :class:`TFramedTransport` instance wrapping `tsocket`.
+ """
+ return TTransport.TFramedTransport(tsocket)
+
+
class Connection(Cassandra.Client):
"""Encapsulation of a client session."""
def __init__(self, keyspace, server, framed_transport=True, timeout=None,
- credentials=None, api_version=None):
+ credentials=None,
+ socket_factory=default_socket_factory,
+ transport_factory=default_transport_factory):
self.keyspace = None
self.server = server
server = server.split(':')
@@ -21,13 +42,10 @@ def __init__(self, keyspace, server, framed_transport=True, timeout=None,
else:
port = server[1]
host = server[0]
- socket = TSocket.TSocket(host, int(port))
+ socket = socket_factory(host, int(port))
if timeout is not None:
socket.setTimeout(timeout * 1000.0)
- if framed_transport:
- self.transport = TTransport.TFramedTransport(socket)
- else:
- self.transport = TTransport.TBufferedTransport(socket)
+ self.transport = transport_factory(socket, host, port)
protocol = TBinaryProtocol.TBinaryProtocolAccelerated(self.transport)
Cassandra.Client.__init__(self, protocol)
self.transport.open()
@@ -45,3 +63,153 @@ def set_keyspace(self, keyspace):
def close(self):
self.transport.close()
+
+
+def make_ssl_socket_factory(ca_certs, validate=True):
+ """
+ A convenience function for creating an SSL socket factory.
+
+ `ca_certs` should contain the path to the certificate file,
+ `validate` determines whether or not SSL certificate validation will be performed.
+ """
+
+ def ssl_socket_factory(host, port):
+ """
+ Returns a :class:`TSSLSocket` instance.
+ """
+ return TSSLSocket.TSSLSocket(host, port, ca_certs=ca_certs, validate=validate)
+
+ return ssl_socket_factory
+
+
+class TSaslClientTransport(TTransportBase, CReadableTransport):
+
+ START = 1
+ OK = 2
+ BAD = 3
+ ERROR = 4
+ COMPLETE = 5
+
+ def __init__(self, transport, host, service,
+ mechanism='GSSAPI', **sasl_kwargs):
+
+ from puresasl.client import SASLClient
+
+ self.transport = transport
+ self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+
+ self.__wbuf = StringIO()
+ self.__rbuf = StringIO()
+
+ def open(self):
+ if not self.transport.isOpen():
+ self.transport.open()
+
+ self.send_sasl_msg(self.START, self.sasl.mechanism)
+ self.send_sasl_msg(self.OK, self.sasl.process())
+
+ while True:
+ status, challenge = self.recv_sasl_msg()
+ if status == self.OK:
+ self.send_sasl_msg(self.OK, self.sasl.process(challenge))
+ elif status == self.COMPLETE:
+ if not self.sasl.complete:
+ raise TTransportException("The server erroneously indicated "
+ "that SASL negotiation was complete")
+ else:
+ break
+ else:
+ raise TTransportException("Bad SASL negotiation status: %d (%s)"
+ % (status, challenge))
+
+ def send_sasl_msg(self, status, body):
+ header = struct.pack(">BI", status, len(body))
+ self.transport.write(header + body)
+ self.transport.flush()
+
+ def recv_sasl_msg(self):
+ header = self.transport.readAll(5)
+ status, length = struct.unpack(">BI", header)
+ if length > 0:
+ payload = self.transport.readAll(length)
+ else:
+ payload = ""
+ return status, payload
+
+ def write(self, data):
+ self.__wbuf.write(data)
+
+ def flush(self):
+ data = self.__wbuf.getvalue()
+ encoded = self.sasl.wrap(data)
+ # Note stolen from TFramedTransport:
+ # N.B.: Doing this string concatenation is WAY cheaper than making
+ # two separate calls to the underlying socket object. Socket writes in
+ # Python turn out to be REALLY expensive, but it seems to do a pretty
+ # good job of managing string buffer operations without excessive copies
+ self.transport.write(''.join((struct.pack("!i", len(encoded)), encoded)))
+ self.transport.flush()
+ self.__wbuf = StringIO()
+
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
+
+ self._read_frame()
+ return self.__rbuf.read(sz)
+
+ def _read_frame(self):
+ header = self.transport.readAll(4)
+ length, = struct.unpack('!i', header)
+ encoded = self.transport.readAll(length)
+ self.__rbuf = StringIO(self.sasl.unwrap(encoded))
+
+ def close(self):
+ self.sasl.dispose()
+ self.transport.close()
+
+ # Implement the CReadableTransport interface.
+ # Stolen shamelessly from TFramedTransport
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
+
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self._read_frame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = StringIO(prefix)
+ return self.__rbuf
+
+
+def make_sasl_transport_factory(credential_factory):
+ """
+ A convenience function for creating a SASL transport factory.
+
+ `credential_factory` should be a function taking two args: `host` and
+ `port`. It should return a ``dict`` of kwargs that will be passed
+ to :func:`puresasl.client.SASLClient.__init__()`.
+
+ Example usage::
+
+ >>> def make_credentials(host, port):
+ ... return {'host': host,
+ ... 'service': 'cassandra',
+ ... 'principal': 'user/role@FOO.EXAMPLE.COM',
+ ... 'mechanism': 'GSSAPI'}
+ >>>
+ >>> factory = make_sasl_transport_factory(make_credentials)
+ >>> pool = ConnectionPool(..., transport_factory=factory)
+
+ """
+
+ def sasl_transport_factory(tsocket, host, port):
+ sasl_kwargs = credential_factory(host, port)
+ sasl_transport = TSaslClientTransport(tsocket, **sasl_kwargs)
+ return TTransport.TFramedTransport(sasl_transport)
+
+ return sasl_transport_factory
33 pycassa/pool.py
View
@@ -13,7 +13,8 @@
from thrift import Thrift
from thrift.transport.TTransport import TTransportException
-from connection import Connection
+from connection import (Connection, default_socket_factory,
+ default_transport_factory)
from logging.pool_logger import PoolLogger
from util import as_interface
from cassandra.ttypes import TimedOutException, UnavailableException
@@ -229,7 +230,7 @@ def _set_max_overflow(self, max_overflow):
up to `pool_timeout` seconds for a connection to be returned to the
pool before giving up. Note that this setting is only meaningful when you
are accessing the pool concurrently, such as with multiple threads.
- This may be set to 0 to fail immediately or -1 to wait forever.
+ This may be set to 0 to fail immediately or -1 to wait forever.
The default value is 30. """
recycle = 10000
@@ -242,7 +243,7 @@ def _set_max_overflow(self, max_overflow):
or :exc:`~.UnavailableException`, which tend to indicate single or
multiple node failure, the operation will be retried on different nodes
up to `max_retries` times before an :exc:`~.MaximumRetryException` is raised.
- Setting this to 0 disables retries and setting to -1 allows unlimited retries.
+ Setting this to 0 disables retries and setting to -1 allows unlimited retries.
The default value is 5. """
logging_name = None
@@ -250,6 +251,23 @@ def _set_max_overflow(self, max_overflow):
If multiple pools are in use for different purposes, setting `logging_name` will
help individual pools to be identified in the logs. """
+ socket_factory = default_socket_factory
+ """ A function that creates the socket for each connection in the pool.
+ This function should take two arguments: `host`, the host the connection is
+ being made to, and `port`, the destination port.
+
+ By default, this is function is :func:`~connection.default_socket_factory`.
+ """
+
+ transport_factory = default_transport_factory
+ """ A function that creates the transport for each connection in the pool.
+ This function should take three arguments: `tsocket`, a TSocket object for the
+ transport, `host`, the host the connection is being made to, and `port`,
+ the destination port.
+
+ By default, this is function is :func:`~connection.default_transport_factory`.
+ """
+
def __init__(self, keyspace,
server_list=['localhost:9160'],
credentials=None,
@@ -257,6 +275,8 @@ def __init__(self, keyspace,
use_threadlocal=True,
pool_size=5,
prefill=True,
+ socket_factory=default_socket_factory,
+ transport_factory=default_transport_factory,
**kwargs):
"""
All connections in the pool will be opened to `keyspace`.
@@ -315,6 +335,8 @@ def __init__(self, keyspace,
self.keyspace = keyspace
self.credentials = credentials
self.timeout = timeout
+ self.socket_factory = socket_factory
+ self.transport_factory = transport_factory
if use_threadlocal:
self._tlocal = threading.local()
@@ -429,9 +451,10 @@ def fill(self):
def _get_new_wrapper(self, server):
return ConnectionWrapper(self, self.max_retries,
self.keyspace, server,
- framed_transport=True,
timeout=self.timeout,
- credentials=self.credentials)
+ credentials=self.credentials,
+ socket_factory=self.socket_factory,
+ transport_factory=self.transport_factory)
def _replace_wrapper(self):
"""Try to replace the connection."""
9 pycassa/system_manager.py
View
@@ -1,6 +1,7 @@
import time
-from pycassa.connection import Connection
+from pycassa.connection import (Connection, default_socket_factory,
+ default_transport_factory)
from pycassa.cassandra.ttypes import IndexType, KsDef, CfDef, ColumnDef,\
SchemaDisagreementException
import pycassa.marshal as marshal
@@ -66,8 +67,10 @@ class SystemManager(object):
"""
def __init__(self, server='localhost:9160', credentials=None, framed_transport=True,
- timeout=_DEFAULT_TIMEOUT):
- self._conn = Connection(None, server, framed_transport, timeout, credentials)
+ timeout=_DEFAULT_TIMEOUT, socket_factory=default_socket_factory,
+ transport_factory=default_transport_factory):
+ self._conn = Connection(None, server, framed_transport, timeout,
+ credentials, socket_factory, transport_factory)
def close(self):
""" Closes the underlying connection """
Please sign in to comment.
Something went wrong with that request. Please try again.