diff --git a/tests/test_builder.py b/tests/test_builder.py index 55474714..62417464 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -67,18 +67,16 @@ def test_prepare_request( class TestCallFactory(object): def test_call(self, mocker, request_definition, request_builder): - instance = object() args = () kwargs = {} request_preparer = mocker.Mock(spec=builder.RequestPreparer) request_preparer.create_request_builder.return_value = request_builder factory = builder.CallFactory( - instance, request_preparer, request_definition) assert factory(*args, **kwargs) is request_preparer.prepare_request.return_value request_definition.define_request.assert_called_with( - request_builder, (instance,) + args, kwargs + request_builder, args, kwargs ) assert request_builder.build.called diff --git a/tests/test_clients.py b/tests/test_clients.py index 90e8276b..20a7ad3c 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -5,7 +5,7 @@ import pytest # Local imports -from uplink.clients import interfaces, requests_, twisted_, get_client +from uplink.clients import interfaces, requests_, twisted_, register try: from uplink.clients import aiohttp_ @@ -17,12 +17,23 @@ not aiohttp_, reason="Requires Python 3.4 or above") +def test_get_default_client_with_non_callable(mocker): + # Setup + old_default = register.get_default_client() + register.set_default_client("client") + default_client = register.get_default_client() + register.set_default_client(old_default) + + # Verify: an object that is not callable should be returned as set. + assert default_client == "client" + + def test_get_client_with_http_client_adapter_subclass(): class HttpClientAdapterMock(interfaces.HttpClientAdapter): def create_request(self): pass - client = get_client(HttpClientAdapterMock) + client = register.get_client(HttpClientAdapterMock) assert isinstance(client, HttpClientAdapterMock) @@ -31,7 +42,7 @@ class TestRequests(object): def test_get_client(self, mocker): import requests session_mock = mocker.Mock(spec=requests.Session) - client = get_client(session_mock) + client = register.get_client(session_mock) assert isinstance(client, requests_.RequestsClient) @@ -93,7 +104,7 @@ class TestAiohttp(object): @requires_aiohttp def test_get_client(self, aiohttp_session_mock): - client = get_client(aiohttp_session_mock) + client = register.get_client(aiohttp_session_mock) assert isinstance(client, aiohttp_.AiohttpClient) @requires_aiohttp diff --git a/uplink/builder.py b/uplink/builder.py index cd739356..7b407343 100644 --- a/uplink/builder.py +++ b/uplink/builder.py @@ -78,13 +78,11 @@ def create_request_builder(self): class CallFactory(object): - def __init__(self, instance, request_preparer, request_definition): - self._instance = instance + def __init__(self, request_preparer, request_definition): self._request_preparer = request_preparer self._request_definition = request_definition def __call__(self, *args, **kwargs): - args = (self._instance,) + args builder = self._request_preparer.create_request_builder() self._request_definition.define_request(builder, args, kwargs) request = builder.build() @@ -133,16 +131,12 @@ def add_converter(self, *converters_): self._converters.extendleft(converters_) @utils.memoize() - def build(self, consumer, definition): + def build(self, definition): """ Creates a callable that uses the provided definition to execute HTTP requests when invoked. """ - return CallFactory( - consumer, - RequestPreparer(self, definition), - definition - ) + return CallFactory(RequestPreparer(self, definition), definition) class ConsumerMethod(object): @@ -172,7 +166,7 @@ def __get__(self, instance, owner): if instance is None: return self._request_definition_builder else: - return instance._builder.build(instance, self._request_definition) + return instance._builder.build(self._request_definition) class ConsumerMeta(type): diff --git a/uplink/types.py b/uplink/types.py index 451004aa..6ba04cc7 100644 --- a/uplink/types.py +++ b/uplink/types.py @@ -137,8 +137,8 @@ def annotations(self): def get_relevant_arguments(self, call_args): return filter(call_args.__contains__, self._arguments) - def handle_call(self, request_builder, func_args, func_kwargs): - call_args = utils.get_call_args(self._func, *func_args, **func_kwargs) + def handle_call(self, request_builder, args, kwargs): + call_args = utils.get_call_args(self._func, None, *args, **kwargs) for name in self.get_relevant_arguments(call_args): self.handle_argument( request_builder,