diff --git a/python/multicorn/__init__.py b/python/multicorn/__init__.py index 81ff3648..a54d6f98 100755 --- a/python/multicorn/__init__.py +++ b/python/multicorn/__init__.py @@ -224,17 +224,13 @@ def can_pushdown_upperrel(self): { "groupby_supported": , # can be ommited if false - "agg_functions": { - : , - ... - }, + "agg_functions": ["min", "max", "sum", ...], "supported_operators": [">", "<", "=", ...] } - Each entry in `agg_functions` dict corresponds to a maping between - the name of a aggregation function in PostgreSQL, and the equivalent - foreign function. If no mapping exists for an aggregate function any - queries containing it won't be pushed down. + Each entry in `agg_functions` list corresponds to the name of a + aggregation function in PostgreSQL, which the FDW can pushdown. + If a query has a function not in this list it won't be pushed down. The `supported_operators` entry lists all operators that can be used in qual (WHERE) clauses so that the aggregation pushdown will still diff --git a/python/multicorn/sqlalchemyfdw.py b/python/multicorn/sqlalchemyfdw.py index 643d042a..8a11b140 100644 --- a/python/multicorn/sqlalchemyfdw.py +++ b/python/multicorn/sqlalchemyfdw.py @@ -159,7 +159,7 @@ from sqlalchemy import create_engine from sqlalchemy.engine.url import make_url, URL from sqlalchemy.exc import UnsupportedCompilationError -from sqlalchemy.sql import select, operators as sqlops, and_ +from sqlalchemy.sql import select, operators as sqlops, func, and_ from sqlalchemy.sql.expression import nullsfirst, nullslast, text from . import ForeignDataWrapper, TableDefinition, ColumnDefinition @@ -225,6 +225,15 @@ def _parse_url_from_options(fdw_options): setattr(url, param, fdw_options[param]) return url +_PG_AGG_FUNC_MAPPING = { + "avg": func.avg, + "min": func.min, + "max": func.max, + "sum": func.sum, + "count": func.count, + "count.*": func.count +} + OPERATORS = { "=": operator.eq, @@ -401,32 +410,65 @@ def can_sort(self, sortkeys): return [] return sortkeys - def explain(self, quals, columns, sortkeys=None, verbose=False): + def can_pushdown_upperrel(self): + return { + "groupby_supported": True, + "agg_functions": list(_PG_AGG_FUNC_MAPPING), + "operators_supported": [op for op in OPERATORS if isinstance(op, str)], + } + + def explain(self, quals, columns, sortkeys=None, aggs=None, group_clauses=None, verbose=False): sortkeys = sortkeys or [] - statement = self._build_statement(quals, columns, sortkeys) - return [str(statement)] + statement = self._build_statement(quals, columns, sortkeys, aggs=aggs, group_clauses=group_clauses) + + # The literal_binds option below ensures that qualifiers are displayed as raw strings + # instead of being masked by placeholder bound parameters, thus providing more transparency + # during use (and testing). + return ["\n" + str(statement.compile(dialect=self.engine.dialect, compile_kwargs={"literal_binds": True})) + "\n"] + + def _build_statement(self, quals, columns, sortkeys, aggs=None, group_clauses=None): + is_aggregation = aggs or group_clauses + + if not is_aggregation: + statement = select([self.table]) + else: + target_list = [] + + if group_clauses is not None: + target_list = [self.table.c[col] for col in group_clauses] + + if aggs is not None: + for agg_name, agg_props in aggs.items(): + agg_func = _PG_AGG_FUNC_MAPPING[agg_props["function"]] + agg_target = agg_func() if agg_props["column"] == "*" else agg_func(self.table.c[agg_props["column"]]) + + target_list.append(agg_target.label(agg_name)) + + statement = select(*target_list).select_from(self.table) - def _build_statement(self, quals, columns, sortkeys): - statement = select([self.table]) clauses = [] for qual in quals: operator = OPERATORS.get(qual.operator, None) if operator: clauses.append(operator(self.table.c[qual.field_name], qual.value)) else: - log_to_postgres("Qual not pushed to foreign db: %s" % qual, WARNING) + log_to_postgres(f"Qual {qual} with operator {qual.operator} not pushed to foreign db", ERROR if is_aggregation else WARNING) if clauses: statement = statement.where(and_(*clauses)) - if columns: - columns = [self.table.c[col] for col in columns] - elif columns is None: - columns = [self.table] - else: - # This is the case where we're asked to output no columns (just a list of empty dicts) - # in a SELECT 1, but I can't get SQLAlchemy to emit `SELECT 1 FROM some_table`, so - # we just select a single column. - columns = [self.table.c[list(self.table.c)[0].name]] - statement = statement.with_only_columns(columns) + + if not is_aggregation: + if columns: + columns = [self.table.c[col] for col in columns] + elif columns is None: + columns = [self.table] + else: + # This is the case where we're asked to output no columns (just a list of empty dicts) + # in a SELECT 1, but I can't get SQLAlchemy to emit `SELECT 1 FROM some_table`, so + # we just select a single column. + columns = [self.table.c[list(self.table.c)[0].name]] + statement = statement.with_only_columns(columns) + elif group_clauses is not None: + statement = statement.group_by(*[self.table.c[col] for col in group_clauses]) for sortkey in sortkeys: column = self.table.c[sortkey.attname] @@ -440,12 +482,13 @@ def _build_statement(self, quals, columns, sortkeys): statement = statement.order_by(column) return statement - def execute(self, quals, columns, sortkeys=None): + def execute(self, quals, columns, sortkeys=None, aggs=None, group_clauses=None): """ The quals are turned into an and'ed where clause. """ sortkeys = sortkeys or [] - statement = self._build_statement(quals, columns, sortkeys) + is_aggregation = aggs or group_clauses + statement = self._build_statement(quals, columns, sortkeys, aggs=aggs, group_clauses=group_clauses) log_to_postgres(str(statement), DEBUG) # If a dialect doesn't support streaming using server-side cursors, @@ -469,6 +512,7 @@ def execute(self, quals, columns, sortkeys=None): for item in rs: yield dict(item) returned += 1 + if self.batch_size is None or returned < self.batch_size: return diff --git a/src/python.c b/src/python.c index 36963196..4474e05c 100644 --- a/src/python.c +++ b/src/python.c @@ -1771,14 +1771,9 @@ canPushdownUpperrel(MulticornPlanState * state) Py_XDECREF(p_object); /* Determine which aggregation functions are supported */ - p_object = PyMapping_GetItemString(p_upperrel_pushdown, "agg_functions"); - if (p_object != NULL && p_object != Py_None) - { - p_agg_funcs = PyMapping_Keys(p_object); - pythonUnicodeSequenceToList(p_agg_funcs, &state->agg_functions); - Py_DECREF(p_agg_funcs); - } - Py_XDECREF(p_object); + p_agg_funcs = PyMapping_GetItemString(p_upperrel_pushdown, "agg_functions"); + pythonUnicodeSequenceToList(p_agg_funcs, &state->agg_functions); + Py_XDECREF(p_agg_funcs); /* Construct supported qual operators list */ p_ops = PyMapping_GetItemString(p_upperrel_pushdown, "operators_supported");