Skip to content

Commit

Permalink
Avoid binding "this" for every row for lambda execution
Browse files Browse the repository at this point in the history
  • Loading branch information
wenleix committed May 14, 2017
1 parent f8495df commit e4a408e
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 46 deletions.
Expand Up @@ -24,10 +24,7 @@
import com.facebook.presto.sql.relational.RowExpressionVisitor; import com.facebook.presto.sql.relational.RowExpressionVisitor;
import com.facebook.presto.sql.relational.VariableReferenceExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression;


import java.lang.invoke.MethodHandle;

import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic;
import static com.facebook.presto.bytecode.instruction.Constant.loadBoolean; import static com.facebook.presto.bytecode.instruction.Constant.loadBoolean;
import static com.facebook.presto.bytecode.instruction.Constant.loadDouble; import static com.facebook.presto.bytecode.instruction.Constant.loadDouble;
import static com.facebook.presto.bytecode.instruction.Constant.loadFloat; import static com.facebook.presto.bytecode.instruction.Constant.loadFloat;
Expand Down Expand Up @@ -194,8 +191,7 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope scope)
{ {
checkState(preGeneratedExpressions.getLambdaFieldMap().containsKey(lambda), "lambda expressions map does not contain this lambda definition"); checkState(preGeneratedExpressions.getLambdaFieldMap().containsKey(lambda), "lambda expressions map does not contain this lambda definition");


return getStatic(preGeneratedExpressions.getLambdaFieldMap().get(lambda)) return scope.getThis().getField(preGeneratedExpressions.getLambdaFieldMap().get(lambda).getInstanceField());
.invoke("bindTo", MethodHandle.class, scope.getThis().cast(Object.class));
} }


@Override @Override
Expand Down
Expand Up @@ -16,7 +16,6 @@
import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.ClassDefinition; import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Scope;
Expand All @@ -31,6 +30,7 @@
import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.RecordCursor;
import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.LambdaExpressionField;
import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.CallExpression;
import com.facebook.presto.sql.relational.ConstantExpression; import com.facebook.presto.sql.relational.ConstantExpression;
import com.facebook.presto.sql.relational.InputReferenceExpression; import com.facebook.presto.sql.relational.InputReferenceExpression;
Expand All @@ -46,6 +46,7 @@
import com.google.common.primitives.Primitives; import com.google.common.primitives.Primitives;
import io.airlift.slice.Slice; import io.airlift.slice.Slice;


import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;


Expand Down Expand Up @@ -74,11 +75,19 @@ public CursorProcessorCompiler(Metadata metadata)
public void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, List<RowExpression> projections) public void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, List<RowExpression> projections)
{ {
CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);
List<PreGeneratedExpressions> allPreGeneratedExpressions = new ArrayList<>(projections.size() + 1);

generateProcessMethod(classDefinition, projections.size()); generateProcessMethod(classDefinition, projections.size());
generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter);
PreGeneratedExpressions filterPreGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, filter, "filter");
allPreGeneratedExpressions.add(filterPreGeneratedExpressions);
generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filterPreGeneratedExpressions, filter);


for (int i = 0; i < projections.size(); i++) { for (int i = 0; i < projections.size(); i++) {
generateProjectMethod(classDefinition, callSiteBinder, cachedInstanceBinder, "project_" + i, projections.get(i)); String methodName = "project_" + i;
PreGeneratedExpressions projectPreGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, projections.get(i), methodName);
allPreGeneratedExpressions.add(projectPreGeneratedExpressions);
generateProjectMethod(classDefinition, callSiteBinder, cachedInstanceBinder, projectPreGeneratedExpressions, methodName, projections.get(i));
} }


MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC)); MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC));
Expand All @@ -87,7 +96,13 @@ public void generateMethods(ClassDefinition classDefinition, CallSiteBinder call
constructorBody.comment("super();") constructorBody.comment("super();")
.append(thisVariable) .append(thisVariable)
.invokeConstructor(Object.class); .invokeConstructor(Object.class);

cachedInstanceBinder.generateInitializations(thisVariable, constructorBody); cachedInstanceBinder.generateInitializations(thisVariable, constructorBody);
for (PreGeneratedExpressions preGeneratedExpressions : allPreGeneratedExpressions) {
for (LambdaExpressionField field : preGeneratedExpressions.getLambdaFieldMap().values()) {
field.generateInitialization(thisVariable, constructorBody);
}
}
constructorBody.ret(); constructorBody.ret();
} }


Expand Down Expand Up @@ -192,7 +207,7 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry(
Set<RowExpression> lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(projection)); Set<RowExpression> lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(projection));


ImmutableMap.Builder<CallExpression, MethodDefinition> tryMethodMap = ImmutableMap.builder(); ImmutableMap.Builder<CallExpression, MethodDefinition> tryMethodMap = ImmutableMap.builder();
ImmutableMap.Builder<LambdaDefinitionExpression, FieldDefinition> lambdaFieldMap = ImmutableMap.builder(); ImmutableMap.Builder<LambdaDefinitionExpression, LambdaExpressionField> lambdaFieldMap = ImmutableMap.builder();


int counter = 0; int counter = 0;
for (RowExpression expression : lambdaAndTryExpressions) { for (RowExpression expression : lambdaAndTryExpressions) {
Expand Down Expand Up @@ -230,15 +245,15 @@ else if (expression instanceof LambdaDefinitionExpression) {
LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) expression; LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) expression;
String fieldName = methodPrefix + "_lambda_" + counter; String fieldName = methodPrefix + "_lambda_" + counter;
PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build());
FieldDefinition methodHandleField = LambdaBytecodeGenerator.preGenerateLambdaExpression( LambdaExpressionField lambdaExpressionField = LambdaBytecodeGenerator.preGenerateLambdaExpression(
lambdaExpression, lambdaExpression,
fieldName, fieldName,
containerClassDefinition, containerClassDefinition,
preGeneratedExpressions, preGeneratedExpressions,
callSiteBinder, callSiteBinder,
cachedInstanceBinder, cachedInstanceBinder,
metadata.getFunctionRegistry()); metadata.getFunctionRegistry());
lambdaFieldMap.put(lambdaExpression, methodHandleField); lambdaFieldMap.put(lambdaExpression, lambdaExpressionField);
} }
else { else {
throw new VerifyException(format("unexpected expression: %s", expression.toString())); throw new VerifyException(format("unexpected expression: %s", expression.toString()));
Expand All @@ -249,10 +264,13 @@ else if (expression instanceof LambdaDefinitionExpression) {
return new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); return new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build());
} }


private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpression filter) private void generateFilterMethod(
ClassDefinition classDefinition,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
PreGeneratedExpressions preGeneratedExpressions,
RowExpression filter)
{ {
PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, filter, "filter");

Parameter session = arg("session", ConnectorSession.class); Parameter session = arg("session", ConnectorSession.class);
Parameter cursor = arg("cursor", RecordCursor.class); Parameter cursor = arg("cursor", RecordCursor.class);
MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "filter", type(boolean.class), session, cursor); MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "filter", type(boolean.class), session, cursor);
Expand Down Expand Up @@ -284,10 +302,14 @@ private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinde
.retBoolean(); .retBoolean();
} }


private void generateProjectMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, String methodName, RowExpression projection) private void generateProjectMethod(
ClassDefinition classDefinition,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
PreGeneratedExpressions preGeneratedExpressions,
String methodName,
RowExpression projection)
{ {
PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, projection, methodName);

Parameter session = arg("session", ConnectorSession.class); Parameter session = arg("session", ConnectorSession.class);
Parameter cursor = arg("cursor", RecordCursor.class); Parameter cursor = arg("cursor", RecordCursor.class);
Parameter output = arg("output", BlockBuilder.class); Parameter output = arg("output", BlockBuilder.class);
Expand Down
Expand Up @@ -29,6 +29,7 @@
import com.facebook.presto.operator.StandardJoinFilterFunction; import com.facebook.presto.operator.StandardJoinFilterFunction;
import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.Block;
import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.LambdaExpressionField;
import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.CallExpression;
import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression;
Expand Down Expand Up @@ -146,11 +147,17 @@ private void generateMethods(ClassDefinition classDefinition, CallSiteBinder cal


FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class); FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class);


generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter, leftBlocksSize, sessionField); PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, leftBlocksSize, filter);
generateConstructor(classDefinition, sessionField, cachedInstanceBinder); generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, preGeneratedExpressions, filter, leftBlocksSize, sessionField);

generateConstructor(classDefinition, sessionField, cachedInstanceBinder, preGeneratedExpressions);
} }


private static void generateConstructor(ClassDefinition classDefinition, FieldDefinition sessionField, CachedInstanceBinder cachedInstanceBinder) private static void generateConstructor(
ClassDefinition classDefinition,
FieldDefinition sessionField,
CachedInstanceBinder cachedInstanceBinder,
PreGeneratedExpressions preGeneratedExpressions)
{ {
Parameter sessionParameter = arg("session", ConnectorSession.class); Parameter sessionParameter = arg("session", ConnectorSession.class);
MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), sessionParameter); MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), sessionParameter);
Expand All @@ -164,13 +171,21 @@ private static void generateConstructor(ClassDefinition classDefinition, FieldDe


body.append(thisVariable.setField(sessionField, sessionParameter)); body.append(thisVariable.setField(sessionField, sessionParameter));
cachedInstanceBinder.generateInitializations(thisVariable, body); cachedInstanceBinder.generateInitializations(thisVariable, body);
for (LambdaExpressionField field : preGeneratedExpressions.getLambdaFieldMap().values()) {
field.generateInitialization(thisVariable, body);
}
body.ret(); body.ret();
} }


private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpression filter, int leftBlocksSize, FieldDefinition sessionField) private void generateFilterMethod(
ClassDefinition classDefinition,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
PreGeneratedExpressions preGeneratedExpressions,
RowExpression filter,
int leftBlocksSize,
FieldDefinition sessionField)
{ {
PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, leftBlocksSize, filter);

// int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks // int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks
Parameter leftPosition = arg("leftPosition", int.class); Parameter leftPosition = arg("leftPosition", int.class);
Parameter leftBlocks = arg("leftBlocks", Block[].class); Parameter leftBlocks = arg("leftBlocks", Block[].class);
Expand Down Expand Up @@ -222,7 +237,7 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry(
{ {
Set<RowExpression> lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(filter)); Set<RowExpression> lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(filter));
ImmutableMap.Builder<CallExpression, MethodDefinition> tryMethodMap = ImmutableMap.builder(); ImmutableMap.Builder<CallExpression, MethodDefinition> tryMethodMap = ImmutableMap.builder();
ImmutableMap.Builder<LambdaDefinitionExpression, FieldDefinition> lambdaFieldMap = ImmutableMap.builder(); ImmutableMap.Builder<LambdaDefinitionExpression, LambdaExpressionField> lambdaFieldMap = ImmutableMap.builder();


int counter = 0; int counter = 0;
for (RowExpression expression : lambdaAndTryExpressions) { for (RowExpression expression : lambdaAndTryExpressions) {
Expand Down Expand Up @@ -265,15 +280,15 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry(
else if (expression instanceof LambdaDefinitionExpression) { else if (expression instanceof LambdaDefinitionExpression) {
LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) expression; LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) expression;
PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build());
FieldDefinition methodHandleField = LambdaBytecodeGenerator.preGenerateLambdaExpression( LambdaExpressionField lambdaExpressionField = LambdaBytecodeGenerator.preGenerateLambdaExpression(
lambdaExpression, lambdaExpression,
"lambda_" + counter, "lambda_" + counter,
containerClassDefinition, containerClassDefinition,
preGeneratedExpressions, preGeneratedExpressions,
callSiteBinder, callSiteBinder,
cachedInstanceBinder, cachedInstanceBinder,
metadata.getFunctionRegistry()); metadata.getFunctionRegistry());
lambdaFieldMap.put(lambdaExpression, methodHandleField); lambdaFieldMap.put(lambdaExpression, lambdaExpressionField);
} }
else { else {
throw new VerifyException(format("unexpected expression: %s", expression.toString())); throw new VerifyException(format("unexpected expression: %s", expression.toString()));
Expand Down
Expand Up @@ -48,12 +48,14 @@
import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantClass; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantClass;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.setStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.setStatic;
import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary;
import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary; import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary;
import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;


public class LambdaBytecodeGenerator public class LambdaBytecodeGenerator
{ {
Expand All @@ -64,7 +66,7 @@ private LambdaBytecodeGenerator()
/** /**
* @return a MethodHandle field that represents the lambda expression * @return a MethodHandle field that represents the lambda expression
*/ */
public static FieldDefinition preGenerateLambdaExpression( public static LambdaExpressionField preGenerateLambdaExpression(
LambdaDefinitionExpression lambdaExpression, LambdaDefinitionExpression lambdaExpression,
String fieldName, String fieldName,
ClassDefinition classDefinition, ClassDefinition classDefinition,
Expand Down Expand Up @@ -100,7 +102,7 @@ public static FieldDefinition preGenerateLambdaExpression(
lambdaExpression); lambdaExpression);
} }


private static FieldDefinition defineLambdaMethodAndField( private static LambdaExpressionField defineLambdaMethodAndField(
BytecodeExpressionVisitor innerExpressionVisitor, BytecodeExpressionVisitor innerExpressionVisitor,
ClassDefinition classDefinition, ClassDefinition classDefinition,
String fieldAndMethodName, String fieldAndMethodName,
Expand All @@ -119,11 +121,12 @@ private static FieldDefinition defineLambdaMethodAndField(
.append(boxPrimitiveIfNecessary(scope, returnType)) .append(boxPrimitiveIfNecessary(scope, returnType))
.ret(returnType); .ret(returnType);


FieldDefinition methodHandleField = classDefinition.declareField(a(PRIVATE, STATIC, FINAL), fieldAndMethodName, type(MethodHandle.class)); FieldDefinition staticField = classDefinition.declareField(a(PRIVATE, STATIC, FINAL), fieldAndMethodName, type(MethodHandle.class));
FieldDefinition instanceField = classDefinition.declareField(a(PRIVATE, FINAL), "binded_" + fieldAndMethodName, type(MethodHandle.class));


classDefinition.getClassInitializer().getBody() classDefinition.getClassInitializer().getBody()
.append(setStatic( .append(setStatic(
methodHandleField, staticField,
invokeStatic( invokeStatic(
Reflection.class, Reflection.class,
"methodHandle", "methodHandle",
Expand All @@ -136,7 +139,8 @@ private static FieldDefinition defineLambdaMethodAndField(
.map(Parameter::getType) .map(Parameter::getType)
.map(BytecodeExpressions::constantClass) .map(BytecodeExpressions::constantClass)
.collect(toImmutableList()))))); .collect(toImmutableList())))));
return methodHandleField;
return new LambdaExpressionField(staticField, instanceField);
} }


private static RowExpressionVisitor<Scope, BytecodeNode> variableReferenceCompiler(Map<String, ParameterAndType> parameterMap) private static RowExpressionVisitor<Scope, BytecodeNode> variableReferenceCompiler(Map<String, ParameterAndType> parameterMap)
Expand Down Expand Up @@ -179,4 +183,30 @@ public BytecodeNode visitVariableReference(VariableReferenceExpression reference
} }
}; };
} }

static class LambdaExpressionField
{
private final FieldDefinition staticField;
// the instance field will be binded to "this" in constructor
private final FieldDefinition instanceField;

public LambdaExpressionField(FieldDefinition staticField, FieldDefinition instanceField)
{
this.staticField = requireNonNull(staticField, "staticField is null");
this.instanceField = requireNonNull(instanceField, "instanceField is null");
}

public FieldDefinition getInstanceField()
{
return instanceField;
}

public void generateInitialization(Variable thisVariable, BytecodeBlock block)
{
block.append(
thisVariable.setField(
instanceField,
getStatic(staticField).invoke("bindTo", MethodHandle.class, thisVariable.cast(Object.class))));
}
}
} }

0 comments on commit e4a408e

Please sign in to comment.