Skip to content

Commit

Permalink
refactory registry creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Mar 3, 2024
1 parent f4ddebd commit f158cbc
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 95 deletions.
10 changes: 4 additions & 6 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
UnsupportedFactoryError,
)
from .provider import Provider
from .registry import Registry, make_registries
from .registry import Registry, RegistryBuilder
from .validation import GraphValidator

T = TypeVar("T")
Expand Down Expand Up @@ -186,11 +186,9 @@ def make_async_container(
lock_factory: Callable[[], Lock] | None = Lock,
skip_validation: bool = False,
) -> AsyncContainer:
registries = make_registries(
*providers,
scopes=scopes,
container_type=AsyncContainer,
)
registries = RegistryBuilder(
scopes=scopes, container_type=AsyncContainer, providers=providers,
).build()
if not skip_validation:
validator = GraphValidator(registries)
validator.validate()
Expand Down
10 changes: 4 additions & 6 deletions src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
UnsupportedFactoryError,
)
from .provider import Provider
from .registry import Registry, make_registries
from .registry import Registry, RegistryBuilder
from .validation import GraphValidator

T = TypeVar("T")
Expand Down Expand Up @@ -187,11 +187,9 @@ def make_container(
lock_factory: Callable[[], Lock] | None = None,
skip_validation: bool = False,
) -> Container:
registries = make_registries(
*providers,
scopes=scopes,
container_type=Container,
)
registries = RegistryBuilder(
scopes=scopes, container_type=Container, providers=providers,
).build()
if not skip_validation:
validator = GraphValidator(registries)
validator.validate()
Expand Down
7 changes: 7 additions & 0 deletions src/dishka/entities/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ class Scope(BaseScope):
REQUEST = "REQUEST"
ACTION = "ACTION"
STEP = "STEP"


class InvalidScopes(BaseScope):
UNKNOWN_SCOPE = "<unknown scope>"

def __str__(self):
return self.value
4 changes: 4 additions & 0 deletions src/dishka/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class InvalidGraphError(DishkaError):
pass


class UnknownScopeError(InvalidGraphError):
pass


class CycleDependenciesError(InvalidGraphError):
def __init__(self, path: Sequence[Factory]) -> None:
self.path = path
Expand Down
249 changes: 166 additions & 83 deletions src/dishka/registry.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, NewType, TypeVar, get_args, get_origin

from ._adaptix.type_tools.basic_utils import get_type_vars, is_generic
from .dependency_source import Alias, Factory, from_context
from .entities.component import DEFAULT_COMPONENT
from .dependency_source import (
Alias,
ContextVariable,
Decorator,
Factory,
from_context,
)
from .entities.component import DEFAULT_COMPONENT, Component
from .entities.key import DependencyKey
from .entities.scope import BaseScope
from .entities.scope import BaseScope, InvalidScopes
from .exceptions import (
CycleDependenciesError,
GraphMissingFactoryError,
NoFactoryError,
UnknownScopeError,
)
from .provider import Provider

Expand Down Expand Up @@ -77,87 +86,161 @@ def _specialize_generic(
)


def make_registries(
*providers: Provider,
scopes: type[BaseScope],
container_type: type,
) -> list[Registry]:
dep_scopes: dict[DependencyKey, BaseScope] = {}
alias_sources: dict[DependencyKey, Any] = {}
aliases: dict[DependencyKey, Alias] = {}
components = {DEFAULT_COMPONENT}
for provider in providers:
component = provider.component
components.add(component)
for source in provider.factories:
provides = source.provides.with_component(component)
dep_scopes[provides] = source.scope
for source in provider.aliases:
provides = source.provides.with_component(component)
alias_sources[provides] = source.source.with_component(component)
aliases[provides] = source

registries: dict[BaseScope, Registry] = {}
for scope in scopes:
registry = Registry(scope)
context_var = from_context(provides=container_type, scope=scope)
for component in components:
registry.add_factory(context_var.as_factory(component))
registries[scope] = registry
class RegistryBuilder:
def __init__(
self,
scopes: type[BaseScope],
providers: Sequence[Provider],
container_type: type,
) -> None:
self.scopes = scopes
self.providers = providers
self.registries: dict[BaseScope, Registry] = {}
self.dependency_scopes: dict[DependencyKey, BaseScope] = {}
self.components: set[Component] = {DEFAULT_COMPONENT}
self.alias_sources: dict[DependencyKey, Any] = {}
self.aliases: dict[DependencyKey, Alias] = {}
self.container_type = container_type
self.decorator_depth: dict[DependencyKey, int] = defaultdict(int)

decorator_depth: dict[DependencyKey, int] = defaultdict(int)
def _collect_components(self) -> None:
for provider in self.providers:
self.components.add(provider.component)

for provider in providers:
component = provider.component
for source in provider.factories:
scope = source.scope
registries[scope].add_factory(source.with_component(component))
for source in provider.aliases:
alias_source = source.source.with_component(component)
visited_keys = []
while alias_source not in dep_scopes:
if alias_source not in alias_sources:
e = NoFactoryError(alias_source)
for s in visited_keys[::-1]:
e.add_path(aliases[s].as_factory("<unknown scope>",
s.component))
e.add_path(source.as_factory("<unknown scope>", component))
raise e
visited_keys.append(alias_source)
alias_source = alias_sources[alias_source]
if alias_source in visited_keys:
raise CycleDependenciesError([
aliases[s].as_factory("<unknown scope>", component)
for s in visited_keys
])
scope = dep_scopes[alias_source]
source = source.as_factory(scope, component)
dep_scopes[source.provides] = scope
registries[scope].add_factory(source)
for source in provider.decorators:
provides = source.provides.with_component(component)
scope = dep_scopes[provides]
registry = registries[scope]
undecorated_type = NewType(
f"{provides.type_hint.__name__}@{decorator_depth[provides]}",
source.provides,
)
decorator_depth[provides] += 1
old_factory = registry.get_factory(provides)
old_factory.provides = DependencyKey(
undecorated_type, old_factory.provides.component,
def _collect_provided_scopes(self) -> None:
for provider in self.providers:
for factory in provider.factories:
if factory.scope not in self.scopes:
raise UnknownScopeError(
f"Scope {factory.scope} is unknown, "
f"expected one of {self.scopes}",
)
provides = factory.provides.with_component(provider.component)
self.dependency_scopes[provides] = factory.scope
for context_var in provider.context_vars:
if context_var.scope not in self.scopes:
raise UnknownScopeError(
f"Scope {context_var.scope} is unknown, "
f"expected one of {self.scopes}",
)
for component in self.components:
provides = context_var.provides.with_component(component)
self.dependency_scopes[provides] = context_var.scope

def _collect_aliases(self) -> None:
for provider in self.providers:
component = provider.component
for alias in provider.aliases:
provides = alias.provides.with_component(component)
alias_source = alias.source.with_component(component)
self.alias_sources[provides] = alias_source
self.aliases[provides] = alias

def _init_registries(self) -> None:
for scope in self.scopes:
registry = Registry(scope)
context_var = from_context(
provides=self.container_type, scope=scope,
)
registry.add_factory(old_factory)
source = source.as_factory(
scope=scope,
new_dependency=DependencyKey(undecorated_type, None),
cache=old_factory.cache,
component=component,
for component in self.components:
registry.add_factory(context_var.as_factory(component))
self.registries[scope] = registry

def _process_factory(
self, provider: Provider, factory: Factory,
) -> None:
registry = self.registries[factory.scope]
registry.add_factory(factory.with_component(provider.component))

def _process_alias(
self, provider: Provider, alias: Alias,
) -> None:
component = provider.component
alias_source = alias.source.with_component(component)
visited_keys = []
while alias_source not in self.dependency_scopes:
if alias_source not in self.alias_sources:
e = NoFactoryError(alias_source)
for key in visited_keys[::-1]:
e.add_path(self.aliases[key].as_factory(
InvalidScopes.UNKNOWN_SCOPE, key.component,
))
e.add_path(alias.as_factory(
InvalidScopes.UNKNOWN_SCOPE, component,
))
raise e
visited_keys.append(alias_source)
alias_source = self.alias_sources[alias_source]
if alias_source in visited_keys:
raise CycleDependenciesError([
self.aliases[s].as_factory(
InvalidScopes.UNKNOWN_SCOPE, component,
)
for s in visited_keys
])

scope = self.dependency_scopes[alias_source]
registry = self.registries[scope]

factory = alias.as_factory(scope, component)
self.dependency_scopes[factory.provides] = scope
registry.add_factory(factory)

def _process_decorator(
self, provider: Provider, decorator: Decorator,
) -> None:
provides = decorator.provides.with_component(provider.component)
if provides not in self.dependency_scopes:
raise GraphMissingFactoryError(
requested=provides,
path=[decorator.as_factory(
scope=InvalidScopes.UNKNOWN_SCOPE,
new_dependency=provides,
cache=False,
component=provider.component,
)],
)
registries[scope].add_factory(source)
for source in provider.context_vars:
scope = source.scope
registry = registries[scope]
for component in components:
registry.add_factory(source.as_factory(component))
return list(registries.values())
scope = self.dependency_scopes[provides]
registry = self.registries[scope]
undecorated_type = NewType(
f"{provides.type_hint.__name__}@{self.decorator_depth[provides]}",
decorator.provides.type_hint,
)
self.decorator_depth[provides] += 1
old_factory = registry.get_factory(provides)
old_factory.provides = DependencyKey(
undecorated_type, old_factory.provides.component,
)
new_factory = decorator.as_factory(
scope=scope,
new_dependency=DependencyKey(undecorated_type, None),
cache=old_factory.cache,
component=provider.component,
)
registry.add_factory(old_factory)
registry.add_factory(new_factory)

def _process_context_var(
self, provider: Provider, context_var: ContextVariable,
) -> None:
registry = self.registries[context_var.scope]
for component in self.components:
registry.add_factory(context_var.as_factory(component))

def build(self):
self._collect_components()
self._collect_provided_scopes()
self._collect_aliases()
self._init_registries()

for provider in self.providers:
for factory in provider.factories:
self._process_factory(provider, factory)
for alias in provider.aliases:
self._process_alias(provider, alias)
for context_var in provider.context_vars:
self._process_context_var(provider, context_var)
for decorator in provider.decorators:
self._process_decorator(provider, decorator)

return list(self.registries.values())

0 comments on commit f158cbc

Please sign in to comment.