Skip to content

Commit

Permalink
fix(dialect, custom-types): support compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
Brooke-white committed Nov 19, 2021
1 parent b22cc98 commit 6fae561
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
34 changes: 27 additions & 7 deletions sqlalchemy_redshift/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
)
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.dialects.postgresql.psycopg2cffi import PGDialect_psycopg2cffi
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.engine import reflection
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import (
Expand Down Expand Up @@ -178,7 +180,18 @@ class RedshiftImpl(postgresql.PostgresqlImpl):
])


class TIMESTAMPTZ(sa.dialects.postgresql.TIMESTAMP):
class RedshiftTypeEngine(TypeEngine):

def _default_dialect(self, default=None):
"""
Returns the default dialect used for TypeEngine compilation yielding String result.
:meth:`~sqlalchemy.sql.type_api.TypeEngine.compile`
"""
return RedshiftDialectMixin()


class TIMESTAMPTZ(RedshiftTypeEngine, sa.dialects.postgresql.TIMESTAMP):
"""
Redshift defines a TIMTESTAMPTZ column type as an alias
of TIMESTAMP WITH TIME ZONE.
Expand All @@ -193,10 +206,13 @@ class TIMESTAMPTZ(sa.dialects.postgresql.TIMESTAMP):
__visit_name__ = 'TIMESTAMPTZ'

def __init__(self, timezone=True, precision=None):
# timezone param must be present as it's provided in base class so the object
# can be instantiated with kwargs
# see :meth:`~sqlalchemy.dialects.postgresql.base.PGDialect._get_column_info`
super(TIMESTAMPTZ, self).__init__(timezone=True, precision=precision)


class TIMETZ(sa.dialects.postgresql.TIME):
class TIMETZ(RedshiftTypeEngine, sa.dialects.postgresql.TIME):
"""
Redshift defines a TIMTETZ column type as an alias
of TIME WITH TIME ZONE.
Expand All @@ -211,10 +227,13 @@ class TIMETZ(sa.dialects.postgresql.TIME):
__visit_name__ = 'TIMETZ'

def __init__(self, timezone=True, precision=None):
# timezone param must be present as it's provided in base class so the object
# can be instantiated with kwargs
# see :meth:`~sqlalchemy.dialects.postgresql.base.PGDialect._get_column_info`
super(TIMETZ, self).__init__(timezone=True, precision=precision)


class GEOMETRY(sa.dialects.postgresql.TEXT):
class GEOMETRY(RedshiftTypeEngine, sa.dialects.postgresql.TEXT):
"""
Redshift defines a GEOMETRY column type
https://docs.aws.amazon.com/redshift/latest/dg/c_Supported_data_types.html
Expand All @@ -233,7 +252,7 @@ def get_dbapi_type(self, dbapi):
return dbapi.GEOMETRY


class SUPER(sa.dialects.postgresql.TEXT):
class SUPER(RedshiftTypeEngine, sa.dialects.postgresql.TEXT):
"""
Redshift defines a SUPER column type
https://docs.aws.amazon.com/redshift/latest/dg/c_Supported_data_types.html
Expand Down Expand Up @@ -261,13 +280,14 @@ def process_bind_param(self, value, dialect):
return value

# Mapping for database schema inspection of Amazon Redshift datatypes
redshift_ischema_names = {
REDSHIFT_ISCHEMA_NAMES = {
"geometry": GEOMETRY,
"super": SUPER,
"time with time zone": TIMETZ,
"timestamp with time zone": TIMESTAMPTZ,
}


class RelationKey(namedtuple('RelationKey', ('name', 'schema'))):
"""
Structured tuple of table/view name and schema name.
Expand Down Expand Up @@ -500,7 +520,7 @@ class RedshiftIdentifierPreparer(PGIdentifierPreparer):
reserved_words = RESERVED_WORDS


class RedshiftDialectMixin(object):
class RedshiftDialectMixin(DefaultDialect):
"""
Define Redshift-specific behavior.
Expand Down Expand Up @@ -551,7 +571,7 @@ def ischema_names(self):
Used in
:meth:`~sqlalchemy.engine.dialects.postgresql.base.PGDialect._get_column_info`.
"""
return {**super(RedshiftDialectMixin, self).ischema_names, **redshift_ischema_names}
return {**super(RedshiftDialectMixin, self).ischema_names, **REDSHIFT_ISCHEMA_NAMES}

@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_dialect_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,10 @@ def test_custom_types_reflection_inspection(
actual = inspect.get_columns(table_name='t1', schema='public')
assert len(actual) == 3
assert isinstance(actual[2]['type'], custom_datatype)


@pytest.mark.parametrize("custom_datatype", redshift_specific_datatypes)
def test_custom_type_compilation(custom_datatype):
dt = custom_datatype()
compiled_dt = dt.compile()
assert compiled_dt == dt.__visit_name__

0 comments on commit 6fae561

Please sign in to comment.