Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow users to register custom encoders #296

Merged
merged 4 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions tests/test_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,3 +946,21 @@ def test_copy_copy():
)
def test_escape_key(key_str, escaped):
assert api.key(key_str).as_string() == escaped


def test_custom_encoders():
import decimal

@api.register_encoder
def encode_decimal(obj):
if isinstance(obj, decimal.Decimal):
return api.float_(str(obj))
raise TypeError

assert api.item(decimal.Decimal("1.23")).as_string() == "1.23"

with pytest.raises(TypeError):
api.item(object())

assert api.dumps({"foo": decimal.Decimal("1.23")}) == "foo = 1.23\n"
api.unregister_encoder(encode_decimal)
4 changes: 4 additions & 0 deletions tomlkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from tomlkit.api import loads
from tomlkit.api import nl
from tomlkit.api import parse
from tomlkit.api import register_encoder
from tomlkit.api import string
from tomlkit.api import table
from tomlkit.api import time
from tomlkit.api import unregister_encoder
from tomlkit.api import value
from tomlkit.api import ws

Expand Down Expand Up @@ -52,4 +54,6 @@
"TOMLDocument",
"value",
"ws",
"register_encoder",
"unregister_encoder",
]
22 changes: 22 additions & 0 deletions tomlkit/api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from __future__ import annotations

import contextlib
import datetime as _datetime

from collections.abc import Mapping
from typing import IO
from typing import Iterable
from typing import TypeVar

from tomlkit._utils import parse_rfc3339
from tomlkit.container import Container
from tomlkit.exceptions import UnexpectedCharError
from tomlkit.items import CUSTOM_ENCODERS
from tomlkit.items import AoT
from tomlkit.items import Array
from tomlkit.items import Bool
from tomlkit.items import Comment
from tomlkit.items import Date
from tomlkit.items import DateTime
from tomlkit.items import DottedKey
from tomlkit.items import Encoder
from tomlkit.items import Float
from tomlkit.items import InlineTable
from tomlkit.items import Integer
Expand Down Expand Up @@ -284,3 +288,21 @@ def nl() -> Whitespace:
def comment(string: str) -> Comment:
"""Create a comment item."""
return Comment(Trivia(comment_ws=" ", comment="# " + string))


E = TypeVar("E", bound=Encoder)


def register_encoder(encoder: E) -> E:
"""Add a custom encoder, which should be a function that will be called
if the value can't otherwise be converted. It should takes a single value
and return a TOMLKit item or raise a ``TypeError``.
"""
CUSTOM_ENCODERS.append(encoder)
return encoder


def unregister_encoder(encoder: Encoder) -> None:
"""Unregister a custom encoder."""
with contextlib.suppress(ValueError):
CUSTOM_ENCODERS.remove(encoder)
24 changes: 23 additions & 1 deletion tomlkit/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Collection
from typing import Iterable
from typing import Iterator
Expand Down Expand Up @@ -57,6 +58,15 @@ class _CustomDict(MutableMapping, dict):


ItemT = TypeVar("ItemT", bound="Item")
Encoder = Callable[[Any], "Item"]
CUSTOM_ENCODERS: list[Encoder] = []


class _ConvertError(TypeError, ValueError):
"""An internal error raised when item() fails to convert a value.
It should be a TypeError, but due to historical reasons
it needs to subclass ValueError as well.
"""


@overload
Expand Down Expand Up @@ -218,8 +228,20 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I
Trivia(),
value.isoformat(),
)
else:
for encoder in CUSTOM_ENCODERS:
try:
rv = encoder(value)
except TypeError:
pass
else:
if not isinstance(rv, Item):
raise _ConvertError(
f"Custom encoder returned {type(rv)}, not a subclass of Item"
)
return rv

raise ValueError(f"Invalid type {type(value)}")
raise _ConvertError(f"Invalid type {type(value)}")


class StringType(Enum):
Expand Down