Skip to content

Commit

Permalink
Migrate to Java 8 streams
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Jun 22, 2016
1 parent e19d941 commit ddfeb54
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 62 deletions.
Expand Up @@ -25,29 +25,27 @@
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.IS_DISTINCT_FROM;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.Iterables.contains;
import static com.google.common.collect.Iterables.filter;
import static com.google.common.collect.Iterables.transform;
import static com.google.common.collect.Lists.newArrayList;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

public final class ExpressionUtils
{
Expand Down Expand Up @@ -86,7 +84,7 @@ public static Expression and(Expression... expressions)
return and(Arrays.asList(expressions));
}

public static Expression and(Iterable<Expression> expressions)
public static Expression and(Collection<Expression> expressions)
{
return binaryExpression(LogicalBinaryExpression.Type.AND, expressions);
}
Expand All @@ -96,19 +94,19 @@ public static Expression or(Expression... expressions)
return or(Arrays.asList(expressions));
}

public static Expression or(Iterable<Expression> expressions)
public static Expression or(Collection<Expression> expressions)
{
return binaryExpression(LogicalBinaryExpression.Type.OR, expressions);
}

public static Expression binaryExpression(LogicalBinaryExpression.Type type, Iterable<Expression> expressions)
public static Expression binaryExpression(LogicalBinaryExpression.Type type, Collection<Expression> expressions)
{
requireNonNull(type, "type is null");
requireNonNull(expressions, "expressions is null");
Preconditions.checkArgument(!Iterables.isEmpty(expressions), "expressions is empty");
Preconditions.checkArgument(!expressions.isEmpty(), "expressions is empty");

// build balanced tree for efficient recursive processing
Queue<Expression> queue = new ArrayDeque<>(newArrayList(expressions));
Queue<Expression> queue = new ArrayDeque<>(expressions);
while (queue.size() > 1) {
queue.add(new LogicalBinaryExpression(type, queue.remove(), queue.remove()));
}
Expand All @@ -120,7 +118,7 @@ public static Expression combinePredicates(LogicalBinaryExpression.Type type, Ex
return combinePredicates(type, Arrays.asList(expressions));
}

public static Expression combinePredicates(LogicalBinaryExpression.Type type, Iterable<Expression> expressions)
public static Expression combinePredicates(LogicalBinaryExpression.Type type, Collection<Expression> expressions)
{
if (type == LogicalBinaryExpression.Type.AND) {
return combineConjuncts(expressions);
Expand All @@ -134,60 +132,64 @@ public static Expression combineConjuncts(Expression... expressions)
return combineConjuncts(Arrays.asList(expressions));
}

public static Expression combineConjuncts(Iterable<Expression> expressions)
public static Expression combineConjuncts(Collection<Expression> expressions)
{
return combineConjunctsWithDefault(expressions, TRUE_LITERAL);
}

public static Expression combineConjunctsWithDefault(Iterable<Expression> expressions, Expression emptyDefault)
public static Expression combineConjunctsWithDefault(Collection<Expression> expressions, Expression emptyDefault)
{
requireNonNull(expressions, "expressions is null");

// Flatten all the expressions into their component conjuncts
expressions = Iterables.concat(transform(expressions, ExpressionUtils::extractConjuncts));
List<Expression> conjuncts = expressions.stream()
.flatMap(e -> ExpressionUtils.extractConjuncts(e).stream())
.filter(e -> !e.equals(TRUE_LITERAL))
.collect(toList());

// Strip out all true literal conjuncts
expressions = filter(expressions, not(Predicates.<Expression>equalTo(TRUE_LITERAL)));
expressions = removeDuplicates(expressions);
conjuncts = removeDuplicates(conjuncts);

if (contains(expressions, FALSE_LITERAL)) {
if (conjuncts.contains(FALSE_LITERAL)) {
return FALSE_LITERAL;
}

return Iterables.isEmpty(expressions) ? emptyDefault : and(expressions);
return conjuncts.isEmpty() ? emptyDefault : and(conjuncts);
}

public static Expression combineDisjuncts(Expression... expressions)
{
return combineDisjuncts(Arrays.asList(expressions));
}

public static Expression combineDisjuncts(Iterable<Expression> expressions)
public static Expression combineDisjuncts(Collection<Expression> expressions)
{
return combineDisjunctsWithDefault(expressions, FALSE_LITERAL);
}

public static Expression combineDisjunctsWithDefault(Iterable<Expression> expressions, Expression emptyDefault)
public static Expression combineDisjunctsWithDefault(Collection<Expression> expressions, Expression emptyDefault)
{
requireNonNull(expressions, "expressions is null");

// Flatten all the expressions into their component disjuncts
expressions = Iterables.concat(transform(expressions, ExpressionUtils::extractDisjuncts));
List<Expression> disjuncts = expressions.stream()
.flatMap(e -> ExpressionUtils.extractDisjuncts(e).stream())
.filter(e -> !e.equals(FALSE_LITERAL))
.collect(toList());

// Strip out all false literal disjuncts
expressions = filter(expressions, not(Predicates.<Expression>equalTo(FALSE_LITERAL)));
expressions = removeDuplicates(expressions);
disjuncts = removeDuplicates(disjuncts);

if (contains(expressions, TRUE_LITERAL)) {
if (disjuncts.contains(TRUE_LITERAL)) {
return TRUE_LITERAL;
}

return Iterables.isEmpty(expressions) ? emptyDefault : or(expressions);
return disjuncts.isEmpty() ? emptyDefault : or(disjuncts);
}

public static Expression stripNonDeterministicConjuncts(Expression expression)
{
return combineConjuncts(filter(extractConjuncts(expression), DeterminismEvaluator::isDeterministic));
Set<Expression> conjuncts = extractConjuncts(expression).stream()
.filter(DeterminismEvaluator::isDeterministic)
.collect(toSet());

return combineConjuncts(conjuncts);
}

public static Expression stripDeterministicConjuncts(Expression expression)
Expand All @@ -205,7 +207,10 @@ public static Function<Expression, Expression> expressionOrNullSymbols(final Pre
resultDisjunct.add(expression);

for (Predicate<Symbol> nullSymbolScope : nullSymbolScopes) {
Iterable<Symbol> symbols = filter(DependencyExtractor.extractUnique(expression), nullSymbolScope);
List<Symbol> symbols = DependencyExtractor.extractUnique(expression).stream()
.filter(nullSymbolScope)
.collect(toImmutableList());

if (Iterables.isEmpty(symbols)) {
continue;
}
Expand All @@ -222,15 +227,22 @@ public static Function<Expression, Expression> expressionOrNullSymbols(final Pre
};
}

private static Iterable<Expression> removeDuplicates(Iterable<Expression> expressions)
private static List<Expression> removeDuplicates(List<Expression> expressions)
{
// Capture all non-deterministic predicates
Iterable<Expression> nonDeterministicDisjuncts = filter(expressions, not(DeterminismEvaluator::isDeterministic));
List<Expression> nonDeterministicDisjuncts = expressions.stream()
.filter(e -> !DeterminismEvaluator.isDeterministic(e))
.collect(toImmutableList());

// Capture and de-dupe all deterministic predicates
Iterable<Expression> deterministicDisjuncts = ImmutableSet.copyOf(filter(expressions, DeterminismEvaluator::isDeterministic));

return Iterables.concat(nonDeterministicDisjuncts, deterministicDisjuncts);
Set<Expression> deterministicDisjuncts = expressions.stream()
.filter(DeterminismEvaluator::isDeterministic)
.collect(toImmutableSet());

return ImmutableList.<Expression>builder()
.addAll(nonDeterministicDisjuncts)
.addAll(deterministicDisjuncts)
.build();
}

public static Expression normalize(Expression expression)
Expand Down
Expand Up @@ -59,8 +59,6 @@
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.Predicates.in;
import static com.google.common.collect.Iterables.transform;
import static java.util.stream.Collectors.toList;

/**
* Computes the effective predicate at the top of the specified PlanNode
Expand Down Expand Up @@ -232,33 +230,33 @@ public Expression visitJoin(JoinNode node, Void context)
case LEFT:
return combineConjuncts(ImmutableList.<Expression>builder()
.add(leftPredicate)
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), in(node.getRight().getOutputSymbols())))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, in(node.getRight().getOutputSymbols())))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getRight().getOutputSymbols()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getRight().getOutputSymbols()::contains))
.build());
case RIGHT:
return combineConjuncts(ImmutableList.<Expression>builder()
.add(rightPredicate)
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), in(node.getLeft().getOutputSymbols())))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, in(node.getLeft().getOutputSymbols())))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getLeft().getOutputSymbols()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getLeft().getOutputSymbols()::contains))
.build());
case FULL:
return combineConjuncts(ImmutableList.<Expression>builder()
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), in(node.getLeft().getOutputSymbols())))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), in(node.getRight().getOutputSymbols())))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, in(node.getLeft().getOutputSymbols()), in(node.getRight().getOutputSymbols())))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getLeft().getOutputSymbols()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getRight().getOutputSymbols()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getLeft().getOutputSymbols()::contains, node.getRight().getOutputSymbols()::contains))
.build());
default:
throw new UnsupportedOperationException("Unknown join type: " + node.getType());
}
}

private Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> conjuncts, com.google.common.base.Predicate<Symbol>... nullSymbolScopes)
private Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> conjuncts, Predicate<Symbol>... nullSymbolScopes)
{
// Conjuncts without any symbol dependencies cannot be applied to the effective predicate (e.g. FALSE literal)
conjuncts = conjuncts.stream()
return conjuncts.stream()
.map(expression -> DependencyExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression)
.collect(toList());
return transform(conjuncts, expressionOrNullSymbols(nullSymbolScopes));
.map(expressionOrNullSymbols(nullSymbolScopes))
.collect(toImmutableList());
}

@Override
Expand Down
Expand Up @@ -52,7 +52,6 @@
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand All @@ -68,6 +67,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts;
Expand All @@ -81,13 +81,13 @@
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.equalTo;
import static com.google.common.base.Predicates.in;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.Iterables.filter;
import static com.google.common.collect.Iterables.transform;
import static java.util.Objects.requireNonNull;

public class PredicatePushDown
Expand Down Expand Up @@ -371,7 +371,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Expression> context)
ImmutableList.Builder<JoinNode.EquiJoinClause> joinConditionBuilder = ImmutableList.builder();
ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder();
for (Expression conjunct : extractConjuncts(newJoinPredicate)) {
if (joinEqualityExpression(node.getLeft().getOutputSymbols()).apply(conjunct)) {
if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) {
ComparisonExpression equality = (ComparisonExpression) conjunct;

boolean alignedComparison = Iterables.all(DependencyExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols()));
Expand Down Expand Up @@ -602,8 +602,14 @@ private InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate,

// Since we only currently support equality in join conjuncts, factor out the non-equality conjuncts to a post-join filter
List<Expression> joinConjunctsList = joinConjuncts.build();
List<Expression> postJoinConjuncts = ImmutableList.copyOf(filter(joinConjunctsList, not(joinEqualityExpression(leftSymbols))));
joinConjunctsList = ImmutableList.copyOf(filter(joinConjunctsList, joinEqualityExpression(leftSymbols)));

List<Expression> postJoinConjuncts = joinConjunctsList.stream()
.filter(joinEqualityExpression(leftSymbols).negate())
.collect(toImmutableList());

joinConjunctsList = joinConjunctsList.stream()
.filter(joinEqualityExpression(leftSymbols))
.collect(toImmutableList());

return new InnerJoinPushDownResult(combineConjuncts(leftPushDownConjuncts.build()), combineConjuncts(rightPushDownConjuncts.build()), combineConjuncts(joinConjunctsList), combineConjuncts(postJoinConjuncts));
}
Expand Down Expand Up @@ -769,13 +775,15 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Expression> cont
Expression rewrittenConjunct = joinInference.rewriteExpression(conjunct, equalTo(node.getFilteringSourceJoinSymbol()));
if (rewrittenConjunct != null && DeterminismEvaluator.isDeterministic(rewrittenConjunct)) {
// Alter conjunct to include an OR filteringSourceJoinSymbol IS NULL disjunct
Expression rewrittenConjunctOrNull = expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol())).apply(rewrittenConjunct);
Expression rewrittenConjunctOrNull = expressionOrNullSymbols(Predicate.isEqual(node.getFilteringSourceJoinSymbol())).apply(rewrittenConjunct);
filteringSourceConjuncts.add(rewrittenConjunctOrNull);
}
}
EqualityInference.EqualityPartition joinInferenceEqualityPartition = joinInference.generateEqualitiesPartitionedBy(equalTo(node.getFilteringSourceJoinSymbol()));
filteringSourceConjuncts.addAll(ImmutableList.copyOf(transform(joinInferenceEqualityPartition.getScopeEqualities(),
expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol())))));

filteringSourceConjuncts.addAll(joinInferenceEqualityPartition.getScopeEqualities().stream()
.map(expressionOrNullSymbols(Predicate.isEqual(node.getFilteringSourceJoinSymbol())))
.collect(Collectors.toList()));

// Push inheritedPredicates down to the source if they don't involve the semi join output
EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
Expand Down
Expand Up @@ -59,6 +59,7 @@
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -733,7 +734,7 @@ private Set<Expression> normalizeConjuncts(Expression... conjuncts)
return normalizeConjuncts(Arrays.asList(conjuncts));
}

private Set<Expression> normalizeConjuncts(Iterable<Expression> conjuncts)
private Set<Expression> normalizeConjuncts(Collection<Expression> conjuncts)
{
return normalizeConjuncts(combineConjuncts(conjuncts));
}
Expand Down

0 comments on commit ddfeb54

Please sign in to comment.