Skip to content

Commit

Permalink
feat: add enforce URI query params with a specific for MySQL (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar committed Apr 18, 2023
1 parent e9b4022 commit 0ad6c87
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
11 changes: 7 additions & 4 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
# A Dict of query parameters that will always be used on every connection
enforce_uri_query_params: Dict[str, Any] = {}

force_column_alias_quotes = False
arraysize = 0
Expand Down Expand Up @@ -1089,11 +1091,12 @@ def adjust_engine_params( # pylint: disable=unused-argument
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
given query is running in order to enforce permissions (see #23385 and #23401).
Currently, changing the catalog is not supported. The method acceps a catalog so
that when catalog support is added to Superse the interface remains the same. This
is important because DB engine specs can be installed from 3rd party packages.
Currently, changing the catalog is not supported. The method accepts a catalog so
that when catalog support is added to Superset the interface remains the same.
This is important because DB engine specs can be installed from 3rd party
packages.
"""
return uri, connect_args
return uri, {**connect_args, **cls.enforce_uri_query_params}

@classmethod
def patch(cls) -> None:
Expand Down
6 changes: 5 additions & 1 deletion superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
),
}
disallow_uri_query_params = {"local_infile"}
enforce_uri_query_params = {"local_infile": 0}

@classmethod
def convert_dttm(
Expand All @@ -198,10 +199,13 @@ def adjust_engine_params(
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
uri, new_connect_args = super(
MySQLEngineSpec, MySQLEngineSpec
).adjust_engine_params(uri, connect_args, catalog, schema)
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))

return uri, connect_args
return uri, new_connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
15 changes: 15 additions & 0 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ def test_impersonate_user_presto(self, mocked_create_engine):
"password": "original_user_password",
}

@unittest.skipUnless(
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
Expand Down
34 changes: 32 additions & 2 deletions tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from datetime import datetime
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Tuple, Type
from unittest.mock import Mock, patch

import pytest
Expand All @@ -33,7 +33,7 @@
TINYINT,
TINYTEXT,
)
from sqlalchemy.engine.url import make_url
from sqlalchemy.engine.url import make_url, URL

from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -119,6 +119,36 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
MySQLEngineSpec.validate_database_uri(url)


@pytest.mark.parametrize(
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
],
)
def test_adjust_engine_params(
sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec

url = make_url(sqlalchemy_uri)
returned_url, returned_connect_args = MySQLEngineSpec.adjust_engine_params(
url, connect_args
)
assert returned_connect_args == returns


@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
Expand Down

0 comments on commit 0ad6c87

Please sign in to comment.