Skip to content

Commit

Permalink
Encode date objects (#816)
Browse files Browse the repository at this point in the history
* Run dos2unix on encoder.py

* Add default encoder for dates

* Don't try to encode unknown types
  • Loading branch information
gsakkis committed Jan 10, 2024
1 parent f2e1439 commit e876e1a
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 184 deletions.
333 changes: 162 additions & 171 deletions beanie/odm/utils/encoder.py
@@ -1,171 +1,162 @@
import dataclasses as dc
import datetime
import decimal
import enum
import ipaddress
import operator
import pathlib
import re
import uuid
from typing import (
Any,
Callable,
Container,
Iterable,
Mapping,
MutableMapping,
Optional,
Tuple,
)

import bson
import pydantic

import beanie
from beanie.odm.fields import Link, LinkTypes
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2, get_model_fields

SingleArgCallable = Callable[[Any], Any]
DEFAULT_CUSTOM_ENCODERS: MutableMapping[type, SingleArgCallable] = {
ipaddress.IPv4Address: str,
ipaddress.IPv4Interface: str,
ipaddress.IPv4Network: str,
ipaddress.IPv6Address: str,
ipaddress.IPv6Interface: str,
ipaddress.IPv6Network: str,
pathlib.PurePath: str,
pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value,
pydantic.SecretStr: pydantic.SecretStr.get_secret_value,
datetime.timedelta: operator.methodcaller("total_seconds"),
enum.Enum: operator.attrgetter("value"),
Link: operator.attrgetter("ref"),
bytes: bson.Binary,
decimal.Decimal: bson.Decimal128,
uuid.UUID: bson.Binary.from_uuid,
re.Pattern: bson.Regex.from_native,
}
if IS_PYDANTIC_V2:
from pydantic_core import Url

DEFAULT_CUSTOM_ENCODERS[Url] = str

BSON_SCALAR_TYPES = (
type(None),
str,
int,
float,
datetime.datetime,
bson.Binary,
bson.DBRef,
bson.Decimal128,
bson.MaxKey,
bson.MinKey,
bson.ObjectId,
)


@dc.dataclass
class Encoder:
"""
BSON encoding class
"""

exclude: Container[str] = frozenset()
custom_encoders: Mapping[type, SingleArgCallable] = dc.field(
default_factory=dict
)
to_db: bool = False
keep_nulls: bool = True

def _encode_document(self, obj: "beanie.Document") -> Mapping[str, Any]:
obj.parse_store()
settings = obj.get_settings()
obj_dict = {}
if settings.union_doc is not None:
obj_dict[settings.class_id] = (
settings.union_doc_alias or obj.__class__.__name__
)
if obj._class_id:
obj_dict[settings.class_id] = obj._class_id

link_fields = obj.get_link_fields() or {}
sub_encoder = Encoder(
# don't propagate self.exclude to subdocuments
custom_encoders=settings.bson_encoders,
to_db=self.to_db,
keep_nulls=self.keep_nulls,
)
for key, value in self._iter_model_items(obj):
if key in link_fields:
link_type = link_fields[key].link_type
if link_type in (LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT):
if value is not None:
value = value.to_ref()
elif link_type in (LinkTypes.LIST, LinkTypes.OPTIONAL_LIST):
if value is not None:
value = [link.to_ref() for link in value]
elif self.to_db:
continue
obj_dict[key] = sub_encoder.encode(value)
return obj_dict

def encode(self, obj: Any) -> Any:
if self.custom_encoders:
encoder = _get_encoder(obj, self.custom_encoders)
if encoder is not None:
return encoder(obj)

if isinstance(obj, BSON_SCALAR_TYPES):
return obj

encoder = _get_encoder(obj, DEFAULT_CUSTOM_ENCODERS)
if encoder is not None:
return encoder(obj)

if isinstance(obj, beanie.Document):
return self._encode_document(obj)
if IS_PYDANTIC_V2 and isinstance(obj, pydantic.RootModel):
return self.encode(obj.root)
if isinstance(obj, pydantic.BaseModel):
items = self._iter_model_items(obj)
return {key: self.encode(value) for key, value in items}
if isinstance(obj, Mapping):
return {str(key): self.encode(value) for key, value in obj.items()}
if isinstance(obj, Iterable):
return [self.encode(value) for value in obj]

errors = []
try:
data = dict(obj)
except Exception as e:
errors.append(e)
try:
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors)
return self.encode(data)

def _iter_model_items(
self, obj: pydantic.BaseModel
) -> Iterable[Tuple[str, Any]]:
exclude, keep_nulls = self.exclude, self.keep_nulls
get_model_field = get_model_fields(obj).get
for key, value in obj.__iter__():
field_info = get_model_field(key)
if field_info is not None:
key = field_info.alias or key
if key not in exclude and (value is not None or keep_nulls):
yield key, value


def _get_encoder(
obj: Any, custom_encoders: Mapping[type, SingleArgCallable]
) -> Optional[SingleArgCallable]:
encoder = custom_encoders.get(type(obj))
if encoder is not None:
return encoder
for cls, encoder in custom_encoders.items():
if isinstance(obj, cls):
return encoder
return None
import dataclasses as dc
import datetime
import decimal
import enum
import ipaddress
import operator
import pathlib
import re
import uuid
from typing import (
Any,
Callable,
Container,
Iterable,
Mapping,
MutableMapping,
Optional,
Tuple,
)

import bson
import pydantic

import beanie
from beanie.odm.fields import Link, LinkTypes
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2, get_model_fields

SingleArgCallable = Callable[[Any], Any]
DEFAULT_CUSTOM_ENCODERS: MutableMapping[type, SingleArgCallable] = {
ipaddress.IPv4Address: str,
ipaddress.IPv4Interface: str,
ipaddress.IPv4Network: str,
ipaddress.IPv6Address: str,
ipaddress.IPv6Interface: str,
ipaddress.IPv6Network: str,
pathlib.PurePath: str,
pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value,
pydantic.SecretStr: pydantic.SecretStr.get_secret_value,
datetime.date: lambda d: datetime.datetime.combine(d, datetime.time.min),
datetime.timedelta: operator.methodcaller("total_seconds"),
enum.Enum: operator.attrgetter("value"),
Link: operator.attrgetter("ref"),
bytes: bson.Binary,
decimal.Decimal: bson.Decimal128,
uuid.UUID: bson.Binary.from_uuid,
re.Pattern: bson.Regex.from_native,
}
if IS_PYDANTIC_V2:
from pydantic_core import Url

DEFAULT_CUSTOM_ENCODERS[Url] = str

BSON_SCALAR_TYPES = (
type(None),
str,
int,
float,
datetime.datetime,
bson.Binary,
bson.DBRef,
bson.Decimal128,
bson.MaxKey,
bson.MinKey,
bson.ObjectId,
)


@dc.dataclass
class Encoder:
"""
BSON encoding class
"""

exclude: Container[str] = frozenset()
custom_encoders: Mapping[type, SingleArgCallable] = dc.field(
default_factory=dict
)
to_db: bool = False
keep_nulls: bool = True

def _encode_document(self, obj: "beanie.Document") -> Mapping[str, Any]:
obj.parse_store()
settings = obj.get_settings()
obj_dict = {}
if settings.union_doc is not None:
obj_dict[settings.class_id] = (
settings.union_doc_alias or obj.__class__.__name__
)
if obj._class_id:
obj_dict[settings.class_id] = obj._class_id

link_fields = obj.get_link_fields() or {}
sub_encoder = Encoder(
# don't propagate self.exclude to subdocuments
custom_encoders=settings.bson_encoders,
to_db=self.to_db,
keep_nulls=self.keep_nulls,
)
for key, value in self._iter_model_items(obj):
if key in link_fields:
link_type = link_fields[key].link_type
if link_type in (LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT):
if value is not None:
value = value.to_ref()
elif link_type in (LinkTypes.LIST, LinkTypes.OPTIONAL_LIST):
if value is not None:
value = [link.to_ref() for link in value]
elif self.to_db:
continue
obj_dict[key] = sub_encoder.encode(value)
return obj_dict

def encode(self, obj: Any) -> Any:
if self.custom_encoders:
encoder = _get_encoder(obj, self.custom_encoders)
if encoder is not None:
return encoder(obj)

if isinstance(obj, BSON_SCALAR_TYPES):
return obj

encoder = _get_encoder(obj, DEFAULT_CUSTOM_ENCODERS)
if encoder is not None:
return encoder(obj)

if isinstance(obj, beanie.Document):
return self._encode_document(obj)
if IS_PYDANTIC_V2 and isinstance(obj, pydantic.RootModel):
return self.encode(obj.root)
if isinstance(obj, pydantic.BaseModel):
items = self._iter_model_items(obj)
return {key: self.encode(value) for key, value in items}
if isinstance(obj, Mapping):
return {str(key): self.encode(value) for key, value in obj.items()}
if isinstance(obj, Iterable):
return [self.encode(value) for value in obj]

raise ValueError(f"Cannot encode {obj!r}")

def _iter_model_items(
self, obj: pydantic.BaseModel
) -> Iterable[Tuple[str, Any]]:
exclude, keep_nulls = self.exclude, self.keep_nulls
get_model_field = get_model_fields(obj).get
for key, value in obj.__iter__():
field_info = get_model_field(key)
if field_info is not None:
key = field_info.alias or key
if key not in exclude and (value is not None or keep_nulls):
yield key, value


def _get_encoder(
obj: Any, custom_encoders: Mapping[type, SingleArgCallable]
) -> Optional[SingleArgCallable]:
encoder = custom_encoders.get(type(obj))
if encoder is not None:
return encoder
for cls, encoder in custom_encoders.items():
if isinstance(obj, cls):
return encoder
return None
16 changes: 3 additions & 13 deletions tests/odm/models.py
Expand Up @@ -283,6 +283,9 @@ class DocumentWithCustomFiledsTypes(Document):
tuple_type: Tuple[int, str]
path: Path

class Settings:
bson_encoders = {Color: vars}

if IS_PYDANTIC_V2:
model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand Down Expand Up @@ -581,19 +584,6 @@ class DocumentWithStringField(Document):
class DocumentForEncodingTestDate(Document):
date_field: datetime.date = Field(default_factory=datetime.date.today)

class Settings:
name = "test_date"
bson_encoders = {
datetime.date: lambda dt: datetime.datetime(
year=dt.year,
month=dt.month,
day=dt.day,
hour=0,
minute=0,
second=0,
)
}


class DocumentUnion(UnionDoc):
class Settings:
Expand Down

0 comments on commit e876e1a

Please sign in to comment.