Skip to content

Commit

Permalink
Allow SymbolsExtractor to accept RowExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
highker committed Apr 17, 2019
1 parent 2b4ddf0 commit 1b05879
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.relational.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
Expand All @@ -35,6 +38,7 @@
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;

public final class SymbolsExtractor
{
Expand Down Expand Up @@ -69,6 +73,12 @@ public static Set<Symbol> extractUnique(Expression expression)
return ImmutableSet.copyOf(extractAll(expression));
}

// TODO: return Set<VariableReferenceExpression>
public static Set<Symbol> extractUnique(RowExpression expression)
{
return ImmutableSet.copyOf(extractAll(expression).stream().map(variable -> new Symbol(variable.getName())).collect(toSet()));
}

public static Set<Symbol> extractUnique(Iterable<? extends Expression> expressions)
{
ImmutableSet.Builder<Symbol> unique = ImmutableSet.builder();
Expand All @@ -85,6 +95,13 @@ public static List<Symbol> extractAll(Expression expression)
return builder.build();
}

public static List<VariableReferenceExpression> extractAll(RowExpression expression)
{
ImmutableList.Builder<VariableReferenceExpression> builder = ImmutableList.builder();
expression.accept(new VariableBuilderVisitor(), builder);
return builder.build();
}

// to extract qualified name with prefix
public static Set<QualifiedName> extractNames(Expression expression, Set<NodeRef<Expression>> columnReferences)
{
Expand Down Expand Up @@ -118,6 +135,17 @@ protected Void visitSymbolReference(SymbolReference node, ImmutableList.Builder<
}
}

private static class VariableBuilderVisitor
extends DefaultRowExpressionTraversalVisitor<ImmutableList.Builder<VariableReferenceExpression>>
{
@Override
public Void visitVariableReference(VariableReferenceExpression variable, ImmutableList.Builder<VariableReferenceExpression> builder)
{
builder.add(variable);
return null;
}
}

private static class QualifiedNameBuilderVisitor
extends DefaultTraversalVisitor<Void, ImmutableSet.Builder<QualifiedName>>
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.relational;

import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;

/**
* The default visitor serves as a template for "consumer-like" tree traversal.
* {@param context} is the consumer to apply customized actions on the visiting RowExpression.
*/
public class DefaultRowExpressionTraversalVisitor<C>
implements RowExpressionVisitor<Void, C>
{
@Override
public Void visitInputReference(InputReferenceExpression input, C context)
{
return null;
}

@Override
public Void visitCall(CallExpression call, C context)
{
call.getArguments().forEach(argument -> argument.accept(this, context));
return null;
}

@Override
public Void visitConstant(ConstantExpression literal, C context)
{
return null;
}

@Override
public Void visitLambda(LambdaDefinitionExpression lambda, C context)
{
return null;
}

@Override
public Void visitVariableReference(VariableReferenceExpression reference, C context)
{
return null;
}

@Override
public Void visitSpecialForm(SpecialFormExpression specialForm, C context)
{
specialForm.getArguments().forEach(argument -> argument.accept(this, context));
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.Map;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
import static com.facebook.presto.sql.planner.SymbolsExtractor.extractAll;
import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUnique;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.testng.Assert.assertEquals;

public class TestSymbolExtractor
{
private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA);
private static final Map<Symbol, Type> SYMBOL_TYPES = ImmutableMap.of(new Symbol("a"), BIGINT, new Symbol("b"), BIGINT, new Symbol("c"), BIGINT);

@Test
public void testSimple()
{
assertSymbols("a > b");
assertSymbols("a + b > c");
assertSymbols("sin(a) - b");
assertSymbols("sin(a) + cos(a) - b");
assertSymbols("sin(a) + cos(a) + a - b");
assertSymbols("COALESCE(a, b, 1)");
assertSymbols("a IN (a, b, c)");
assertSymbols("transform(sequence(1, 5), a -> a + b)");
assertSymbols("bigint '1'");
}

private static void assertSymbols(String expression)
{
Expression expected = rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expression, new ParsingOptions()));
RowExpression actual = TRANSLATOR.translate(expected, TypeProvider.copyOf(SYMBOL_TYPES));
assertEquals(extractUnique(expected), extractUnique(actual));
assertEquals(
extractAll(expected).stream().sorted().collect(toImmutableList()),
extractAll(actual).stream().map(variable -> new Symbol(variable.getName())).sorted().collect(toImmutableList()));
}
}

0 comments on commit 1b05879

Please sign in to comment.