Skip to content

Commit

Permalink
Merge branch 'security'
Browse files Browse the repository at this point in the history
  • Loading branch information
thobbs committed Nov 8, 2012
2 parents cd53f63 + 487aa26 commit 4312053
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 16 deletions.
184 changes: 176 additions & 8 deletions pycassa/connection.py
@@ -1,5 +1,9 @@
from thrift.transport import TTransport
from thrift.transport import TSocket
import struct
from cStringIO import StringIO

from thrift.transport import TTransport, TSocket, TSSLSocket
from thrift.transport.TTransport import (TTransportBase, CReadableTransport,
TTransportException)
from thrift.protocol import TBinaryProtocol

from pycassa.cassandra import Cassandra
Expand All @@ -8,11 +12,28 @@
DEFAULT_SERVER = 'localhost:9160'
DEFAULT_PORT = 9160


def default_socket_factory(host, port):
"""
Returns a normal :class:`TSocket` instance.
"""
return TSocket.TSocket(host, port)


def default_transport_factory(tsocket, host, port):
"""
Returns a normal :class:`TFramedTransport` instance wrapping `tsocket`.
"""
return TTransport.TFramedTransport(tsocket)


class Connection(Cassandra.Client):
"""Encapsulation of a client session."""

def __init__(self, keyspace, server, framed_transport=True, timeout=None,
credentials=None, api_version=None):
credentials=None,
socket_factory=default_socket_factory,
transport_factory=default_transport_factory):
self.keyspace = None
self.server = server
server = server.split(':')
Expand All @@ -21,13 +42,10 @@ def __init__(self, keyspace, server, framed_transport=True, timeout=None,
else:
port = server[1]
host = server[0]
socket = TSocket.TSocket(host, int(port))
socket = socket_factory(host, int(port))
if timeout is not None:
socket.setTimeout(timeout * 1000.0)
if framed_transport:
self.transport = TTransport.TFramedTransport(socket)
else:
self.transport = TTransport.TBufferedTransport(socket)
self.transport = transport_factory(socket, host, port)
protocol = TBinaryProtocol.TBinaryProtocolAccelerated(self.transport)
Cassandra.Client.__init__(self, protocol)
self.transport.open()
Expand All @@ -45,3 +63,153 @@ def set_keyspace(self, keyspace):

def close(self):
self.transport.close()


def make_ssl_socket_factory(ca_certs, validate=True):
"""
A convenience function for creating an SSL socket factory.
`ca_certs` should contain the path to the certificate file,
`validate` determines whether or not SSL certificate validation will be performed.
"""

def ssl_socket_factory(host, port):
"""
Returns a :class:`TSSLSocket` instance.
"""
return TSSLSocket.TSSLSocket(host, port, ca_certs=ca_certs, validate=validate)

return ssl_socket_factory


class TSaslClientTransport(TTransportBase, CReadableTransport):

START = 1
OK = 2
BAD = 3
ERROR = 4
COMPLETE = 5

def __init__(self, transport, host, service,
mechanism='GSSAPI', **sasl_kwargs):

from puresasl.client import SASLClient

self.transport = transport
self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

self.__wbuf = StringIO()
self.__rbuf = StringIO()

def open(self):
if not self.transport.isOpen():
self.transport.open()

self.send_sasl_msg(self.START, self.sasl.mechanism)
self.send_sasl_msg(self.OK, self.sasl.process())

while True:
status, challenge = self.recv_sasl_msg()
if status == self.OK:
self.send_sasl_msg(self.OK, self.sasl.process(challenge))
elif status == self.COMPLETE:
if not self.sasl.complete:
raise TTransportException("The server erroneously indicated "
"that SASL negotiation was complete")
else:
break
else:
raise TTransportException("Bad SASL negotiation status: %d (%s)"
% (status, challenge))

def send_sasl_msg(self, status, body):
header = struct.pack(">BI", status, len(body))
self.transport.write(header + body)
self.transport.flush()

def recv_sasl_msg(self):
header = self.transport.readAll(5)
status, length = struct.unpack(">BI", header)
if length > 0:
payload = self.transport.readAll(length)
else:
payload = ""
return status, payload

def write(self, data):
self.__wbuf.write(data)

def flush(self):
data = self.__wbuf.getvalue()
encoded = self.sasl.wrap(data)
# Note stolen from TFramedTransport:
# N.B.: Doing this string concatenation is WAY cheaper than making
# two separate calls to the underlying socket object. Socket writes in
# Python turn out to be REALLY expensive, but it seems to do a pretty
# good job of managing string buffer operations without excessive copies
self.transport.write(''.join((struct.pack("!i", len(encoded)), encoded)))
self.transport.flush()
self.__wbuf = StringIO()

def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret

self._read_frame()
return self.__rbuf.read(sz)

def _read_frame(self):
header = self.transport.readAll(4)
length, = struct.unpack('!i', header)
encoded = self.transport.readAll(length)
self.__rbuf = StringIO(self.sasl.unwrap(encoded))

def close(self):
self.sasl.dispose()
self.transport.close()

# Implement the CReadableTransport interface.
# Stolen shamelessly from TFramedTransport
@property
def cstringio_buf(self):
return self.__rbuf

def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately.
while len(prefix) < reqlen:
self._read_frame()
prefix += self.__rbuf.getvalue()
self.__rbuf = StringIO(prefix)
return self.__rbuf


def make_sasl_transport_factory(credential_factory):
"""
A convenience function for creating a SASL transport factory.
`credential_factory` should be a function taking two args: `host` and
`port`. It should return a ``dict`` of kwargs that will be passed
to :func:`puresasl.client.SASLClient.__init__()`.
Example usage::
>>> def make_credentials(host, port):
... return {'host': host,
... 'service': 'cassandra',
... 'principal': 'user/role@FOO.EXAMPLE.COM',
... 'mechanism': 'GSSAPI'}
>>>
>>> factory = make_sasl_transport_factory(make_credentials)
>>> pool = ConnectionPool(..., transport_factory=factory)
"""

def sasl_transport_factory(tsocket, host, port):
sasl_kwargs = credential_factory(host, port)
sasl_transport = TSaslClientTransport(tsocket, **sasl_kwargs)
return TTransport.TFramedTransport(sasl_transport)

return sasl_transport_factory
33 changes: 28 additions & 5 deletions pycassa/pool.py
Expand Up @@ -13,7 +13,8 @@

from thrift import Thrift
from thrift.transport.TTransport import TTransportException
from connection import Connection
from connection import (Connection, default_socket_factory,
default_transport_factory)
from logging.pool_logger import PoolLogger
from util import as_interface
from cassandra.ttypes import TimedOutException, UnavailableException
Expand Down Expand Up @@ -229,7 +230,7 @@ def _set_max_overflow(self, max_overflow):
up to `pool_timeout` seconds for a connection to be returned to the
pool before giving up. Note that this setting is only meaningful when you
are accessing the pool concurrently, such as with multiple threads.
This may be set to 0 to fail immediately or -1 to wait forever.
This may be set to 0 to fail immediately or -1 to wait forever.
The default value is 30. """

recycle = 10000
Expand All @@ -242,21 +243,40 @@ def _set_max_overflow(self, max_overflow):
or :exc:`~.UnavailableException`, which tend to indicate single or
multiple node failure, the operation will be retried on different nodes
up to `max_retries` times before an :exc:`~.MaximumRetryException` is raised.
Setting this to 0 disables retries and setting to -1 allows unlimited retries.
Setting this to 0 disables retries and setting to -1 allows unlimited retries.
The default value is 5. """

logging_name = None
""" By default, each pool identifies itself in the logs using ``id(self)``.
If multiple pools are in use for different purposes, setting `logging_name` will
help individual pools to be identified in the logs. """

socket_factory = default_socket_factory
""" A function that creates the socket for each connection in the pool.
This function should take two arguments: `host`, the host the connection is
being made to, and `port`, the destination port.
By default, this is function is :func:`~connection.default_socket_factory`.
"""

transport_factory = default_transport_factory
""" A function that creates the transport for each connection in the pool.
This function should take three arguments: `tsocket`, a TSocket object for the
transport, `host`, the host the connection is being made to, and `port`,
the destination port.
By default, this is function is :func:`~connection.default_transport_factory`.
"""

def __init__(self, keyspace,
server_list=['localhost:9160'],
credentials=None,
timeout=0.5,
use_threadlocal=True,
pool_size=5,
prefill=True,
socket_factory=default_socket_factory,
transport_factory=default_transport_factory,
**kwargs):
"""
All connections in the pool will be opened to `keyspace`.
Expand Down Expand Up @@ -315,6 +335,8 @@ def __init__(self, keyspace,
self.keyspace = keyspace
self.credentials = credentials
self.timeout = timeout
self.socket_factory = socket_factory
self.transport_factory = transport_factory
if use_threadlocal:
self._tlocal = threading.local()

Expand Down Expand Up @@ -429,9 +451,10 @@ def fill(self):
def _get_new_wrapper(self, server):
return ConnectionWrapper(self, self.max_retries,
self.keyspace, server,
framed_transport=True,
timeout=self.timeout,
credentials=self.credentials)
credentials=self.credentials,
socket_factory=self.socket_factory,
transport_factory=self.transport_factory)

def _replace_wrapper(self):
"""Try to replace the connection."""
Expand Down
9 changes: 6 additions & 3 deletions pycassa/system_manager.py
@@ -1,6 +1,7 @@
import time

from pycassa.connection import Connection
from pycassa.connection import (Connection, default_socket_factory,
default_transport_factory)
from pycassa.cassandra.ttypes import IndexType, KsDef, CfDef, ColumnDef,\
SchemaDisagreementException
import pycassa.marshal as marshal
Expand Down Expand Up @@ -66,8 +67,10 @@ class SystemManager(object):
"""

def __init__(self, server='localhost:9160', credentials=None, framed_transport=True,
timeout=_DEFAULT_TIMEOUT):
self._conn = Connection(None, server, framed_transport, timeout, credentials)
timeout=_DEFAULT_TIMEOUT, socket_factory=default_socket_factory,
transport_factory=default_transport_factory):
self._conn = Connection(None, server, framed_transport, timeout,
credentials, socket_factory, transport_factory)

def close(self):
""" Closes the underlying connection """
Expand Down

0 comments on commit 4312053

Please sign in to comment.