diff --git a/ppa/in_memory.py b/ppa/in_memory.py index 8ea1ff1..42867fd 100644 --- a/ppa/in_memory.py +++ b/ppa/in_memory.py @@ -7,13 +7,11 @@ def in_memory_repository( entity_already_exists_exception: Type[Exception] = ValueError, entity_not_found_exception: Type[Exception] = ValueError, ): - if find_by_fields is None: - find_by_fields = [] def decorator(cls) -> Type: class InMemoryRepository(cls): def __init__(self) -> None: self._store: dict[str, Any] = {} - for key in find_by_fields: + for key in find_by_fields or []: setattr(self, f"find_by_{key}", self._make_find_by_method(key)) def add(self, entity: Any) -> Any: diff --git a/ppa/test/test_in_memory.py b/ppa/test/test_in_memory.py index d735cd5..45cebee 100644 --- a/ppa/test/test_in_memory.py +++ b/ppa/test/test_in_memory.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from ppa.in_memory import in_memory_repository @@ -7,14 +9,18 @@ class User: def __init__(self, id_: int, name: str): self.id = id_ self.name = name + + @pytest.fixture def user(): - return User(1, 'john') + return User(1, "john") + @in_memory_repository(find_by_fields=[]) class UserRepository: pass + @pytest.fixture def user_repository(): return UserRepository() @@ -25,13 +31,14 @@ def test_in_memory_repository_find_by_fields(user: User): class FieldsRepository: pass - rep = FieldsRepository() + rep: Any = FieldsRepository() rep.add(user) assert rep.find_by_id(user.id) == [user] assert rep.find_by_name(user.name) == [user] + def test_in_memory_repository_exceptions(user: User): class CustomException1(Exception): pass @@ -46,8 +53,7 @@ class CustomException2(Exception): class ExceptionsRepository: pass - - rep = ExceptionsRepository() + rep: Any = ExceptionsRepository() with pytest.raises(CustomException2): rep.delete(user.id) @@ -57,7 +63,7 @@ class ExceptionsRepository: rep.add(user) -def test_in_memory_repository_methods(user_repository: UserRepository, user: User): +def test_in_memory_repository_methods(user_repository: Any, user: User): user_repository.add(user) assert user_repository.retrieve(user.id) == user updated_user = User(id_=user.id, name="john_updated") @@ -69,17 +75,17 @@ def test_in_memory_repository_methods(user_repository: UserRepository, user: Use assert user_repository.retrieve(user.id) -def test_in_memory_repository_add_already_exists(user_repository: UserRepository, user: User): +def test_in_memory_repository_add_already_exists(user_repository: Any, user: User): user_repository.add(user) with pytest.raises(ValueError): user_repository.add(user) -def test_in_memory_repository_update_not_found(user_repository: UserRepository, user: User): + +def test_in_memory_repository_update_not_found(user_repository: Any, user: User): with pytest.raises(ValueError): user_repository.update(user) -def test_in_memory_repository_delete_not_found(user_repository: UserRepository, user: User): +def test_in_memory_repository_delete_not_found(user_repository: Any, user: User): with pytest.raises(ValueError): user_repository.delete(user) -