Skip to content

Commit

Permalink
Merge pull request #489 from notEvil/ssl_context
Browse files Browse the repository at this point in the history
Changes related to SSL context
  • Loading branch information
comrumino committed Jun 3, 2022
2 parents 6d4cf10 + 7cf41e8 commit 5e50641
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 43 deletions.
30 changes: 21 additions & 9 deletions rpyc/core/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def ssl_connect(cls, host, port, ssl_kwargs, **kwargs):
:param host: the host name
:param port: the TCP port
:param ssl_kwargs: a dictionary of keyword arguments to be passed
directly to ``ssl.wrap_socket``
:param ssl_kwargs: a dictionary of keyword arguments for
``ssl.SSLContext`` and ``ssl.SSLContext.wrap_socket``
:param kwargs: additional keyword arguments: ``family``, ``socktype``,
``proto``, ``timeout``, ``nodelay``, passed directly to
the ``socket`` constructor, or ``ipv6``.
Expand All @@ -206,20 +206,32 @@ def ssl_connect(cls, host, port, ssl_kwargs, **kwargs):
:returns: a :class:`SocketStream`
"""
from ssl import SSLContext
import ssl
if kwargs.pop("ipv6", False):
kwargs["family"] = socket.AF_INET6
s = cls._connect(host, port, **kwargs)
try:
context = SSLContext(ssl_kwargs.pop('ssl_version'))
certfile = ssl_kwargs.pop('certfile', None)
keyfile = ssl_kwargs.pop('keyfile', None)
if "ssl_version" in ssl_kwargs:
context = ssl.SSLContext(ssl_kwargs.pop("ssl_version"))
else:
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
certfile = ssl_kwargs.pop("certfile", None)
keyfile = ssl_kwargs.pop("keyfile", None)
if certfile is not None:
context.load_cert_chain(certfile, keyfile=keyfile)
context.check_hostname = ssl_kwargs.pop('check_hostname', True)
context.verify_mode = ssl_kwargs.pop('cert_reqs', ssl.CERT_NONE)
s2 = context.wrap_socket(s, **ssl_kwargs)
ca_certs = ssl_kwargs.pop("ca_certs", None)
if ca_certs is not None:
context.load_verify_locations(ca_certs)
ciphers = ssl_kwargs.pop("ciphers", None)
if ciphers is not None:
context.set_ciphers(ciphers)
check_hostname = ssl_kwargs.pop("check_hostname", None)
if check_hostname is not None:
context.check_hostname = check_hostname
cert_reqs = ssl_kwargs.pop("cert_reqs", None)
if cert_reqs is not None:
context.verify_mode = cert_reqs
s2 = context.wrap_socket(s, server_hostname=host, **ssl_kwargs)
return cls(s2)
except BaseException:
s.close()
Expand Down
27 changes: 15 additions & 12 deletions rpyc/utils/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class AuthenticationError(Exception):

class SSLAuthenticator(object):
"""An implementation of the authenticator protocol for ``SSL``. The given
socket is wrapped by ``ssl.wrap_socket`` and is validated based on
socket is wrapped by ``ssl.SSLContext.wrap_socket`` and is validated based on
certificates
:param keyfile: the server's key file
Expand All @@ -48,7 +48,7 @@ class SSLAuthenticator(object):
to restrict the available ciphers. New in Python 2.7/3.2
:param ssl_version: the SSL version to use
Refer to `ssl.wrap_socket <http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket>`_
Refer to `ssl.SSLContext <http://docs.python.org/dev/library/ssl.html#ssl.SSLContext>`_
for more info.
Clients can connect to this authenticator using
Expand All @@ -70,19 +70,22 @@ def __init__(self, keyfile, certfile, ca_certs=None, cert_reqs=None,
self.cert_reqs = ssl.CERT_NONE
else:
self.cert_reqs = cert_reqs
if ssl_version is None:
self.ssl_version = ssl.PROTOCOL_TLS
else:
self.ssl_version = ssl_version
self.ssl_version = ssl_version

def __call__(self, sock):
kwargs = dict(keyfile=self.keyfile, certfile=self.certfile,
server_side=True, ca_certs=self.ca_certs, cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version)
if self.ciphers is not None:
kwargs["ciphers"] = self.ciphers
try:
sock2 = ssl.wrap_socket(sock, **kwargs)
if self.ssl_version is None:
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
else:
context = ssl.SSLContext(self.ssl_version)
context.load_cert_chain(self.certfile, keyfile=self.keyfile)
if self.ca_certs is not None:
context.load_verify_locations(self.ca_certs)
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
if self.cert_reqs is not None:
context.verify_mode = self.cert_reqs
sock2 = context.wrap_socket(sock, server_side=True)
except ssl.SSLError:
ex = sys.exc_info()[1]
raise AuthenticationError(str(ex))
Expand Down
25 changes: 14 additions & 11 deletions rpyc/utils/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,20 @@ def ssl_connect(host, port=DEFAULT_SERVER_SSL_PORT, keyfile=None,
:param port: the TCP port to use
:param ipv6: whether to create an IPv6 socket or an IPv4 one
The following arguments are passed directly to
`ssl.wrap_socket <http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket>`_:
:param keyfile: see ``ssl.wrap_socket``. May be ``None``
:param certfile: see ``ssl.wrap_socket``. May be ``None``
:param ca_certs: see ``ssl.wrap_socket``. May be ``None``
:param cert_reqs: see ``ssl.wrap_socket``. By default, if ``ca_cert`` is specified,
the requirement is set to ``CERT_REQUIRED``; otherwise it is
set to ``CERT_NONE``
:param ssl_version: see ``ssl.wrap_socket``. The default is ``PROTOCOL_TLSv1``
:param ciphers: see ``ssl.wrap_socket``. May be ``None``. New in Python 2.7/3.2
The following arguments are passed to
`ssl.SSLContext <http://docs.python.org/dev/library/ssl.html#ssl.SSLContext>`_ and
its corresponding methods:
:param keyfile: see ``ssl.SSLContext.load_cert_chain``. May be ``None``
:param certfile: see ``ssl.SSLContext.load_cert_chain``. May be ``None``
:param ca_certs: see ``ssl.SSLContext.load_verify_locations``. May be ``None``
:param cert_reqs: see ``ssl.SSLContext.verify_mode``. By default, if ``ca_cert`` is
specified, the requirement is set to ``CERT_REQUIRED``; otherwise
it is set to ``CERT_NONE``
:param ssl_version: see ``ssl.SSLContext``. The default is defined by
``ssl.create_default_context``
:param ciphers: see ``ssl.SSLContext.set_ciphers``. May be ``None``. New in
Python 2.7/3.2
:returns: an RPyC connection exposing ``SlaveService``
Expand Down
22 changes: 11 additions & 11 deletions rpyc/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,16 @@ def ssl_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
:param ipv6: whether to create an IPv6 socket or an IPv4 one(defaults to ``False``)
:param keepalive: whether to set TCP keepalive on the socket (defaults to ``False``)
:param keyfile: see ``ssl.wrap_socket``. May be ``None``
:param certfile: see ``ssl.wrap_socket``. May be ``None``
:param ca_certs: see ``ssl.wrap_socket``. May be ``None``
:param cert_reqs: see ``ssl.wrap_socket``. By default, if ``ca_cert`` is specified,
the requirement is set to ``CERT_REQUIRED``; otherwise it is
set to ``CERT_NONE``
:param ssl_version: see ``ssl.wrap_socket``. The default is ``PROTOCOL_TLS_CLIENT``
:param ciphers: see ``ssl.wrap_socket``. May be ``None``. New in Python 2.7/3.2
:param keyfile: see ``ssl.SSLContext.load_cert_chain``. May be ``None``
:param certfile: see ``ssl.SSLContext.load_cert_chain``. May be ``None``
:param ca_certs: see ``ssl.SSLContext.load_verify_locations``. May be ``None``
:param cert_reqs: see ``ssl.SSLContext.verify_mode``. By default, if ``ca_cert`` is
specified, the requirement is set to ``CERT_REQUIRED``; otherwise
it is set to ``CERT_NONE``
:param ssl_version: see ``ssl.SSLContext``. The default is defined by
``ssl.create_default_context``
:param ciphers: see ``ssl.SSLContext.set_ciphers``. May be ``None``. New in
Python 2.7/3.2
:param verify_mode: see ``ssl.SSLContext.verify_mode``
:returns: an RPyC connection
Expand All @@ -159,9 +161,7 @@ def ssl_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
ssl_kwargs["cert_reqs"] = cert_reqs
elif cert_reqs != ssl.CERT_NONE:
ssl_kwargs["check_hostname"] = False
if ssl_version is None:
ssl_kwargs["ssl_version"] = ssl.PROTOCOL_TLS_CLIENT
else:
if ssl_version is not None:
ssl_kwargs["ssl_version"] = ssl_version
if ciphers is not None:
ssl_kwargs["ciphers"] = ciphers
Expand Down

0 comments on commit 5e50641

Please sign in to comment.