Skip to content

Commit

Permalink
quote column name if db requires (apache#15465)
Browse files Browse the repository at this point in the history
Co-authored-by: hughhhh <hughmil3s@gmail.com>
  • Loading branch information
eschutho and hughhhh committed Jul 2, 2021
1 parent 68704a5 commit 80b8df0
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 2 deletions.
18 changes: 17 additions & 1 deletion superset/connectors/sqla/models.py
Expand Up @@ -60,7 +60,14 @@
)
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql import (
column,
ColumnElement,
literal_column,
quoted_name,
table,
text,
)
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause
from sqlalchemy.sql.selectable import Alias, TableClause
Expand Down Expand Up @@ -912,16 +919,25 @@ def make_sqla_column_compatible(
self, sqla_col: Column, label: Optional[str] = None
) -> Column:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
also adds quotes to the column if engine is configured for quotes.
:param sqla_col: sqlalchemy column instance
:param label: alias/label that column is expected to have
:return: either a sql alchemy column or label instance if supported by engine
"""
label_expected = label or sqla_col.name
db_engine_spec = self.db_engine_spec

# add quotes to column
if db_engine_spec.force_column_alias_quotes:
sqla_col = column(
quoted_name(sqla_col.name, True), sqla_col.type, sqla_col.is_literal
)

# add quotes to tables
if db_engine_spec.allows_alias_in_select:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)

sqla_col.key = label_expected
return sqla_col

Expand Down
57 changes: 56 additions & 1 deletion tests/integration_tests/core_tests.py
Expand Up @@ -25,6 +25,9 @@
import logging
from typing import Dict, List
from urllib.parse import quote

from sqlalchemy.sql import column, quoted_name, literal_column
from sqlalchemy import select
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
)
Expand All @@ -40,7 +43,7 @@
import sqlalchemy as sqla
from sqlalchemy.exc import SQLAlchemyError
from superset.models.cache import CacheKey
from superset.utils.core import get_example_database
from superset.utils.core import get_example_database, get_or_create_db
from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice,
Expand Down Expand Up @@ -898,6 +901,58 @@ def test_comments_in_sqlatable_query(self):
rendered_query = str(table.get_from_clause())
self.assertEqual(clean_query, rendered_query)

def test_make_column_compatible(self):
"""
DB Eng Specs: Make column compatible
"""

# with force_column_alias_quotes enabled
snowflake_database = get_or_create_db("snowflake", "snowflake://")

table = SqlaTable(
table_name="test_columns_with_alias_quotes", database=snowflake_database,
)

col = table.make_sqla_column_compatible(column("foo"))
s = select([col])
self.assertEqual(str(s), 'SELECT "foo" AS "foo"')

# with literal_column
table = SqlaTable(
table_name="test_columns_with_alias_quotes_on_literal_column",
database=snowflake_database,
)

col = table.make_sqla_column_compatible(literal_column("foo"))
s = select([col])
self.assertEqual(str(s), 'SELECT foo AS "foo"')

# with force_column_alias_quotes NOT enabled
postgres_database = get_or_create_db("postgresql", "postgresql://")

table = SqlaTable(
table_name="test_columns_with_no_quotes", database=postgres_database,
)

col = table.make_sqla_column_compatible(column("foo"))
s = select([col])
self.assertEqual(str(s), "SELECT foo AS foo")

# with literal_column
table = SqlaTable(
table_name="test_columns_with_no_quotes_on_literal_column",
database=postgres_database,
)

col = table.make_sqla_column_compatible(literal_column("foo"))
s = select([col])
self.assertEqual(str(s), "SELECT foo AS foo")

# cleanup
db.session.delete(snowflake_database)
db.session.delete(postgres_database)
db.session.commit()

def test_slice_payload_no_datasource(self):
self.login(username="admin")
data = self.get_json_resp("/superset/explore_json/", raise_on_error=False)
Expand Down
16 changes: 16 additions & 0 deletions tests/integration_tests/db_engine_specs/snowflake_tests.py
Expand Up @@ -16,13 +16,29 @@
# under the License.
import json

from sqlalchemy import column

from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.core import Database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec


class TestSnowflakeDbEngineSpec(TestDbEngineSpec):
def test_snowflake_sqla_column_label(self):
"""
DB Eng Specs (snowflake): Test column label
"""
test_cases = {
"Col": "Col",
"SUM(x)": "SUM(x)",
"SUM[x]": "SUM[x]",
"12345_col": "12345_col",
}
for original, expected in test_cases.items():
actual = SnowflakeEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)

def test_convert_dttm(self):
dttm = self.get_dttm()

Expand Down

0 comments on commit 80b8df0

Please sign in to comment.