Skip to content

Commit

Permalink
Use Query-based table handles in JDBC connectors
Browse files Browse the repository at this point in the history
This changes how aggregation pushdown is modelled in `JdbcTableHandle`.
Instead of  `JdbcTableHandle.groupingSets` and
`JdbcColumnHandle.expression`, it is converted directly into a query.

This allows further pushdowns, e.g. predicate after aggregation.
  • Loading branch information
findepi committed Jan 25, 2021
1 parent 6442c96 commit 02b59cd
Show file tree
Hide file tree
Showing 32 changed files with 222 additions and 142 deletions.
Expand Up @@ -420,6 +420,32 @@ public Connection getConnection(ConnectorSession session, JdbcSplit split)
return connection;
}

@Override
public PreparedQuery prepareQuery(
ConnectorSession session,
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions)
{
try (Connection connection = connectionFactory.openConnection(session)) {
PreparedQuery preparedQuery = new QueryBuilder(this).prepareQuery(
session,
connection,
table.getRelationHandle(),
groupingSets,
columns,
columnExpressions,
table.getConstraint(),
Optional.empty());
preparedQuery = preparedQuery.transformQuery(tryApplyLimit(table.getLimit()));
return preparedQuery;
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

@Override
public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List<JdbcColumnHandle> columns)
throws SQLException
Expand All @@ -429,8 +455,9 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio
session,
connection,
table.getRelationHandle(),
table.getGroupingSets(),
Optional.empty(),
columns,
ImmutableMap.of(),
table.getConstraint(),
split.getAdditionalPredicate());
preparedQuery = preparedQuery.transformQuery(tryApplyLimit(table.getLimit()));
Expand Down
Expand Up @@ -190,6 +190,17 @@ public void abortReadConnection(Connection connection)
delegate.abortReadConnection(connection);
}

@Override
public PreparedQuery prepareQuery(
ConnectorSession session,
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions)
{
return delegate.prepareQuery(session, table, groupingSets, columns, columnExpressions);
}

@Override
public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List<JdbcColumnHandle> columns)
throws SQLException
Expand Down
Expand Up @@ -141,6 +141,17 @@ public void abortReadConnection(Connection connection)
delegate().abortReadConnection(connection);
}

@Override
public PreparedQuery prepareQuery(
ConnectorSession session,
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions)
{
return delegate().prepareQuery(session, table, groupingSets, columns, columnExpressions);
}

@Override
public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle tableHandle, List<JdbcColumnHandle> columnHandles)
throws SQLException
Expand Down
Expand Up @@ -91,6 +91,13 @@ default void abortReadConnection(Connection connection)
// most drivers do not need this
}

PreparedQuery prepareQuery(
ConnectorSession session,
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions);

PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List<JdbcColumnHandle> columns)
throws SQLException;

Expand Down
Expand Up @@ -22,14 +22,12 @@

import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

import static java.util.Objects.requireNonNull;

public final class JdbcColumnHandle
implements ColumnHandle
{
private final Optional<String> expression;
private final String columnName;
private final JdbcTypeHandle jdbcTypeHandle;
private final Type columnType;
Expand All @@ -39,7 +37,7 @@ public final class JdbcColumnHandle
// All and only required fields
public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type columnType)
{
this(Optional.empty(), columnName, jdbcTypeHandle, columnType, true, Optional.empty());
this(columnName, jdbcTypeHandle, columnType, true, Optional.empty());
}

/**
Expand All @@ -48,7 +46,7 @@ public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type c
@Deprecated
public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type columnType, boolean nullable)
{
this(Optional.empty(), columnName, jdbcTypeHandle, columnType, nullable, Optional.empty());
this(columnName, jdbcTypeHandle, columnType, nullable, Optional.empty());
}

/**
Expand All @@ -57,27 +55,19 @@ public JdbcColumnHandle(String columnName, JdbcTypeHandle jdbcTypeHandle, Type c
@Deprecated
@JsonCreator
public JdbcColumnHandle(
@JsonProperty("expression") Optional<String> expression,
@JsonProperty("columnName") String columnName,
@JsonProperty("jdbcTypeHandle") JdbcTypeHandle jdbcTypeHandle,
@JsonProperty("columnType") Type columnType,
@JsonProperty("nullable") boolean nullable,
@JsonProperty("comment") Optional<String> comment)
{
this.expression = requireNonNull(expression, "expression is null");
this.columnName = requireNonNull(columnName, "columnName is null");
this.jdbcTypeHandle = requireNonNull(jdbcTypeHandle, "jdbcTypeHandle is null");
this.columnType = requireNonNull(columnType, "columnType is null");
this.nullable = nullable;
this.comment = requireNonNull(comment, "comment is null");
}

@JsonProperty
public Optional<String> getExpression()
{
return expression;
}

@JsonProperty
public String getColumnName()
{
Expand Down Expand Up @@ -141,19 +131,11 @@ public int hashCode()
public String toString()
{
return Joiner.on(":").skipNulls().join(
expression.orElse(null),
columnName,
columnType.getDisplayName(),
jdbcTypeHandle.getJdbcTypeName().orElse(null));
}

public String toSqlExpression(Function<String, String> identifierQuote)
{
requireNonNull(identifierQuote, "identifierQuote is null");
return expression
.orElseGet(() -> identifierQuote.apply(columnName));
}

public static Builder builder()
{
return new Builder();
Expand All @@ -166,7 +148,6 @@ public static Builder builderFrom(JdbcColumnHandle handle)

public static final class Builder
{
private Optional<String> expression = Optional.empty();
private String columnName;
private JdbcTypeHandle jdbcTypeHandle;
private Type columnType;
Expand All @@ -177,20 +158,13 @@ public Builder() {}

private Builder(JdbcColumnHandle handle)
{
this.expression = handle.getExpression();
this.columnName = handle.getColumnName();
this.jdbcTypeHandle = handle.getJdbcTypeHandle();
this.columnType = handle.getColumnType();
this.nullable = handle.isNullable();
this.comment = handle.getComment();
}

public Builder setExpression(Optional<String> expression)
{
this.expression = expression;
return this;
}

public Builder setColumnName(String columnName)
{
this.columnName = columnName;
Expand Down Expand Up @@ -224,7 +198,6 @@ public Builder setComment(Optional<String> comment)
public JdbcColumnHandle build()
{
return new JdbcColumnHandle(
expression,
columnName,
jdbcTypeHandle,
columnType,
Expand Down
Expand Up @@ -56,7 +56,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Functions.identity;
Expand Down Expand Up @@ -114,21 +113,6 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
{
JdbcTableHandle handle = (JdbcTableHandle) table;

if (handle.getGroupingSets().isPresent()) {
if (constraint.getSummary().isNone()) {
return Optional.empty();
}

Set<ColumnHandle> constraintColumns = constraint.getSummary().getDomains().orElseThrow().keySet();
List<List<JdbcColumnHandle>> groupingSets = handle.getGroupingSets().get();
boolean canPushDown = groupingSets.stream()
.allMatch(groupingSet -> ImmutableSet.copyOf(groupingSet).containsAll(constraintColumns));

if (!canPushDown) {
return Optional.empty();
}
}

TupleDomain<ColumnHandle> oldDomain = handle.getConstraint();
TupleDomain<ColumnHandle> newDomain = oldDomain.intersect(constraint.getSummary());

Expand Down Expand Up @@ -168,13 +152,23 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
handle = new JdbcTableHandle(
handle.getRelationHandle(),
newDomain,
handle.getGroupingSets(),
handle.getLimit(),
handle.getColumns());

return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter));
}

private JdbcTableHandle flushAttributesAsQuery(ConnectorSession session, JdbcTableHandle handle)
{
List<JdbcColumnHandle> columns = jdbcClient.getColumns(session, handle);
PreparedQuery preparedQuery = jdbcClient.prepareQuery(session, handle, Optional.empty(), columns, ImmutableMap.of());
return new JdbcTableHandle(
new JdbcQueryRelationHandle(preparedQuery),
TupleDomain.all(),
OptionalLong.empty(),
Optional.of(columns));
}

@Override
public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjection(
ConnectorSession session,
Expand All @@ -196,7 +190,6 @@ public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjecti
new JdbcTableHandle(
handle.getRelationHandle(),
handle.getConstraint(),
handle.getGroupingSets(),
handle.getLimit(),
Optional.of(newColumns)),
projections,
Expand All @@ -222,13 +215,10 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega

JdbcTableHandle handle = (JdbcTableHandle) table;

if (handle.getLimit().isPresent()) {
// handle's limit is applied after aggregations, so we cannot apply aggregations if limit is already set
return Optional.empty();
}
// Global aggregation is represented by [[]]
verify(!groupingSets.isEmpty(), "No grouping sets provided");

if (handle.getGroupingSets().isPresent()) {
// table handle cannot express aggregation on top of aggregation
if (groupingSets.size() > 1 && !jdbcClient.supportsGroupingSets()) {
return Optional.empty();
}

Expand All @@ -237,11 +227,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
return Optional.empty();
}

// Global aggregation is represented by [[]]
verify(!groupingSets.isEmpty(), "No grouping sets provided");

if (groupingSets.size() > 1 && !jdbcClient.supportsGroupingSets()) {
return Optional.empty();
if (handle.getLimit().isPresent()) {
handle = flushAttributesAsQuery(session, handle);
}

List<JdbcColumnHandle> columns = jdbcClient.getColumns(session, handle);
Expand All @@ -253,6 +240,7 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ImmutableList.Builder<JdbcColumnHandle> newColumns = ImmutableList.builder();
ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
ImmutableMap.Builder<String, String> expressions = ImmutableMap.builder();
for (AggregateFunction aggregate : aggregates) {
Optional<JdbcExpression> expression = jdbcClient.implementAggregation(session, aggregate, assignments);
if (expression.isEmpty()) {
Expand All @@ -263,30 +251,47 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
syntheticNextIdentifier++;
}

String columnName = SYNTHETIC_COLUMN_NAME_PREFIX + syntheticNextIdentifier;
JdbcColumnHandle newColumn = JdbcColumnHandle.builder()
.setExpression(Optional.of(expression.get().getExpression()))
.setColumnName(SYNTHETIC_COLUMN_NAME_PREFIX + syntheticNextIdentifier)
.setColumnName(columnName)
.setJdbcTypeHandle(expression.get().getJdbcTypeHandle())
.setColumnType(aggregate.getOutputType())
.setComment(Optional.of("synthetic"))
.build();
syntheticNextIdentifier++;

newColumns.add(newColumn);
projections.add(new Variable(newColumn.getColumnName(), aggregate.getOutputType()));
resultAssignments.add(new Assignment(newColumn.getColumnName(), newColumn, aggregate.getOutputType()));
expressions.put(columnName, expression.get().getExpression());

syntheticNextIdentifier++;
}

List<List<JdbcColumnHandle>> groupingSetsAsJdbcColumnHandles = groupingSets.stream()
.map(groupingSet -> groupingSet.stream()
.map(JdbcColumnHandle.class::cast)
.collect(toImmutableList()))
.collect(toImmutableList());

List<JdbcColumnHandle> newColumnsList = newColumns.build();

PreparedQuery preparedQuery = jdbcClient.prepareQuery(
session,
handle,
Optional.of(groupingSetsAsJdbcColumnHandles),
ImmutableList.<JdbcColumnHandle>builder()
.addAll(groupingSetsAsJdbcColumnHandles.stream()
.flatMap(List::stream)
.distinct()
.iterator())
.addAll(newColumnsList)
.build(),
expressions.build());
handle = new JdbcTableHandle(
handle.getRelationHandle(),
handle.getConstraint(),
Optional.of(groupingSets.stream()
.map(groupingSet -> groupingSet.stream()
.map(JdbcColumnHandle.class::cast)
.collect(toImmutableList()))
.collect(toImmutableList())),
OptionalLong.empty(), // limit
Optional.of(newColumns.build()));
new JdbcQueryRelationHandle(preparedQuery),
TupleDomain.all(),
OptionalLong.empty(),
Optional.of(newColumnsList));

return Optional.of(new AggregationApplicationResult<>(handle, projections.build(), resultAssignments.build(), ImmutableMap.of()));
}
Expand All @@ -307,7 +312,6 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
handle = new JdbcTableHandle(
handle.getRelationHandle(),
handle.getConstraint(),
handle.getGroupingSets(),
OptionalLong.of(limit),
handle.getColumns());

Expand Down

0 comments on commit 02b59cd

Please sign in to comment.