Skip to content

Commit

Permalink
Add ability to specify catalog in SQLAlchemy Table objects
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkaplan authored and hashhar committed Aug 16, 2022
1 parent a0f524b commit aee6064
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/unit/sqlalchemy/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,24 @@
String,
Table,
)
from sqlalchemy.schema import CreateTable

from trino.sqlalchemy.dialect import TrinoDialect

metadata = MetaData()
table = Table(
'table',
metadata,
Column('id', Integer, primary_key=True),
Column('id', Integer),
Column('name', String),
)
table_with_catalog = Table(
'table',
metadata,
Column('id', Integer),
schema='default',
trino_catalog='other'
)


@pytest.fixture
Expand Down Expand Up @@ -64,3 +72,20 @@ def test_cte_insert_order(dialect):
'FROM "table")\n'\
' SELECT cte.id, cte.name \n'\
'FROM cte'


def test_catalogs_argument(dialect):
statement = select(table_with_catalog)
query = statement.compile(dialect=dialect)
assert str(query) == 'SELECT default."table".id \nFROM "other".default."table"'


def test_catalogs_create_table(dialect):
statement = CreateTable(table_with_catalog)
query = statement.compile(dialect=dialect)
assert str(query) == \
'\n'\
'CREATE TABLE "other".default."table" (\n'\
'\tid INTEGER\n'\
')\n'\
'\n'
40 changes: 40 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.sql import compiler
try:
from sqlalchemy.sql.expression import (
Alias,
CTE,
Subquery,
)
except ImportError:
# For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist
from sqlalchemy.sql.expression import Alias
CTE = type(None)
Subquery = type(None)

# https://trino.io/docs/current/language/reserved.html
RESERVED_WORDS = {
Expand Down Expand Up @@ -102,6 +113,31 @@ def limit_clause(self, select, **kw):
text += "\nLIMIT " + self.process(select._limit_clause, **kw)
return text

def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
sql = super(TrinoSQLCompiler, self).visit_table(
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
)
return self.add_catalog(sql, table)

@staticmethod
def add_catalog(sql, table):
if table is None:
return sql

if isinstance(table, (Alias, CTE, Subquery)):
return sql

if (
'trino' not in table.dialect_options
or 'catalog' not in table.dialect_options['trino']
):
return sql

catalog = table.dialect_options['trino']['catalog']
sql = f'"{catalog}".{sql}'
return sql


class TrinoDDLCompiler(compiler.DDLCompiler):
pass
Expand Down Expand Up @@ -173,3 +209,7 @@ def visit_TIME(self, type_, **kw):

class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS

def format_table(self, table, use_schema=True, name=None):
result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name)
return TrinoSQLCompiler.add_catalog(result, table)

0 comments on commit aee6064

Please sign in to comment.