diff --git a/CHANGELOG.md b/CHANGELOG.md index a49387d..d93e633 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* sqlalchemy's DATETIME type now rendered as YDB's Datetime instead of Timestamp + ## 0.0.1b19 ## * Do not use set for columns in index, use dict (preserve order) diff --git a/test/test_core.py b/test/test_core.py index 2464c5e..88f0aab 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -213,8 +213,8 @@ def define_tables(cls, metadata: sa.MetaData): Table( "test_datetime_types", metadata, - Column("datetime", sa.DateTime, primary_key=True), - Column("datetime_tz", sa.DateTime(timezone=True)), + Column("datetime", sa.DATETIME, primary_key=True), + Column("datetime_tz", sa.DATETIME(timezone=True)), Column("timestamp", sa.TIMESTAMP), Column("timestamp_tz", sa.TIMESTAMP(timezone=True)), Column("date", sa.Date), @@ -253,17 +253,30 @@ def test_integer_types(self, connection): assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64") def test_datetime_types(self, connection: sa.Connection): + stmt = sa.Select( + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_datetime", datetime.datetime.now(), sa.DateTime))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_DATETIME", datetime.datetime.now(), sa.DATETIME))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_TIMESTAMP", datetime.datetime.now(), sa.TIMESTAMP))), + ) + + result = connection.execute(stmt).fetchone() + assert result == (b"Timestamp", b"Datetime", b"Timestamp") + + def test_datetime_types_timezone(self, connection: sa.Connection): table = self.tables.test_datetime_types + tzinfo = datetime.timezone(datetime.timedelta(hours=3, minutes=42)) - now_dt = datetime.datetime.now() - now_dt_tz = now_dt.replace(tzinfo=datetime.timezone(datetime.timedelta(hours=3, minutes=42))) - today = now_dt.date() + timestamp_value = datetime.datetime.now() + timestamp_value_tz = timestamp_value.replace(tzinfo=tzinfo) + datetime_value = timestamp_value.replace(microsecond=0) + datetime_value_tz = timestamp_value_tz.replace(microsecond=0) + today = timestamp_value.date() statement = sa.insert(table).values( - datetime=now_dt, - datetime_tz=now_dt_tz, - timestamp=now_dt, - timestamp_tz=now_dt_tz, + datetime=datetime_value, + datetime_tz=datetime_value_tz, + timestamp=timestamp_value, + timestamp_tz=timestamp_value_tz, date=today, # interval=datetime.timedelta(minutes=45), ) @@ -271,12 +284,11 @@ def test_datetime_types(self, connection: sa.Connection): row = connection.execute(sa.select(table)).fetchone() - now_dt_tz_utc = now_dt.replace(tzinfo=datetime.timezone.utc) - datetime.timedelta(hours=3, minutes=42) assert row == ( - now_dt, - now_dt_tz_utc, # YDB doesn't store timezone, so it is always utc - now_dt, - now_dt_tz_utc, # YDB doesn't store timezone, so it is always utc + datetime_value, + datetime_value_tz.astimezone(datetime.timezone.utc), # YDB doesn't store timezone, so it is always utc + timestamp_value, + timestamp_value_tz.astimezone(datetime.timezone.utc), # YDB doesn't store timezone, so it is always utc today, ) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index f3735ba..326b47f 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -132,7 +132,13 @@ def visit_BINARY(self, type_: sa.BINARY, **kw): def visit_BLOB(self, type_: sa.BLOB, **kw): return "String" - def visit_DATETIME(self, type_: sa.TIMESTAMP, **kw): + def visit_datetime(self, type_: sa.TIMESTAMP, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_DATETIME(self, type_: sa.DATETIME, **kw): + return "DateTime" + + def visit_TIMESTAMP(self, type_: sa.TIMESTAMP, **kw): return "Timestamp" def visit_list_type(self, type_: types.ListType, **kw): @@ -193,7 +199,10 @@ def get_ydb_type( elif isinstance(type_, types.YqlJSON.YqlJSONPathType): ydb_type = ydb.PrimitiveType.Utf8 # Json - + elif isinstance(type_, sa.DATETIME): + ydb_type = ydb.PrimitiveType.Datetime + elif isinstance(type_, sa.TIMESTAMP): + ydb_type = ydb.PrimitiveType.Timestamp elif isinstance(type_, sa.DateTime): ydb_type = ydb.PrimitiveType.Timestamp elif isinstance(type_, sa.Date): @@ -540,7 +549,7 @@ def upsert(table): ydb.PrimitiveType.Yson: sa.TEXT, ydb.PrimitiveType.Date: sa.DATE, ydb.PrimitiveType.Datetime: sa.DATETIME, - ydb.PrimitiveType.Timestamp: sa.DATETIME, + ydb.PrimitiveType.Timestamp: sa.TIMESTAMP, ydb.PrimitiveType.Interval: sa.INTEGER, ydb.PrimitiveType.Bool: sa.BOOLEAN, ydb.PrimitiveType.DyNumber: sa.TEXT, @@ -610,7 +619,9 @@ class YqlDialect(StrCompileDialect): colspecs = { sa.types.JSON: types.YqlJSON, sa.types.JSON.JSONPathType: types.YqlJSON.YqlJSONPathType, - sa.types.DateTime: types.YqlDateTime, + sa.types.DateTime: types.YqlTimestamp, # Because YDB's DateTime doesn't store microseconds + sa.types.DATETIME: types.YqlDateTime, + sa.types.TIMESTAMP: types.YqlTimestamp, } connection_characteristics = util.immutabledict( diff --git a/ydb_sqlalchemy/sqlalchemy/datetime_types.py b/ydb_sqlalchemy/sqlalchemy/datetime_types.py index 651310f..d2f8283 100644 --- a/ydb_sqlalchemy/sqlalchemy/datetime_types.py +++ b/ydb_sqlalchemy/sqlalchemy/datetime_types.py @@ -3,16 +3,28 @@ from sqlalchemy import Dialect from sqlalchemy import types as sqltypes -from sqlalchemy.sql.type_api import _ResultProcessorType +from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType -class YqlDateTime(sqltypes.DateTime): +class YqlTimestamp(sqltypes.TIMESTAMP): def result_processor(self, dialect: Dialect, coltype: str) -> Optional[_ResultProcessorType[datetime.datetime]]: def process(value: Optional[datetime.datetime]) -> Optional[datetime.datetime]: if value is None: return None + if not self.timezone: + return value return value.replace(tzinfo=datetime.timezone.utc) - if self.timezone: - return process - return None + return process + + +class YqlDateTime(YqlTimestamp, sqltypes.DATETIME): + def bind_processor(self, dialect: Dialect) -> Optional[_BindProcessorType[datetime.datetime]]: + def process(value: Optional[datetime.datetime]) -> Optional[int]: + if value is None: + return None + if not self.timezone: # if timezone is disabled, consider it as utc + value = value.replace(tzinfo=datetime.timezone.utc) + return int(value.timestamp()) + + return process diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 30f9002..c97a3e0 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -3,7 +3,7 @@ from sqlalchemy import ARRAY, ColumnElement, exc, types from sqlalchemy.sql import type_api -from .datetime_types import YqlDateTime # noqa: F401 +from .datetime_types import YqlTimestamp, YqlDateTime # noqa: F401 from .json import YqlJSON # noqa: F401