Skip to content

Commit

Permalink
Revert 5 commits that fix analysis/planning for lambda arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhun committed Nov 2, 2017
1 parent 9721731 commit 9750c2e
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 113 deletions.
Expand Up @@ -439,13 +439,6 @@ public Map<NodeRef<Expression>, FieldId> getColumnReferenceFields()
return unmodifiableMap(columnReferences);
}

public boolean isColumnReference(Expression expression)
{
requireNonNull(expression, "expression is null");
checkArgument(getType(expression) != null, "expression %s has not been analyzed", expression);
return columnReferences.containsKey(NodeRef.of(expression));
}

public void addTypes(Map<NodeRef<Expression>, Type> types)
{
this.types.putAll(types);
Expand Down
Expand Up @@ -98,6 +98,7 @@
import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
Expand Down Expand Up @@ -254,12 +255,12 @@ public Map<NodeRef<Identifier>, LambdaArgumentDeclaration> getLambdaArgumentRefe
public Type analyze(Expression expression, Scope scope)
{
Visitor visitor = new Visitor(scope);
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda(scope)));
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda()));
}

private Type analyze(Expression expression, Scope baseScope, Context context)
private Type analyze(Expression expression, Scope scope, Context context)
{
Visitor visitor = new Visitor(baseScope);
Visitor visitor = new Visitor(scope);
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(context));
}

Expand All @@ -281,12 +282,11 @@ public Set<NodeRef<QuantifiedComparisonExpression>> getQuantifiedComparisons()
private class Visitor
extends StackableAstVisitor<Type, Context>
{
// Used to resolve FieldReferences (e.g. during local execution planning)
private final Scope baseScope;
private final Scope scope;

public Visitor(Scope baseScope)
private Visitor(Scope scope)
{
this.baseScope = requireNonNull(baseScope, "baseScope is null");
this.scope = requireNonNull(scope, "scope is null");
}

@Override
Expand Down Expand Up @@ -348,9 +348,10 @@ protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext<Con
protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
Optional<ResolvedField> resolvedField = context.getContext().getScope().tryResolveField(node, QualifiedName.of(node.getName()));
if (resolvedField.isPresent() && context.getContext().getFieldToLambdaArgumentDeclaration().containsKey(FieldId.from(resolvedField.get()))) {
return setExpressionType(node, resolvedField.get().getType());
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getName());
if (lambdaArgumentDeclaration != null) {
Type result = getExpressionType(lambdaArgumentDeclaration);
return setExpressionType(node, result);
}
}
Type type = symbolTypes.get(Symbol.from(node));
Expand All @@ -360,26 +361,24 @@ protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorCon
@Override
protected Type visitIdentifier(Identifier node, StackableAstVisitorContext<Context> context)
{
ResolvedField resolvedField = context.getContext().getScope().resolveField(node, QualifiedName.of(node.getValue()));
return handleResolvedField(node, resolvedField, context);
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getValue());
if (lambdaArgumentDeclaration != null) {
lambdaArgumentReferences.put(NodeRef.of(node), lambdaArgumentDeclaration);
Type result = getExpressionType(lambdaArgumentDeclaration);
return setExpressionType(node, result);
}
}
return handleResolvedField(node, scope.resolveField(node, QualifiedName.of(node.getValue())));
}

private Type handleResolvedField(Expression node, ResolvedField resolvedField, StackableAstVisitorContext<Context> context)
private Type handleResolvedField(Expression node, ResolvedField resolvedField)
{
return handleResolvedField(node, FieldId.from(resolvedField), resolvedField.getType(), context);
return handleResolvedField(node, FieldId.from(resolvedField), resolvedField.getType());
}

private Type handleResolvedField(Expression node, FieldId fieldId, Type resolvedType, StackableAstVisitorContext<Context> context)
private Type handleResolvedField(Expression node, FieldId fieldId, Type resolvedType)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getFieldToLambdaArgumentDeclaration().get(fieldId);
if (lambdaArgumentDeclaration != null) {
// Lambda argument reference is not a column reference
lambdaArgumentReferences.put(NodeRef.of((Identifier) node), lambdaArgumentDeclaration);
return setExpressionType(node, resolvedType);
}
}

FieldId previous = columnReferences.put(NodeRef.of(node), fieldId);
checkState(previous == null, "%s already known to refer to %s", node, previous);
return setExpressionType(node, resolvedType);
Expand All @@ -390,15 +389,16 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA
{
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);

// If this Dereference looks like column reference, try match it to column first.
if (qualifiedName != null) {
Scope scope = context.getContext().getScope();
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
return handleResolvedField(node, resolvedField.get(), context);
}
if (!scope.isColumnReference(qualifiedName)) {
throw missingAttributeException(node, qualifiedName);
if (!context.getContext().isInLambda()) {
// If this Dereference looks like column reference, try match it to column first.
if (qualifiedName != null) {
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
return handleResolvedField(node, resolvedField.get());
}
if (!scope.isColumnReference(qualifiedName)) {
throw missingAttributeException(node, qualifiedName);
}
}
}

Expand Down Expand Up @@ -793,11 +793,11 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext<C
parameters,
isDescribe);
if (context.getContext().isInLambda()) {
for (LambdaArgumentDeclaration argument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) {
for (LambdaArgumentDeclaration argument : context.getContext().getNameToLambdaArgumentDeclarationMap().values()) {
innerExpressionAnalyzer.setExpressionType(argument, getExpressionType(argument));
}
}
return innerExpressionAnalyzer.analyze(expression, baseScope, context.getContext().expectingLambda(types)).getTypeSignature();
return innerExpressionAnalyzer.analyze(expression, scope, context.getContext().expectingLambda(types)).getTypeSignature();
}));
}
else {
Expand Down Expand Up @@ -970,7 +970,7 @@ protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisi
}
StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node);
Scope subqueryScope = Scope.builder()
.withParent(context.getContext().getScope())
.withParent(scope)
.build();
Scope queryScope = analyzer.analyze(node.getQuery(), subqueryScope);

Expand Down Expand Up @@ -1001,7 +1001,7 @@ else if (previousNode instanceof QuantifiedComparisonExpression) {
protected Type visitExists(ExistsPredicate node, StackableAstVisitorContext<Context> context)
{
StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node);
Scope subqueryScope = Scope.builder().withParent(context.getContext().getScope()).build();
Scope subqueryScope = Scope.builder().withParent(scope).build();
analyzer.analyze(node.getSubquery(), subqueryScope);

existsSubqueries.add(NodeRef.of(node));
Expand Down Expand Up @@ -1045,8 +1045,8 @@ protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpressio
@Override
public Type visitFieldReference(FieldReference node, StackableAstVisitorContext<Context> context)
{
Type type = baseScope.getRelationType().getFieldByIndex(node.getFieldIndex()).getType();
return handleResolvedField(node, new FieldId(baseScope.getRelationId(), node.getFieldIndex()), type, context);
Type type = scope.getRelationType().getFieldByIndex(node.getFieldIndex()).getType();
return handleResolvedField(node, new FieldId(scope.getRelationId(), node.getFieldIndex()), type);
}

@Override
Expand All @@ -1065,29 +1065,16 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC
format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size()));
}

ImmutableList.Builder<Field> fields = ImmutableList.builder();
for (int i = 0; i < lambdaArguments.size(); i++) {
LambdaArgumentDeclaration lambdaArgument = lambdaArguments.get(i);
Type type = types.get(i);
fields.add(Field.newUnqualified(lambdaArgument.getName().getValue(), type));
setExpressionType(lambdaArgument, type);
}

Scope lambdaScope = Scope.builder()
.withParent(context.getContext().getScope())
.withRelationType(RelationId.of(node), new RelationType(fields.build()))
.build();

ImmutableMap.Builder<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration = ImmutableMap.builder();
Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap = new HashMap<>();
if (context.getContext().isInLambda()) {
fieldToLambdaArgumentDeclaration.putAll(context.getContext().getFieldToLambdaArgumentDeclaration());
nameToLambdaArgumentDeclarationMap.putAll(context.getContext().getNameToLambdaArgumentDeclarationMap());
}
for (LambdaArgumentDeclaration lambdaArgument : lambdaArguments) {
ResolvedField resolvedField = lambdaScope.resolveField(lambdaArgument, QualifiedName.of(lambdaArgument.getName().getValue()));
fieldToLambdaArgumentDeclaration.put(FieldId.from(resolvedField), lambdaArgument);
for (int i = 0; i < lambdaArguments.size(); i++) {
LambdaArgumentDeclaration lambdaArgument = lambdaArguments.get(i);
nameToLambdaArgumentDeclarationMap.put(lambdaArgument.getName().getValue(), lambdaArgument);
setExpressionType(lambdaArgument, types.get(i));
}

Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(Context.inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.build())));
Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(Context.inLambda(nameToLambdaArgumentDeclarationMap)));
FunctionType functionType = new FunctionType(types, returnType);
return setExpressionType(node, functionType);
}
Expand Down Expand Up @@ -1266,66 +1253,57 @@ else if (typeOnlyCoercions.contains(ref)) {

private static class Context
{
private final Scope scope;

// functionInputTypes and nameToLambdaDeclarationMap can be null or non-null independently. All 4 combinations are possible.

// The list of types when expecting a lambda (i.e. processing lambda parameters of a function); null otherwise.
// Empty list represents expecting a lambda with no arguments.
private final List<Type> functionInputTypes;
// The mapping from names to corresponding lambda argument declarations when inside a lambda; null otherwise.
// Empty map means that the all lambda expressions surrounding the current node has no arguments.
private final Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration;
private final Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap;

private Context(
Scope scope,
List<Type> functionInputTypes,
Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration)
Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
{
this.scope = requireNonNull(scope, "scope is null");
this.functionInputTypes = functionInputTypes;
this.fieldToLambdaArgumentDeclaration = fieldToLambdaArgumentDeclaration;
this.nameToLambdaArgumentDeclarationMap = nameToLambdaArgumentDeclarationMap;
}

public static Context notInLambda(Scope scope)
public static Context notInLambda()
{
return new Context(scope, null, null);
return new Context(null, null);
}

public static Context inLambda(Scope scope, Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration)
public static Context inLambda(Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
{
return new Context(scope, null, requireNonNull(fieldToLambdaArgumentDeclaration, "fieldToLambdaArgumentDeclaration is null"));
return new Context(null, requireNonNull(nameToLambdaArgumentDeclarationMap, "nameToLambdaArgumentDeclarationMap is null"));
}

public Context expectingLambda(List<Type> functionInputTypes)
{
return new Context(scope, requireNonNull(functionInputTypes, "functionInputTypes is null"), this.fieldToLambdaArgumentDeclaration);
return new Context(requireNonNull(functionInputTypes, "functionInputTypes is null"), this.nameToLambdaArgumentDeclarationMap);
}

public Context notExpectingLambda()
{
return new Context(scope, null, this.fieldToLambdaArgumentDeclaration);
}

Scope getScope()
{
return scope;
return new Context(null, this.nameToLambdaArgumentDeclarationMap);
}

public boolean isInLambda()
{
return fieldToLambdaArgumentDeclaration != null;
return nameToLambdaArgumentDeclarationMap != null;
}

public boolean isExpectingLambda()
{
return functionInputTypes != null;
}

public Map<FieldId, LambdaArgumentDeclaration> getFieldToLambdaArgumentDeclaration()
public Map<String, LambdaArgumentDeclaration> getNameToLambdaArgumentDeclarationMap()
{
checkState(isInLambda());
return fieldToLambdaArgumentDeclaration;
return nameToLambdaArgumentDeclarationMap;
}

public List<Type> getFunctionInputTypes()
Expand Down
Expand Up @@ -1673,7 +1673,7 @@ private Scope computeAndAssignOrderByScopeWithAggregation(OrderBy node, Scope so
.filter(expression -> hasReferencesToScope(expression, analysis, outputScope))
.collect(toImmutableList());
List<Expression> orderByAggregationExpressions = orderByAggregationExpressionsBuilder.build().stream()
.filter(expression -> !orderByExpressionsReferencingOutputScope.contains(expression) || analysis.isColumnReference(expression))
.filter(expression -> !orderByExpressionsReferencingOutputScope.contains(expression) || analysis.getColumnReferences().contains(NodeRef.of(expression)))
.collect(toImmutableList());

// generate placeholder fields
Expand Down
Expand Up @@ -171,7 +171,7 @@ public RelationPlan plan(QuerySpecification node)
builder = project(builder, Iterables.concat(outputs, orderByAggregates));
outputs = toSymbolReferences(computeOutputs(builder, outputs));
List<Expression> complexOrderByAggregatesToRemap = orderByAggregates.stream()
.filter(expression -> !analysis.isColumnReference(expression))
.filter(expression -> !analysis.getColumnReferences().contains(NodeRef.of(expression)))
.collect(toImmutableList());
builder = planBuilderFor(builder, analysis.getScope(node.getOrderBy().get()), complexOrderByAggregatesToRemap);
}
Expand Down

0 comments on commit 9750c2e

Please sign in to comment.