Skip to content

Commit

Permalink
fix(dialect, metadata): support inspection of Redshift datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Brooke-white committed Nov 16, 2021
1 parent a3eff80 commit b22cc98
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
25 changes: 21 additions & 4 deletions sqlalchemy_redshift/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ class TIMESTAMPTZ(sa.dialects.postgresql.TIMESTAMP):

__visit_name__ = 'TIMESTAMPTZ'

def __init__(self):
super(TIMESTAMPTZ, self).__init__(timezone=True)
def __init__(self, timezone=True, precision=None):
super(TIMESTAMPTZ, self).__init__(timezone=True, precision=precision)


class TIMETZ(sa.dialects.postgresql.TIME):
Expand All @@ -210,8 +210,8 @@ class TIMETZ(sa.dialects.postgresql.TIME):

__visit_name__ = 'TIMETZ'

def __init__(self):
super(TIMETZ, self).__init__(timezone=True)
def __init__(self, timezone=True, precision=None):
super(TIMETZ, self).__init__(timezone=True, precision=precision)


class GEOMETRY(sa.dialects.postgresql.TEXT):
Expand Down Expand Up @@ -260,6 +260,13 @@ def process_bind_param(self, value, dialect):
return json.dumps(value)
return value

# Mapping for database schema inspection of Amazon Redshift datatypes
redshift_ischema_names = {
"geometry": GEOMETRY,
"super": SUPER,
"time with time zone": TIMETZ,
"timestamp with time zone": TIMESTAMPTZ,
}

class RelationKey(namedtuple('RelationKey', ('name', 'schema'))):
"""
Expand Down Expand Up @@ -536,6 +543,16 @@ def __init__(self, *args, **kw):
# Redshift does not support user-created domains.
self._domains = None

@property
def ischema_names(self):
"""
Returns information about datatypes supported by Amazon Redshift.
Used in
:meth:`~sqlalchemy.engine.dialects.postgresql.base.PGDialect._get_column_info`.
"""
return {**super(RedshiftDialectMixin, self).ischema_names, **redshift_ischema_names}

@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/test_dialect_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import sqlalchemy_redshift.dialect
import sqlalchemy
from sqlalchemy.engine import reflection
from sqlalchemy import MetaData


def test_defined_types():
Expand Down Expand Up @@ -130,3 +132,32 @@ def test_custom_types_ddl_generation(
create_table = sqlalchemy.schema.CreateTable(table)
actual = compiler.process(create_table)
assert expected == actual


redshift_specific_datatypes = [
sqlalchemy_redshift.dialect.GEOMETRY,
sqlalchemy_redshift.dialect.SUPER,
sqlalchemy_redshift.dialect.TIMETZ,
sqlalchemy_redshift.dialect.TIMESTAMPTZ
]


@pytest.mark.parametrize("custom_datatype", redshift_specific_datatypes)
def test_custom_types_reflection_inspection(
custom_datatype, redshift_engine
):
metadata = MetaData(bind=redshift_engine)
sqlalchemy.Table(
't1',
metadata,
sqlalchemy.Column('id', sqlalchemy.INTEGER, primary_key=True),
sqlalchemy.Column('name', sqlalchemy.String),
sqlalchemy.Column('test_col', custom_datatype),
schema='public'
)
metadata.create_all()
inspect = reflection.Inspector.from_engine(redshift_engine)

actual = inspect.get_columns(table_name='t1', schema='public')
assert len(actual) == 3
assert isinstance(actual[2]['type'], custom_datatype)

0 comments on commit b22cc98

Please sign in to comment.