Skip to content

Commit

Permalink
fix: full methods of integer and float (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostming committed Jul 27, 2023
1 parent 9e39a63 commit e07f6a1
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 96 deletions.
65 changes: 65 additions & 0 deletions tomlkit/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import TypeVar


WT = TypeVar("WT", bound="WrapperType")

if TYPE_CHECKING: # pragma: no cover
# Define _CustomList and _CustomDict as a workaround for:
# https://github.com/python/mypy/issues/11427
#
# According to this issue, the typeshed contains a "lie"
# (it adds MutableSequence to the ancestry of list and MutableMapping to
# the ancestry of dict) which completely messes with the type inference for
# Table, InlineTable, Array and Container.
#
# Importing from builtins is preferred over simple assignment, see issues:
# https://github.com/python/mypy/issues/8715
# https://github.com/python/mypy/issues/10068
from builtins import dict as _CustomDict # noqa: N812
from builtins import float as _CustomFloat # noqa: N812
from builtins import int as _CustomInt # noqa: N812
from builtins import list as _CustomList # noqa: N812
from typing import Callable
from typing import Concatenate
from typing import ParamSpec
from typing import Protocol

P = ParamSpec("P")

class WrapperType(Protocol):
def _new(self: WT, value: Any) -> WT:
...

else:
from collections.abc import MutableMapping
from collections.abc import MutableSequence
from numbers import Integral
from numbers import Real

class _CustomList(MutableSequence, list):
"""Adds MutableSequence mixin while pretending to be a builtin list"""

class _CustomDict(MutableMapping, dict):
"""Adds MutableMapping mixin while pretending to be a builtin dict"""

class _CustomInt(Integral, int):
"""Adds Integral mixin while pretending to be a builtin int"""

class _CustomFloat(Real, float):
"""Adds Real mixin while pretending to be a builtin float"""


def wrap_method(
original_method: Callable[Concatenate[WT, P], Any]
) -> Callable[Concatenate[WT, P], Any]:
def wrapper(self: WT, *args: P.args, **kwargs: P.kwargs) -> Any:
result = original_method(self, *args, **kwargs)
if result is NotImplemented:
return result
return self._new(result)

return wrapper
2 changes: 1 addition & 1 deletion tomlkit/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterator

from tomlkit._compat import decode
from tomlkit._types import _CustomDict
from tomlkit._utils import merge_dicts
from tomlkit.exceptions import KeyAlreadyPresent
from tomlkit.exceptions import NonExistentKey
Expand All @@ -19,7 +20,6 @@
from tomlkit.items import Table
from tomlkit.items import Trivia
from tomlkit.items import Whitespace
from tomlkit.items import _CustomDict
from tomlkit.items import item as _item


Expand Down
185 changes: 92 additions & 93 deletions tomlkit/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import abc
import copy
import dataclasses
import math
import re
import string
import sys

from datetime import date
from datetime import datetime
Expand All @@ -24,42 +26,24 @@

from tomlkit._compat import PY38
from tomlkit._compat import decode
from tomlkit._types import _CustomDict
from tomlkit._types import _CustomFloat
from tomlkit._types import _CustomInt
from tomlkit._types import _CustomList
from tomlkit._types import wrap_method
from tomlkit._utils import CONTROL_CHARS
from tomlkit._utils import escape_string
from tomlkit.exceptions import InvalidStringError


if TYPE_CHECKING: # pragma: no cover
# Define _CustomList and _CustomDict as a workaround for:
# https://github.com/python/mypy/issues/11427
#
# According to this issue, the typeshed contains a "lie"
# (it adds MutableSequence to the ancestry of list and MutableMapping to
# the ancestry of dict) which completely messes with the type inference for
# Table, InlineTable, Array and Container.
#
# Importing from builtins is preferred over simple assignment, see issues:
# https://github.com/python/mypy/issues/8715
# https://github.com/python/mypy/issues/10068
from builtins import dict as _CustomDict # noqa: N812, TC004
from builtins import list as _CustomList # noqa: N812, TC004

# Allow type annotations but break circular imports
if TYPE_CHECKING:
from tomlkit import container
else:
from collections.abc import MutableMapping
from collections.abc import MutableSequence

class _CustomList(MutableSequence, list):
"""Adds MutableSequence mixin while pretending to be a builtin list"""

class _CustomDict(MutableMapping, dict):
"""Adds MutableMapping mixin while pretending to be a builtin dict"""


ItemT = TypeVar("ItemT", bound="Item")
Encoder = Callable[[Any], "Item"]
CUSTOM_ENCODERS: list[Encoder] = []
AT = TypeVar("AT", bound="AbstractTable")


class _ConvertError(TypeError, ValueError):
Expand Down Expand Up @@ -456,7 +440,7 @@ def __eq__(self, other: Any) -> bool:
class DottedKey(Key):
def __init__(
self,
keys: Iterable[Key],
keys: Iterable[SingleKey],
sep: str | None = None,
original: str | None = None,
) -> None:
Expand Down Expand Up @@ -606,25 +590,27 @@ def __str__(self) -> str:
return f"{self._trivia.indent}{decode(self._trivia.comment)}"


class Integer(int, Item):
class Integer(Item, _CustomInt):
"""
An integer literal.
"""

def __new__(cls, value: int, trivia: Trivia, raw: str) -> Integer:
return super().__new__(cls, value)
return int.__new__(cls, value)

def __init__(self, _: int, trivia: Trivia, raw: str) -> None:
def __init__(self, value: int, trivia: Trivia, raw: str) -> None:
super().__init__(trivia)

self._original = value
self._raw = raw
self._sign = False

if re.match(r"^[+\-]\d+$", raw):
self._sign = True

def unwrap(self) -> int:
return int(self)
return self._original

__int__ = unwrap

@property
def discriminant(self) -> int:
Expand All @@ -638,30 +624,6 @@ def value(self) -> int:
def as_string(self) -> str:
return self._raw

def __add__(self, other):
result = super().__add__(other)
if result is NotImplemented:
return result
return self._new(result)

def __radd__(self, other):
result = super().__radd__(other)
if result is NotImplemented:
return result
return self._new(result)

def __sub__(self, other):
result = super().__sub__(other)
if result is NotImplemented:
return result
return self._new(result)

def __rsub__(self, other):
result = super().__rsub__(other)
if result is NotImplemented:
return result
return self._new(result)

def _new(self, result):
raw = str(result)
if self._sign:
Expand All @@ -673,26 +635,63 @@ def _new(self, result):
def _getstate(self, protocol=3):
return int(self), self._trivia, self._raw


class Float(float, Item):
# int methods
__abs__ = wrap_method(int.__abs__)
__add__ = wrap_method(int.__add__)
__and__ = wrap_method(int.__and__)
__ceil__ = wrap_method(int.__ceil__)
__eq__ = int.__eq__
__floor__ = wrap_method(int.__floor__)
__floordiv__ = wrap_method(int.__floordiv__)
__invert__ = wrap_method(int.__invert__)
__le__ = int.__le__
__lshift__ = wrap_method(int.__lshift__)
__lt__ = int.__lt__
__mod__ = wrap_method(int.__mod__)
__mul__ = wrap_method(int.__mul__)
__neg__ = wrap_method(int.__neg__)
__or__ = wrap_method(int.__or__)
__pos__ = wrap_method(int.__pos__)
__pow__ = wrap_method(int.__pow__)
__radd__ = wrap_method(int.__radd__)
__rand__ = wrap_method(int.__rand__)
__rfloordiv__ = wrap_method(int.__rfloordiv__)
__rlshift__ = wrap_method(int.__rlshift__)
__rmod__ = wrap_method(int.__rmod__)
__rmul__ = wrap_method(int.__rmul__)
__ror__ = wrap_method(int.__ror__)
__round__ = wrap_method(int.__round__)
__rpow__ = wrap_method(int.__rpow__)
__rrshift__ = wrap_method(int.__rrshift__)
__rshift__ = wrap_method(int.__rshift__)
__rtruediv__ = wrap_method(int.__rtruediv__)
__rxor__ = wrap_method(int.__rxor__)
__truediv__ = wrap_method(int.__truediv__)
__trunc__ = wrap_method(int.__trunc__)
__xor__ = wrap_method(int.__xor__)


class Float(Item, _CustomFloat):
"""
A float literal.
"""

def __new__(cls, value: float, trivia: Trivia, raw: str) -> Integer:
return super().__new__(cls, value)
def __new__(cls, value: float, trivia: Trivia, raw: str) -> Float:
return float.__new__(cls, value)

def __init__(self, _: float, trivia: Trivia, raw: str) -> None:
def __init__(self, value: float, trivia: Trivia, raw: str) -> None:
super().__init__(trivia)

self._original = value
self._raw = raw
self._sign = False

if re.match(r"^[+\-].+$", raw):
self._sign = True

def unwrap(self) -> float:
return float(self)
return self._original

__float__ = unwrap

@property
def discriminant(self) -> int:
Expand All @@ -706,32 +705,6 @@ def value(self) -> float:
def as_string(self) -> str:
return self._raw

def __add__(self, other):
result = super().__add__(other)

return self._new(result)

def __radd__(self, other):
result = super().__radd__(other)

if isinstance(other, Float):
return self._new(result)

return result

def __sub__(self, other):
result = super().__sub__(other)

return self._new(result)

def __rsub__(self, other):
result = super().__rsub__(other)

if isinstance(other, Float):
return self._new(result)

return result

def _new(self, result):
raw = str(result)

Expand All @@ -744,6 +717,35 @@ def _new(self, result):
def _getstate(self, protocol=3):
return float(self), self._trivia, self._raw

# float methods
__abs__ = wrap_method(float.__abs__)
__add__ = wrap_method(float.__add__)
__eq__ = float.__eq__
__floordiv__ = wrap_method(float.__floordiv__)
__le__ = float.__le__
__lt__ = float.__lt__
__mod__ = wrap_method(float.__mod__)
__mul__ = wrap_method(float.__mul__)
__neg__ = wrap_method(float.__neg__)
__pos__ = wrap_method(float.__pos__)
__pow__ = wrap_method(float.__pow__)
__radd__ = wrap_method(float.__radd__)
__rfloordiv__ = wrap_method(float.__rfloordiv__)
__rmod__ = wrap_method(float.__rmod__)
__rmul__ = wrap_method(float.__rmul__)
__round__ = wrap_method(float.__round__)
__rpow__ = wrap_method(float.__rpow__)
__rtruediv__ = wrap_method(float.__rtruediv__)
__truediv__ = wrap_method(float.__truediv__)
__trunc__ = float.__trunc__

if sys.version_info >= (3, 9):
__ceil__ = float.__ceil__
__floor__ = float.__floor__
else:
__ceil__ = math.ceil
__floor__ = math.floor


class Bool(Item):
"""
Expand Down Expand Up @@ -1410,9 +1412,6 @@ def _getstate(self, protocol=3):
return list(self._iter_items()), self._trivia, self._multiline


AT = TypeVar("AT", bound="AbstractTable")


class AbstractTable(Item, _CustomDict):
"""Common behaviour of both :class:`Table` and :class:`InlineTable`"""

Expand Down Expand Up @@ -1452,11 +1451,11 @@ def append(self, key, value):
raise NotImplementedError

@overload
def add(self: AT, value: Comment | Whitespace) -> AT:
def add(self: AT, key: Comment | Whitespace) -> AT:
...

@overload
def add(self: AT, key: Key | str, value: Any) -> AT:
def add(self: AT, key: Key | str, value: Any = ...) -> AT:
...

def add(self, key, value=None):
Expand Down
2 changes: 1 addition & 1 deletion tomlkit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Parser:
Parser for TOML documents.
"""

def __init__(self, string: str) -> None:
def __init__(self, string: str | bytes) -> None:
# Input to parse
self._src = Source(decode(string))

Expand Down
2 changes: 1 addition & 1 deletion tomlkit/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, source: Source) -> None:
def __call__(self, *args, **kwargs):
return _State(self._source, *args, **kwargs)

def __enter__(self) -> None:
def __enter__(self) -> _State:
state = self()
self._states.append(state)
return state.__enter__()
Expand Down

0 comments on commit e07f6a1

Please sign in to comment.