Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 80 additions & 47 deletions injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,25 @@
import threading
import types
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
cast,
Dict,
Generator,
Generic,
get_args,
Iterable,
List,
Optional,
overload,
Set,
Tuple,
Type,
TypeVar,
TYPE_CHECKING,
Union,
cast,
get_args,
overload,
Comment on lines +27 to +43
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports got sorted when I ran black

)

try:
Expand All @@ -52,13 +53,13 @@
# canonical. Since this typing_extensions import is only for mypy it'll work even without
# typing_extensions actually installed so all's good.
if TYPE_CHECKING:
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints
else:
# Ignoring errors here as typing_extensions stub doesn't know about those things yet
try:
from typing import _AnnotatedAlias, Annotated, get_type_hints
from typing import Annotated, _AnnotatedAlias, get_type_hints
except ImportError:
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints


__author__ = 'Alec Thomas <alec@swapoff.org>'
Expand Down Expand Up @@ -340,39 +341,60 @@ def __repr__(self) -> str:


@private
class ListOfProviders(Provider, Generic[T]):
class MultiBinder(Provider, Generic[T]):
"""Provide a list of instances via other Providers."""

_providers: List[Provider[T]]
_multi_bindings: List['Binding']

def __init__(self) -> None:
self._providers = []
def __init__(self, parent: 'Binder') -> None:
self._multi_bindings = []
self._binder = Binder(parent.injector, auto_bind=False, parent=parent)

def append(self, provider: Provider[T]) -> None:
self._providers.append(provider)
def append(self, provider: Provider[T], scope: Type['Scope']) -> None:
# HACK: generate a pseudo-type for this element in the list.
# This is needed for scopes to work properly. Some, like the Singleton scope,
# key instances by type, so we need one that is unique to this binding.
pseudo_type = type(f"multibind-type-{id(provider)}", (provider.__class__,), {})
self._multi_bindings.append(Binding(pseudo_type, provider, scope))

def get_scoped_providers(self, injector: 'Injector') -> Generator[Provider[T], None, None]:
for binding in self._multi_bindings:
if (
isinstance(binding.provider, ClassProvider)
and binding.scope is NoScope
and self._binder.parent
and self._binder.parent.has_explicit_binding_for(binding.provider._cls)
):
parent_binding, _ = self._binder.parent.get_binding(binding.provider._cls)
scope_binding, _ = self._binder.parent.get_binding(parent_binding.scope)
else:
scope_binding, _ = self._binder.get_binding(binding.scope)
scope_instance: Scope = scope_binding.provider.get(injector)
provider_instance = scope_instance.get(binding.interface, binding.provider)
yield provider_instance

def __repr__(self) -> str:
return '%s(%r)' % (type(self).__name__, self._providers)
return '%s(%r)' % (type(self).__name__, self._multi_bindings)


class MultiBindProvider(ListOfProviders[List[T]]):
class MultiBindProvider(MultiBinder[List[T]]):
"""Used by :meth:`Binder.multibind` to flatten results of providers that
return sequences."""

def get(self, injector: 'Injector') -> List[T]:
result: List[T] = []
for provider in self._providers:
for provider in self.get_scoped_providers(injector):
instances: List[T] = _ensure_iterable(provider.get(injector))
result.extend(instances)
return result


class MapBindProvider(ListOfProviders[Dict[str, T]]):
class MapBindProvider(MultiBinder[Dict[str, T]]):
"""A provider for map bindings."""

def get(self, injector: 'Injector') -> Dict[str, T]:
map: Dict[str, T] = {}
for provider in self._providers:
for provider in self.get_scoped_providers(injector):
map.update(provider.get(injector))
return map

Expand All @@ -387,7 +409,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]:
return {self._key: self._provider.get(injector)}


_BindingBase = namedtuple('_BindingBase', 'interface provider scope')
@dataclass
class _BindingBase:
interface: type
provider: Provider
scope: Type['Scope']
Comment on lines -390 to +416
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to get more type-checking confidence in my work, I converted this named tuple to a dataclass.



@private
Expand Down Expand Up @@ -531,44 +557,51 @@ def multibind(

:param scope: Optional Scope in which to bind.
"""
if interface not in self._bindings:
provider: ListOfProviders
if (
isinstance(interface, dict)
or isinstance(interface, type)
and issubclass(interface, dict)
or _get_origin(_punch_through_alias(interface)) is dict
):
provider = MapBindProvider()
else:
provider = MultiBindProvider()
binding = self.create_binding(interface, provider, scope)
self._bindings[interface] = binding
else:
binding = self._bindings[interface]
provider = binding.provider
assert isinstance(provider, ListOfProviders)

if isinstance(provider, MultiBindProvider) and isinstance(to, list):
multi_binder = self._get_multi_binder(interface)
if isinstance(multi_binder, MultiBindProvider) and isinstance(to, list):
try:
element_type = get_args(_punch_through_alias(interface))[0]
except IndexError:
raise InvalidInterface(
f"Use typing.List[T] or list[T] to specify the element type of the list"
)
for element in to:
provider.append(self.provider_for(element_type, element))
elif isinstance(provider, MapBindProvider) and isinstance(to, dict):
element_binding = self.create_binding(element_type, element, scope)
multi_binder.append(element_binding.provider, element_binding.scope)
elif isinstance(multi_binder, MapBindProvider) and isinstance(to, dict):
try:
value_type = get_args(_punch_through_alias(interface))[1]
except IndexError:
raise InvalidInterface(
f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
)
for key, value in to.items():
provider.append(KeyValueProvider(key, self.provider_for(value_type, value)))
element_binding = self.create_binding(value_type, value, scope)
multi_binder.append(KeyValueProvider(key, element_binding.provider), element_binding.scope)
else:
provider.append(self.provider_for(interface, to))
element_binding = self.create_binding(interface, to, scope)
multi_binder.append(element_binding.provider, element_binding.scope)

def _get_multi_binder(self, interface: type) -> MultiBinder:
multi_binder: MultiBinder
if interface not in self._bindings:
if (
isinstance(interface, dict)
or isinstance(interface, type)
and issubclass(interface, dict)
or _get_origin(_punch_through_alias(interface)) is dict
):
multi_binder = MapBindProvider(self)
else:
multi_binder = MultiBindProvider(self)
binding = self.create_binding(interface, multi_binder)
self._bindings[interface] = binding
else:
binding = self._bindings[interface]
assert isinstance(binding.provider, MultiBinder)
multi_binder = binding.provider

return multi_binder

def install(self, module: _InstallableModuleType) -> None:
"""Install a module into this binder.
Expand Down Expand Up @@ -611,10 +644,10 @@ def create_binding(
self, interface: type, to: Any = None, scope: Union['ScopeDecorator', Type['Scope'], None] = None
) -> Binding:
provider = self.provider_for(interface, to)
scope = scope or getattr(to or interface, '__scope__', NoScope)
scope = scope or getattr(to or interface, '__scope__', None)
if isinstance(scope, ScopeDecorator):
scope = scope.scope
return Binding(interface, provider, scope)
return Binding(interface, provider, scope or NoScope)

def provider_for(self, interface: Any, to: Any = None) -> Provider:
base_type = _punch_through_alias(interface)
Expand Down Expand Up @@ -696,7 +729,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']:
# The special interface is added here so that requesting a special
# interface with auto_bind disabled works
if self._auto_bind or self._is_special_interface(interface):
binding = ImplicitBinding(*self.create_binding(interface))
binding = ImplicitBinding(**self.create_binding(interface).__dict__)
self._bindings[interface] = binding
return binding, self

Expand Down Expand Up @@ -817,7 +850,7 @@ def __repr__(self) -> str:
class NoScope(Scope):
"""An unscoped provider."""

def get(self, unused_key: Type[T], provider: Provider[T]) -> Provider[T]:
def get(self, key: Type[T], provider: Provider[T]) -> Provider[T]:
Comment on lines -820 to +853
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by: fix invalid override (keyword arguments must match that of the method being overridden).

return provider


Expand Down
93 changes: 76 additions & 17 deletions injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

"""Functional tests for the "Injector" dependency injection framework."""

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, NewType, Optional, Union
import abc
import sys
import threading
import traceback
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, NewType, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
Expand All @@ -29,32 +29,32 @@
import pytest

from injector import (
AssistedBuilder,
Binder,
CallError,
CircularDependency,
ClassAssistedBuilder,
ClassProvider,
Error,
Inject,
Injector,
InstanceProvider,
InvalidInterface,
Module,
NoInject,
ProviderOf,
Scope,
Comment on lines +32 to 46
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran black on this file as well, which caused the imports to be sorted

InstanceProvider,
ClassProvider,
ScopeDecorator,
SingletonScope,
UnknownArgument,
UnsatisfiedRequirement,
get_bindings,
inject,
multiprovider,
noninjectable,
provider,
singleton,
threadlocal,
UnsatisfiedRequirement,
CircularDependency,
Module,
SingletonScope,
ScopeDecorator,
AssistedBuilder,
provider,
ProviderOf,
ClassAssistedBuilder,
Error,
UnknownArgument,
InvalidInterface,
)


Expand Down Expand Up @@ -723,6 +723,65 @@ def configure_dict(binder: Binder):
Injector([configure_dict])


def test_multibind_types_respect_the_bound_type_scope() -> None:
def configure(binder: Binder) -> None:
binder.bind(PluginA, to=PluginA, scope=singleton)
binder.multibind(List[Plugin], to=PluginA)

injector = Injector([configure])
first_list = injector.get(List[Plugin])
second_list = injector.get(List[Plugin])
child_injector = injector.create_child_injector()
third_list = child_injector.get(List[Plugin])

assert first_list[0] is second_list[0]
assert third_list[0] is second_list[0]


def test_multibind_list_scopes_applies_to_the_bound_items() -> None:
def configure(binder: Binder) -> None:
binder.multibind(List[Plugin], to=PluginA, scope=singleton)
binder.multibind(List[Plugin], to=PluginB)
binder.multibind(List[Plugin], to=[PluginC], scope=singleton)

injector = Injector([configure])
first_list = injector.get(List[Plugin])
second_list = injector.get(List[Plugin])

assert first_list is not second_list
assert first_list[0] is second_list[0]
assert first_list[1] is not second_list[1]
assert first_list[2] is second_list[2]


def test_multibind_dict_scopes_applies_to_the_bound_items() -> None:
def configure(binder: Binder) -> None:
binder.multibind(Dict[str, Plugin], to={'a': PluginA}, scope=singleton)
binder.multibind(Dict[str, Plugin], to={'b': PluginB})
binder.multibind(Dict[str, Plugin], to={'c': PluginC}, scope=singleton)

injector = Injector([configure])
first_dict = injector.get(Dict[str, Plugin])
second_dict = injector.get(Dict[str, Plugin])

assert first_dict is not second_dict
assert first_dict['a'] is second_dict['a']
assert first_dict['b'] is not second_dict['b']
assert first_dict['c'] is second_dict['c']


def test_multibind_scopes_does_not_apply_to_the_type_globally() -> None:
def configure(binder: Binder) -> None:
binder.multibind(List[Plugin], to=PluginA, scope=singleton)

injector = Injector([configure])
plugins = injector.get(List[Plugin])

assert plugins[0] is not injector.get(PluginA)
assert plugins[0] is not injector.get(Plugin)
assert injector.get(PluginA) is not injector.get(PluginA)


def test_regular_bind_and_provider_dont_work_with_multibind():
# We only want multibind and multiprovider to work to avoid confusion

Expand Down