Skip to content

Commit

Permalink
Use Scope to resolve lambda arguments in ExpressionAnalyzer
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi authored and sopel39 committed Nov 3, 2017
1 parent 3148df9 commit e506142
Showing 1 changed file with 50 additions and 37 deletions.
Expand Up @@ -98,7 +98,6 @@
import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
Expand Down Expand Up @@ -349,10 +348,9 @@ protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext<Con
protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getName());
if (lambdaArgumentDeclaration != null) {
Type result = getExpressionType(lambdaArgumentDeclaration);
return setExpressionType(node, result);
Optional<ResolvedField> resolvedField = context.getContext().getScope().tryResolveField(node, QualifiedName.of(node.getName()));
if (resolvedField.isPresent() && context.getContext().getFieldToLambdaArgumentDeclaration().containsKey(FieldId.from(resolvedField.get()))) {
return setExpressionType(node, resolvedField.get().getType());
}
}
Type type = symbolTypes.get(Symbol.from(node));
Expand All @@ -362,24 +360,26 @@ protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorCon
@Override
protected Type visitIdentifier(Identifier node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getValue());
if (lambdaArgumentDeclaration != null) {
lambdaArgumentReferences.put(NodeRef.of(node), lambdaArgumentDeclaration);
Type result = getExpressionType(lambdaArgumentDeclaration);
return setExpressionType(node, result);
}
}
return handleResolvedField(node, context.getContext().getScope().resolveField(node, QualifiedName.of(node.getValue())));
ResolvedField resolvedField = context.getContext().getScope().resolveField(node, QualifiedName.of(node.getValue()));
return handleResolvedField(node, resolvedField, context);
}

private Type handleResolvedField(Expression node, ResolvedField resolvedField)
private Type handleResolvedField(Expression node, ResolvedField resolvedField, StackableAstVisitorContext<Context> context)
{
return handleResolvedField(node, FieldId.from(resolvedField), resolvedField.getType());
return handleResolvedField(node, FieldId.from(resolvedField), resolvedField.getType(), context);
}

private Type handleResolvedField(Expression node, FieldId fieldId, Type resolvedType)
private Type handleResolvedField(Expression node, FieldId fieldId, Type resolvedType, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getFieldToLambdaArgumentDeclaration().get(fieldId);
if (lambdaArgumentDeclaration != null) {
// Lambda argument reference is not a column reference
lambdaArgumentReferences.put(NodeRef.of((Identifier) node), lambdaArgumentDeclaration);
return setExpressionType(node, resolvedType);
}
}

FieldId previous = columnReferences.put(NodeRef.of(node), fieldId);
checkState(previous == null, "%s already known to refer to %s", node, previous);
return setExpressionType(node, resolvedType);
Expand All @@ -396,7 +396,7 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA
Scope scope = context.getContext().getScope();
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
return handleResolvedField(node, resolvedField.get());
return handleResolvedField(node, resolvedField.get(), context);
}
if (!scope.isColumnReference(qualifiedName)) {
throw missingAttributeException(node, qualifiedName);
Expand Down Expand Up @@ -795,7 +795,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext<C
parameters,
isDescribe);
if (context.getContext().isInLambda()) {
for (LambdaArgumentDeclaration argument : context.getContext().getNameToLambdaArgumentDeclarationMap().values()) {
for (LambdaArgumentDeclaration argument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) {
innerExpressionAnalyzer.setExpressionType(argument, getExpressionType(argument));
}
}
Expand Down Expand Up @@ -1048,7 +1048,7 @@ protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpressio
public Type visitFieldReference(FieldReference node, StackableAstVisitorContext<Context> context)
{
Type type = baseScope.getRelationType().getFieldByIndex(node.getFieldIndex()).getType();
return handleResolvedField(node, new FieldId(baseScope.getRelationId(), node.getFieldIndex()), type);
return handleResolvedField(node, new FieldId(baseScope.getRelationId(), node.getFieldIndex()), type, context);
}

@Override
Expand All @@ -1067,16 +1067,29 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC
format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size()));
}

Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap = new HashMap<>();
if (context.getContext().isInLambda()) {
nameToLambdaArgumentDeclarationMap.putAll(context.getContext().getNameToLambdaArgumentDeclarationMap());
}
ImmutableList.Builder<Field> fields = ImmutableList.builder();
for (int i = 0; i < lambdaArguments.size(); i++) {
LambdaArgumentDeclaration lambdaArgument = lambdaArguments.get(i);
nameToLambdaArgumentDeclarationMap.put(lambdaArgument.getName().getValue(), lambdaArgument);
setExpressionType(lambdaArgument, types.get(i));
Type type = types.get(i);
fields.add(Field.newUnqualified(lambdaArgument.getName().getValue(), type));
setExpressionType(lambdaArgument, type);
}

Scope lambdaScope = Scope.builder()
.withParent(context.getContext().getScope())
.withRelationType(RelationId.of(node), new RelationType(fields.build()))
.build();

ImmutableMap.Builder<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration = ImmutableMap.builder();
if (context.getContext().isInLambda()) {
fieldToLambdaArgumentDeclaration.putAll(context.getContext().getFieldToLambdaArgumentDeclaration());
}
Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(Context.inLambda(context.getContext().getScope(), nameToLambdaArgumentDeclarationMap)));
for (LambdaArgumentDeclaration lambdaArgument : lambdaArguments) {
ResolvedField resolvedField = lambdaScope.resolveField(lambdaArgument, QualifiedName.of(lambdaArgument.getName().getValue()));
fieldToLambdaArgumentDeclaration.put(FieldId.from(resolvedField), lambdaArgument);
}

Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(Context.inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.build())));
FunctionType functionType = new FunctionType(types, returnType);
return setExpressionType(node, functionType);
}
Expand Down Expand Up @@ -1264,36 +1277,36 @@ private static class Context
private final List<Type> functionInputTypes;
// The mapping from names to corresponding lambda argument declarations when inside a lambda; null otherwise.
// Empty map means that the all lambda expressions surrounding the current node has no arguments.
private final Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap;
private final Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration;

private Context(
Scope scope,
List<Type> functionInputTypes,
Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration)
{
this.scope = requireNonNull(scope, "scope is null");
this.functionInputTypes = functionInputTypes;
this.nameToLambdaArgumentDeclarationMap = nameToLambdaArgumentDeclarationMap;
this.fieldToLambdaArgumentDeclaration = fieldToLambdaArgumentDeclaration;
}

public static Context notInLambda(Scope scope)
{
return new Context(scope, null, null);
}

public static Context inLambda(Scope scope, Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
public static Context inLambda(Scope scope, Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration)
{
return new Context(scope, null, requireNonNull(nameToLambdaArgumentDeclarationMap, "nameToLambdaArgumentDeclarationMap is null"));
return new Context(scope, null, requireNonNull(fieldToLambdaArgumentDeclaration, "fieldToLambdaArgumentDeclaration is null"));
}

public Context expectingLambda(List<Type> functionInputTypes)
{
return new Context(scope, requireNonNull(functionInputTypes, "functionInputTypes is null"), this.nameToLambdaArgumentDeclarationMap);
return new Context(scope, requireNonNull(functionInputTypes, "functionInputTypes is null"), this.fieldToLambdaArgumentDeclaration);
}

public Context notExpectingLambda()
{
return new Context(scope, null, this.nameToLambdaArgumentDeclarationMap);
return new Context(scope, null, this.fieldToLambdaArgumentDeclaration);
}

Scope getScope()
Expand All @@ -1303,18 +1316,18 @@ Scope getScope()

public boolean isInLambda()
{
return nameToLambdaArgumentDeclarationMap != null;
return fieldToLambdaArgumentDeclaration != null;
}

public boolean isExpectingLambda()
{
return functionInputTypes != null;
}

public Map<String, LambdaArgumentDeclaration> getNameToLambdaArgumentDeclarationMap()
public Map<FieldId, LambdaArgumentDeclaration> getFieldToLambdaArgumentDeclaration()
{
checkState(isInLambda());
return nameToLambdaArgumentDeclarationMap;
return fieldToLambdaArgumentDeclaration;
}

public List<Type> getFunctionInputTypes()
Expand Down

0 comments on commit e506142

Please sign in to comment.