Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type hints for types and mapper modules #446

Merged
merged 6 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ ignore_missing_imports = true
no_implicit_optional = true
warn_unused_ignores = true

[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*,trino.mapper,trino.types]
[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*]
ignore_errors = true
69 changes: 35 additions & 34 deletions trino/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import abc
import base64
import sys
import uuid
from datetime import date, datetime, time, timedelta, timezone, tzinfo
from decimal import Decimal
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar

try:
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
else:
from backports.zoneinfo import ZoneInfo

import trino.exceptions
Expand Down Expand Up @@ -55,7 +56,7 @@ def map(self, value: Any) -> Optional[int]:


class DoubleValueMapper(ValueMapper[float]):
def map(self, value) -> Optional[float]:
def map(self, value: Any) -> Optional[float]:
if value is None:
return None
if value == 'Infinity':
Expand All @@ -68,7 +69,7 @@ def map(self, value) -> Optional[float]:


class DecimalValueMapper(ValueMapper[Decimal]):
def map(self, value) -> Optional[Decimal]:
def map(self, value: Any) -> Optional[Decimal]:
if value is None:
return None
return Decimal(value)
Expand All @@ -82,25 +83,25 @@ def map(self, value: Any) -> Optional[str]:


class BinaryValueMapper(ValueMapper[bytes]):
def map(self, value) -> Optional[bytes]:
def map(self, value: Any) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value.encode("utf8"))


class DateValueMapper(ValueMapper[date]):
def map(self, value) -> Optional[date]:
def map(self, value: Any) -> Optional[date]:
if value is None:
return None
return date.fromisoformat(value)


class TimeValueMapper(ValueMapper[time]):
def __init__(self, precision):
def __init__(self, precision: int):
self.time_default_size = 8 # size of 'HH:MM:SS'
self.precision = precision

def map(self, value) -> Optional[time]:
def map(self, value: Any) -> Optional[time]:
if value is None:
return None
whole_python_temporal_value = value[:self.time_default_size]
Expand All @@ -115,7 +116,7 @@ def _add_second(self, time_value: time) -> time:


class TimeWithTimeZoneValueMapper(TimeValueMapper):
def map(self, value) -> Optional[time]:
def map(self, value: Any) -> Optional[time]:
if value is None:
return None
whole_python_temporal_value = value[:self.time_default_size]
Expand All @@ -128,11 +129,11 @@ def map(self, value) -> Optional[time]:


class TimestampValueMapper(ValueMapper[datetime]):
def __init__(self, precision):
def __init__(self, precision: int):
self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds)
self.precision = precision

def map(self, value) -> Optional[datetime]:
def map(self, value: Any) -> Optional[datetime]:
if value is None:
return None
whole_python_temporal_value = value[:self.datetime_default_size]
Expand All @@ -144,7 +145,7 @@ def map(self, value) -> Optional[datetime]:


class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
def map(self, value) -> Optional[datetime]:
def map(self, value: Any) -> Optional[datetime]:
if value is None:
return None
datetime_with_fraction, timezone_part = value.rsplit(' ', 1)
Expand Down Expand Up @@ -175,36 +176,36 @@ class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
def __init__(self, mapper: ValueMapper[Any]):
self.mapper = mapper

def map(self, values: List[Any]) -> Optional[List[Any]]:
if values is None:
def map(self, value: Optional[List[Any]]) -> Optional[List[Any]]:
if value is None:
return None
return [self.mapper.map(value) for value in values]
return [self.mapper.map(v) for v in value]


class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
self.key_mapper = key_mapper
self.value_mapper = value_mapper

def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
if values is None:
def map(self, value: Any) -> Optional[Dict[Any, Optional[Any]]]:
if value is None:
return None
return {
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
self.key_mapper.map(k): self.value_mapper.map(v) for k, v in value.items()
}


class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
def __init__(self, mappers: List[ValueMapper[Any]], names: List[str], types: List[str]):
def __init__(self, mappers: List[ValueMapper[Any]], names: List[Optional[str]], types: List[str]):
self.mappers = mappers
self.names = names
self.types = types

def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
if values is None:
def map(self, value: Optional[List[Any]]) -> Optional[Tuple[Optional[Any], ...]]:
if value is None:
return None
return NamedRowTuple(
list(self.mappers[index].map(value) for index, value in enumerate(values)),
list(self.mappers[i].map(v) for i, v in enumerate(value)),
self.names,
self.types
)
Expand All @@ -218,7 +219,7 @@ def map(self, value: Any) -> Optional[uuid.UUID]:


class NoOpValueMapper(ValueMapper[Any]):
def map(self, value) -> Optional[Any]:
def map(self, value: Any) -> Optional[Any]:
return value


Expand All @@ -228,7 +229,7 @@ class NoOpRowMapper:
Used when legacy_primitive_types is False.
"""

def map(self, rows):
def map(self, rows: List[List[Any]]) -> List[List[Any]]:
return rows


Expand All @@ -240,14 +241,14 @@ class RowMapperFactory:
"""
NO_OP_ROW_MAPPER = NoOpRowMapper()

def create(self, columns, legacy_primitive_types):
def create(self, columns: List[Any], legacy_primitive_types: bool) -> RowMapper | NoOpRowMapper:
assert columns is not None

if not legacy_primitive_types:
return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns])
return RowMapperFactory.NO_OP_ROW_MAPPER

def _create_value_mapper(self, column) -> ValueMapper:
def _create_value_mapper(self, column: Dict[str, Any]) -> ValueMapper[Any]:
col_type = column['rawType']

# primitive types
Expand Down Expand Up @@ -285,9 +286,9 @@ def _create_value_mapper(self, column) -> ValueMapper:
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
return MapValueMapper(key_mapper, value_mapper)
if col_type == 'row':
mappers = []
names = []
types = []
mappers: List[ValueMapper[Any]] = []
names: List[Optional[str]] = []
types: List[str] = []
for arg in column['arguments']:
mappers.append(self._create_value_mapper(arg['value']['typeSignature']))
names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None)
Expand All @@ -299,7 +300,7 @@ def _create_value_mapper(self, column) -> ValueMapper:
return UuidValueMapper()
return NoOpValueMapper()

def _get_precision(self, column: Dict[str, Any]):
def _get_precision(self, column: Dict[str, Any]) -> int:
args = column['arguments']
if len(args) == 0:
return 3
Expand All @@ -310,18 +311,18 @@ class RowMapper:
"""
Maps a row of data given a list of mapping functions
"""
def __init__(self, columns):
def __init__(self, columns: List[ValueMapper[Any]]):
self.columns = columns

def map(self, rows):
def map(self, rows: List[List[Any]]) -> List[List[Any]]:
if len(self.columns) == 0:
return rows
return [self._map_row(row) for row in rows]

def _map_row(self, row):
def _map_row(self, row: List[Any]) -> List[Any]:
return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)]

def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]:
def _map_value(self, value: Any, value_mapper: ValueMapper[T]) -> Optional[T]:
try:
return value_mapper.map(value)
except ValueError as e:
Expand Down
25 changes: 14 additions & 11 deletions trino/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
from datetime import datetime, time, timedelta
from decimal import Decimal
from typing import Any, Dict, Generic, List, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast

from dateutil import tz

Expand All @@ -26,15 +26,16 @@ def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> Temporal
def to_python_type(self) -> PythonTemporalType:
pass

def round_to(self, precision: int) -> TemporalType:
def round_to(self, precision: int) -> TemporalType[PythonTemporalType]:
"""
Python datetime and time only support up to microsecond precision
In case the supplied value exceeds the specified precision,
the value needs to be rounded.
"""
precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER)
remaining_fractional_seconds = self._remaining_fractional_seconds
digits = abs(remaining_fractional_seconds.as_tuple().exponent)
# exponent can return `n`, `N`, `F` too if the value is a NaN for example
digits = abs(remaining_fractional_seconds.as_tuple().exponent) # type: ignore
if digits > precision:
rounding_factor = POWERS_OF_TEN[precision]
rounded = remaining_fractional_seconds.quantize(Decimal(1 / rounding_factor))
Expand Down Expand Up @@ -101,33 +102,35 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ

def normalize(self, value: datetime) -> datetime:
if tz.datetime_ambiguous(value):
return self._whole_python_temporal_value.tzinfo.normalize(value)
# This appears to be dead code since tzinfo doesn't actually have a `normalize` method.
# TODO: Fix this or remove. (https://github.com/trinodb/trino-python-client/issues/449)
return self._whole_python_temporal_value.tzinfo.normalize(value) # type: ignore
return value


class NamedRowTuple(tuple):
class NamedRowTuple(Tuple[Any, ...]):
"""Custom tuple class as namedtuple doesn't support missing or duplicate names"""
def __new__(cls, values, names: List[str], types: List[str]):
return super().__new__(cls, values)
def __new__(cls, values: List[Any], names: List[str], types: List[str]) -> NamedRowTuple:
return cast(NamedRowTuple, super().__new__(cls, values))

def __init__(self, values, names: List[str], types: List[str]):
def __init__(self, values: List[Any], names: List[Optional[str]], types: List[str]):
self._names = names
# With names and types users can retrieve the name and Trino data type of a row
self.__annotations__ = dict()
self.__annotations__["names"] = names
self.__annotations__["types"] = types
elements: List[Any] = []
for name, value in zip(names, values):
if names.count(name) == 1:
if name is not None and names.count(name) == 1:
setattr(self, name, value)
elements.append(f"{name}: {repr(value)}")
else:
elements.append(repr(value))
self._repr = "(" + ", ".join(elements) + ")"

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if self._names.count(name):
raise ValueError("Ambiguous row field reference: " + name)

def __repr__(self):
def __repr__(self) -> str:
return self._repr
Loading