Skip to content

Commit

Permalink
Numbers and core type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
DenverCoder1 committed Feb 3, 2023
1 parent 25e4360 commit 9b67367
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 42 deletions.
31 changes: 11 additions & 20 deletions babel/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
import pickle
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any

from babel import localedata
from babel.plural import PluralRule
Expand Down Expand Up @@ -260,21 +260,13 @@ def negotiate(
if identifier:
return Locale.parse(identifier, sep=sep)

@overload
@classmethod
def parse(cls, identifier: None, sep: str = ..., resolve_likely_subtags: bool = ...) -> None: ...

@overload
@classmethod
def parse(cls, identifier: str | Locale, sep: str = ..., resolve_likely_subtags: bool = ...) -> Locale: ...

@classmethod
def parse(
cls,
identifier: str | Locale | None,
sep: str = '_',
resolve_likely_subtags: bool = True,
) -> Locale | None:
) -> Locale:
"""Create a `Locale` instance for the given locale identifier.
>>> l = Locale.parse('de-DE', sep='-')
Expand Down Expand Up @@ -317,10 +309,9 @@ def parse(
identifier
:raise `UnknownLocaleError`: if no locale data is available for the
requested locale
:raise `TypeError`: if the identifier is not a string or a `Locale`
"""
if identifier is None:
return None
elif isinstance(identifier, Locale):
if isinstance(identifier, Locale):
return identifier
elif not isinstance(identifier, str):
raise TypeError(f"Unexpected value for identifier: {identifier!r}")
Expand Down Expand Up @@ -364,9 +355,9 @@ def _try_load_reducing(parts):
language, territory, script, variant = parts
modifier = None
language = get_global('language_aliases').get(language, language)
territory = get_global('territory_aliases').get(territory, (territory,))[0]
script = get_global('script_aliases').get(script, script)
variant = get_global('variant_aliases').get(variant, variant)
territory = get_global('territory_aliases').get(territory or '', (territory,))[0]
script = get_global('script_aliases').get(script or '', script)
variant = get_global('variant_aliases').get(variant or '', variant)

if territory == 'ZZ':
territory = None
Expand All @@ -389,9 +380,9 @@ def _try_load_reducing(parts):
if likely_subtag is not None:
parts2 = parse_locale(likely_subtag)
if len(parts2) == 5:
language2, _, script2, variant2, modifier2 = parse_locale(likely_subtag)
language2, _, script2, variant2, modifier2 = parts2
else:
language2, _, script2, variant2 = parse_locale(likely_subtag)
language2, _, script2, variant2 = parts2
modifier2 = None
locale = _try_load_reducing((language2, territory, script2, variant2, modifier2))
if locale is not None:
Expand Down Expand Up @@ -1147,7 +1138,7 @@ def negotiate_locale(preferred: Iterable[str], available: Iterable[str], sep: st
def parse_locale(
identifier: str,
sep: str = '_'
) -> tuple[str, str | None, str | None, str | None, str | None]:
) -> tuple[str, str | None, str | None, str | None] | tuple[str, str | None, str | None, str | None, str | None]:
"""Parse a locale identifier into a tuple of the form ``(language,
territory, script, variant, modifier)``.
Expand Down Expand Up @@ -1261,7 +1252,7 @@ def get_locale_identifier(
:param tup: the tuple as returned by :func:`parse_locale`.
:param sep: the separator for the identifier.
"""
tup = tuple(tup[:5])
tup = tuple(tup[:5]) # type: ignore # length should be no more than 5
lang, territory, script, variant, modifier = tup + (None,) * (5 - len(tup))
ret = sep.join(filter(None, (lang, script, territory, variant)))
return f'{ret}@{modifier}' if modifier else ret
45 changes: 23 additions & 22 deletions babel/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import decimal
import re
import warnings
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any, cast, overload

from babel.core import Locale, default_locale, get_global
from babel.localedata import LocaleDataDict
Expand Down Expand Up @@ -428,7 +428,7 @@ def get_decimal_quantum(precision: int | decimal.Decimal) -> decimal.Decimal:

def format_decimal(
number: float | decimal.Decimal | str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
decimal_quantization: bool = True,
group_separator: bool = True,
Expand Down Expand Up @@ -474,8 +474,8 @@ def format_decimal(
number format.
"""
locale = Locale.parse(locale)
if not format:
format = locale.decimal_formats.get(format)
if format is None:
format = locale.decimal_formats[format]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)
Expand Down Expand Up @@ -513,15 +513,15 @@ def format_compact_decimal(
number, format = _get_compact_format(number, compact_format, locale, fraction_digits)
# Did not find a format, fall back.
if format is None:
format = locale.decimal_formats.get(None)
format = locale.decimal_formats[None]
pattern = parse_pattern(format)
return pattern.apply(number, locale, decimal_quantization=False)


def _get_compact_format(
number: float | decimal.Decimal | str,
compact_format: LocaleDataDict,
locale: Locale | str | None,
locale: Locale,
fraction_digits: int,
) -> tuple[decimal.Decimal, NumberPattern | None]:
"""Returns the number after dividing by the unit and the format pattern to use.
Expand All @@ -543,7 +543,7 @@ def _get_compact_format(
break
# otherwise, we need to divide the number by the magnitude but remove zeros
# equal to the number of 0's in the pattern minus 1
number = number / (magnitude // (10 ** (pattern.count("0") - 1)))
number = cast(decimal.Decimal, number / (magnitude // (10 ** (pattern.count("0") - 1))))
# round to the number of fraction digits requested
rounded = round(number, fraction_digits)
# if the remaining number is singular, use the singular format
Expand All @@ -565,7 +565,7 @@ class UnknownCurrencyFormatError(KeyError):
def format_currency(
number: float | decimal.Decimal | str,
currency: str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
currency_digits: bool = True,
format_type: Literal["name", "standard", "accounting"] = "standard",
Expand Down Expand Up @@ -680,7 +680,7 @@ def format_currency(
def _format_currency_long_name(
number: float | decimal.Decimal | str,
currency: str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
currency_digits: bool = True,
format_type: Literal["name", "standard", "accounting"] = "standard",
Expand All @@ -706,7 +706,7 @@ def _format_currency_long_name(

# Step 5.
if not format:
format = locale.decimal_formats.get(format)
format = locale.decimal_formats[format]

pattern = parse_pattern(format)

Expand Down Expand Up @@ -758,13 +758,15 @@ def format_compact_currency(
# compress adjacent spaces into one
format = re.sub(r'(\s)\s+', r'\1', format).strip()
break
if format is None:
raise ValueError('No compact currency format found for the given number and locale.')
pattern = parse_pattern(format)
return pattern.apply(number, locale, currency=currency, currency_digits=False, decimal_quantization=False)


def format_percent(
number: float | decimal.Decimal | str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
decimal_quantization: bool = True,
group_separator: bool = True,
Expand Down Expand Up @@ -808,15 +810,15 @@ def format_percent(
"""
locale = Locale.parse(locale)
if not format:
format = locale.percent_formats.get(format)
format = locale.percent_formats[format]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)


def format_scientific(
number: float | decimal.Decimal | str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
decimal_quantization: bool = True,
) -> str:
Expand Down Expand Up @@ -847,7 +849,7 @@ def format_scientific(
"""
locale = Locale.parse(locale)
if not format:
format = locale.scientific_formats.get(format)
format = locale.scientific_formats[format]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization)
Expand All @@ -856,7 +858,7 @@ def format_scientific(
class NumberFormatError(ValueError):
"""Exception raised when a string cannot be parsed into a number."""

def __init__(self, message: str, suggestions: str | None = None) -> None:
def __init__(self, message: str, suggestions: list[str] | None = None) -> None:
super().__init__(message)
#: a list of properly formatted numbers derived from the invalid input
self.suggestions = suggestions
Expand Down Expand Up @@ -1140,7 +1142,7 @@ def scientific_notation_elements(self, value: decimal.Decimal, locale: Locale |

def apply(
self,
value: float | decimal.Decimal,
value: float | decimal.Decimal | str,
locale: Locale | str | None,
currency: str | None = None,
currency_digits: bool = True,
Expand Down Expand Up @@ -1211,9 +1213,9 @@ def apply(
number = ''.join([
self._quantize_value(value, locale, frac_prec, group_separator),
get_exponential_symbol(locale),
exp_sign,
self._format_int(
str(exp), self.exp_prec[0], self.exp_prec[1], locale)])
exp_sign, # type: ignore # exp_sign is always defined here
self._format_int(str(exp), self.exp_prec[0], self.exp_prec[1], locale) # type: ignore # exp is always defined here
])

# Is it a significant digits pattern?
elif '@' in self.pattern:
Expand All @@ -1234,9 +1236,8 @@ def apply(
number if self.number_pattern != '' else '',
self.suffix[is_negative]])

if '¤' in retval:
retval = retval.replace('¤¤¤',
get_currency_name(currency, value, locale))
if '¤' in retval and currency is not None:
retval = retval.replace('¤¤¤', get_currency_name(currency, value, locale))
retval = retval.replace('¤¤', currency.upper())
retval = retval.replace('¤', get_currency_symbol(currency, locale))

Expand Down

0 comments on commit 9b67367

Please sign in to comment.