Skip to content

Commit

Permalink
fix: pass valid SQL to SM (apache#27464)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored and qleroy committed Apr 28, 2024
1 parent 22bcd3e commit f1492a2
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 2 deletions.
2 changes: 1 addition & 1 deletion superset/commands/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
return populate_owner_list(owner_ids, default_to_user=True)


class UpdateMixin: # pylint: disable=too-few-public-methods
class UpdateMixin:
@staticmethod
def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
"""
Expand Down
6 changes: 5 additions & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
DatasetInvalidPermissionEvaluationException,
SupersetSecurityException,
)
from superset.jinja_context import get_template_processor
from superset.security.guest_token import (
GuestToken,
GuestTokenResources,
Expand Down Expand Up @@ -1956,11 +1957,14 @@ def raise_for_access(
return

if query:
# make sure the quuery is valid SQL by rendering any Jinja
processor = get_template_processor(database=query.database)
rendered_sql = processor.process_template(query.sql)
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery(
query.sql,
rendered_sql,
engine=database.db_engine_spec.engine,
).tables
}
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/commands/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
35 changes: 35 additions & 0 deletions tests/unit_tests/security/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from superset.extensions import appbuilder
from superset.models.slice import Slice
from superset.security.manager import SupersetSecurityManager
from superset.sql_parse import Table
from superset.superset_typing import AdhocMetric
from superset.utils.core import override_user

Expand Down Expand Up @@ -245,6 +246,40 @@ def test_raise_for_access_query_default_schema(
)


def test_raise_for_access_jinja_sql(mocker: MockFixture, app_context: None) -> None:
"""
Test that Jinja gets rendered to SQL.
"""
sm = SupersetSecurityManager(appbuilder)
mocker.patch.object(sm, "can_access_database", return_value=False)
mocker.patch.object(sm, "get_schema_perm", return_value="[PostgreSQL].[public]")
mocker.patch.object(sm, "can_access", return_value=False)
mocker.patch.object(sm, "is_guest_user", return_value=False)
get_table_access_error_object = mocker.patch.object(
sm, "get_table_access_error_object"
)
SqlaTable = mocker.patch("superset.connectors.sqla.models.SqlaTable")
SqlaTable.query_datasources_by_name.return_value = []

database = mocker.MagicMock()
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.database = database
query.sql = "SELECT * FROM {% if True %}ab_user{% endif %} WHERE 1=1"

with pytest.raises(SupersetSecurityException):
sm.raise_for_access(
database=None,
datasource=None,
query=query,
query_context=None,
table=None,
viz=None,
)

get_table_access_error_object.assert_called_with({Table("ab_user", "public")})


def test_raise_for_access_chart_for_datasource_permission(
mocker: MockFixture,
app_context: None,
Expand Down

0 comments on commit f1492a2

Please sign in to comment.