Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft: TLS support for meta and storage clients #239

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions nebula3/fbthrift/transport/TSSLSocket.py
Expand Up @@ -120,7 +120,7 @@ def _warn_if_insecure_version_specified(version):

def _get_ssl_socket(socket, ssl_version, cert_reqs=ssl.CERT_NONE,
ca_certs=None, keyfile=None, certfile=None,
disable_weaker_versions=True):
disable_weaker_versions=True, server_hostname=None):
ctx = ssl.SSLContext(ssl_version)
ctx.verify_mode = cert_reqs
if certfile is not None:
Expand All @@ -144,7 +144,7 @@ def _get_ssl_socket(socket, ssl_version, cert_reqs=ssl.CERT_NONE,
if hasattr(ssl, 'OP_NO_TLSv1_1'):
ctx.options |= ssl.OP_NO_TLSv1_1

return ctx.wrap_socket(socket)
return ctx.wrap_socket(socket, server_hostname=server_hostname)


class TSSLSocket(TSocket):
Expand Down Expand Up @@ -222,6 +222,7 @@ def open(self):
keyfile=self.client_keyfile,
certfile=self.client_certfile,
disable_weaker_versions=not self.allow_weak_ssl_versions,
server_hostname=self.host,
)

if self.verify_name:
Expand Down
29 changes: 22 additions & 7 deletions nebula3/mclient/__init__.py
Expand Up @@ -28,19 +28,20 @@
)
from nebula3.meta import ttypes, MetaService

from nebula3.fbthrift.transport import TSocket, TTransport
from nebula3.fbthrift.transport import TSocket, TTransport, TSSLSocket
from nebula3.fbthrift.protocol import TBinaryProtocol
from nebula3.logger import logger


class MetaClient(object):
def __init__(self, addresses, timeout):
def __init__(self, addresses, timeout, ssl_config=None):
if len(addresses) == 0:
raise RuntimeError('Input empty addresses')
self._timeout = timeout
self._connection = None
self._retry_count = 3
self._addresses = addresses
self._ssl_config = ssl_config
for address in addresses:
try:
socket.gethostbyname(address[0])
Expand All @@ -50,13 +51,27 @@ def __init__(self, addresses, timeout):
self._lock = RLock()

def open(self):
"""open the connection to connect meta service

"""open the SSL connection to connect meta service
:ssl_config: configs for SSL
:eturn: void
"""
try:
self.close()
s = TSocket.TSocket(self._leader[0], self._leader[1])
if self._ssl_config is not None:
s = TSSLSocket.TSSLSocket(
self._leader[0],
self._leader[1],
self._ssl_config.unix_socket,
self._ssl_config.ssl_version,
self._ssl_config.cert_reqs,
self._ssl_config.ca_certs,
self._ssl_config.verify_name,
self._ssl_config.keyfile,
self._ssl_config.certfile,
self._ssl_config.allow_weak_ssl_versions,
)
else:
s = TSocket.TSocket(self._leader[0], self._leader[1])
if self._timeout > 0:
s.setTimeout(self._timeout)
transport = TTransport.TBufferedTransport(s)
Expand Down Expand Up @@ -267,7 +282,7 @@ def __repr__(self):
self.parts_alloc,
)

def __init__(self, meta_addrs, timeout=2000, load_period=10, decode_type='utf-8'):
def __init__(self, meta_addrs, timeout=2000, load_period=10, decode_type='utf-8', ssl_config=None):
self._decode_type = decode_type
self._load_period = load_period
self._lock = RLock()
Expand All @@ -276,7 +291,7 @@ def __init__(self, meta_addrs, timeout=2000, load_period=10, decode_type='utf-8'
self._storage_addrs = []
self._storage_leader = {}
self._close = False
self._meta_client = MetaClient(meta_addrs, timeout)
self._meta_client = MetaClient(meta_addrs, timeout, ssl_config=ssl_config)
self._meta_client.open()

# load meta data
Expand Down
8 changes: 6 additions & 2 deletions nebula3/sclient/GraphStorageClient.py
Expand Up @@ -35,10 +35,11 @@ class GraphStorageClient(object):
DEFAULT_END_TIME = sys.maxsize
DEFAULT_LIMIT = 1000

def __init__(self, meta_cache, storage_addrs=None, time_out=60000):
def __init__(self, meta_cache, storage_addrs=None, time_out=60000, ssl_config=None):
self._meta_cache = meta_cache
self._storage_addrs = storage_addrs
self._time_out = time_out
self._ssl_config = ssl_config
self._connections = []
self._create_connection()

Expand Down Expand Up @@ -76,7 +77,10 @@ def _create_connection(self):
try:
for addr in self._storage_addrs:
conn = GraphStorageConnection(addr, self._time_out, self._meta_cache)
conn.open()
if self._ssl_config is None:
conn.open()
else:
conn.open_SSL(ssl_config=self._ssl_config)
self._connections.append(conn)
except Exception as e:
logger.error('Create storage connection failed: {}'.format(e))
Expand Down
27 changes: 25 additions & 2 deletions nebula3/sclient/net/__init__.py
Expand Up @@ -11,7 +11,7 @@
from nebula3.Exception import InValidHostname
from nebula3.storage import GraphStorageService
from nebula3.storage.ttypes import ScanVertexRequest, ScanEdgeRequest
from nebula3.fbthrift.transport import TSocket, TTransport
from nebula3.fbthrift.transport import TSocket, TTransport, TSSLSocket
from nebula3.fbthrift.protocol import TBinaryProtocol


Expand All @@ -22,6 +22,7 @@ def __init__(self, address, timeout, meta_cache):
self._meta_cache = meta_cache
self._connection = None
self._ip = ''
self._ssl_conf = None
try:
self._ip = socket.gethostbyname(address.host)
if not isinstance(address.port, int):
Expand All @@ -30,9 +31,31 @@ def __init__(self, address, timeout, meta_cache):
raise InValidHostname(str(address.host))

def open(self):
self.open_SSL(ssl_config=None)

def open_SSL(self, ssl_config=None):
"""open the SSL connection
:ssl_config: configs for SSL
:return: void
"""
self._ssl_conf = ssl_config
try:
self.close()
s = TSocket.TSocket(self._address.host, self._address.port)
if self._ssl_conf is not None:
s = TSSLSocket.TSSLSocket(
self._address.host,
self._address.port,
ssl_config.unix_socket,
ssl_config.ssl_version,
ssl_config.cert_reqs,
ssl_config.ca_certs,
ssl_config.verify_name,
ssl_config.keyfile,
ssl_config.certfile,
ssl_config.allow_weak_ssl_versions,
)
else:
s = TSocket.TSocket(self._address.host, self._address.port)
if self._timeout > 0:
s.setTimeout(self._timeout)
transport = TTransport.TBufferedTransport(s)
Expand Down