diff --git a/injector/__init__.py b/injector/__init__.py index 9b52729..d231ff3 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -22,24 +22,25 @@ import threading import types from abc import ABCMeta, abstractmethod -from collections import namedtuple +from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, Callable, - cast, Dict, + Generator, Generic, - get_args, Iterable, List, Optional, - overload, Set, Tuple, Type, TypeVar, - TYPE_CHECKING, Union, + cast, + get_args, + overload, ) try: @@ -52,13 +53,13 @@ # canonical. Since this typing_extensions import is only for mypy it'll work even without # typing_extensions actually installed so all's good. if TYPE_CHECKING: - from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints + from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints else: # Ignoring errors here as typing_extensions stub doesn't know about those things yet try: - from typing import _AnnotatedAlias, Annotated, get_type_hints + from typing import Annotated, _AnnotatedAlias, get_type_hints except ImportError: - from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints + from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints __author__ = 'Alec Thomas ' @@ -340,39 +341,60 @@ def __repr__(self) -> str: @private -class ListOfProviders(Provider, Generic[T]): +class MultiBinder(Provider, Generic[T]): """Provide a list of instances via other Providers.""" - _providers: List[Provider[T]] + _multi_bindings: List['Binding'] - def __init__(self) -> None: - self._providers = [] + def __init__(self, parent: 'Binder') -> None: + self._multi_bindings = [] + self._binder = Binder(parent.injector, auto_bind=False, parent=parent) - def append(self, provider: Provider[T]) -> None: - self._providers.append(provider) + def append(self, provider: Provider[T], scope: Type['Scope']) -> None: + # HACK: generate a pseudo-type for this element in the list. + # This is needed for scopes to work properly. Some, like the Singleton scope, + # key instances by type, so we need one that is unique to this binding. + pseudo_type = type(f"multibind-type-{id(provider)}", (provider.__class__,), {}) + self._multi_bindings.append(Binding(pseudo_type, provider, scope)) + + def get_scoped_providers(self, injector: 'Injector') -> Generator[Provider[T], None, None]: + for binding in self._multi_bindings: + if ( + isinstance(binding.provider, ClassProvider) + and binding.scope is NoScope + and self._binder.parent + and self._binder.parent.has_explicit_binding_for(binding.provider._cls) + ): + parent_binding, _ = self._binder.parent.get_binding(binding.provider._cls) + scope_binding, _ = self._binder.parent.get_binding(parent_binding.scope) + else: + scope_binding, _ = self._binder.get_binding(binding.scope) + scope_instance: Scope = scope_binding.provider.get(injector) + provider_instance = scope_instance.get(binding.interface, binding.provider) + yield provider_instance def __repr__(self) -> str: - return '%s(%r)' % (type(self).__name__, self._providers) + return '%s(%r)' % (type(self).__name__, self._multi_bindings) -class MultiBindProvider(ListOfProviders[List[T]]): +class MultiBindProvider(MultiBinder[List[T]]): """Used by :meth:`Binder.multibind` to flatten results of providers that return sequences.""" def get(self, injector: 'Injector') -> List[T]: result: List[T] = [] - for provider in self._providers: + for provider in self.get_scoped_providers(injector): instances: List[T] = _ensure_iterable(provider.get(injector)) result.extend(instances) return result -class MapBindProvider(ListOfProviders[Dict[str, T]]): +class MapBindProvider(MultiBinder[Dict[str, T]]): """A provider for map bindings.""" def get(self, injector: 'Injector') -> Dict[str, T]: map: Dict[str, T] = {} - for provider in self._providers: + for provider in self.get_scoped_providers(injector): map.update(provider.get(injector)) return map @@ -387,7 +409,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]: return {self._key: self._provider.get(injector)} -_BindingBase = namedtuple('_BindingBase', 'interface provider scope') +@dataclass +class _BindingBase: + interface: type + provider: Provider + scope: Type['Scope'] @private @@ -531,25 +557,8 @@ def multibind( :param scope: Optional Scope in which to bind. """ - if interface not in self._bindings: - provider: ListOfProviders - if ( - isinstance(interface, dict) - or isinstance(interface, type) - and issubclass(interface, dict) - or _get_origin(_punch_through_alias(interface)) is dict - ): - provider = MapBindProvider() - else: - provider = MultiBindProvider() - binding = self.create_binding(interface, provider, scope) - self._bindings[interface] = binding - else: - binding = self._bindings[interface] - provider = binding.provider - assert isinstance(provider, ListOfProviders) - - if isinstance(provider, MultiBindProvider) and isinstance(to, list): + multi_binder = self._get_multi_binder(interface) + if isinstance(multi_binder, MultiBindProvider) and isinstance(to, list): try: element_type = get_args(_punch_through_alias(interface))[0] except IndexError: @@ -557,8 +566,9 @@ def multibind( f"Use typing.List[T] or list[T] to specify the element type of the list" ) for element in to: - provider.append(self.provider_for(element_type, element)) - elif isinstance(provider, MapBindProvider) and isinstance(to, dict): + element_binding = self.create_binding(element_type, element, scope) + multi_binder.append(element_binding.provider, element_binding.scope) + elif isinstance(multi_binder, MapBindProvider) and isinstance(to, dict): try: value_type = get_args(_punch_through_alias(interface))[1] except IndexError: @@ -566,9 +576,32 @@ def multibind( f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict" ) for key, value in to.items(): - provider.append(KeyValueProvider(key, self.provider_for(value_type, value))) + element_binding = self.create_binding(value_type, value, scope) + multi_binder.append(KeyValueProvider(key, element_binding.provider), element_binding.scope) else: - provider.append(self.provider_for(interface, to)) + element_binding = self.create_binding(interface, to, scope) + multi_binder.append(element_binding.provider, element_binding.scope) + + def _get_multi_binder(self, interface: type) -> MultiBinder: + multi_binder: MultiBinder + if interface not in self._bindings: + if ( + isinstance(interface, dict) + or isinstance(interface, type) + and issubclass(interface, dict) + or _get_origin(_punch_through_alias(interface)) is dict + ): + multi_binder = MapBindProvider(self) + else: + multi_binder = MultiBindProvider(self) + binding = self.create_binding(interface, multi_binder) + self._bindings[interface] = binding + else: + binding = self._bindings[interface] + assert isinstance(binding.provider, MultiBinder) + multi_binder = binding.provider + + return multi_binder def install(self, module: _InstallableModuleType) -> None: """Install a module into this binder. @@ -611,10 +644,10 @@ def create_binding( self, interface: type, to: Any = None, scope: Union['ScopeDecorator', Type['Scope'], None] = None ) -> Binding: provider = self.provider_for(interface, to) - scope = scope or getattr(to or interface, '__scope__', NoScope) + scope = scope or getattr(to or interface, '__scope__', None) if isinstance(scope, ScopeDecorator): scope = scope.scope - return Binding(interface, provider, scope) + return Binding(interface, provider, scope or NoScope) def provider_for(self, interface: Any, to: Any = None) -> Provider: base_type = _punch_through_alias(interface) @@ -696,7 +729,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']: # The special interface is added here so that requesting a special # interface with auto_bind disabled works if self._auto_bind or self._is_special_interface(interface): - binding = ImplicitBinding(*self.create_binding(interface)) + binding = ImplicitBinding(**self.create_binding(interface).__dict__) self._bindings[interface] = binding return binding, self @@ -817,7 +850,7 @@ def __repr__(self) -> str: class NoScope(Scope): """An unscoped provider.""" - def get(self, unused_key: Type[T], provider: Provider[T]) -> Provider[T]: + def get(self, key: Type[T], provider: Provider[T]) -> Provider[T]: return provider diff --git a/injector_test.py b/injector_test.py index 6260033..9df43eb 100644 --- a/injector_test.py +++ b/injector_test.py @@ -10,14 +10,14 @@ """Functional tests for the "Injector" dependency injection framework.""" -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Any, NewType, Optional, Union import abc import sys import threading import traceback import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, NewType, Optional, Union if sys.version_info >= (3, 9): from typing import Annotated @@ -29,32 +29,32 @@ import pytest from injector import ( + AssistedBuilder, Binder, CallError, + CircularDependency, + ClassAssistedBuilder, + ClassProvider, + Error, Inject, Injector, + InstanceProvider, + InvalidInterface, + Module, NoInject, + ProviderOf, Scope, - InstanceProvider, - ClassProvider, + ScopeDecorator, + SingletonScope, + UnknownArgument, + UnsatisfiedRequirement, get_bindings, inject, multiprovider, noninjectable, + provider, singleton, threadlocal, - UnsatisfiedRequirement, - CircularDependency, - Module, - SingletonScope, - ScopeDecorator, - AssistedBuilder, - provider, - ProviderOf, - ClassAssistedBuilder, - Error, - UnknownArgument, - InvalidInterface, ) @@ -723,6 +723,65 @@ def configure_dict(binder: Binder): Injector([configure_dict]) +def test_multibind_types_respect_the_bound_type_scope() -> None: + def configure(binder: Binder) -> None: + binder.bind(PluginA, to=PluginA, scope=singleton) + binder.multibind(List[Plugin], to=PluginA) + + injector = Injector([configure]) + first_list = injector.get(List[Plugin]) + second_list = injector.get(List[Plugin]) + child_injector = injector.create_child_injector() + third_list = child_injector.get(List[Plugin]) + + assert first_list[0] is second_list[0] + assert third_list[0] is second_list[0] + + +def test_multibind_list_scopes_applies_to_the_bound_items() -> None: + def configure(binder: Binder) -> None: + binder.multibind(List[Plugin], to=PluginA, scope=singleton) + binder.multibind(List[Plugin], to=PluginB) + binder.multibind(List[Plugin], to=[PluginC], scope=singleton) + + injector = Injector([configure]) + first_list = injector.get(List[Plugin]) + second_list = injector.get(List[Plugin]) + + assert first_list is not second_list + assert first_list[0] is second_list[0] + assert first_list[1] is not second_list[1] + assert first_list[2] is second_list[2] + + +def test_multibind_dict_scopes_applies_to_the_bound_items() -> None: + def configure(binder: Binder) -> None: + binder.multibind(Dict[str, Plugin], to={'a': PluginA}, scope=singleton) + binder.multibind(Dict[str, Plugin], to={'b': PluginB}) + binder.multibind(Dict[str, Plugin], to={'c': PluginC}, scope=singleton) + + injector = Injector([configure]) + first_dict = injector.get(Dict[str, Plugin]) + second_dict = injector.get(Dict[str, Plugin]) + + assert first_dict is not second_dict + assert first_dict['a'] is second_dict['a'] + assert first_dict['b'] is not second_dict['b'] + assert first_dict['c'] is second_dict['c'] + + +def test_multibind_scopes_does_not_apply_to_the_type_globally() -> None: + def configure(binder: Binder) -> None: + binder.multibind(List[Plugin], to=PluginA, scope=singleton) + + injector = Injector([configure]) + plugins = injector.get(List[Plugin]) + + assert plugins[0] is not injector.get(PluginA) + assert plugins[0] is not injector.get(Plugin) + assert injector.get(PluginA) is not injector.get(PluginA) + + def test_regular_bind_and_provider_dont_work_with_multibind(): # We only want multibind and multiprovider to work to avoid confusion