Skip to content

Commit

Permalink
Fix incorrect time-zone in results after localization
Browse files Browse the repository at this point in the history
Consider the following example using the canonical way to add zones to
datetime objects:

    >>> import pytz
    >>> import datetime
    >>> import zoneinfo
    >>> datetime.datetime(2023, 1, 1, tzinfo=pytz.timezone("America/Los_Angeles")).isoformat()
    '2023-01-01T00:00:00-07:53'
    >>> datetime.datetime(2023, 1, 1, tzinfo=zoneinfo.ZoneInfo("America/Los_Angeles")).isoformat()
    '2023-01-01T00:00:00-08:00'

pytz does eager timezone evaluation and uses the local-mean-time since
the instant in time is not known. It requires an additional `localize`
call to get the correct zone like so:

    >>> pytz.timezone("America/Los_Angeles").localize(datetime.datetime(2023, 1, 1)).isoformat()
    '2023-01-01T00:00:00-08:00'

This increases chances of introducing bugs when writing idiomatic
Python.

The only reason to use pytz was because it allowed to control what
happens with ambiguous datetimes but the standard library also allows
provides control over that since 3.9 (and is available as
backports.zoneinfo for older versions).
  • Loading branch information
john-bodley authored and hashhar committed May 8, 2023
1 parent a031cc2 commit 2b9ca0c
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 40 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
python_requires='>=3.7',
install_requires=[
"backports.zoneinfo;python_version<'3.9'",
"python-dateutil",
"pytz",
"requests",
"tzlocal",
Expand Down
39 changes: 19 additions & 20 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
from decimal import Decimal
from typing import Tuple

try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import pytest
import pytz
import requests
from tzlocal import get_localzone_name # type: ignore

Expand Down Expand Up @@ -234,7 +238,7 @@ def test_legacy_primitive_types_with_connection_and_cursor(
assert rows[0][0] == Decimal('0.142857')
assert rows[0][1] == date(2018, 1, 1)
assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1)))
assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone('UTC'))
assert rows[0][3] == datetime(2019, 1, 1, tzinfo=ZoneInfo('UTC'))
assert rows[0][4] == datetime(2019, 1, 1)
assert rows[0][5] == time(0, 0, 0, 0)
else:
Expand Down Expand Up @@ -338,7 +342,7 @@ def test_datetime_query_param(trino_connection):
def test_datetime_with_utc_time_zone_query_param(trino_connection):
cur = trino_connection.cursor()

params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('UTC'))
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('UTC'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand All @@ -364,7 +368,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
def test_datetime_with_named_time_zone_query_param(trino_connection):
cur = trino_connection.cursor()

params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('America/Los_Angeles'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand Down Expand Up @@ -407,32 +411,24 @@ def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection):
cur = trino_connection.cursor()

# This is a datetime that lies within a DST transition and not actually exists.
params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone('Europe/Brussels'))
params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=ZoneInfo('Europe/Brussels'))
with pytest.raises(trino.exceptions.TrinoUserError):
cur.execute("SELECT ?", params=(params,))
cur.fetchall()


def test_doubled_datetimes(trino_connection):
# Trino doesn't distinguish between doubled datetimes that lie within a DST transition. See also
@pytest.mark.parametrize('fold', [0, 1])
def test_doubled_datetimes(trino_connection, fold):
# Trino doesn't distinguish between doubled datetimes that lie within a DST transition.
# See also https://github.com/trinodb/trino/issues/5781
cur = trino_connection.cursor()

params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True)
params = datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'), fold=fold)

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))

cur = trino_connection.cursor()

params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False)

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'))


def test_date_query_param(trino_connection):
Expand Down Expand Up @@ -529,7 +525,7 @@ def test_time_query_param(trino_connection):
def test_time_with_named_time_zone_query_param(trino_connection):
cur = trino_connection.cursor()

params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('Asia/Shanghai'))
params = time(16, 43, 22, 320000, tzinfo=ZoneInfo('Asia/Shanghai'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand Down Expand Up @@ -693,7 +689,10 @@ def test_array_timestamp_query_param(trino_connection):
def test_array_timestamp_with_timezone_query_param(trino_connection):
cur = trino_connection.cursor()

params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)]
params = [
datetime(2020, 1, 1, 0, 0, 0, tzinfo=ZoneInfo('UTC')),
datetime(2020, 1, 2, 0, 0, 0, tzinfo=ZoneInfo('UTC')),
]

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
from datetime import date, datetime, time, timedelta, timezone, tzinfo
from decimal import Decimal

try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import pytest
import pytz

import trino
from tests.integration.conftest import trino_version
Expand Down Expand Up @@ -729,7 +733,7 @@ def create_timezone(timezone_str: str) -> tzinfo:
else:
return timezone(-timedelta(hours=hours, minutes=minutes))
else:
return pytz.timezone(timezone_str)
return ZoneInfo(timezone_str)


def test_interval(trino_connection):
Expand Down
22 changes: 9 additions & 13 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,18 @@
from time import sleep
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import pytz
try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import requests
from pytz.tzinfo import BaseTzInfo
from dateutil import tz
from tzlocal import get_localzone_name # type: ignore

import trino.logging
from trino import constants, exceptions

try:
from zoneinfo import ZoneInfo # type: ignore

except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo # type: ignore


__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"]

logger = trino.logging.get_logger(__name__)
Expand Down Expand Up @@ -946,7 +943,7 @@ def _create_tzinfo(timezone_str: str) -> tzinfo:
return timezone(-timedelta(hours=int(hours), minutes=int(minutes)))
return timezone(timedelta(hours=int(hours), minutes=int(minutes)))
else:
return pytz.timezone(timezone_str)
return ZoneInfo(timezone_str)


def _fraction_to_decimal(fractional_str: str) -> Decimal:
Expand Down Expand Up @@ -996,8 +993,7 @@ def add_time_delta(self, time_delta: timedelta) -> PythonTemporalType:
def normalize(self, value: PythonTemporalType) -> PythonTemporalType:
"""
If `add_time_delta` results in value crossing DST boundaries, this method should
return a normalized version of the value to account for it, for example,
using `pytz.timezone.normalize`.
return a normalized version of the value to account for it.
"""
return value

Expand Down Expand Up @@ -1041,7 +1037,7 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ
return TimestampWithTimeZone(value, fraction)

def normalize(self, value: datetime) -> datetime:
if isinstance(self._whole_python_temporal_value.tzinfo, BaseTzInfo):
if tz.datetime_ambiguous(value):
return self._whole_python_temporal_value.tzinfo.normalize(value)
return value

Expand Down
13 changes: 8 additions & 5 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
from urllib.parse import urlparse

import pytz
try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import trino.client
import trino.exceptions
Expand Down Expand Up @@ -425,8 +428,8 @@ def _format_prepared_param(self, param):
if isinstance(param, datetime.datetime) and param.tzinfo is not None:
datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
# named timezones
if hasattr(param.tzinfo, 'zone'):
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone)
if isinstance(param.tzinfo, ZoneInfo):
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.key)
# offset-based timezones
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param))

Expand All @@ -438,8 +441,8 @@ def _format_prepared_param(self, param):
if isinstance(param, datetime.time) and param.tzinfo is not None:
time_str = param.strftime("%H:%M:%S.%f")
# named timezones
if hasattr(param.tzinfo, 'zone'):
utc_offset = datetime.datetime.now(pytz.timezone(param.tzinfo.zone)).strftime('%z')
if isinstance(param.tzinfo, ZoneInfo):
utc_offset = datetime.datetime.now(tz=param.tzinfo).strftime('%z')
return "TIME '%s %s:%s'" % (time_str, utc_offset[:3], utc_offset[3:])
# offset-based timezones
return "TIME '%s %s'" % (time_str, param.strftime('%Z')[3:])
Expand Down

0 comments on commit 2b9ca0c

Please sign in to comment.