Skip to content

Commit

Permalink
Add DynamicFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrii Rosa authored and amoghmargoor committed Mar 21, 2018
1 parent 760b0ff commit 5622cb0
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 0 deletions.
130 changes: 130 additions & 0 deletions presto-main/src/main/java/com/facebook/presto/sql/DynamicFilter.java
@@ -0,0 +1,130 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql;

import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.DeferredSymbolReference;
import com.facebook.presto.sql.tree.Expression;

import java.util.Objects;
import java.util.Optional;

import static com.facebook.presto.sql.DynamicFilterUtils.isDynamicFilter;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class DynamicFilter
{
private final Expression sourceExpression;
private final String tupleDomainSourceId;
private final String tupleDomainName;
private final ComparisonExpressionType comparisonType;

public static Optional<DynamicFilter> getDynamicFilterOptional(Expression expression)
{
if (!isDynamicFilter(expression)) {
return Optional.empty();
}
return Optional.of(from(expression));
}

public static DynamicFilter from(Expression expression)
{
checkArgument(expression instanceof ComparisonExpression, "Unexpected expression: %s", expression);
ComparisonExpression comparison = (ComparisonExpression) expression;

Expression sourceExpression = null;
DeferredSymbolReference deferredReference = null;

checkState(comparison.getLeft() instanceof DeferredSymbolReference ^ comparison.getRight() instanceof DeferredSymbolReference, "Exactly one deferred symbol per expression is required");
if (comparison.getLeft() instanceof DeferredSymbolReference) {
sourceExpression = comparison.getRight();
deferredReference = (DeferredSymbolReference) comparison.getLeft();
}
else if (comparison.getRight() instanceof DeferredSymbolReference) {
sourceExpression = comparison.getLeft();
deferredReference = (DeferredSymbolReference) comparison.getRight();
}

return new DynamicFilter(
sourceExpression,
deferredReference.getSourceId(),
deferredReference.getSymbol(),
comparison.getType());
}

public DynamicFilter(Expression sourceExpression, String tupleDomainSourceId, String tupleDomainName, ComparisonExpressionType comparisonType)
{
this.sourceExpression = requireNonNull(sourceExpression, "symbol is null");
this.tupleDomainSourceId = requireNonNull(tupleDomainSourceId, "tupleDomainSourceId is null");
this.tupleDomainName = requireNonNull(tupleDomainName, "tupleDomainName is null");
this.comparisonType = requireNonNull(comparisonType, "comparisonType is null");
}

public Expression getSourceExpression()
{
return sourceExpression;
}

public String getTupleDomainSourceId()
{
return tupleDomainSourceId;
}

public String getTupleDomainName()
{
return tupleDomainName;
}

public ComparisonExpressionType getComparisonType()
{
return comparisonType;
}

@Override
public String toString()
{
return toStringHelper(this)
.add("sourceExpression", sourceExpression)
.add("tupleDomainSourceId", tupleDomainSourceId)
.add("tupleDomainName", tupleDomainName)
.add("comparisonType", comparisonType)
.toString();
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
DynamicFilter that = (DynamicFilter) o;
return Objects.equals(sourceExpression, that.sourceExpression) &&
Objects.equals(tupleDomainSourceId, that.tupleDomainSourceId) &&
Objects.equals(tupleDomainName, that.tupleDomainName) &&
comparisonType == that.comparisonType;
}

@Override
public int hashCode()
{
return Objects.hash(sourceExpression, tupleDomainSourceId, tupleDomainName, comparisonType);
}
}
@@ -0,0 +1,97 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql;

import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DeferredSymbolReference;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts;
import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

public class DynamicFilterUtils
{
private DynamicFilterUtils() {}

public static Expression stripDynamicFilters(Expression expression)
{
return combineConjuncts(extractConjuncts(expression)
.stream()
.filter(conjunct -> !isDynamicFilter(conjunct))
.collect(toImmutableList()));
}

public static boolean isDynamicFilter(Expression expression)
{
if (!(expression instanceof ComparisonExpression)) {
return false;
}

ComparisonExpression comparison = (ComparisonExpression) expression;
checkState(!(comparison.getLeft() instanceof DeferredSymbolReference && comparison.getRight() instanceof DeferredSymbolReference), "Dynamic filter cannot have DeferredSymbolReferences");
return comparison.getLeft() instanceof DeferredSymbolReference || comparison.getRight() instanceof DeferredSymbolReference;
}

public static ExtractDynamicFiltersResult extractDynamicFilters(Expression expression)
{
List<Expression> filters = extractConjuncts(expression);

List<Expression> staticFilters = new ArrayList<>(filters.size());
List<Expression> dynamicFilters = new ArrayList<>(filters.size());

for (Expression filter : filters) {
if (isDynamicFilter(filter)) {
dynamicFilters.add(filter);
}
else {
staticFilters.add(filter);
}
}

return new ExtractDynamicFiltersResult(
combineConjuncts(staticFilters),
dynamicFilters.stream().map(DynamicFilter::from).collect(toImmutableSet()));
}

public static class ExtractDynamicFiltersResult
{
private final Expression staticFilters;
private final Set<DynamicFilter> dynamicFilters;

public ExtractDynamicFiltersResult(Expression staticFilters, Set<DynamicFilter> dynamicFilters)
{
this.staticFilters = requireNonNull(staticFilters, "staticFilters is null");
this.dynamicFilters = ImmutableSet.copyOf(requireNonNull(dynamicFilters, "dynamicFilters is null"));
}

public Expression getStaticFilters()
{
return staticFilters;
}

public Set<DynamicFilter> getDynamicFilters()
{
return dynamicFilters;
}
}
}
Expand Up @@ -682,6 +682,11 @@ protected R visitSymbolReference(SymbolReference node, C context)
return visitExpression(node, context);
}

protected R visitDeferredSymbolReference(DeferredSymbolReference node, C context)
{
return visitExpression(node, context);
}

protected R visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, C context)
{
return visitExpression(node, context);
Expand Down
@@ -0,0 +1,78 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.tree;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

public class DeferredSymbolReference
extends Expression
{
private final String sourceId;
private final String symbol;

public DeferredSymbolReference(String sourceId, String symbol)
{
super(Optional.empty());
this.sourceId = requireNonNull(sourceId, "sourceId is null");
this.symbol = requireNonNull(symbol, "symbol is null");
}

public String getSourceId()
{
return sourceId;
}

public String getSymbol()
{
return symbol;
}

@Override
protected <R, C> R accept(AstVisitor<R, C> visitor, C context)
{
return visitor.visitDeferredSymbolReference(this, context);
}

@Override
public List<? extends Node> getChildren()
{
return ImmutableList.of();
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
DeferredSymbolReference reference = (DeferredSymbolReference) o;
return Objects.equals(sourceId, reference.sourceId) &&
Objects.equals(symbol, reference.symbol);
}

@Override
public int hashCode()
{
return Objects.hash(sourceId, symbol);
}
}

0 comments on commit 5622cb0

Please sign in to comment.