Skip to content

Commit

Permalink
Merge 4348175 into d10d688
Browse files Browse the repository at this point in the history
  • Loading branch information
sjaensch committed Oct 12, 2018
2 parents d10d688 + 4348175 commit 7a2dd13
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 10 deletions.
53 changes: 48 additions & 5 deletions bravado_asyncio/http_client.py
@@ -1,11 +1,14 @@
import asyncio
import logging
import ssl
from collections import Mapping
from distutils.version import LooseVersion
from typing import Any
from typing import Callable # noqa: F401
from typing import cast
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Union

import aiohttp
Expand Down Expand Up @@ -56,12 +59,24 @@ class AsyncioClient(HttpClient):
async / await.
"""

def __init__(self, run_mode: RunMode=RunMode.THREAD, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
def __init__(
self,
run_mode: RunMode=RunMode.THREAD,
loop: Optional[asyncio.AbstractEventLoop]=None,
ssl_verify: Optional[Union[bool, str]]=None,
ssl_cert: Optional[Union[str, Sequence[str]]]=None,
) -> None:
"""Instantiate a client using the given run_mode. If you do not pass in an event loop, then
either a shared loop in a separate thread (THREAD mode) or the default asyncio
event loop (FULL_ASYNCIO mode) will be used.
Not passing in an event loop will make sure we share the :py:class:`aiohttp.ClientSession` object
between AsyncioClient instances.
:param ssl_verify: Set to False to disable SSL certificate validation. Provide the path to a
CA bundle if you need to use a custom one.
:param ssl_cert: Provide a client-side certificate to use. Either a sequence of strings pointing
to the certificate (1) and the private key (2), or a string pointing to the combined certificate
and key.
"""
self.run_mode = run_mode
if self.run_mode == RunMode.THREAD:
Expand All @@ -87,11 +102,27 @@ def __init__(self, run_mode: RunMode=RunMode.THREAD, loop: Optional[asyncio.Abst
else:
self.client_session = get_client_session(self.loop)

# translate the requests-type SSL options to a ssl.SSLContext object as used by aiohttp.
# see https://aiohttp.readthedocs.io/en/stable/client_advanced.html#ssl-control-for-tcp-sockets
if isinstance(ssl_verify, str) or ssl_cert:
self.ssl_verify = None # type: Optional[bool]
cafile = None
if isinstance(ssl_verify, str):
cafile = ssl_verify
self.ssl_context = ssl.create_default_context(cafile=cafile) # type: Optional[ssl.SSLContext]
if ssl_cert:
if isinstance(ssl_cert, str):
ssl_cert = [ssl_cert]
self.ssl_context.load_cert_chain(*ssl_cert)
else:
self.ssl_verify = ssl_verify
self.ssl_context = None

def request(
self,
request_params: Dict[str, Any],
operation: Optional[Operation]=None,
request_config: Optional[RequestConfig]=None,
self,
request_params: Dict[str, Any],
operation: Optional[Operation]=None,
request_config: Optional[RequestConfig]=None,
) -> HttpFuture:
"""Sets up the request params for aiohttp and executes the request in the background.
Expand Down Expand Up @@ -151,6 +182,7 @@ def request(
},
skip_auto_headers=skip_auto_headers,
timeout=timeout,
**self._get_ssl_params()
)

future = self.run_coroutine_func(coroutine, loop=self.loop)
Expand All @@ -171,3 +203,14 @@ def prepare_params(self, params: Optional[Dict[str, Any]]) -> Union[Optional[Dic
entries = [(key, str(value))] if not is_list_like(value) else [(key, str(v)) for v in value]
items.extend(entries)
return MultiDict(items)

def _get_ssl_params(self) -> Dict[str, Any]:
aiohttp_version = LooseVersion(aiohttp.__version__)
if aiohttp_version < LooseVersion('3'):
if (self.ssl_verify is not None) or (self.ssl_context is not None):
log.warning('SSL options are not supported and will be ignored for aiohttp versions below 3.')
return {}
else:
return {
'ssl': self.ssl_context if self.ssl_context else self.ssl_verify
}
3 changes: 1 addition & 2 deletions requirements-dev.txt
@@ -1,7 +1,6 @@
aiobravado
bottle
bravado-core>=4.11.0
bravado[testing]>=10.1.0
bravado[integration-tests]>=10.1.0
coverage
ephemeral_port_reserve
mock
Expand Down
71 changes: 68 additions & 3 deletions tests/http_client_test.py
Expand Up @@ -16,9 +16,14 @@ def mock_client_session():
yield _mock


@pytest.fixture(params=[RunMode.THREAD])
def asyncio_client(request, mock_client_session):
client = AsyncioClient(run_mode=request.param, loop=mock.Mock(spec=asyncio.AbstractEventLoop))
@pytest.fixture
def asyncio_client(mock_client_session, ssl_verify=None, ssl_cert=None):
client = AsyncioClient(
run_mode=RunMode.THREAD,
loop=mock.Mock(spec=asyncio.AbstractEventLoop),
ssl_verify=ssl_verify,
ssl_cert=ssl_cert,
)
client.run_coroutine_func = mock.Mock('run_coroutine_func')
return client

Expand All @@ -38,11 +43,24 @@ def request_params():
}


@pytest.fixture
def mock_aiohttp_version():
with mock.patch('aiohttp.__version__', new='3.0.0'):
yield


@pytest.fixture
def mock_create_default_context():
with mock.patch('ssl.create_default_context', autospec=True) as _mock:
yield _mock


def test_fail_on_unknown_run_mode():
with pytest.raises(ValueError):
AsyncioClient(run_mode='unknown/invalid')


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_request(asyncio_client, mock_client_session, request_params):
"""Make sure request calls the right functions and instantiates the HttpFuture correctly."""
asyncio_client.response_adapter = mock.Mock(name='response_adapter', spec=AioHTTPResponseAdapter)
Expand All @@ -59,6 +77,7 @@ def test_request(asyncio_client, mock_client_session, request_params):
data=mock.ANY,
headers={},
skip_auto_headers=['Content-Type'],
ssl=None,
timeout=None,
)
assert mock_client_session.return_value.request.call_args[1]['data']._fields == []
Expand All @@ -76,6 +95,7 @@ def test_request(asyncio_client, mock_client_session, request_params):
)


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_simple_get(asyncio_client, mock_client_session, request_params):
request_params['params'] = {'foo': 'bar'}

Expand All @@ -88,6 +108,7 @@ def test_simple_get(asyncio_client, mock_client_session, request_params):
data=mock.ANY,
headers={},
skip_auto_headers=['Content-Type'],
ssl=None,
timeout=None,
)
assert mock_client_session.return_value.request.call_args[1]['data']._fields == []
Expand All @@ -100,6 +121,7 @@ def test_int_param(asyncio_client, mock_client_session, request_params):
assert mock_client_session.return_value.request.call_args[1]['params'] == {'foo': '5'}


@pytest.mark.usefixtures('mock_aiohttp_version')
@pytest.mark.parametrize(
'param_name, param_value, expected_param_value',
(
Expand All @@ -121,6 +143,7 @@ def test_formdata(asyncio_client, mock_client_session, request_params, param_nam
data=mock.ANY,
headers={},
skip_auto_headers=['Content-Type'],
ssl=None,
timeout=None,
)

Expand Down Expand Up @@ -168,3 +191,45 @@ def test_connect_timeout_logs_warning(asyncio_client, mock_client_session, reque
assert mock_log.warning.call_count == 1
assert 'connect_timeout' in mock_log.warning.call_args[0][0]
assert mock_client_session.return_value.request.call_args[1]['timeout'] is None


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_disable_ssl_verification(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify=False)
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] is False
assert mock_create_default_context.call_count == 0


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_use_custom_ssl_ca(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert')
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] == mock_create_default_context.return_value
mock_create_default_context.assert_called_once_with(cafile='my_ca_cert')
assert mock_create_default_context.return_value.load_cert_chain.call_count == 0


@pytest.mark.usefixtures('mock_aiohttp_version')
@pytest.mark.parametrize(
'ssl_cert, expected_args',
(
('my_cert', ('my_cert',)),
(['my_cert'], ('my_cert',)),
(['my_cert', 'my_key'], ('my_cert', 'my_key')),
),
)
def test_use_custom_ssl_cert(ssl_cert, expected_args, mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_cert=ssl_cert)
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] == mock_create_default_context.return_value
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == expected_args


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_use_custom_ssl_ca_and_cert(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert', ssl_cert='my_cert')
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] == mock_create_default_context.return_value
mock_create_default_context.assert_called_once_with(cafile='my_ca_cert')
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == ('my_cert',)

0 comments on commit 7a2dd13

Please sign in to comment.