Skip to content

Commit

Permalink
Add option to disable or customise SSL verification
Browse files Browse the repository at this point in the history
  • Loading branch information
sjaensch committed Oct 10, 2018
1 parent d10d688 commit f20a0c1
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 10 deletions.
43 changes: 38 additions & 5 deletions bravado_asyncio/http_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import logging
import ssl
from collections import Mapping
from typing import Any
from typing import Callable # noqa: F401
from typing import cast
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Union

import aiohttp
Expand Down Expand Up @@ -56,12 +58,26 @@ class AsyncioClient(HttpClient):
async / await.
"""

def __init__(self, run_mode: RunMode=RunMode.THREAD, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
ssl_verify: Optional[Union[bool, ssl.SSLContext]]

def __init__(
self,
run_mode: RunMode=RunMode.THREAD,
loop: Optional[asyncio.AbstractEventLoop]=None,
ssl_verify: Union[bool, str]=True,
ssl_cert: Optional[Union[str, Sequence[str]]]=None,
) -> None:
"""Instantiate a client using the given run_mode. If you do not pass in an event loop, then
either a shared loop in a separate thread (THREAD mode) or the default asyncio
event loop (FULL_ASYNCIO mode) will be used.
Not passing in an event loop will make sure we share the :py:class:`aiohttp.ClientSession` object
between AsyncioClient instances.
:param ssl_verify: Set to False to disable SSL certificate validation. Provide the path to a
CA bundle if you need to use a custom one.
:param ssl_cert: Provide a client-side certificate to use. Either a sequence of strings pointing
to the certificate (1) and the private key (2), or a string pointing to the combined certificate
and key.
"""
self.run_mode = run_mode
if self.run_mode == RunMode.THREAD:
Expand All @@ -87,11 +103,27 @@ def __init__(self, run_mode: RunMode=RunMode.THREAD, loop: Optional[asyncio.Abst
else:
self.client_session = get_client_session(self.loop)

# translate the requests-type SSL options to a ssl.SSLContext object as used by aiohttp.
# see https://aiohttp.readthedocs.io/en/stable/client_advanced.html#ssl-control-for-tcp-sockets
self.ssl_verify = None # None is the default value of the ssl argument for the aiohttp request function
if ssl_verify is False:
self.ssl_verify = False

if isinstance(ssl_verify, str) or ssl_cert:
cafile = None
if isinstance(ssl_verify, str):
cafile = ssl_verify
self.ssl_verify = ssl.create_default_context(cafile=cafile)
if ssl_cert:
if isinstance(ssl_cert, str):
ssl_cert = [ssl_cert]
self.ssl_verify.load_cert_chain(*ssl_cert)

def request(
self,
request_params: Dict[str, Any],
operation: Optional[Operation]=None,
request_config: Optional[RequestConfig]=None,
self,
request_params: Dict[str, Any],
operation: Optional[Operation]=None,
request_config: Optional[RequestConfig]=None,
) -> HttpFuture:
"""Sets up the request params for aiohttp and executes the request in the background.
Expand Down Expand Up @@ -150,6 +182,7 @@ def request(
for k, v in request_params.get('headers', {}).items()
},
skip_auto_headers=skip_auto_headers,
ssl=self.ssl_verify,
timeout=timeout,
)

Expand Down
3 changes: 1 addition & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
aiobravado
bottle
bravado-core>=4.11.0
bravado[testing]>=10.1.0
bravado>=10.1.0
coverage
ephemeral_port_reserve
mock
Expand Down
58 changes: 55 additions & 3 deletions tests/http_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ def mock_client_session():
yield _mock


@pytest.fixture(params=[RunMode.THREAD])
def asyncio_client(request, mock_client_session):
client = AsyncioClient(run_mode=request.param, loop=mock.Mock(spec=asyncio.AbstractEventLoop))
@pytest.fixture
def asyncio_client(mock_client_session, ssl_verify=True, ssl_cert=None):
client = AsyncioClient(
run_mode=RunMode.THREAD,
loop=mock.Mock(spec=asyncio.AbstractEventLoop),
ssl_verify=ssl_verify,
ssl_cert=ssl_cert,
)
client.run_coroutine_func = mock.Mock('run_coroutine_func')
return client

Expand All @@ -38,6 +43,12 @@ def request_params():
}


@pytest.fixture
def mock_create_default_context():
with mock.patch('ssl.create_default_context', autospec=True) as _mock:
yield _mock


def test_fail_on_unknown_run_mode():
with pytest.raises(ValueError):
AsyncioClient(run_mode='unknown/invalid')
Expand All @@ -59,6 +70,7 @@ def test_request(asyncio_client, mock_client_session, request_params):
data=mock.ANY,
headers={},
skip_auto_headers=['Content-Type'],
ssl=None,
timeout=None,
)
assert mock_client_session.return_value.request.call_args[1]['data']._fields == []
Expand Down Expand Up @@ -88,6 +100,7 @@ def test_simple_get(asyncio_client, mock_client_session, request_params):
data=mock.ANY,
headers={},
skip_auto_headers=['Content-Type'],
ssl=None,
timeout=None,
)
assert mock_client_session.return_value.request.call_args[1]['data']._fields == []
Expand Down Expand Up @@ -121,6 +134,7 @@ def test_formdata(asyncio_client, mock_client_session, request_params, param_nam
data=mock.ANY,
headers={},
skip_auto_headers=['Content-Type'],
ssl=None,
timeout=None,
)

Expand Down Expand Up @@ -168,3 +182,41 @@ def test_connect_timeout_logs_warning(asyncio_client, mock_client_session, reque
assert mock_log.warning.call_count == 1
assert 'connect_timeout' in mock_log.warning.call_args[0][0]
assert mock_client_session.return_value.request.call_args[1]['timeout'] is None


def test_disable_ssl_verification(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify=False)
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] is False
assert mock_create_default_context.call_count == 0


def test_use_custom_ssl_ca(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert')
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] == mock_create_default_context.return_value
mock_create_default_context.assert_called_once_with(cafile='my_ca_cert')
assert mock_create_default_context.return_value.load_cert_chain.call_count == 0


@pytest.mark.parametrize(
'ssl_cert, expected_args',
(
('my_cert', ('my_cert',)),
(['my_cert'], ('my_cert',)),
(['my_cert', 'my_key'], ('my_cert', 'my_key')),
),
)
def test_use_custom_ssl_cert(ssl_cert, expected_args, mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_cert=ssl_cert)
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] == mock_create_default_context.return_value
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == expected_args


def test_use_custom_ssl_ca_and_cert(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert', ssl_cert='my_cert')
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] == mock_create_default_context.return_value
mock_create_default_context.assert_called_once_with(cafile='my_ca_cert')
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == ('my_cert',)

0 comments on commit f20a0c1

Please sign in to comment.