Skip to content

Commit

Permalink
move registry validation to registry builder
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Mar 3, 2024
1 parent f158cbc commit 22af242
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 75 deletions.
9 changes: 4 additions & 5 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)
from .provider import Provider
from .registry import Registry, RegistryBuilder
from .validation import GraphValidator

T = TypeVar("T")

Expand Down Expand Up @@ -187,11 +186,11 @@ def make_async_container(
skip_validation: bool = False,
) -> AsyncContainer:
registries = RegistryBuilder(
scopes=scopes, container_type=AsyncContainer, providers=providers,
scopes=scopes,
container_type=AsyncContainer,
providers=providers,
skip_validation=skip_validation,
).build()
if not skip_validation:
validator = GraphValidator(registries)
validator.validate()
return AsyncContainer(
*registries,
context=context,
Expand Down
9 changes: 4 additions & 5 deletions src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)
from .provider import Provider
from .registry import Registry, RegistryBuilder
from .validation import GraphValidator

T = TypeVar("T")

Expand Down Expand Up @@ -188,9 +187,9 @@ def make_container(
skip_validation: bool = False,
) -> Container:
registries = RegistryBuilder(
scopes=scopes, container_type=Container, providers=providers,
scopes=scopes,
container_type=Container,
providers=providers,
skip_validation=skip_validation,
).build()
if not skip_validation:
validator = GraphValidator(registries)
validator.validate()
return Container(*registries, context=context, lock_factory=lock_factory)
60 changes: 59 additions & 1 deletion src/dishka/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,66 @@ def _specialize_generic(
)


class GraphValidator:
def __init__(self, registries: Sequence[Registry]) -> None:
self.registries = registries
self.valid_keys = {}
self.path = {}

def _validate_key(
self, key: DependencyKey, registry_index: int,
) -> None:
if key in self.valid_keys:
return
if key in self.path:
keys = list(self.path)
factories = list(self.path.values())[keys.index(key):]
raise CycleDependenciesError(factories)
for index in range(registry_index + 1):
registry = self.registries[index]
factory = registry.get_factory(key)
if factory:
self._validate_factory(factory, registry_index)
return
raise NoFactoryError(requested=key)

def _validate_factory(
self, factory: Factory, registry_index: int,
):
self.path[factory.provides] = factory
try:
for dep in factory.dependencies:
# ignore TypeVar parameters
if not isinstance(dep.type_hint, TypeVar):
self._validate_key(dep, registry_index)
except NoFactoryError as e:
e.add_path(factory)
raise
finally:
self.path.pop(factory.provides)
self.valid_keys[factory.provides] = True

def validate(self):
for registry_index, registry in enumerate(self.registries):
for factory in registry.factories.values():
self.path = {}
try:
self._validate_factory(factory, registry_index)
except NoFactoryError as e:
raise GraphMissingFactoryError(
e.requested, e.path,
) from None
except CycleDependenciesError as e:
raise e from None


class RegistryBuilder:
def __init__(
self,
scopes: type[BaseScope],
providers: Sequence[Provider],
container_type: type,
skip_validation: bool,
) -> None:
self.scopes = scopes
self.providers = providers
Expand All @@ -102,6 +156,7 @@ def __init__(
self.aliases: dict[DependencyKey, Alias] = {}
self.container_type = container_type
self.decorator_depth: dict[DependencyKey, int] = defaultdict(int)
self.skip_validation = skip_validation

def _collect_components(self) -> None:
for provider in self.providers:
Expand Down Expand Up @@ -243,4 +298,7 @@ def build(self):
for decorator in provider.decorators:
self._process_decorator(provider, decorator)

return list(self.registries.values())
registries = list(self.registries.values())
if not self.skip_validation:
GraphValidator(registries).validate()
return registries
64 changes: 0 additions & 64 deletions src/dishka/validation.py

This file was deleted.

0 comments on commit 22af242

Please sign in to comment.