Skip to content

Commit

Permalink
fix(sqla): convert prequery results to native python types (apache#17195)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed Oct 22, 2021
1 parent 35cbcc4 commit 2ba046f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
13 changes: 8 additions & 5 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
if c not in metrics and c in groupby_series_columns
]
top_groups = self._get_top_groups(
result.df, dimensions, groupby_series_columns
result.df, dimensions, groupby_series_columns, columns_by_name
)
qry = qry.where(top_groups)

Expand Down Expand Up @@ -1436,20 +1436,23 @@ def _get_series_orderby(
return ob

def _get_top_groups(
self, df: pd.DataFrame, dimensions: List[str], groupby_exprs: Dict[str, Any],
self,
df: pd.DataFrame,
dimensions: List[str],
groupby_exprs: Dict[str, Any],
columns_by_name: Dict[str, TableColumn],
) -> ColumnElement:
column_map = {column.column_name: column for column in self.columns}
groups = []
for _unused, row in df.iterrows():
group = []
for dimension in dimensions:
value = row[dimension]
value = utils.normalize_prequery_result_type(row[dimension])

# Some databases like Druid will return timestamps as strings, but
# do not perform automatic casting when comparing these strings to
# a timestamp. For cases like this we convert the value from a
# string into a timestamp.
if column_map[dimension].is_temporal and isinstance(value, str):
if columns_by_name[dimension].is_temporal and isinstance(value, str):
dttm = dateutil.parser.parse(value)
value = text(self.db_engine_spec.convert_dttm("TIMESTAMP", dttm))

Expand Down
32 changes: 32 additions & 0 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,3 +1813,35 @@ def escape_sqla_query_binds(sql: str) -> str:
sql = sql.replace(bind, bind.replace(":", "\\:"))
processed_binds.add(bind)
return sql


def normalize_prequery_result_type(
    value: Union[str, int, float, bool, np.generic]
) -> Union[str, int, float, bool]:
    """
    Unwrap a numpy scalar into the corresponding builtin Python value.

    Anything that is not a numpy scalar (i.e. not an ``np.generic`` instance)
    is handed back unchanged.

    :param value: a primitive, given either as a numpy scalar or a Python value
    :return: the native Python counterpart of ``value``

    >>> normalize_prequery_result_type('abc')
    'abc'
    >>> normalize_prequery_result_type(True)
    True
    >>> normalize_prequery_result_type(123)
    123
    >>> normalize_prequery_result_type(np.int16(123))
    123
    >>> normalize_prequery_result_type(np.uint32(123))
    123
    >>> normalize_prequery_result_type(np.int64(123))
    123
    >>> normalize_prequery_result_type(123.456)
    123.456
    >>> normalize_prequery_result_type(np.float32(123.456))
    123.45600128173828
    >>> normalize_prequery_result_type(np.float64(123.456))
    123.456
    """
    # Plain Python values need no conversion; return them as-is.
    if not isinstance(value, np.generic):
        return value
    # ``item()`` copies the numpy scalar out as a standard Python scalar.
    return value.item()

0 comments on commit 2ba046f

Please sign in to comment.