Skip to content

Commit

Permalink
Add support for aggregation pushdow
Browse files Browse the repository at this point in the history
  • Loading branch information
Parth-Brahmbhatt authored and martint committed Jun 11, 2020
1 parent a1858e5 commit 4052f33
Show file tree
Hide file tree
Showing 22 changed files with 744 additions and 58 deletions.
Expand Up @@ -18,6 +18,7 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorInsertTableHandle;
Expand All @@ -33,7 +34,6 @@
import io.prestosql.spi.connector.ConstraintApplicationResult;
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.SystemTable;
Expand Down
Expand Up @@ -25,6 +25,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.airlift.log.Logger;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorMetadata;
Expand All @@ -37,7 +38,6 @@
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.NotFoundException;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.TableNotFoundException;
Expand Down
Expand Up @@ -49,6 +49,7 @@
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
Expand All @@ -68,7 +69,6 @@
import io.prestosql.spi.connector.DiscretePredicates;
import io.prestosql.spi.connector.InMemoryRecordSet;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaNotFoundException;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
Expand Down
Expand Up @@ -60,6 +60,7 @@
import io.prestosql.spi.Page;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorInsertTableHandle;
Expand All @@ -84,7 +85,6 @@
import io.prestosql.spi.connector.ConstraintApplicationResult;
import io.prestosql.spi.connector.DiscretePredicates;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.RecordCursor;
import io.prestosql.spi.connector.RecordPageSource;
import io.prestosql.spi.connector.SchemaTableName;
Expand Down Expand Up @@ -3058,7 +3058,7 @@ public void testApplyProjection()
assertProjectionResult(projectionResult, false, expectedProjections, expectedAssignments);

// Round-2: input projections [symbol_2.int0 and onelevelrow0#f_int0]. Virtual handle is reused.
ProjectionApplicationResult.Assignment newlyCreatedColumn = getOnlyElement(projectionResult.get().getAssignments().stream()
Assignment newlyCreatedColumn = getOnlyElement(projectionResult.get().getAssignments().stream()
.filter(handle -> handle.getVariable().equals("onelevelrow0#f_int0"))
.collect(toList()));
inputAssignments = ImmutableMap.<String, ColumnHandle>builder()
Expand Down
Expand Up @@ -18,6 +18,7 @@
import io.airlift.slice.Slice;
import io.prestosql.plugin.kudu.properties.KuduTableProperties;
import io.prestosql.plugin.kudu.properties.PartitionDesign;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorInsertTableHandle;
Expand All @@ -33,7 +34,6 @@
import io.prestosql.spi.connector.ConstraintApplicationResult;
import io.prestosql.spi.connector.NotFoundException;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.expression.ConnectorExpression;
Expand Down
9 changes: 9 additions & 0 deletions presto-main/src/main/java/io/prestosql/metadata/Metadata.java
Expand Up @@ -21,6 +21,8 @@
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.BlockEncoding;
import io.prestosql.spi.block.BlockEncodingSerde;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.AggregationApplicationResult;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
Expand Down Expand Up @@ -353,6 +355,13 @@ public interface Metadata

Optional<TableHandle> applySample(Session session, TableHandle table, SampleType sampleType, double sampleRatio);

Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
Session session,
TableHandle table,
List<AggregateFunction> aggregations,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets);

default void validateScan(Session session, TableHandle table) {}

//
Expand Down
Expand Up @@ -44,6 +44,9 @@
import io.prestosql.spi.block.SingleMapBlockEncoding;
import io.prestosql.spi.block.SingleRowBlockEncoding;
import io.prestosql.spi.block.VariableWidthBlockEncoding;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.AggregationApplicationResult;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
Expand Down Expand Up @@ -1108,6 +1111,57 @@ public Optional<TableHandle> applySample(Session session, TableHandle table, Sam
Optional.empty()));
}

@Override
public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
Session session,
TableHandle table,
List<AggregateFunction> aggregations,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
{
CatalogName catalogName = table.getCatalogName();
ConnectorMetadata metadata = getMetadata(session, catalogName);

if (metadata.usesLegacyTableLayouts()) {
return Optional.empty();
}

ConnectorSession connectorSession = session.toConnectorSession(catalogName);
return metadata.applyAggregation(connectorSession, table.getConnectorHandle(), aggregations, assignments, groupingSets)
.map(result -> {
verifyProjection(table, result.getProjections(), result.getAssignments(), aggregations.size());

return new AggregationApplicationResult<>(
new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()),
result.getProjections(),
result.getAssignments(),
result.getGroupingColumnMapping());
});
}

private void verifyProjection(TableHandle table, List<ConnectorExpression> projections, List<Assignment> assignments, int expectedProjectionSize)
{
projections.forEach(projection -> requireNonNull(projection, "one of the projections is null"));
assignments.forEach(assignment -> requireNonNull(assignment, "one of the assignments is null"));

verify(
expectedProjectionSize == projections.size(),
"ConnectorMetadata returned invalid number of projections: %s instead of %s for %s",
projections.size(),
expectedProjectionSize,
table);

Set<String> assignedVariables = assignments.stream()
.map(Assignment::getVariable)
.collect(toImmutableSet());
projections.stream()
.flatMap(connectorExpression -> ConnectorExpressions.extractVariables(connectorExpression).stream())
.map(Variable::getName)
.filter(variableName -> !assignedVariables.contains(variableName))
.findAny()
.ifPresent(variableName -> { throw new IllegalStateException("Unbound variable: " + variableName); });
}

@Override
public void validateScan(Session session, TableHandle table)
{
Expand Down Expand Up @@ -1146,24 +1200,7 @@ public Optional<ProjectionApplicationResult<TableHandle>> applyProjection(Sessio
ConnectorSession connectorSession = session.toConnectorSession(catalogName);
return metadata.applyProjection(connectorSession, table.getConnectorHandle(), projections, assignments)
.map(result -> {
result.getProjections().forEach(projection -> requireNonNull(projection, "one of the projections is null"));
result.getAssignments().forEach(assignment -> requireNonNull(assignment, "one of the assignments is null"));

verify(
projections.size() == result.getProjections().size(),
"ConnectorMetadata returned invalid number of projections: %s instead of %s for %s",
result.getProjections().size(),
projections.size(),
table);

Set<String> assignedVariables = result.getAssignments().stream()
.map(ProjectionApplicationResult.Assignment::getVariable)
.collect(toImmutableSet());
result.getProjections().stream()
.flatMap(connectorExpression -> ConnectorExpressions.extractVariables(connectorExpression).stream())
.map(Variable::getName)
.filter(variableName -> !assignedVariables.contains(variableName))
.findAny().ifPresent(variableName -> { throw new IllegalStateException("Unbound variable: " + variableName); });
verifyProjection(table, result.getProjections(), result.getAssignments(), projections.size());

return new ProjectionApplicationResult<>(
new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()),
Expand Down
Expand Up @@ -107,6 +107,7 @@
import io.prestosql.sql.planner.iterative.rule.PruneUnnestSourceColumns;
import io.prestosql.sql.planner.iterative.rule.PruneValuesColumns;
import io.prestosql.sql.planner.iterative.rule.PruneWindowColumns;
import io.prestosql.sql.planner.iterative.rule.PushAggregationIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushDeleteIntoConnector;
import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughFilter;
Expand Down Expand Up @@ -510,6 +511,7 @@ public PlanOptimizers(
.add(new PushLimitIntoTableScan(metadata))
.add(new PushPredicateIntoTableScan(metadata, typeAnalyzer))
.add(new PushSampleIntoTableScan(metadata))
.add(new PushAggregationIntoTableScan(metadata))
.build()),
new IterativeOptimizer(
ruleStats,
Expand Down
Expand Up @@ -17,6 +17,7 @@
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.TableHandle;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
Expand Down Expand Up @@ -88,7 +89,7 @@ public static Optional<PlanNode> pruneColumns(Metadata metadata, TypeProvider ty
handle = result.get().getHandle();

Map<String, ColumnHandle> assignments = result.get().getAssignments().stream()
.collect(toImmutableMap(ProjectionApplicationResult.Assignment::getVariable, ProjectionApplicationResult.Assignment::getColumn));
.collect(toImmutableMap(Assignment::getVariable, Assignment::getColumn));

ImmutableMap.Builder<Symbol, ColumnHandle> builder = ImmutableMap.builder();
for (int i = 0; i < newOutputs.size(); i++) {
Expand Down

0 comments on commit 4052f33

Please sign in to comment.