From fc0bf23020de5c7727c2ab096e1736d1a1c6a326 Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Fri, 13 Oct 2017 17:33:51 +0000 Subject: [PATCH] BUG: Fix discover for sa.TIMESTAMP. `discover` was returning a bytes dtype for the `sa.TIMESTAMP` instead of `datetime_`. That bug was because the SQLAlchemy dialect for MSSQL currently directly imports `sa.TIMESTAMP`, causing a collision in `odo.backends.sql.revtypes`. Fix by 1) creating a subclass of `mssql.TIMESTAMP` to use as the key in `revtypes` so that it does not overwrite `sa.TIMESTAMP` 2) assign that subclass to the mssql dialect's `'TIMESTAMP'` (using `ischema_names`), so that the subclass will be returned by the type engine instead of `mssql.TIMESTAMP`. The added test for the (`sa.TIMESTAMP`, `datetime_`) without the fix applied would result in this error: ``` a = Bytes(), b = DateTime(tz=None), path = ('.measure', "['ts']", '.ty') kwargs = {'check_dim': True, 'check_record_order': True} @assert_dshape_equal.register(object, object) def _base_case(a, b, path=None, **kwargs): > assert a == b, '%s != %s\n%s' % (a, b, _fmt_path(path)) E AssertionError: bytes != datetime E path: _.measure['value'].ty ``` This patch includes checks for the other types besides `sa.TIMESTAMP` for better protection against the general case of failures when using `revtypes` to map SQLAlchemy types to dtypes. Those extra cases exposed an issue with `sa.Float(precision=24)`. That case is commented out to keep the fix of this patch on the `sa.TIMESTAMP` mapping. (I would have used `pytest.param` to mark it xfail, but pytest for this project needs to be upgraded first.) See: https://github.com/blaze/odo/issues/567 https://github.com/blaze/odo/pull/568 https://bitbucket.org/zzzeek/sqlalchemy/issues/4092/type-problem-with-mssqltimestamp https://github.com/blaze/blaze/pull/1656 --- odo/backends/sql.py | 15 +++++++++++++- odo/backends/tests/test_sql.py | 38 ++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/odo/backends/sql.py b/odo/backends/sql.py index bd456068..61b83846 100644 --- a/odo/backends/sql.py +++ b/odo/backends/sql.py @@ -83,6 +83,19 @@ revtypes = dict(map(reversed, types.items())) +# Subclass mssql.TIMESTAMP subclass for use when differentiating between +# mssql.TIMESTAMP and sa.TIMESTAMP. +# At the time of this writing, (mssql.TIMESTAMP == sa.TIMESTAMP) is True, +# which causes a collision when defining the revtypes mappings. +# +# See: +# https://bitbucket.org/zzzeek/sqlalchemy/issues/4092/type-problem-with-mssqltimestamp +class MSSQLTimestamp(mssql.TIMESTAMP): + pass + +# Assign the custom subclass as the type to use instead of `mssql.TIMESTAMP`. +mssql.base.ischema_names['TIMESTAMP'] = MSSQLTimestamp + revtypes.update({ sa.DATETIME: datetime_, sa.TIMESTAMP: datetime_, @@ -103,7 +116,7 @@ mssql.UNIQUEIDENTIFIER: string, # The SQL Server TIMESTAMP value doesn't correspond to the ISO Standard # It is instead just a binary(8) value with no relation to dates or times - mssql.TIMESTAMP: bytes_, + MSSQLTimestamp: bytes_, }) # interval types are special cased in discover_typeengine so remove them from diff --git a/odo/backends/tests/test_sql.py b/odo/backends/tests/test_sql.py index 8b361afa..b0f6e32e 100644 --- a/odo/backends/tests/test_sql.py +++ b/odo/backends/tests/test_sql.py @@ -8,7 +8,20 @@ from functools import partial import datashape -from datashape import discover, dshape, float32, float64, Option +from datashape import ( + date_, + datetime_, + discover, + dshape, + int_, + int64, + float32, + float64, + string, + var, + Option, + R, +) from datashape.util.testing import assert_dshape_equal import numpy as np import pandas as pd @@ -279,6 +292,28 @@ def test_discover_oracle_intervals(freq): assert discover(t) == dshape('var * {dur: ?timedelta[unit="%s"]}' % freq) +@pytest.mark.parametrize( + 'typ,dtype', ( + (sa.DATETIME, datetime_), + (sa.TIMESTAMP, datetime_), + (sa.FLOAT, float64), + (sa.DATE, date_), + (sa.BIGINT, int64), + (sa.INTEGER, int_), + (sa.BIGINT, int64), + (sa.types.NullType, string), + (sa.REAL, float32), + (sa.Float, float64), + # sa.Float(precision=24), float32 (reason="Currently returns float64") + (sa.Float(precision=53), float64), + ), +) +def test_types(typ, dtype): + expected = var * R['value': Option(dtype)] + t = sa.Table('t', sa.MetaData(), sa.Column('value', typ)) + assert_dshape_equal(discover(t), expected) + + def test_mssql_types(): typ = sa.dialects.mssql.BIT() t = sa.Table('t', sa.MetaData(), sa.Column('bit', typ)) @@ -296,7 +331,6 @@ def test_mssql_types(): t = sa.Table('t', sa.MetaData(), sa.Column('uuid', typ)) assert_dshape_equal(discover(t), dshape('var * {uuid: ?string}')) - def test_create_from_datashape(): engine = sa.create_engine('sqlite:///:memory:') ds = dshape('''{bank: var * {name: string, amount: int},