Skip to content

Commit

Permalink
Merge pull request #1132 from pika/pika-1107-python-3.7
Browse files Browse the repository at this point in the history
Better parsing of URL parameter "ssl_options"
  • Loading branch information
michaelklishin committed Nov 5, 2018
2 parents e58f3a1 + 1e6023e commit 5144400
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 26 deletions.
71 changes: 54 additions & 17 deletions pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import logging
import math
import numbers
import os
import platform
import socket
import warnings
import ssl

Expand Down Expand Up @@ -963,27 +963,64 @@ def _set_url_ssl_options(self, value):
"""Deserialize and apply the corresponding query string arg
"""
options = ast.literal_eval(value)
if options is None:
opts = ast.literal_eval(value)
if opts is None:
if self.ssl_options is not None:
raise ValueError(
'Specified ssl_options=None URL arg is inconsistent with '
'the specified https URL scheme.')
else:
# Convert options to pika.SSLOptions via ssl.SSLSocket()
sock = socket.socket()
try:
ssl_sock = ssl.SSLSocket(sock=sock, **options)
try:
self.ssl_options = pika.SSLOptions(
context=ssl_sock.context,
server_hostname=ssl_sock.server_hostname)
finally:
ssl_sock.close()
finally:
sock.close()


# Note: this is the deprecated wrap_socket signature and info:
#
# Internally, function creates a SSLContext with protocol
# ssl_version and SSLContext.options set to cert_reqs.
# If parameters keyfile, certfile, ca_certs or ciphers are set,
# then the values are passed to SSLContext.load_cert_chain(),
# SSLContext.load_verify_locations(), and SSLContext.set_ciphers().
#
# ssl.wrap_socket(sock,
# keyfile=None,
# certfile=None,
# server_side=False, # Not URL-supported
# cert_reqs=CERT_NONE, # Not URL-supported
# ssl_version=PROTOCOL_TLS, # Not URL-supported
# ca_certs=None,
# do_handshake_on_connect=True, # Not URL-supported
# suppress_ragged_eofs=True, # Not URL-supported
# ciphers=None
cxt = None
if 'ca_certs' in opts:
opt_ca_certs = opts['ca_certs']
if os.path.isfile(opt_ca_certs):
cxt = ssl.create_default_context(cafile=opt_ca_certs)
elif os.path.isdir(opt_ca_certs):
cxt = ssl.create_default_context(capath=opt_ca_certs)
else:
LOGGER.warning('ca_certs is specified via ssl_options but '
'is neither a valid file nor directory: "%s"',
opt_ca_certs)

if 'certfile' in opts:
if os.path.isfile(opts['certfile']):
keyfile = opts.get('keyfile')
password = opts.get('password')
cxt.load_cert_chain(opts['certfile'], keyfile, password)
else:
LOGGER.warning('certfile is specified via ssl_options but '
'is not a valid file: "%s"',
opts['certfile'])

if 'ciphers' in opts:
opt_ciphers = opts['ciphers']
if opt_ciphers is not None:
cxt.set_ciphers(opt_ciphers)
else:
LOGGER.warning('ciphers specified in ssl_options but '
'evaluates to None')

server_hostname = opts.get('server_hostname')
self.ssl_options = pika.SSLOptions(context=cxt,
server_hostname=server_hostname)

def _set_url_tcp_options(self, value):
"""Deserialize and apply the corresponding query string arg"""
Expand Down
16 changes: 7 additions & 9 deletions tests/unit/connection_parameters_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,12 +699,10 @@ def test_good_parameters(self):
# on <VerifyMode.CERT_NONE: 1>:
# {'cert_reqs': <VerifyMode.CERT_NONE: 1>, 'server_hostname': 'blah.blah.com'}
'ssl_options': {
'keyfile': None,
'certfile': None,
'ssl_version': int(ssl.PROTOCOL_SSLv23),
'ca_certs': None,
'cert_reqs': int(ssl.CERT_NONE),
'npn_protocols': None,
'ca_certs': '/etc/ssl',
'certfile': '/etc/certs/cert.pem',
'keyfile': '/etc/certs/key.pem',
'password': 'test123',
'ciphers': None,
'server_hostname': 'blah.blah.com'
},
Expand All @@ -719,7 +717,7 @@ def test_good_parameters(self):
test_params['backpressure_detection'] = backpressure
virtual_host = '/'
query_string = urlencode(test_params)
test_url = ('https://myuser:mypass@www.test.com:5678/%s?%s' % (
test_url = ('amqps://myuser:mypass@www.test.com:5678/%s?%s' % (
url_quote(virtual_host, safe=''),
query_string,
))
Expand All @@ -733,8 +731,6 @@ def test_good_parameters(self):
actual_value = getattr(params, t_param)

if t_param == 'ssl_options':
self.assertEqual(actual_value.context.verify_mode,
expected_value['cert_reqs'])
self.assertEqual(actual_value.server_hostname,
expected_value['server_hostname'])
else:
Expand All @@ -749,6 +745,8 @@ def test_good_parameters(self):

# check all values from base URL
self.assertIsNotNone(params.ssl_options)
self.assertIsNotNone(params.ssl_options.context)
self.assertIsInstance(params.ssl_options.context, ssl.SSLContext)
self.assertEqual(params.credentials.username, 'myuser')
self.assertEqual(params.credentials.password, 'mypass')
self.assertEqual(params.host, 'www.test.com')
Expand Down

0 comments on commit 5144400

Please sign in to comment.