Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-right committed Sep 10, 2021
1 parent 8615709 commit d8251fd
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 11 deletions.
27 changes: 16 additions & 11 deletions beanie/odm/documents.py
Expand Up @@ -44,7 +44,7 @@
from beanie.odm.queries.update import UpdateMany
from beanie.odm.utils.collection import collection_factory
from beanie.odm.utils.dump import get_dict
from beanie.odm.utils.state import saved_state_needed
from beanie.odm.utils.state import saved_state_needed, save_state_after

DocType = TypeVar("DocType", bound="Document")
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)
Expand Down Expand Up @@ -91,6 +91,7 @@ async def _sync(self) -> None:
self._save_state()

@wrap_with_actions(EventTypes.INSERT)
@save_state_after
async def insert(
self: DocType, session: Optional[ClientSession] = None
) -> DocType:
Expand All @@ -105,8 +106,6 @@ async def insert(
if not isinstance(new_id, self.__fields__["id"].type_):
new_id = self.__fields__["id"].type_(new_id)
self.id = new_id
if self.use_state_management():
self._save_state()
return self

async def create(
Expand Down Expand Up @@ -424,6 +423,7 @@ def all(
)

@wrap_with_actions(EventTypes.REPLACE)
@save_state_after
async def replace(
self: DocType, session: Optional[ClientSession] = None
) -> DocType:
Expand All @@ -439,8 +439,6 @@ async def replace(
await self.find_one({"_id": self.id}).replace_one(
self, session=session
)
if self.use_state_management():
self._save_state()
return self

async def save(
Expand Down Expand Up @@ -479,6 +477,7 @@ async def replace_many(
await cls.find(In(cls.id, ids_list), session=session).delete()
await cls.insert_many(documents, session=session)

@save_state_after
async def update(
self, *args, session: Optional[ClientSession] = None
) -> None:
Expand Down Expand Up @@ -668,17 +667,15 @@ def _save_state(self):
if self.use_state_management():
self._saved_state = self.dict()

def get_saved_state(self):
return self._saved_state

@classmethod
def _parse_obj_saving_state(cls: Type[DocType], obj: Any) -> DocType:
result: DocType = cls.parse_obj(obj)
result._save_state()
return result

@saved_state_needed
def rollback(self) -> None:
for key, value in self._saved_state.items(): # type: ignore
setattr(self, key, value)

@property # type: ignore
@saved_state_needed
def is_changed(self) -> bool:
Expand All @@ -688,21 +685,29 @@ def is_changed(self) -> bool:

@saved_state_needed
def get_changes(self) -> Dict[str, Any]:
# TODO search deeply
changes = {}
if self.is_changed:
current_state = self.dict()
for k, v in self._saved_state.items(): # type: ignore
if v != current_state[k]:
changes[k] = v
changes[k] = current_state[k]
return changes

@saved_state_needed
@save_state_after
async def save_changes(self) -> None:
if not self.is_changed:
return None
changes = self.get_changes()
await self.set(changes)

@saved_state_needed
def rollback(self) -> None:
if self.is_changed:
for key, value in self._saved_state.items(): # type: ignore
setattr(self, key, value)

class Config:
json_encoders = {
ObjectId: lambda v: str(v),
Expand Down
11 changes: 11 additions & 0 deletions beanie/odm/utils/state.py
Expand Up @@ -31,3 +31,14 @@ async def async_wrapper(self: "DocType", *args, **kwargs):
if inspect.iscoroutinefunction(f):
return async_wrapper
return sync_wrapper


def save_state_after(f: Callable):
@wraps(f)
async def wrapper(self: "DocType", *args, **kwargs):
result = await f(self, *args, **kwargs)
if self.use_state_management():
self._save_state()
return result

return wrapper
4 changes: 4 additions & 0 deletions tests/odm/conftest.py
Expand Up @@ -17,6 +17,8 @@
DocumentWithCustomIdUUID,
DocumentWithCustomIdInt,
DocumentWithActions,
DocumentWithTurnedOnStateManagement,
DocumentWithTurnedOffStateManagement,
)
from tests.odm.models import (
Sample,
Expand Down Expand Up @@ -124,6 +126,8 @@ async def init(loop, db):
DocumentWithCustomIdInt,
Sample,
DocumentWithActions,
DocumentWithTurnedOnStateManagement,
DocumentWithTurnedOffStateManagement,
]
await init_beanie(
database=db,
Expand Down
13 changes: 13 additions & 0 deletions tests/odm/models.py
Expand Up @@ -155,3 +155,16 @@ def num_2_change(self):
@after_event(Replace)
def num_3_change(self):
self.num_3 -= 1


class DocumentWithTurnedOnStateManagement(Document):
num_1: int
num_2: int

class Collection:
use_state_management = True


class DocumentWithTurnedOffStateManagement(Document):
num_1: int
num_2: int
122 changes: 122 additions & 0 deletions tests/odm/test_state_management.py
@@ -0,0 +1,122 @@
import pytest
from bson import ObjectId

from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved
from tests.odm.models import (
DocumentWithTurnedOnStateManagement,
DocumentWithTurnedOffStateManagement,
)


@pytest.fixture
def state():
return {"num_1": 1, "num_2": 2, "id": ObjectId()}


@pytest.fixture
def doc(state):
return DocumentWithTurnedOnStateManagement._parse_obj_saving_state(state)


@pytest.fixture
async def saved_doc(doc):
await doc.insert()
return doc


def test_use_state_management_property():
assert DocumentWithTurnedOnStateManagement.use_state_management() is True
assert DocumentWithTurnedOffStateManagement.use_state_management() is False


def test_save_state():
doc = DocumentWithTurnedOnStateManagement(num_1=1, num_2=2)
assert doc.get_saved_state() is None
doc._save_state()
assert doc.get_saved_state() == {"num_1": 1, "num_2": 2, "id": None}


def test_parse_object_with_saving_state():
obj = {"num_1": 1, "num_2": 2, "id": ObjectId()}
doc = DocumentWithTurnedOnStateManagement._parse_obj_saving_state(obj)
assert doc.get_saved_state() == obj


def test_saved_state_needed():
doc_1 = DocumentWithTurnedOffStateManagement(num_1=1, num_2=2)
with pytest.raises(StateManagementIsTurnedOff):
doc_1.is_changed

doc_2 = DocumentWithTurnedOnStateManagement(num_1=1, num_2=2)
with pytest.raises(StateNotSaved):
doc_2.is_changed


def test_if_changed(doc):
assert doc.is_changed is False
doc.num_1 = 10
assert doc.is_changed is True


def test_get_changes(doc):
doc.num_1 = 100
assert doc.get_changes() == {"num_1": 100}


async def test_save_changes(saved_doc):
saved_doc.num_1 = 100
await saved_doc.save_changes()

assert saved_doc.get_saved_state()["num_1"] == 100

new_doc = await DocumentWithTurnedOnStateManagement.get(saved_doc.id)
assert new_doc.num_1 == 100


async def test_find_one(saved_doc, state):
new_doc = await DocumentWithTurnedOnStateManagement.get(saved_doc.id)
assert new_doc.get_saved_state() == state

new_doc = await DocumentWithTurnedOnStateManagement.find_one(
DocumentWithTurnedOnStateManagement.id == saved_doc.id
)
assert new_doc.get_saved_state() == state


async def test_find_many():
docs = []
for i in range(10):
docs.append(DocumentWithTurnedOnStateManagement(num_1=i, num_2=i + 1))
await DocumentWithTurnedOnStateManagement.insert_many(docs)

found_docs = await DocumentWithTurnedOnStateManagement.find(
DocumentWithTurnedOnStateManagement.num_1 > 4
).to_list()

for doc in found_docs:
assert doc.get_saved_state() is not None


async def test_insert(state):
doc = DocumentWithTurnedOnStateManagement.parse_obj(state)
assert doc.get_saved_state() is None
await doc.insert()
assert doc.get_saved_state() == state


async def test_replace(saved_doc):
saved_doc.num_1 = 100
await saved_doc.replace()
assert saved_doc.get_saved_state()["num_1"] == 100


async def test_save_chages(saved_doc):
saved_doc.num_1 = 100
await saved_doc.save_changes()
assert saved_doc.get_saved_state()["num_1"] == 100


async def test_rollback(doc, state):
doc.num_1 = 100
doc.rollback()
assert doc.num_1 == state["num_1"]

0 comments on commit d8251fd

Please sign in to comment.