Skip to content

Commit

Permalink
Add support for aiohttp 2.X, don't use Python 3.6 annotation syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
sjaensch committed Oct 11, 2018
1 parent f20a0c1 commit 23d2b1e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 12 deletions.
31 changes: 21 additions & 10 deletions bravado_asyncio/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import ssl
from collections import Mapping
from distutils.version import LooseVersion
from typing import Any
from typing import Callable # noqa: F401
from typing import cast
Expand Down Expand Up @@ -58,13 +59,11 @@ class AsyncioClient(HttpClient):
async / await.
"""

ssl_verify: Optional[Union[bool, ssl.SSLContext]]

def __init__(
self,
run_mode: RunMode=RunMode.THREAD,
loop: Optional[asyncio.AbstractEventLoop]=None,
ssl_verify: Union[bool, str]=True,
ssl_verify: Optional[Union[bool, str]]=None,
ssl_cert: Optional[Union[str, Sequence[str]]]=None,
) -> None:
"""Instantiate a client using the given run_mode. If you do not pass in an event loop, then
Expand Down Expand Up @@ -105,19 +104,19 @@ def __init__(

# translate the requests-type SSL options to a ssl.SSLContext object as used by aiohttp.
# see https://aiohttp.readthedocs.io/en/stable/client_advanced.html#ssl-control-for-tcp-sockets
self.ssl_verify = None # None is the default value of the ssl argument for the aiohttp request function
if ssl_verify is False:
self.ssl_verify = False

if isinstance(ssl_verify, str) or ssl_cert:
self.ssl_verify = True # type: Optional[bool]
cafile = None
if isinstance(ssl_verify, str):
cafile = ssl_verify
self.ssl_verify = ssl.create_default_context(cafile=cafile)
self.ssl_context = ssl.create_default_context(cafile=cafile) # type: Optional[ssl.SSLContext]
if ssl_cert:
if isinstance(ssl_cert, str):
ssl_cert = [ssl_cert]
self.ssl_verify.load_cert_chain(*ssl_cert)
self.ssl_context.load_cert_chain(*ssl_cert)
else:
self.ssl_verify = ssl_verify
self.ssl_context = None

def request(
self,
Expand Down Expand Up @@ -182,8 +181,8 @@ def request(
for k, v in request_params.get('headers', {}).items()
},
skip_auto_headers=skip_auto_headers,
ssl=self.ssl_verify,
timeout=timeout,
**self._get_ssl_params()
)

future = self.run_coroutine_func(coroutine, loop=self.loop)
Expand All @@ -204,3 +203,15 @@ def prepare_params(self, params: Optional[Dict[str, Any]]) -> Union[Optional[Dic
entries = [(key, str(value))] if not is_list_like(value) else [(key, str(v)) for v in value]
items.extend(entries)
return MultiDict(items)

def _get_ssl_params(self) -> Dict[str, Any]:
aiohttp_version = LooseVersion(aiohttp.__version__)
if aiohttp_version < LooseVersion('3'):
return {
'verify_ssl': self.ssl_verify,
'ssl_context': self.ssl_context
}
else:
return {
'ssl': self.ssl_context if self.ssl_context else self.ssl_verify
}
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
bottle
bravado-core>=4.11.0
bravado>=10.1.0
bravado[integration-tests]>=10.1.0
coverage
ephemeral_port_reserve
mock
Expand Down
71 changes: 70 additions & 1 deletion tests/http_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def mock_client_session():


@pytest.fixture
def asyncio_client(mock_client_session, ssl_verify=True, ssl_cert=None):
def asyncio_client(mock_client_session, ssl_verify=None, ssl_cert=None):
client = AsyncioClient(
run_mode=RunMode.THREAD,
loop=mock.Mock(spec=asyncio.AbstractEventLoop),
Expand All @@ -43,6 +43,18 @@ def request_params():
}


@pytest.fixture
def mock_aiohttp_version():
with mock.patch('aiohttp.__version__', new='3.0.0'):
yield


@pytest.fixture
def mock_legacy_aiohttp_version():
with mock.patch('aiohttp.__version__', new='2.3.10'):
yield


@pytest.fixture
def mock_create_default_context():
with mock.patch('ssl.create_default_context', autospec=True) as _mock:
Expand All @@ -54,6 +66,7 @@ def test_fail_on_unknown_run_mode():
AsyncioClient(run_mode='unknown/invalid')


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_request(asyncio_client, mock_client_session, request_params):
"""Make sure request calls the right functions and instantiates the HttpFuture correctly."""
asyncio_client.response_adapter = mock.Mock(name='response_adapter', spec=AioHTTPResponseAdapter)
Expand Down Expand Up @@ -88,6 +101,7 @@ def test_request(asyncio_client, mock_client_session, request_params):
)


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_simple_get(asyncio_client, mock_client_session, request_params):
request_params['params'] = {'foo': 'bar'}

Expand All @@ -113,6 +127,7 @@ def test_int_param(asyncio_client, mock_client_session, request_params):
assert mock_client_session.return_value.request.call_args[1]['params'] == {'foo': '5'}


@pytest.mark.usefixtures('mock_aiohttp_version')
@pytest.mark.parametrize(
'param_name, param_value, expected_param_value',
(
Expand Down Expand Up @@ -184,13 +199,15 @@ def test_connect_timeout_logs_warning(asyncio_client, mock_client_session, reque
assert mock_client_session.return_value.request.call_args[1]['timeout'] is None


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_disable_ssl_verification(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify=False)
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] is False
assert mock_create_default_context.call_count == 0


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_use_custom_ssl_ca(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert')
client.request({})
Expand All @@ -199,6 +216,7 @@ def test_use_custom_ssl_ca(mock_client_session, mock_create_default_context):
assert mock_create_default_context.return_value.load_cert_chain.call_count == 0


@pytest.mark.usefixtures('mock_aiohttp_version')
@pytest.mark.parametrize(
'ssl_cert, expected_args',
(
Expand All @@ -214,9 +232,60 @@ def test_use_custom_ssl_cert(ssl_cert, expected_args, mock_client_session, mock_
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == expected_args


@pytest.mark.usefixtures('mock_aiohttp_version')
def test_use_custom_ssl_ca_and_cert(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert', ssl_cert='my_cert')
client.request({})
assert mock_client_session.return_value.request.call_args[1]['ssl'] == mock_create_default_context.return_value
mock_create_default_context.assert_called_once_with(cafile='my_ca_cert')
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == ('my_cert',)


# SSL tests for legacy aiohttp versions (2.X)

@pytest.mark.usefixtures('mock_legacy_aiohttp_version')
def test_disable_ssl_verification_legacy(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify=False)
client.request({})
assert mock_client_session.return_value.request.call_args[1]['verify_ssl'] is False
assert mock_create_default_context.call_count == 0


@pytest.mark.usefixtures('mock_legacy_aiohttp_version')
def test_use_custom_ssl_ca_legacy(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert')
client.request({})
assert mock_client_session.return_value.request.call_args[1]['verify_ssl'] is True
assert mock_client_session.return_value.request.call_args[1]['ssl_context'] == \
mock_create_default_context.return_value
mock_create_default_context.assert_called_once_with(cafile='my_ca_cert')
assert mock_create_default_context.return_value.load_cert_chain.call_count == 0


@pytest.mark.usefixtures('mock_legacy_aiohttp_version')
@pytest.mark.parametrize(
'ssl_cert, expected_args',
(
('my_cert', ('my_cert',)),
(['my_cert'], ('my_cert',)),
(['my_cert', 'my_key'], ('my_cert', 'my_key')),
),
)
def test_use_custom_ssl_cert_legacy(ssl_cert, expected_args, mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_cert=ssl_cert)
client.request({})
assert mock_client_session.return_value.request.call_args[1]['verify_ssl'] is True
assert mock_client_session.return_value.request.call_args[1]['ssl_context'] ==\
mock_create_default_context.return_value
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == expected_args


@pytest.mark.usefixtures('mock_legacy_aiohttp_version')
def test_use_custom_ssl_ca_and_cert_legacy(mock_client_session, mock_create_default_context):
client = asyncio_client(mock_client_session=mock_client_session, ssl_verify='my_ca_cert', ssl_cert='my_cert')
client.request({})
assert mock_client_session.return_value.request.call_args[1]['verify_ssl'] is True
assert mock_client_session.return_value.request.call_args[1]['ssl_context'] == \
mock_create_default_context.return_value
mock_create_default_context.assert_called_once_with(cafile='my_ca_cert')
assert mock_create_default_context.return_value.load_cert_chain.call_args[0] == ('my_cert',)

0 comments on commit 23d2b1e

Please sign in to comment.