feat: add global max row limit (apache#16683)
* feat: add global max limit

* fix lint and tests

* leave SAMPLES_ROW_LIMIT unchanged

* fix sample rowcount test

* replace max global limit with existing sql max row limit

* fix test

* make max_limit optional in util

* improve comments
villebro committed Sep 16, 2021
1 parent 633f29f commit 4e3d4f6
Showing 11 changed files with 123 additions and 27 deletions.
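
The net effect of the commit: every row limit in the request path is now funneled through a single helper that caps it at the `SQL_MAX_ROW` config value. A minimal standalone sketch of the capping rule (the real helper, added under superset/utils/core.py below, reads `SQL_MAX_ROW` from Flask's `current_app.config`; the plain dict here is an illustrative stand-in):

```python
from typing import Optional

CONFIG = {"SQL_MAX_ROW": 100000}  # stand-in for Flask's current_app.config

def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
    """Cap a requested row limit; a limit of 0 means 'no limit requested'."""
    if max_limit is None:
        max_limit = CONFIG["SQL_MAX_ROW"]
    if limit != 0:
        return min(max_limit, limit)
    return max_limit

assert apply_max_row_limit(100000, 10) == 10  # the cap wins over a larger request
assert apply_max_row_limit(10, 100000) == 10  # requests under the cap are honored
assert apply_max_row_limit(0) == 100000       # no request falls back to the cap
```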
4 changes: 0 additions & 4 deletions superset/common/query_actions.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import copy
import math
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING

from flask_babel import _
@@ -131,14 +130,11 @@ def _get_samples(
query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
) -> Dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
row_limit = query_obj.row_limit or math.inf
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
query_obj.orderby = []
query_obj.metrics = []
query_obj.post_processing = []
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
query_obj.row_offset = 0
query_obj.columns = [o.column_name for o in datasource.columns]
return _get_full(query_context, query_obj, force_cached)

2 changes: 1 addition & 1 deletion superset/common/query_context.py
@@ -100,11 +100,11 @@ def __init__(
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.result_type = result_type or ChartDataResultType.FULL
self.result_format = result_format or ChartDataResultFormat.JSON
self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
self.cache_values = {
"datasource": datasource,
"queries": queries,
17 changes: 14 additions & 3 deletions superset/common/query_object.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING

from flask_babel import gettext as _
from pandas import DataFrame
@@ -28,6 +28,7 @@
from superset.typing import Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
apply_max_row_limit,
ChartDataResultType,
DatasourceDict,
DTTM_ALIAS,
@@ -41,6 +42,10 @@
from superset.utils.hashing import md5_sha_from_dict
from superset.views.utils import get_time_range_endpoints

if TYPE_CHECKING:
from superset.common.query_context import QueryContext # pragma: no cover


config = app.config
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,6 +108,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes

def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
query_context: "QueryContext",
annotation_layers: Optional[List[Dict[str, Any]]] = None,
applied_time_extras: Optional[Dict[str, str]] = None,
apply_fetch_values_predicate: bool = False,
@@ -146,7 +152,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
self.result_type = result_type
self.result_type = result_type or query_context.result_type
self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
self.annotation_layers = [
layer
@@ -186,7 +192,12 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
for x in metrics
]

self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
default_row_limit = (
config["SAMPLES_ROW_LIMIT"]
if self.result_type == ChartDataResultType.SAMPLES
else config["ROW_LIMIT"]
)
self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
self.row_offset = row_offset or 0
self.filter = filters or []
self.series_limit = series_limit
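
With the clamp removed from `_get_samples`, the default-and-cap logic now lives in the `QueryObject` constructor above: pick a per-result-type default (`SAMPLES_ROW_LIMIT` for sample queries, `ROW_LIMIT` otherwise), then cap through `apply_max_row_limit`. A hedged sketch of that resolution — `resolve_row_limit` is a hypothetical name, and the config values are illustrative:

```python
from enum import Enum
from typing import Optional

class ChartDataResultType(str, Enum):
    FULL = "full"
    SAMPLES = "samples"

# Illustrative values only; the real ones come from superset/config.py.
CONFIG = {"ROW_LIMIT": 50000, "SAMPLES_ROW_LIMIT": 1000, "SQL_MAX_ROW": 100000}

def resolve_row_limit(requested: Optional[int], result_type: ChartDataResultType) -> int:
    """Mirror of the new QueryObject logic: per-result-type default, then global cap."""
    default = (
        CONFIG["SAMPLES_ROW_LIMIT"]
        if result_type == ChartDataResultType.SAMPLES
        else CONFIG["ROW_LIMIT"]
    )
    limit = requested or default  # a missing (None) or zero request falls back to the default
    return min(limit, CONFIG["SQL_MAX_ROW"])

assert resolve_row_limit(None, ChartDataResultType.SAMPLES) == 1000      # sample default
assert resolve_row_limit(10_000_000, ChartDataResultType.FULL) == 100_000  # capped
```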
8 changes: 3 additions & 5 deletions superset/config.py
@@ -121,9 +121,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
# default viz used in chart explorer
DEFAULT_VIZ_TYPE = "table"

# default row limit when requesting chart data
ROW_LIMIT = 50000
VIZ_ROW_LIMIT = 10000
# max rows retrieved when requesting samples from datasource in explore view
# default row limit when requesting samples from datasource in explore view
SAMPLES_ROW_LIMIT = 1000
# max rows retrieved by filter select auto complete
FILTER_SELECT_ROW_LIMIT = 10000
@@ -671,9 +671,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
# Set this API key to enable Mapbox visualizations
MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")

# Maximum number of rows returned from a database
# in async mode, no more than SQL_MAX_ROW will be returned and stored
# in the results backend. This also becomes the limit when exporting CSVs
# Maximum number of rows returned for any analytical database query
SQL_MAX_ROW = 100000

# Maximum number of rows displayed in SQL Lab UI
22 changes: 22 additions & 0 deletions superset/utils/core.py
@@ -1762,3 +1762,25 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
return bool(strtobool(bool_str.lower()))
except ValueError:
return False


def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
"""
Override row limit if max global limit is defined
:param limit: requested row limit
:param max_limit: Maximum allowed row limit
:return: Capped row limit
>>> apply_max_row_limit(100000, 10)
10
>>> apply_max_row_limit(10, 100000)
10
>>> apply_max_row_limit(0, 10000)
10000
"""
if max_limit is None:
max_limit = current_app.config["SQL_MAX_ROW"]
if limit != 0:
return min(max_limit, limit)
return max_limit
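
A brief note on the design, since the doctests only hint at it: a `limit` of 0 is read as "no limit requested" and resolves to the cap itself rather than to 0, and the optional `max_limit` argument lets a call site impose a cap stricter than the global `SQL_MAX_ROW` — e.g. with `SQL_MAX_ROW` at 100000, `apply_max_row_limit(0)` returns 100000, while `apply_max_row_limit(500, 100)` returns 100, the stricter explicit cap.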
5 changes: 3 additions & 2 deletions superset/utils/sqllab_execution_context.py
@@ -23,10 +23,11 @@

from flask import g

from superset import app, is_feature_enabled
from superset import is_feature_enabled
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
from superset.utils.core import apply_max_row_limit
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name

@@ -102,7 +103,7 @@ def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]:

@staticmethod
def _get_limit_param(query_params: Dict[str, Any]) -> int:
limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
if limit < 0:
logger.warning(
"Invalid limit of %i specified. Defaulting to max limit.", limit
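
With this change, a missing `queryLimit` resolves to 0, which `apply_max_row_limit` maps to the full `SQL_MAX_ROW`; the old fallback is preserved while oversized explicit limits are now also clamped (the real method additionally warns on and resets negative limits, as the surrounding lines show). A standalone sketch of the new flow, with `get_limit_param` as a hypothetical stand-in for `_get_limit_param`:

```python
from typing import Any, Dict

def get_limit_param(query_params: Dict[str, Any], sql_max_row: int = 100000) -> int:
    """Missing limits fall back to the global cap; explicit limits are clamped to it."""
    requested = query_params.get("queryLimit") or 0
    return min(sql_max_row, requested) if requested else sql_max_row

assert get_limit_param({}) == 100000                     # old fallback preserved
assert get_limit_param({"queryLimit": 250}) == 250       # modest limits honored
assert get_limit_param({"queryLimit": 10**9}) == 100000  # oversized limits clamped
```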
5 changes: 3 additions & 2 deletions superset/views/core.py
@@ -104,7 +104,7 @@
from superset.utils import core as utils, csv
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.core import ReservedUrlParameters
from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
from superset.utils.dates import now_as_float
from superset.utils.decorators import check_dashboard_access
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
@@ -897,8 +897,9 @@ def filter( # pylint: disable=no-self-use
return json_error_response(DATASOURCE_MISSING_ERR)

datasource.raise_for_access()
row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
payload = json.dumps(
datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
datasource.values_for_column(column, row_limit),
default=utils.json_int_dttm_ser,
ignore_nan=True,
)
7 changes: 4 additions & 3 deletions superset/viz.py
@@ -69,6 +69,7 @@
from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
apply_max_row_limit,
DTTM_ALIAS,
ExtraFiltersReasonType,
JS_MAX_INTEGER,
@@ -324,7 +325,10 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals
)
limit = int(self.form_data.get("limit") or 0)
timeseries_limit_metric = self.form_data.get("timeseries_limit_metric")

# apply row limit to query
row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"])
row_limit = apply_max_row_limit(row_limit)

# default order direction
order_desc = self.form_data.get("order_desc", True)
@@ -1687,9 +1691,6 @@ class HistogramViz(BaseViz):
def query_obj(self) -> QueryObjectDict:
"""Returns the query object for this visualization"""
query_obj = super().query_obj()
query_obj["row_limit"] = self.form_data.get(
"row_limit", int(config["VIZ_ROW_LIMIT"])
)
numeric_columns = self.form_data.get("all_columns_x")
if numeric_columns is None:
raise QueryObjectValidationError(
62 changes: 58 additions & 4 deletions tests/integration_tests/charts/api_tests.py
@@ -18,7 +18,7 @@
"""Unit tests for Superset"""
import json
import unittest
from datetime import datetime, timedelta
from datetime import datetime
from io import BytesIO
from typing import Optional
from unittest import mock
@@ -1203,18 +1203,54 @@ def test_chart_data_default_row_limit(self):
self.login(username="admin")
request_payload = get_query_context("birth_names")
del request_payload["queries"][0]["row_limit"]

rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 7)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
)
def test_chart_data_sql_max_row_limit(self):
"""
Chart data API: Ensure row count doesn't exceed max global row limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["row_limit"] = 10000000
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
)
def test_chart_data_sample_default_limit(self):
"""
Chart data API: Ensure sample response row count defaults to the SAMPLES_ROW_LIMIT config value
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
del request_payload["queries"][0]["row_limit"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.common.query_actions.config",
{**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
)
def test_chart_data_default_sample_limit(self):
def test_chart_data_sample_custom_limit(self):
"""
Chart data API: Ensure sample response row count doesn't exceed default limit
Chart data API: Ensure requested sample response row count is between
default and SQL max row limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
@@ -1223,6 +1259,24 @@ def test_chart_data_default_sample_limit(self):
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
)
def test_chart_data_sql_max_row_sample_limit(self):
"""
Chart data API: Ensure requested sample response row count doesn't
exceed SQL max row limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10000000
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)

def test_chart_data_incorrect_result_type(self):
17 changes: 14 additions & 3 deletions tests/integration_tests/charts/schema_tests.py
@@ -16,17 +16,25 @@
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
from typing import Any, Dict, Tuple
from unittest import mock

import pytest

from marshmallow import ValidationError
from tests.integration_tests.test_app import app
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.query_context import QueryContext
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
)
from tests.integration_tests.fixtures.query_context import get_query_context


class TestSchema(SupersetTestCase):
@mock.patch(
"superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_limit_and_offset(self):
self.login(username="admin")
payload = get_query_context("birth_names")
@@ -36,7 +44,7 @@ def test_query_context_limit_and_offset(self):
payload["queries"][0].pop("row_offset", None)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
self.assertEqual(query_object.row_limit, 5000)
self.assertEqual(query_object.row_offset, 0)

# Valid limit and offset
@@ -55,12 +63,14 @@ def test_query_context_series_limit(self):
self.assertIn("row_limit", context.exception.messages["queries"][0])
self.assertIn("row_offset", context.exception.messages["queries"][0])

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_null_timegrain(self):
self.login(username="admin")
payload = get_query_context("birth_names")
payload["queries"][0]["extras"]["time_grain_sqla"] = None
_ = ChartDataQueryContextSchema().load(payload)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_series_limit(self):
self.login(username="admin")
payload = get_query_context("birth_names")
@@ -82,6 +92,7 @@ def test_query_context_series_limit(self):
}
_ = ChartDataQueryContextSchema().load(payload)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_null_post_processing_op(self):
self.login(username="admin")
payload = get_query_context("birth_names")
1 change: 1 addition & 0 deletions tests/integration_tests/query_context_tests.py
@@ -90,6 +90,7 @@ def test_schema_deserialization(self):
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"])

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_cache(self):
table_name = "birth_names"
table = self.get_table(name=table_name)
