Skip to content

Commit

Permalink
update model during save validation (#776)
Browse files Browse the repository at this point in the history
* update model during save validation

* test suite

* update typing in the test

* Pydantic v1 test compatibility
  • Loading branch information
roman-right committed Dec 3, 2023
1 parent f109cd8 commit 8626f88
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 2 deletions.
2 changes: 2 additions & 0 deletions beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from beanie.odm.enums import SortDirection
from beanie.odm.fields import (
BackLink,
BeanieObjectId,
DeleteRules,
Indexed,
Link,
Expand All @@ -40,6 +41,7 @@
"UnionDoc",
"init_beanie",
"PydanticObjectId",
"BeanieObjectId",
"Indexed",
"TimeSeriesConfig",
"Granularity",
Expand Down
3 changes: 2 additions & 1 deletion beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,8 @@ def check_hidden_fields(cls):
async def validate_self(self, *args, **kwargs):
# TODO: it can be sync, but needs some actions controller improvements
if self.get_settings().validate_on_save:
parse_model(self.__class__, get_model_dump(self))
new_model = parse_model(self.__class__, get_model_dump(self))
merge_models(self, new_model)

def to_ref(self):
if self.id is None:
Expand Down
2 changes: 2 additions & 0 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def __modify_schema__(cls, field_schema):
PydanticObjectId
] = str # it is a workaround to force pydantic make json schema for this field

BeanieObjectId = PydanticObjectId


class ExpressionField(str):
def __getitem__(self, item):
Expand Down
29 changes: 28 additions & 1 deletion tests/odm/documents/test_validation_on_save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Optional

import pytest
from pydantic import ValidationError
from bson import ObjectId
from pydantic import BaseModel, ValidationError

from beanie import PydanticObjectId
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
from tests.odm.models import (
DocumentWithValidationOnSave,
Lock,
Expand Down Expand Up @@ -31,6 +36,28 @@ async def test_validate_on_save_changes():
await doc.save_changes()


async def test_validate_on_save_keep_the_id_type():
class UpdateModel(BaseModel):
num_1: Optional[int] = None
related: Optional[PydanticObjectId] = None

doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
await doc.insert()
update = UpdateModel(related=PydanticObjectId())
if IS_PYDANTIC_V2:
doc = doc.model_copy(update=update.model_dump(exclude_unset=True))
else:
doc = doc.copy(update=update.dict(exclude_unset=True))
doc.num_2 = 1000
await doc.save()
in_db = await DocumentWithValidationOnSave.get_motor_collection().find_one(
{"_id": doc.id}
)
assert isinstance(in_db["related"], ObjectId)
new_doc = await DocumentWithValidationOnSave.get(doc.id)
assert isinstance(new_doc.related, PydanticObjectId)


async def test_validate_on_save_action():
doc = DocumentWithValidationOnSave(num_1=1, num_2=2)
await doc.insert()
Expand Down
1 change: 1 addition & 0 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ class DocumentWithTurnedOffStateManagement(Document):
class DocumentWithValidationOnSave(Document):
num_1: int
num_2: int
related: PydanticObjectId = Field(default_factory=PydanticObjectId)

@after_event(ValidateOnSave)
def num_2_plus_1(self):
Expand Down

0 comments on commit 8626f88

Please sign in to comment.