diff --git a/docs/advanced-usage.rst b/docs/advanced-usage.rst index ac7d2c7..fcf8951 100644 --- a/docs/advanced-usage.rst +++ b/docs/advanced-usage.rst @@ -113,3 +113,40 @@ This is useful for workflows where cookies or other information need to persist It's often more useful in logs to know which module initiated the code doing the logging. ``apiron`` allows for an existing logger object to be passed to an endpoint call using the ``logger`` argument so that logs will indicate the caller module rather than :mod:`apiron.client`. + + +********************** +Instantiated endpoints +********************** + +While the other documented usage patterns implement the singleton pattern, you may wish to use instantiated +services for reasons such as those mentioned `in this issue `_. + +This feature can be enabled by setting ``APIRON_INSTANTIATED_SERVICES=1`` either in the shell in which your +program runs or early in the entrypoint to your program, prior to the evaluation of your service classes. + +Endpoints should then be called on instances of ``Service`` subclasses, rather than the class itself: + +As an additional benefit, arguments passed into the constructor will be passed through to the endpoint as arguments +to the caller that forms the request. See ``aprion.client.call`` for all the available options. + +.. code-block:: python + + import os + + import requests + + from apiron import JsonEndpoint, Service + + os.environ['APIRON_INSTANTIATED_SERVICES'] = "1" + + + class GitHub(Service): + domain = 'https://api.github.com' + user = JsonEndpoint(path='/users/{username}') + repo = JsonEndpoint(path='/repos/{org}/{repo}') + + + service = GitHub(session=requests.Session()) + response = service.user(username='defunkt') + print(response) diff --git a/src/apiron/endpoint/endpoint.py b/src/apiron/endpoint/endpoint.py index 22afba7..9b52fba 100644 --- a/src/apiron/endpoint/endpoint.py +++ b/src/apiron/endpoint/endpoint.py @@ -2,21 +2,8 @@ import logging import string -import sys import warnings -from functools import partial, update_wrapper -from typing import Optional, Any, Callable, Dict, Iterable, List, TypeVar, Union, TYPE_CHECKING - -if TYPE_CHECKING: - if sys.version_info >= (3, 10): - from typing import Concatenate, ParamSpec - else: - from typing_extensions import Concatenate, ParamSpec - - from apiron.service import Service - - P = ParamSpec("P") - R = TypeVar("R") +from typing import Optional, Iterable, Any, Dict, List, Union import requests @@ -27,26 +14,13 @@ LOGGER = logging.getLogger(__name__) -def _create_caller( - call_fn: Callable["Concatenate[Service, Endpoint, P]", "R"], - instance: Any, - owner: Any, -) -> Callable["P", "R"]: - return partial(call_fn, instance, owner) - - class Endpoint: """ A basic service endpoint that responds with the default ``Content-Type`` for that endpoint """ - def __get__(self, instance, owner): - caller = _create_caller(client.call, owner, self) - update_wrapper(caller, client.call) - return caller - - def __call__(self): - raise TypeError("Endpoints are only callable in conjunction with a Service class.") + def __call__(self, service, *args, **kwargs): + return client.call(service, self, *args, **{**self.kwargs, **service._kwargs, **kwargs}) def __init__( self, @@ -55,6 +29,7 @@ def __init__( default_params: Optional[Dict[str, Any]] = None, required_params: Optional[Iterable[str]] = None, return_raw_response_object: bool = False, + **kwargs, ): """ :param str path: @@ -72,8 +47,11 @@ def __init__( Whether to return a :class:`requests.Response` object or call :func:`format_response` on it first. This can be overridden when calling the endpoint. (Default ``False``) + :param kwargs: + Default arguments to pass through to `apiron.client.call`. """ self.default_method = default_method + self.kwargs = kwargs if "?" in path: warnings.warn( diff --git a/src/apiron/endpoint/stub.py b/src/apiron/endpoint/stub.py index 6d9e774..9b8a457 100644 --- a/src/apiron/endpoint/stub.py +++ b/src/apiron/endpoint/stub.py @@ -11,8 +11,8 @@ class StubEndpoint(Endpoint): before the endpoint is complete. """ - def __get__(self, instance, owner): - return self.stub_response + def __call__(self, service, *args, **kwargs): + return self.stub_response(*args, **kwargs) def __init__(self, stub_response: Optional[Any] = None, **kwargs): """ diff --git a/src/apiron/service/base.py b/src/apiron/service/base.py index 6d7097b..d59a90e 100644 --- a/src/apiron/service/base.py +++ b/src/apiron/service/base.py @@ -1,16 +1,33 @@ +import os +import types from typing import Any, Dict, List, Set from apiron import Endpoint class ServiceMeta(type): + _instance: "ServiceBase" + @property def required_headers(cls) -> Dict[str, str]: return cls().required_headers - @property - def endpoints(cls) -> Set[Endpoint]: - return {attr for attr_name, attr in cls.__dict__.items() if isinstance(attr, Endpoint)} + @classmethod + def _instantiated_services(cls) -> bool: + setting_variable = "APIRON_INSTANTIATED_SERVICES" + false_values = ["0", "false"] + true_values = ["1", "true"] + environment_setting = os.getenv(setting_variable, "false").lower() + if environment_setting in false_values: + return False + elif environment_setting in true_values: + return True + + setting_values = false_values + true_values + raise ValueError( + f'Invalid {setting_variable}, "{environment_setting}"\n', + f"{setting_variable} must be one of {setting_values}\n", + ) def __str__(cls) -> str: return str(cls()) @@ -18,14 +35,48 @@ def __str__(cls) -> str: def __repr__(cls) -> str: return repr(cls()) + def __new__(cls, name, bases, namespace, **kwargs): + klass = super().__new__(cls, name, bases, namespace, **kwargs) + + # Behave as a normal class if instantiated services are enabled or if + # this is an apiron base class. + if cls._instantiated_services() or klass.__module__.split(".")[:2] == ["apiron", "service"]: + return klass + + # Singleton class. + if not hasattr(klass, "_instance"): + klass._instance = klass() + + # Mask declared Endpoints with bound instance methods. (singleton) + for k, v in namespace.items(): + if isinstance(v, Endpoint): + setattr(klass, k, types.MethodType(v, klass._instance)) + + return klass._instance + class ServiceBase(metaclass=ServiceMeta): required_headers: Dict[str, Any] = {} auth = () proxies: Dict[str, str] = {} + domain: str - @classmethod - def get_hosts(cls) -> List[str]: + def __setattr__(self, name, value): + """Transform assigned Endpoints into bound instance methods.""" + if isinstance(value, Endpoint): + value = types.MethodType(value, self) + super().__setattr__(name, value) + + @property + def endpoints(self) -> Set[Endpoint]: + endpoints = set() + for attr in self.__dict__.values(): + func = getattr(attr, "__func__", None) + if isinstance(func, Endpoint): + endpoints.add(func) + return endpoints + + def get_hosts(self) -> List[str]: """ The fully-qualified hostnames that correspond to this service. These are often determined by asking a load balancer or service discovery mechanism. @@ -35,7 +86,7 @@ def get_hosts(cls) -> List[str]: :rtype: list """ - return [] + return [self.domain] class Service(ServiceBase): @@ -45,23 +96,21 @@ class Service(ServiceBase): A service has a domain off of which one or more endpoints stem. """ - domain: str + @property + def domain(self): + return self._domain if self._domain else self.__class__.domain - @classmethod - def get_hosts(cls) -> List[str]: - """ - The fully-qualified hostnames that correspond to this service. - These are often determined by asking a load balancer or service discovery mechanism. + def __init__(self, domain=None, **kwargs): + self._domain = domain + self._kwargs = kwargs - :return: - The hostname strings corresponding to this service - :rtype: - list - """ - return [cls.domain] + # Mask declared Endpoints with bound instance methods. (instantiated) + for name, attr in self.__class__.__dict__.items(): + if isinstance(attr, Endpoint): + setattr(self, name, types.MethodType(attr, self)) def __str__(self) -> str: - return self.__class__.domain + return self.domain def __repr__(self) -> str: - return f"{self.__class__.__name__}(domain={self.__class__.domain})" + return f"{self.__class__.__name__}(domain={self.domain})" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..87eb7e0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,38 @@ +import os + +import pytest + +import apiron + + +def instantiated_service(returntype="instance"): + os.environ["APIRON_INSTANTIATED_SERVICES"] = "1" + + class SomeService(apiron.Service): + pass + + if returntype == "instance": + return SomeService(domain="http://foo.com") + elif returntype == "class": + return SomeService + + raise ValueError('Expected "returntype" value to be "instance" or "class".') + + +def singleton_service(): + os.environ["APIRON_INSTANTIATED_SERVICES"] = "0" + + class SomeService(apiron.Service): + domain = "http://foo.com" + + return SomeService + + +@pytest.fixture(scope="function", params=["singleton", "instance"]) +def service(request): + if request.param == "singleton": + yield singleton_service() + elif request.param == "instance": + yield instantiated_service() + else: + raise ValueError(f'unknown service type "{request.param}"') diff --git a/tests/service/test_base.py b/tests/service/test_base.py index 8ce47c5..db6e2bb 100644 --- a/tests/service/test_base.py +++ b/tests/service/test_base.py @@ -1,20 +1,7 @@ -import pytest - -from apiron import Endpoint, Service, ServiceBase - - -@pytest.fixture -def service(): - class SomeService(Service): - domain = "http://foo.com" - - return SomeService +from apiron import Endpoint class TestServiceBase: - def test_get_hosts_returns_empty_list_by_default(self): - assert [] == ServiceBase.get_hosts() - def test_required_headers_returns_empty_dict_by_default(self, service): assert {} == service.required_headers diff --git a/tests/service/test_instantiated.py b/tests/service/test_instantiated.py new file mode 100644 index 0000000..8c30ad0 --- /dev/null +++ b/tests/service/test_instantiated.py @@ -0,0 +1,35 @@ +import os + +import pytest + +from apiron.service.base import ServiceMeta + +from .. import conftest + + +class TestInstantiatedServices: + @pytest.mark.parametrize("value,result", [("0", False), ("false", False), ("1", True), ("true", True)]) + def test_instantiated_services_variable_true(self, value, result): + os.environ["APIRON_INSTANTIATED_SERVICES"] = value + + assert ServiceMeta._instantiated_services() is result + + @pytest.mark.parametrize("value", ["", "YES"]) + def test_instantiated_services_variable_other(self, value): + os.environ["APIRON_INSTANTIATED_SERVICES"] = value + + with pytest.raises(ValueError, match="Invalid"): + ServiceMeta._instantiated_services() + + def test_singleton_constructor_arguments(self): + """Singleton services do not accept arguments.""" + service = conftest.singleton_service() + + with pytest.raises(TypeError, match="object is not callable"): + service(foo="bar") + + def test_instantiated_services_constructor_arguments(self): + """Instantiated services accept arguments.""" + service = conftest.instantiated_service(returntype="class") + + service(foo="bar") diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index f5ec476..78a25f3 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -6,14 +6,6 @@ import apiron -@pytest.fixture -def service(): - class SomeService(apiron.Service): - domain = "http://foo.com" - - return SomeService - - @pytest.fixture def stub_function(): def stub_response(**kwargs): @@ -28,12 +20,12 @@ def stub_response(**kwargs): class TestEndpoint: def test_call(self, service): service.foo = apiron.Endpoint() - service.foo() + service.foo() # type: ignore def test_call_without_service_raises_exception(self): foo = apiron.Endpoint() with pytest.raises(TypeError): - foo() + foo() # type: ignore def test_default_attributes_from_constructor(self): foo = apiron.Endpoint() @@ -163,11 +155,11 @@ def test_repr_method(self): class TestStubEndpoint: def test_stub_default_response(self, service): service.stub_endpoint = apiron.StubEndpoint() - assert service.stub_endpoint() == {"response": "StubEndpoint(path='/')"} + assert service.stub_endpoint() == {"response": "StubEndpoint(path='/')"} # type: ignore def test_call_static(self, service): service.stub_endpoint = apiron.StubEndpoint(stub_response="stub response") - assert service.stub_endpoint() == "stub response" + assert service.stub_endpoint() == "stub response" # type: ignore @pytest.mark.parametrize( "test_call_kwargs,expected_response", @@ -181,7 +173,7 @@ def test_call_dynamic(self, test_call_kwargs, expected_response, service, stub_f def test_call_without_service_raises_exception(self): stub_endpoint = apiron.StubEndpoint(stub_response="foo") with pytest.raises(TypeError): - stub_endpoint() + stub_endpoint() # type: ignore def test_str_method(self): foo = apiron.StubEndpoint(path="/bar/baz")