Skip to content

Commit

Permalink
Add fix for DDL statements
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkaplan committed Jul 6, 2022
1 parent 2ae7b19 commit 318bee9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
27 changes: 20 additions & 7 deletions tests/unit/sqlalchemy/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
String,
Table,
)
from sqlalchemy.schema import CreateTable

from trino.sqlalchemy.dialect import TrinoDialect

Expand All @@ -29,6 +30,13 @@
Column('id', Integer, primary_key=True),
Column('name', String),
)
system_table = Table(
'table',
metadata,
Column('id', Integer, primary_key=True),
schema='information_schema',
trino_catalog='system'
)


@pytest.fixture
Expand All @@ -55,18 +63,23 @@ def test_offset(dialect):


def test_catalogs_argument(dialect):
system_table = Table(
'table',
MetaData(),
Column('id', Integer, primary_key=True),
schema='information_schema',
trino_catalog='system'
)
statement = select(system_table)
query = statement.compile(dialect=dialect)
assert str(query) == 'SELECT "system".information_schema."table".id \nFROM "system".information_schema."table"'


def test_catalogs_create_table(dialect):
statement = CreateTable(system_table)
query = statement.compile(dialect=dialect)
assert str(query) == \
'\n'\
'CREATE TABLE "system".information_schema."table" (\n'\
'\tid INTEGER NOT NULL, \n'\
'\tPRIMARY KEY (id)\n'\
')\n'\
'\n'


def test_cte_insert_order(dialect):
cte = select(table).cte('cte')
statement = insert(table).from_select(table.columns, cte)
Expand Down
10 changes: 7 additions & 3 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,17 @@ def visit_column(self, column, add_to_result_map=None, include_table=True, **kwa
column, add_to_result_map, include_table, **kwargs
)
table = column.table
return self.__add_catalog(sql, table)
return self.add_catalog(sql, table)

def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
sql = super(TrinoSQLCompiler, self).visit_table(
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
)
return self.__add_catalog(sql, table)
return self.add_catalog(sql, table)

@staticmethod
def __add_catalog(sql, table):
def add_catalog(sql, table):
if table is None:
return sql

Expand Down Expand Up @@ -215,3 +215,7 @@ def visit_TIME(self, type_, **kw):

class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS

def format_table(self, table, use_schema=True, name=None):
result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name)
return TrinoSQLCompiler.add_catalog(result, table)

0 comments on commit 318bee9

Please sign in to comment.