diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index 5f95096..1fe409d 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -15,7 +15,6 @@ ) from .provider import Provider from .registry import Registry, RegistryBuilder -from .validation import GraphValidator T = TypeVar("T") @@ -187,11 +186,11 @@ def make_async_container( skip_validation: bool = False, ) -> AsyncContainer: registries = RegistryBuilder( - scopes=scopes, container_type=AsyncContainer, providers=providers, + scopes=scopes, + container_type=AsyncContainer, + providers=providers, + skip_validation=skip_validation, ).build() - if not skip_validation: - validator = GraphValidator(registries) - validator.validate() return AsyncContainer( *registries, context=context, diff --git a/src/dishka/container.py b/src/dishka/container.py index 06836f3..55918d2 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -15,7 +15,6 @@ ) from .provider import Provider from .registry import Registry, RegistryBuilder -from .validation import GraphValidator T = TypeVar("T") @@ -188,9 +187,9 @@ def make_container( skip_validation: bool = False, ) -> Container: registries = RegistryBuilder( - scopes=scopes, container_type=Container, providers=providers, + scopes=scopes, + container_type=Container, + providers=providers, + skip_validation=skip_validation, ).build() - if not skip_validation: - validator = GraphValidator(registries) - validator.validate() return Container(*registries, context=context, lock_factory=lock_factory) diff --git a/src/dishka/registry.py b/src/dishka/registry.py index fcfd828..4abdcce 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -86,12 +86,67 @@ def _specialize_generic( ) +class GraphValidator: + def __init__(self, registries: Sequence[Registry]) -> None: + self.registries = registries + self.valid_keys = {} + self.path = {} + + def _validate_key( + self, key: DependencyKey, registry_index: int, + ) -> None: + if key in self.valid_keys: + return + if key in self.path: + keys = list(self.path) + factories = list(self.path.values())[keys.index(key):] + raise CycleDependenciesError(factories) + for index in range(registry_index + 1): + registry = self.registries[index] + factory = registry.get_factory(key) + if factory: + self._validate_factory(factory, registry_index) + return + raise NoFactoryError(requested=key) + + def _validate_factory( + self, factory: Factory, registry_index: int, + ): + self.path[factory.provides] = factory + try: + for dep in factory.dependencies: + # ignore TypeVar parameters + if not isinstance(dep.type_hint, TypeVar): + self._validate_key(dep, registry_index) + except NoFactoryError as e: + e.add_path(factory) + raise + finally: + self.path.pop(factory.provides) + self.valid_keys[factory.provides] = True + + def validate(self): + for registry_index, registry in enumerate(self.registries): + for factory in registry.factories.values(): + self.path = {} + try: + self._validate_factory(factory, registry_index) + except NoFactoryError as e: + raise GraphMissingFactoryError( + e.requested, e.path, + ) from None + except CycleDependenciesError as e: + raise e from None + + class RegistryBuilder: def __init__( self, + *, scopes: type[BaseScope], providers: Sequence[Provider], container_type: type, + skip_validation: bool, ) -> None: self.scopes = scopes self.providers = providers @@ -102,6 +157,7 @@ def __init__( self.aliases: dict[DependencyKey, Alias] = {} self.container_type = container_type self.decorator_depth: dict[DependencyKey, int] = defaultdict(int) + self.skip_validation = skip_validation def _collect_components(self) -> None: for provider in self.providers: @@ -243,4 +299,7 @@ def build(self): for decorator in provider.decorators: self._process_decorator(provider, decorator) - return list(self.registries.values()) + registries = list(self.registries.values()) + if not self.skip_validation: + GraphValidator(registries).validate() + return registries diff --git a/src/dishka/validation.py b/src/dishka/validation.py deleted file mode 100644 index 277535c..0000000 --- a/src/dishka/validation.py +++ /dev/null @@ -1,64 +0,0 @@ -from collections.abc import Sequence -from typing import TypeVar - -from dishka.dependency_source import Factory -from dishka.entities.key import DependencyKey -from dishka.exceptions import ( - CycleDependenciesError, - GraphMissingFactoryError, - NoFactoryError, -) -from dishka.registry import Registry - - -class GraphValidator: - def __init__(self, registries: Sequence[Registry]) -> None: - self.registries = registries - self.valid_keys = {} - self.path = {} - - def _validate_key( - self, key: DependencyKey, registry_index: int, - ) -> None: - if key in self.valid_keys: - return - if key in self.path: - keys = list(self.path) - factories = list(self.path.values())[keys.index(key):] - raise CycleDependenciesError(factories) - for index in range(registry_index + 1): - registry = self.registries[index] - factory = registry.get_factory(key) - if factory: - self._validate_factory(factory, registry_index) - return - raise NoFactoryError(requested=key) - - def _validate_factory( - self, factory: Factory, registry_index: int, - ): - self.path[factory.provides] = factory - try: - for dep in factory.dependencies: - # ignore TypeVar parameters - if not isinstance(dep.type_hint, TypeVar): - self._validate_key(dep, registry_index) - except NoFactoryError as e: - e.add_path(factory) - raise - finally: - self.path.pop(factory.provides) - self.valid_keys[factory.provides] = True - - def validate(self): - for registry_index, registry in enumerate(self.registries): - for factory in registry.factories.values(): - self.path = {} - try: - self._validate_factory(factory, registry_index) - except NoFactoryError as e: - raise GraphMissingFactoryError( - e.requested, e.path, - ) from None - except CycleDependenciesError as e: - raise e from None