-
Notifications
You must be signed in to change notification settings - Fork 94
fix: Multibind scopes #284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
11042c7
bdbf66c
1032e80
efc62e1
4d57f11
b4d9fcb
37e071e
11b52e2
2d45015
9faa32a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,24 +22,25 @@ | |
import threading | ||
import types | ||
from abc import ABCMeta, abstractmethod | ||
from collections import namedtuple | ||
from dataclasses import dataclass | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Callable, | ||
cast, | ||
Dict, | ||
Generator, | ||
Generic, | ||
get_args, | ||
Iterable, | ||
List, | ||
Optional, | ||
overload, | ||
Set, | ||
Tuple, | ||
Type, | ||
TypeVar, | ||
TYPE_CHECKING, | ||
Union, | ||
cast, | ||
get_args, | ||
overload, | ||
) | ||
|
||
try: | ||
|
@@ -52,13 +53,13 @@ | |
# canonical. Since this typing_extensions import is only for mypy it'll work even without | ||
# typing_extensions actually installed so all's good. | ||
if TYPE_CHECKING: | ||
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints | ||
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints | ||
else: | ||
# Ignoring errors here as typing_extensions stub doesn't know about those things yet | ||
try: | ||
from typing import _AnnotatedAlias, Annotated, get_type_hints | ||
from typing import Annotated, _AnnotatedAlias, get_type_hints | ||
except ImportError: | ||
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints | ||
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints | ||
|
||
|
||
__author__ = 'Alec Thomas <alec@swapoff.org>' | ||
|
@@ -340,39 +341,60 @@ def __repr__(self) -> str: | |
|
||
|
||
@private | ||
class ListOfProviders(Provider, Generic[T]): | ||
class MultiBinder(Provider, Generic[T]): | ||
"""Provide a list of instances via other Providers.""" | ||
|
||
_providers: List[Provider[T]] | ||
_multi_bindings: List['Binding'] | ||
|
||
def __init__(self) -> None: | ||
self._providers = [] | ||
def __init__(self, parent: 'Binder') -> None: | ||
self._multi_bindings = [] | ||
self._binder = Binder(parent.injector, auto_bind=False, parent=parent) | ||
|
||
def append(self, provider: Provider[T]) -> None: | ||
self._providers.append(provider) | ||
def append(self, provider: Provider[T], scope: Type['Scope']) -> None: | ||
# HACK: generate a pseudo-type for this element in the list. | ||
# This is needed for scopes to work properly. Some, like the Singleton scope, | ||
# key instances by type, so we need one that is unique to this binding. | ||
pseudo_type = type(f"multibind-type-{id(provider)}", (provider.__class__,), {}) | ||
self._multi_bindings.append(Binding(pseudo_type, provider, scope)) | ||
|
||
def get_scoped_providers(self, injector: 'Injector') -> Generator[Provider[T], None, None]: | ||
for binding in self._multi_bindings: | ||
if ( | ||
isinstance(binding.provider, ClassProvider) | ||
and binding.scope is NoScope | ||
and self._binder.parent | ||
and self._binder.parent.has_explicit_binding_for(binding.provider._cls) | ||
): | ||
parent_binding, _ = self._binder.parent.get_binding(binding.provider._cls) | ||
scope_binding, _ = self._binder.parent.get_binding(parent_binding.scope) | ||
else: | ||
scope_binding, _ = self._binder.get_binding(binding.scope) | ||
scope_instance: Scope = scope_binding.provider.get(injector) | ||
provider_instance = scope_instance.get(binding.interface, binding.provider) | ||
yield provider_instance | ||
|
||
def __repr__(self) -> str: | ||
return '%s(%r)' % (type(self).__name__, self._providers) | ||
return '%s(%r)' % (type(self).__name__, self._multi_bindings) | ||
|
||
|
||
class MultiBindProvider(ListOfProviders[List[T]]): | ||
class MultiBindProvider(MultiBinder[List[T]]): | ||
"""Used by :meth:`Binder.multibind` to flatten results of providers that | ||
return sequences.""" | ||
|
||
def get(self, injector: 'Injector') -> List[T]: | ||
result: List[T] = [] | ||
for provider in self._providers: | ||
for provider in self.get_scoped_providers(injector): | ||
instances: List[T] = _ensure_iterable(provider.get(injector)) | ||
result.extend(instances) | ||
return result | ||
|
||
|
||
class MapBindProvider(ListOfProviders[Dict[str, T]]): | ||
class MapBindProvider(MultiBinder[Dict[str, T]]): | ||
"""A provider for map bindings.""" | ||
|
||
def get(self, injector: 'Injector') -> Dict[str, T]: | ||
map: Dict[str, T] = {} | ||
for provider in self._providers: | ||
for provider in self.get_scoped_providers(injector): | ||
map.update(provider.get(injector)) | ||
return map | ||
|
||
|
@@ -387,7 +409,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]: | |
return {self._key: self._provider.get(injector)} | ||
|
||
|
||
_BindingBase = namedtuple('_BindingBase', 'interface provider scope') | ||
@dataclass | ||
class _BindingBase: | ||
interface: type | ||
provider: Provider | ||
scope: Type['Scope'] | ||
Comment on lines
-390
to
+416
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to get more type-checking confidence in my work, I converted this named tuple to a dataclass. |
||
|
||
|
||
@private | ||
|
@@ -531,44 +557,51 @@ def multibind( | |
|
||
:param scope: Optional Scope in which to bind. | ||
""" | ||
if interface not in self._bindings: | ||
provider: ListOfProviders | ||
if ( | ||
isinstance(interface, dict) | ||
or isinstance(interface, type) | ||
and issubclass(interface, dict) | ||
or _get_origin(_punch_through_alias(interface)) is dict | ||
): | ||
provider = MapBindProvider() | ||
else: | ||
provider = MultiBindProvider() | ||
binding = self.create_binding(interface, provider, scope) | ||
self._bindings[interface] = binding | ||
else: | ||
binding = self._bindings[interface] | ||
provider = binding.provider | ||
assert isinstance(provider, ListOfProviders) | ||
|
||
if isinstance(provider, MultiBindProvider) and isinstance(to, list): | ||
multi_binder = self._get_multi_binder(interface) | ||
if isinstance(multi_binder, MultiBindProvider) and isinstance(to, list): | ||
try: | ||
element_type = get_args(_punch_through_alias(interface))[0] | ||
except IndexError: | ||
raise InvalidInterface( | ||
f"Use typing.List[T] or list[T] to specify the element type of the list" | ||
) | ||
for element in to: | ||
provider.append(self.provider_for(element_type, element)) | ||
elif isinstance(provider, MapBindProvider) and isinstance(to, dict): | ||
element_binding = self.create_binding(element_type, element, scope) | ||
multi_binder.append(element_binding.provider, element_binding.scope) | ||
elif isinstance(multi_binder, MapBindProvider) and isinstance(to, dict): | ||
try: | ||
value_type = get_args(_punch_through_alias(interface))[1] | ||
except IndexError: | ||
raise InvalidInterface( | ||
f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict" | ||
) | ||
for key, value in to.items(): | ||
provider.append(KeyValueProvider(key, self.provider_for(value_type, value))) | ||
element_binding = self.create_binding(value_type, value, scope) | ||
multi_binder.append(KeyValueProvider(key, element_binding.provider), element_binding.scope) | ||
else: | ||
provider.append(self.provider_for(interface, to)) | ||
element_binding = self.create_binding(interface, to, scope) | ||
multi_binder.append(element_binding.provider, element_binding.scope) | ||
|
||
def _get_multi_binder(self, interface: type) -> MultiBinder: | ||
multi_binder: MultiBinder | ||
if interface not in self._bindings: | ||
if ( | ||
isinstance(interface, dict) | ||
or isinstance(interface, type) | ||
and issubclass(interface, dict) | ||
or _get_origin(_punch_through_alias(interface)) is dict | ||
): | ||
multi_binder = MapBindProvider(self) | ||
else: | ||
multi_binder = MultiBindProvider(self) | ||
binding = self.create_binding(interface, multi_binder) | ||
self._bindings[interface] = binding | ||
else: | ||
binding = self._bindings[interface] | ||
assert isinstance(binding.provider, MultiBinder) | ||
multi_binder = binding.provider | ||
|
||
return multi_binder | ||
|
||
def install(self, module: _InstallableModuleType) -> None: | ||
"""Install a module into this binder. | ||
|
@@ -611,10 +644,10 @@ def create_binding( | |
self, interface: type, to: Any = None, scope: Union['ScopeDecorator', Type['Scope'], None] = None | ||
) -> Binding: | ||
provider = self.provider_for(interface, to) | ||
scope = scope or getattr(to or interface, '__scope__', NoScope) | ||
scope = scope or getattr(to or interface, '__scope__', None) | ||
if isinstance(scope, ScopeDecorator): | ||
scope = scope.scope | ||
return Binding(interface, provider, scope) | ||
return Binding(interface, provider, scope or NoScope) | ||
|
||
def provider_for(self, interface: Any, to: Any = None) -> Provider: | ||
base_type = _punch_through_alias(interface) | ||
|
@@ -696,7 +729,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']: | |
# The special interface is added here so that requesting a special | ||
# interface with auto_bind disabled works | ||
if self._auto_bind or self._is_special_interface(interface): | ||
binding = ImplicitBinding(*self.create_binding(interface)) | ||
binding = ImplicitBinding(**self.create_binding(interface).__dict__) | ||
self._bindings[interface] = binding | ||
return binding, self | ||
|
||
|
@@ -817,7 +850,7 @@ def __repr__(self) -> str: | |
class NoScope(Scope): | ||
"""An unscoped provider.""" | ||
|
||
def get(self, unused_key: Type[T], provider: Provider[T]) -> Provider[T]: | ||
def get(self, key: Type[T], provider: Provider[T]) -> Provider[T]: | ||
Comment on lines
-820
to
+853
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Drive-by: fix invalid override (keyword arguments must match that of the method being overridden). |
||
return provider | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,14 +10,14 @@ | |
|
||
"""Functional tests for the "Injector" dependency injection framework.""" | ||
|
||
from contextlib import contextmanager | ||
from dataclasses import dataclass | ||
from typing import Any, NewType, Optional, Union | ||
import abc | ||
import sys | ||
import threading | ||
import traceback | ||
import warnings | ||
from contextlib import contextmanager | ||
from dataclasses import dataclass | ||
from typing import Any, NewType, Optional, Union | ||
|
||
if sys.version_info >= (3, 9): | ||
from typing import Annotated | ||
|
@@ -29,32 +29,32 @@ | |
import pytest | ||
|
||
from injector import ( | ||
AssistedBuilder, | ||
Binder, | ||
CallError, | ||
CircularDependency, | ||
ClassAssistedBuilder, | ||
ClassProvider, | ||
Error, | ||
Inject, | ||
Injector, | ||
InstanceProvider, | ||
InvalidInterface, | ||
Module, | ||
NoInject, | ||
ProviderOf, | ||
Scope, | ||
Comment on lines
+32
to
46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I ran black on this file as well, which caused the imports to be sorted |
||
InstanceProvider, | ||
ClassProvider, | ||
ScopeDecorator, | ||
SingletonScope, | ||
UnknownArgument, | ||
UnsatisfiedRequirement, | ||
get_bindings, | ||
inject, | ||
multiprovider, | ||
noninjectable, | ||
provider, | ||
singleton, | ||
threadlocal, | ||
UnsatisfiedRequirement, | ||
CircularDependency, | ||
Module, | ||
SingletonScope, | ||
ScopeDecorator, | ||
AssistedBuilder, | ||
provider, | ||
ProviderOf, | ||
ClassAssistedBuilder, | ||
Error, | ||
UnknownArgument, | ||
InvalidInterface, | ||
) | ||
|
||
|
||
|
@@ -723,6 +723,65 @@ def configure_dict(binder: Binder): | |
Injector([configure_dict]) | ||
|
||
|
||
def test_multibind_types_respect_the_bound_type_scope() -> None: | ||
def configure(binder: Binder) -> None: | ||
binder.bind(PluginA, to=PluginA, scope=singleton) | ||
binder.multibind(List[Plugin], to=PluginA) | ||
|
||
injector = Injector([configure]) | ||
first_list = injector.get(List[Plugin]) | ||
second_list = injector.get(List[Plugin]) | ||
child_injector = injector.create_child_injector() | ||
third_list = child_injector.get(List[Plugin]) | ||
|
||
assert first_list[0] is second_list[0] | ||
assert third_list[0] is second_list[0] | ||
|
||
|
||
def test_multibind_list_scopes_applies_to_the_bound_items() -> None: | ||
def configure(binder: Binder) -> None: | ||
binder.multibind(List[Plugin], to=PluginA, scope=singleton) | ||
binder.multibind(List[Plugin], to=PluginB) | ||
binder.multibind(List[Plugin], to=[PluginC], scope=singleton) | ||
|
||
injector = Injector([configure]) | ||
first_list = injector.get(List[Plugin]) | ||
second_list = injector.get(List[Plugin]) | ||
|
||
assert first_list is not second_list | ||
assert first_list[0] is second_list[0] | ||
assert first_list[1] is not second_list[1] | ||
assert first_list[2] is second_list[2] | ||
|
||
|
||
def test_multibind_dict_scopes_applies_to_the_bound_items() -> None: | ||
def configure(binder: Binder) -> None: | ||
binder.multibind(Dict[str, Plugin], to={'a': PluginA}, scope=singleton) | ||
binder.multibind(Dict[str, Plugin], to={'b': PluginB}) | ||
binder.multibind(Dict[str, Plugin], to={'c': PluginC}, scope=singleton) | ||
|
||
injector = Injector([configure]) | ||
first_dict = injector.get(Dict[str, Plugin]) | ||
second_dict = injector.get(Dict[str, Plugin]) | ||
|
||
assert first_dict is not second_dict | ||
assert first_dict['a'] is second_dict['a'] | ||
assert first_dict['b'] is not second_dict['b'] | ||
assert first_dict['c'] is second_dict['c'] | ||
|
||
|
||
def test_multibind_scopes_does_not_apply_to_the_type_globally() -> None: | ||
def configure(binder: Binder) -> None: | ||
binder.multibind(List[Plugin], to=PluginA, scope=singleton) | ||
|
||
injector = Injector([configure]) | ||
plugins = injector.get(List[Plugin]) | ||
|
||
assert plugins[0] is not injector.get(PluginA) | ||
assert plugins[0] is not injector.get(Plugin) | ||
assert injector.get(PluginA) is not injector.get(PluginA) | ||
|
||
|
||
def test_regular_bind_and_provider_dont_work_with_multibind(): | ||
# We only want multibind and multiprovider to work to avoid confusion | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Imports got sorted when I ran black