Skip to content

Commit

Permalink
Extend BindExpression to bind multiple values at a time
Browse files Browse the repository at this point in the history
  • Loading branch information
wenleix committed Jun 12, 2017
1 parent e1922b0 commit 0c904db
Show file tree
Hide file tree
Showing 15 changed files with 123 additions and 58 deletions.
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -375,7 +375,12 @@ protected Boolean visitLambdaExpression(LambdaExpression node, Void context)
@Override @Override
protected Boolean visitBindExpression(BindExpression node, Void context) protected Boolean visitBindExpression(BindExpression node, Void context)
{ {
return process(node.getValue(), context) && process(node.getFunction(), context); for (Expression value : node.getValues()) {
if (!process(value, context)) {
return false;
}
}
return process(node.getFunction(), context);
} }


@Override @Override
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -1087,18 +1087,24 @@ protected Type visitBindExpression(BindExpression node, StackableAstVisitorConte
{ {
verify(context.getContext().isExpectingLambda(), "bind expression found when lambda is not expected"); verify(context.getContext().isExpectingLambda(), "bind expression found when lambda is not expected");


List<Type> functionInputTypes = ImmutableList.<Type>builder() StackableAstVisitorContext<Context> innerContext = new StackableAstVisitorContext<>(context.getContext().notExpectingLambda());
.add(process(node.getValue(), new StackableAstVisitorContext<>(context.getContext().notExpectingLambda()))) ImmutableList.Builder<Type> functionInputTypesBuilder = ImmutableList.builder();
.addAll(context.getContext().getFunctionInputTypes()) for (Expression value : node.getValues()) {
.build(); functionInputTypesBuilder.add(process(value, innerContext));
}
functionInputTypesBuilder.addAll(context.getContext().getFunctionInputTypes());
List<Type> functionInputTypes = functionInputTypesBuilder.build();


FunctionType functionType = (FunctionType) process(node.getFunction(), new StackableAstVisitorContext<>(context.getContext().expectingLambda(functionInputTypes))); FunctionType functionType = (FunctionType) process(node.getFunction(), new StackableAstVisitorContext<>(context.getContext().expectingLambda(functionInputTypes)));


List<Type> argumentTypes = functionType.getArgumentTypes(); List<Type> argumentTypes = functionType.getArgumentTypes();
int numCapturedValues = node.getValues().size();
verify(argumentTypes.size() == functionInputTypes.size()); verify(argumentTypes.size() == functionInputTypes.size());
verify(functionInputTypes.get(0) == argumentTypes.get(0)); for (int i = 0; i < numCapturedValues; i++) {
verify(functionInputTypes.get(i) == argumentTypes.get(i));
}


FunctionType result = new FunctionType(argumentTypes.subList(1, argumentTypes.size()), functionType.getReturnType()); FunctionType result = new FunctionType(argumentTypes.subList(numCapturedValues, argumentTypes.size()), functionType.getReturnType());
return setExpressionType(node, result); return setExpressionType(node, result);
} }


Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.bytecode.expression.BytecodeExpression;
import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.RowExpression;
Expand Down Expand Up @@ -47,15 +48,20 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon


Variable wasNull = scope.getVariable("wasNull"); Variable wasNull = scope.getVariable("wasNull");


Class<?> valueType = Primitives.wrap(arguments.get(0).getType().getJavaType()); ImmutableList.Builder<BytecodeExpression> captureVariablesBuilder = ImmutableList.builder();
Variable valueVariable = scope.createTempVariable(valueType); int numValues = arguments.size() - 1;
block.append(context.generate(arguments.get(0))); for (int i = 0; i < numValues; i++) {
block.append(boxPrimitiveIfNecessary(scope, valueType)); Class<?> valueType = Primitives.wrap(arguments.get(i).getType().getJavaType());
block.putVariable(valueVariable); Variable valueVariable = scope.createTempVariable(valueType);
block.append(wasNull.set(constantFalse())); block.append(context.generate(arguments.get(i)));
block.append(boxPrimitiveIfNecessary(scope, valueType));
block.putVariable(valueVariable);
block.append(wasNull.set(constantFalse()));
captureVariablesBuilder.add(valueVariable.cast(Object.class));
}


Variable functionVariable = scope.createTempVariable(MethodHandle.class); Variable functionVariable = scope.createTempVariable(MethodHandle.class);
block.append(context.generate(arguments.get(1))); block.append(context.generate(arguments.get(numValues)));
block.append( block.append(
new IfStatement() new IfStatement()
.condition(wasNull) .condition(wasNull)
Expand All @@ -69,7 +75,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon
MethodHandle.class, MethodHandle.class,
functionVariable, functionVariable,
constantInt(0), constantInt(0),
newArray(type(Object[].class), ImmutableList.of(valueVariable.cast(Object.class))))))); newArray(type(Object[].class), captureVariablesBuilder.build())))));


return block; return block;
} }
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -1002,16 +1002,23 @@ protected Object visitLambdaExpression(LambdaExpression node, Object context)
@Override @Override
protected Object visitBindExpression(BindExpression node, Object context) protected Object visitBindExpression(BindExpression node, Object context)
{ {
Object value = process(node.getValue(), context); List<Object> values = node.getValues().stream()
.map(value -> process(value, context))
.collect(toImmutableList());
Object function = process(node.getFunction(), context); Object function = process(node.getFunction(), context);


if (hasUnresolvedValue(value, function)) { if (hasUnresolvedValue(values) || hasUnresolvedValue(function)) {
ImmutableList.Builder<Expression> builder = ImmutableList.builder();
for (int i = 0; i < values.size(); i++) {
builder.add(toExpression(values.get(i), type(node.getValues().get(i))));
}

return new BindExpression( return new BindExpression(
toExpression(value, type(node.getValue())), builder.build(),
toExpression(function, type(node.getFunction()))); toExpression(function, type(node.getFunction())));
} }


return MethodHandles.insertArguments((MethodHandle) function, 0, value); return MethodHandles.insertArguments((MethodHandle) function, 0, values.toArray());
} }


@Override @Override
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ public Expression rewriteLambdaExpression(LambdaExpression node, Context context
} }
newLambdaArguments.addAll(node.getArguments()); newLambdaArguments.addAll(node.getArguments());
Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), replaceSymbols(rewrittenBody, captureSymbolToExtraSymbol.build())); Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), replaceSymbols(rewrittenBody, captureSymbolToExtraSymbol.build()));
for (Symbol captureSymbol : captureSymbols) {
rewrittenExpression = new BindExpression(new SymbolReference(captureSymbol.getName()), rewrittenExpression); if (captureSymbols.size() != 0) {
List<Expression> capturedValues = captureSymbols.stream()
.map(symbol -> new SymbolReference(symbol.getName()))
.collect(toImmutableList());
rewrittenExpression = new BindExpression(capturedValues, rewrittenExpression);
} }


context.getReferencedSymbols().addAll(captureSymbols); context.getReferencedSymbols().addAll(captureSymbols);
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -155,9 +155,14 @@ public static Signature trySignature(Type returnType)
return new Signature(TRY, SCALAR, returnType.getTypeSignature()); return new Signature(TRY, SCALAR, returnType.getTypeSignature());
} }


public static Signature bindSignature(Type returnType, Type valueType, Type functionType) public static Signature bindSignature(Type returnType, List<Type> valueTypes, Type functionType)
{ {
return new Signature(BIND, SCALAR, returnType.getTypeSignature(), valueType.getTypeSignature(), functionType.getTypeSignature()); ImmutableList.Builder<TypeSignature> typeSignatureBuilder = ImmutableList.builder();
for (Type valueType : valueTypes) {
typeSignatureBuilder.add(valueType.getTypeSignature());
}
typeSignatureBuilder.add(functionType.getTypeSignature());
return new Signature(BIND, SCALAR, returnType.getTypeSignature(), typeSignatureBuilder.build());
} }


// **************** functions that require varargs and/or complex types (e.g., lists) **************** // **************** functions that require varargs and/or complex types (e.g., lists) ****************
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -353,14 +353,20 @@ protected RowExpression visitLambdaExpression(LambdaExpression node, Void contex
@Override @Override
protected RowExpression visitBindExpression(BindExpression node, Void context) protected RowExpression visitBindExpression(BindExpression node, Void context)
{ {
RowExpression value = process(node.getValue(), context); ImmutableList.Builder<Type> valueTypesBuilder = ImmutableList.builder();
ImmutableList.Builder<RowExpression> argumentsBuilder = ImmutableList.builder();
for (Expression value : node.getValues()) {
RowExpression valueRowExpression = process(value, context);
valueTypesBuilder.add(valueRowExpression.getType());
argumentsBuilder.add(valueRowExpression);
}
RowExpression function = process(node.getFunction(), context); RowExpression function = process(node.getFunction(), context);
argumentsBuilder.add(function);


return call( return call(
bindSignature(getType(node), value.getType(), function.getType()), bindSignature(getType(node), valueTypesBuilder.build(), function.getType()),
getType(node), getType(node),
value, argumentsBuilder.build());
function);
} }


@Override @Override
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -129,15 +129,23 @@ public RowExpression visitCall(CallExpression call, Void context)
return call(signature, call.getType(), arguments); return call(signature, call.getType(), arguments);
} }
case BIND: { case BIND: {
checkState(call.getArguments().size() == 2, BIND + " function should have 2 arguments. Got " + call.getArguments().size()); checkState(call.getArguments().size() >= 1, BIND + " function should have at least 1 argument. Got " + call.getArguments().size());
RowExpression optimizedValue = call.getArguments().get(0).accept(this, context);
RowExpression optimizedFunction = call.getArguments().get(1).accept(this, context); boolean allConstantExpression = true;
if (optimizedValue instanceof ConstantExpression && optimizedFunction instanceof ConstantExpression) { ImmutableList.Builder<RowExpression> optimizedArgumentsBuilder = ImmutableList.builder();
// Here, optimizedValue and optimizedFunction should be merged together into a new ConstantExpression. for (RowExpression argument : call.getArguments()) {
RowExpression optimizedArgument = argument.accept(this, context);
if (!(optimizedArgument instanceof ConstantExpression)) {
allConstantExpression = false;
}
optimizedArgumentsBuilder.add(optimizedArgument);
}
if (allConstantExpression) {
// Here, optimizedArguments should be merged together into a new ConstantExpression.
// It's not implemented because it would be dead code anyways because visitLambda does not produce ConstantExpression. // It's not implemented because it would be dead code anyways because visitLambda does not produce ConstantExpression.
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
return call(signature, call.getType(), ImmutableList.of(optimizedValue, optimizedFunction)); return call(signature, call.getType(), optimizedArgumentsBuilder.build());
} }
case NULL_IF: case NULL_IF:
case SWITCH: case SWITCH:
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public TestArrayTransformFunction()
public void testBasic() public void testBasic()
throws Exception throws Exception
{ {
assertFunction("transform(ARRAY [5, 6], x -> 9)", new ArrayType(INTEGER), ImmutableList.of(9, 9));
assertFunction("transform(ARRAY [5, 6], x -> x + 1)", new ArrayType(INTEGER), ImmutableList.of(6, 7)); assertFunction("transform(ARRAY [5, 6], x -> x + 1)", new ArrayType(INTEGER), ImmutableList.of(6, 7));
assertFunction("transform(ARRAY [5 + RANDOM(1), 6], x -> x + 1)", new ArrayType(INTEGER), ImmutableList.of(6, 7)); assertFunction("transform(ARRAY [5 + RANDOM(1), 6], x -> x + 1)", new ArrayType(INTEGER), ImmutableList.of(6, 7));
} }
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ public void testBind()
{ {
assertFunction("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", INTEGER, 99); assertFunction("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", INTEGER, 99);
assertFunction("invoke(\"$internal$bind\"(8, x -> x + 1))", INTEGER, 9); assertFunction("invoke(\"$internal$bind\"(8, x -> x + 1))", INTEGER, 9);
assertFunction("apply(900, \"$internal$bind\"(90, \"$internal$bind\"(9, (x, y, z) -> x + y + z)))", INTEGER, 999); assertFunction("apply(900, \"$internal$bind\"(90, 9, (x, y, z) -> x + y + z))", INTEGER, 999);
assertFunction("invoke(\"$internal$bind\"(90, \"$internal$bind\"(9, (x, y) -> x + y)))", INTEGER, 99); assertFunction("invoke(\"$internal$bind\"(90, 9, (x, y) -> x + y))", INTEGER, 99);
} }


@Test @Test
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -368,9 +368,14 @@ protected String visitLambdaExpression(LambdaExpression node, Void context)
@Override @Override
protected String visitBindExpression(BindExpression node, Void context) protected String visitBindExpression(BindExpression node, Void context)
{ {
return "\"$INTERNAL$BIND\"(" + StringBuilder builder = new StringBuilder();
process(node.getValue(), context) + ", " +
process(node.getFunction(), context) + ")"; builder.append("\"$INTERNAL$BIND\"(");
for (Expression value : node.getValues()) {
builder.append(process(value, context) + ", ");
}
builder.append(process(node.getFunction(), context) + ")");
return builder.toString();
} }


@Override @Override
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;


import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format; import static java.lang.String.format;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -1278,14 +1279,20 @@ public Node visitFunctionCall(SqlBaseParser.FunctionCallContext context)
return new TryExpression(getLocation(context), (Expression) visit(getOnlyElement(context.expression()))); return new TryExpression(getLocation(context), (Expression) visit(getOnlyElement(context.expression())));
} }
if (name.toString().equalsIgnoreCase("$internal$bind")) { if (name.toString().equalsIgnoreCase("$internal$bind")) {
check(context.expression().size() == 2, "The '$internal$bind' function must have exactly two arguments", context); check(context.expression().size() >= 1, "The '$internal$bind' function must have at least one arguments", context);
check(!window.isPresent(), "OVER clause not valid for '$internal$bind' function", context); check(!window.isPresent(), "OVER clause not valid for '$internal$bind' function", context);
check(!distinct, "DISTINCT not valid for '$internal$bind' function", context); check(!distinct, "DISTINCT not valid for '$internal$bind' function", context);


int numValues = context.expression().size() - 1;
List<Expression> arguments = context.expression().stream()
.map(this::visit)
.map(Expression.class::cast)
.collect(toImmutableList());

return new BindExpression( return new BindExpression(
getLocation(context), getLocation(context),
(Expression) visit(context.expression(0)), arguments.subList(0, numValues),
(Expression) visit(context.expression(1))); arguments.get(numValues));
} }


return new FunctionCall( return new FunctionCall(
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -47,31 +47,31 @@
public class BindExpression public class BindExpression
extends Expression extends Expression
{ {
private final Expression value; private final List<Expression> values;
// Function expression must be of function type. // Function expression must be of function type.
// It is not necessarily a lambda. For example, it can be another bind expression. // It is not necessarily a lambda. For example, it can be another bind expression.
private final Expression function; private final Expression function;


public BindExpression(Expression value, Expression function) public BindExpression(List<Expression> values, Expression function)
{ {
this(Optional.empty(), value, function); this(Optional.empty(), values, function);
} }


public BindExpression(NodeLocation location, Expression value, Expression function) public BindExpression(NodeLocation location, List<Expression> values, Expression function)
{ {
this(Optional.of(location), value, function); this(Optional.of(location), values, function);
} }


private BindExpression(Optional<NodeLocation> location, Expression value, Expression function) private BindExpression(Optional<NodeLocation> location, List<Expression> values, Expression function)
{ {
super(location); super(location);
this.value = requireNonNull(value, "value is null"); this.values = requireNonNull(values, "value is null");
this.function = requireNonNull(function, "function is null"); this.function = requireNonNull(function, "function is null");
} }


public Expression getValue() public List<Expression> getValues()
{ {
return value; return values;
} }


public Expression getFunction() public Expression getFunction()
Expand All @@ -89,7 +89,7 @@ public <R, C> R accept(AstVisitor<R, C> visitor, C context)
public List<Node> getChildren() public List<Node> getChildren()
{ {
ImmutableList.Builder<Node> nodes = ImmutableList.builder(); ImmutableList.Builder<Node> nodes = ImmutableList.builder();
return nodes.add(value) return nodes.addAll(values)
.add(function) .add(function)
.build(); .build();
} }
Expand All @@ -104,13 +104,13 @@ public boolean equals(Object o)
return false; return false;
} }
BindExpression that = (BindExpression) o; BindExpression that = (BindExpression) o;
return Objects.equals(value, that.value) && return Objects.equals(values, that.values) &&
Objects.equals(function, that.function); Objects.equals(function, that.function);
} }


@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(value, function); return Objects.hash(values, function);
} }
} }
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ protected R visitTryExpression(TryExpression node, C context)
@Override @Override
protected R visitBindExpression(BindExpression node, C context) protected R visitBindExpression(BindExpression node, C context)
{ {
process(node.getValue(), context); for (Expression value : node.getValues()) {
process(value, context);
}
process(node.getFunction(), context); process(node.getFunction(), context);


return null; return null;
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;


import static com.google.common.collect.ImmutableList.toImmutableList;

public final class ExpressionTreeRewriter<C> public final class ExpressionTreeRewriter<C>
{ {
private final ExpressionRewriter<C> rewriter; private final ExpressionRewriter<C> rewriter;
Expand Down Expand Up @@ -598,13 +600,14 @@ protected Expression visitBindExpression(BindExpression node, Context<C> context
} }
} }


Expression value = rewrite(node.getValue(), context.get()); List<Expression> values = node.getValues().stream()
.map(value -> rewrite(value, context.get()))
.collect(toImmutableList());
Expression function = rewrite(node.getFunction(), context.get()); Expression function = rewrite(node.getFunction(), context.get());


if ((value != node.getValue()) || (function != node.getFunction())) { if (!sameElements(values, node.getValues()) || (function != node.getFunction())) {
return new BindExpression(value, function); return new BindExpression(values, function);
} }

return node; return node;
} }


Expand Down

0 comments on commit 0c904db

Please sign in to comment.