Skip to content

Commit

Permalink
Refactor PageProcessor to use ByteCodeExpressions
Browse files Browse the repository at this point in the history
Example of generate byte code tree
https://gist.github.com/nileema/21678edd182daa30517b
  • Loading branch information
nileema committed Nov 9, 2015
1 parent d40f771 commit ff4003e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 104 deletions.
Expand Up @@ -16,16 +16,19 @@
import com.facebook.presto.byteCode.ByteCodeNode; import com.facebook.presto.byteCode.ByteCodeNode;
import com.facebook.presto.byteCode.ByteCodeVisitor; import com.facebook.presto.byteCode.ByteCodeVisitor;
import com.facebook.presto.byteCode.FieldDefinition; import com.facebook.presto.byteCode.FieldDefinition;
import com.facebook.presto.byteCode.MethodDefinition;
import com.facebook.presto.byteCode.MethodGenerationContext; import com.facebook.presto.byteCode.MethodGenerationContext;
import com.facebook.presto.byteCode.ParameterizedType; import com.facebook.presto.byteCode.ParameterizedType;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.MethodVisitor;


import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.List;


import static com.facebook.presto.byteCode.ParameterizedType.type; import static com.facebook.presto.byteCode.ParameterizedType.type;
import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantInt; import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantInt;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.transform; import static com.google.common.collect.Iterables.transform;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;


Expand Down Expand Up @@ -123,6 +126,13 @@ public final ByteCodeExpression invoke(Method method, ByteCodeExpression... para
return invoke(method, ImmutableList.copyOf(requireNonNull(parameters, "parameters is null"))); return invoke(method, ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")));
} }


public final ByteCodeExpression invoke(MethodDefinition method, Iterable<? extends ByteCodeExpression> parameters)
{
List<ByteCodeExpression> params = ImmutableList.copyOf(parameters);
checkArgument(method.getParameters().size() == params.size(), "Expected %s params found %s", method.getParameters().size(), params.size());
return invoke(method.getName(), method.getReturnType(), parameters);
}

public final ByteCodeExpression invoke(Method method, Iterable<? extends ByteCodeExpression> parameters) public final ByteCodeExpression invoke(Method method, Iterable<? extends ByteCodeExpression> parameters)
{ {
return invoke(method.getName(), type(method.getReturnType()), parameters); return invoke(method.getName(), type(method.getReturnType()), parameters);
Expand Down
Expand Up @@ -18,11 +18,11 @@
import com.facebook.presto.byteCode.ClassDefinition; import com.facebook.presto.byteCode.ClassDefinition;
import com.facebook.presto.byteCode.MethodDefinition; import com.facebook.presto.byteCode.MethodDefinition;
import com.facebook.presto.byteCode.Parameter; import com.facebook.presto.byteCode.Parameter;
import com.facebook.presto.byteCode.ParameterizedType;
import com.facebook.presto.byteCode.Scope; import com.facebook.presto.byteCode.Scope;
import com.facebook.presto.byteCode.Variable; import com.facebook.presto.byteCode.Variable;
import com.facebook.presto.byteCode.control.ForLoop; import com.facebook.presto.byteCode.control.ForLoop;
import com.facebook.presto.byteCode.control.IfStatement; import com.facebook.presto.byteCode.control.IfStatement;
import com.facebook.presto.byteCode.expression.ByteCodeExpression;
import com.facebook.presto.byteCode.instruction.LabelNode; import com.facebook.presto.byteCode.instruction.LabelNode;
import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.PageProcessor; import com.facebook.presto.operator.PageProcessor;
Expand All @@ -39,22 +39,30 @@
import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.RowExpressionVisitor; import com.facebook.presto.sql.relational.RowExpressionVisitor;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import com.google.common.primitives.Primitives; import com.google.common.primitives.Primitives;
import io.airlift.slice.Slice; import io.airlift.slice.Slice;


import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.TreeSet; import java.util.TreeSet;


import static com.facebook.presto.byteCode.Access.PUBLIC; import static com.facebook.presto.byteCode.Access.PUBLIC;
import static com.facebook.presto.byteCode.Access.a; import static com.facebook.presto.byteCode.Access.a;
import static com.facebook.presto.byteCode.OpCode.NOP;
import static com.facebook.presto.byteCode.Parameter.arg; import static com.facebook.presto.byteCode.Parameter.arg;
import static com.facebook.presto.byteCode.ParameterizedType.type; import static com.facebook.presto.byteCode.ParameterizedType.type;
import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.add;
import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantFalse;
import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.constantInt;
import static com.facebook.presto.byteCode.expression.ByteCodeExpressions.lessThan;
import static com.facebook.presto.byteCode.instruction.JumpInstruction.jump;
import static com.facebook.presto.sql.gen.ByteCodeUtils.generateWrite; import static com.facebook.presto.sql.gen.ByteCodeUtils.generateWrite;
import static com.facebook.presto.sql.gen.ByteCodeUtils.loadConstant; import static com.facebook.presto.sql.gen.ByteCodeUtils.loadConstant;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format; import static java.lang.String.format;
import static java.util.Collections.nCopies; import static java.util.stream.Collectors.toList;


public class PageProcessorCompiler public class PageProcessorCompiler
implements BodyCompiler<PageProcessor> implements BodyCompiler<PageProcessor>
Expand All @@ -69,15 +77,15 @@ public PageProcessorCompiler(Metadata metadata)
@Override @Override
public void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, List<RowExpression> projections) public void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, List<RowExpression> projections)
{ {
generateProcessMethod(classDefinition, filter, projections); ImmutableList.Builder<MethodDefinition> projectionMethods = ImmutableList.builder();
generateFilterMethod(classDefinition, callSiteBinder, filter);

for (int i = 0; i < projections.size(); i++) { for (int i = 0; i < projections.size(); i++) {
generateProjectMethod(classDefinition, callSiteBinder, "project_" + i, projections.get(i)); projectionMethods.add(generateProjectMethod(classDefinition, callSiteBinder, "project_" + i, projections.get(i)));
} }
generateProcessMethod(classDefinition, filter, projections, projectionMethods.build());
generateFilterMethod(classDefinition, callSiteBinder, filter);
} }


private void generateProcessMethod(ClassDefinition classDefinition, RowExpression filter, List<RowExpression> projections) private static void generateProcessMethod(ClassDefinition classDefinition, RowExpression filter, List<RowExpression> projections, List<MethodDefinition> projectionMethods)
{ {
Parameter session = arg("session", ConnectorSession.class); Parameter session = arg("session", ConnectorSession.class);
Parameter page = arg("page", Page.class); Parameter page = arg("page", Page.class);
Expand All @@ -88,105 +96,56 @@ private void generateProcessMethod(ClassDefinition classDefinition, RowExpressio


Scope scope = method.getScope(); Scope scope = method.getScope();
Variable thisVariable = method.getThis(); Variable thisVariable = method.getThis();
Variable position = scope.declareVariable(int.class, "position");

method.getBody()
.comment("int position = start;")
.getVariable(start)
.putVariable(position);


// extract blocks
List<Integer> allInputChannels = getInputChannels(Iterables.concat(projections, ImmutableList.of(filter))); List<Integer> allInputChannels = getInputChannels(Iterables.concat(projections, ImmutableList.of(filter)));
ImmutableMap.Builder<Integer, Variable> channelBlockBuilder = ImmutableMap.builder();
for (int channel : allInputChannels) { for (int channel : allInputChannels) {
Variable blockVariable = scope.declareVariable(Block.class, "block_" + channel); Variable blockVariable = scope.declareVariable(Block.class, "block_" + channel);
method.getBody() method.getBody().append(blockVariable.set(page.invoke("getBlock", Block.class, constantInt(channel))));
.comment("Block %s = page.getBlock(%s);", blockVariable.getName(), channel) channelBlockBuilder.put(channel, blockVariable);
.getVariable(page)
.push(channel)
.invokeVirtual(Page.class, "getBlock", Block.class, int.class)
.putVariable(blockVariable);
} }
Map<Integer, Variable> channelBlock = channelBlockBuilder.build();
Map<RowExpression, List<Variable>> expressionInputBlocks = getExpressionInputBlocks(projections, filter, channelBlock);


// // extract block builders
// for loop loop body ImmutableList.Builder<Variable> builder = ImmutableList.<Variable>builder();
// for (int projectionIndex = 0; projectionIndex < projections.size(); projectionIndex++) {
LabelNode done = new LabelNode("done"); Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder_" + projectionIndex);
method.getBody().append(blockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(projectionIndex))));
builder.add(blockBuilder);
}
List<Variable> blockBuilders = builder.build();


ByteCodeBlock loopBody = new ByteCodeBlock(); // projection body
Variable position = scope.declareVariable(int.class, "position");


ForLoop loop = new ForLoop() ByteCodeBlock project = new ByteCodeBlock()
.initialize(NOP) .append(pageBuilder.invoke("declarePosition", void.class));
.condition(new ByteCodeBlock()
.comment("position < end")
.getVariable(position)
.getVariable(end)
.invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)
)
.update(new ByteCodeBlock()
.comment("position++")
.incrementVariable(position, (byte) 1))
.body(loopBody);

loopBody.comment("if (pageBuilder.isFull()) break;")
.getVariable(pageBuilder)
.invokeVirtual(PageBuilder.class, "isFull", boolean.class)
.ifTrueGoto(done);

// if (filter(cursor))
IfStatement filterBlock = new IfStatement();
filterBlock.condition()
.append(thisVariable)
.getVariable(session)
.append(pushBlockVariables(scope, getInputChannels(filter)))
.getVariable(position)
.invokeVirtual(classDefinition.getType(),
"filter",
type(boolean.class),
ImmutableList.<ParameterizedType>builder()
.add(type(ConnectorSession.class))
.addAll(nCopies(getInputChannels(filter).size(), type(Block.class)))
.add(type(int.class))
.build());

filterBlock.ifTrue()
.append(pageBuilder)
.invokeVirtual(PageBuilder.class, "declarePosition", void.class);


for (int projectionIndex = 0; projectionIndex < projections.size(); projectionIndex++) { for (int projectionIndex = 0; projectionIndex < projections.size(); projectionIndex++) {
List<Integer> inputChannels = getInputChannels(projections.get(projectionIndex)); RowExpression projection = projections.get(projectionIndex);

project.append(invokeProject(thisVariable, session, expressionInputBlocks.get(projection), position, blockBuilders.get(projectionIndex), projectionMethods.get(projectionIndex)));
filterBlock.ifTrue()
.append(thisVariable)
.append(session)
.append(pushBlockVariables(scope, inputChannels))
.getVariable(position);

filterBlock.ifTrue()
.comment("pageBuilder.getBlockBuilder(%d)", projectionIndex)
.append(pageBuilder)
.push(projectionIndex)
.invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class);

filterBlock.ifTrue()
.comment("project_%d(session, block_%s, position, blockBuilder)", projectionIndex, inputChannels)
.invokeVirtual(classDefinition.getType(),
"project_" + projectionIndex,
type(void.class),
ImmutableList.<ParameterizedType>builder()
.add(type(ConnectorSession.class))
.addAll(nCopies(inputChannels.size(), type(Block.class)))
.add(type(int.class))
.add(type(BlockBuilder.class))
.build());
} }
LabelNode done = new LabelNode("done");


loopBody.append(filterBlock); // for loop loop body
ForLoop loop = new ForLoop()
.initialize(position.set(start))
.condition(lessThan(position, end))
.update(position.set(add(position, constantInt(1))))
.body(new ByteCodeBlock()
.append(new IfStatement()
.condition(pageBuilder.invoke("isFull", boolean.class))
.ifTrue(jump(done)))
.append(new IfStatement()
.condition(invokeFilter(thisVariable, session, expressionInputBlocks.get(filter), position))
.ifTrue(project)));


method.getBody() method.getBody()
.append(loop) .append(loop)
.visitLabel(done) .visitLabel(done)
.comment("return position;") .append(position.ret());
.getVariable(position)
.retInt();
} }


private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter) private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter)
Expand Down Expand Up @@ -229,7 +188,7 @@ private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinde
.retBoolean(); .retBoolean();
} }


private void generateProjectMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, String methodName, RowExpression projection) private MethodDefinition generateProjectMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, String methodName, RowExpression projection)
{ {
Parameter session = arg("session", ConnectorSession.class); Parameter session = arg("session", ConnectorSession.class);
List<Parameter> inputs = toBlockParameters(getInputChannels(projection)); List<Parameter> inputs = toBlockParameters(getInputChannels(projection));
Expand All @@ -252,8 +211,7 @@ private void generateProjectMethod(ClassDefinition classDefinition, CallSiteBind
Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull");


ByteCodeBlock body = method.getBody() ByteCodeBlock body = method.getBody()
.comment("boolean wasNull = false;") .append(wasNullVariable.set(constantFalse()));
.putVariable(wasNullVariable, false);


ByteCodeExpressionVisitor visitor = new ByteCodeExpressionVisitor(callSiteBinder, fieldReferenceCompiler(callSiteBinder, position, wasNullVariable), metadata.getFunctionRegistry()); ByteCodeExpressionVisitor visitor = new ByteCodeExpressionVisitor(callSiteBinder, fieldReferenceCompiler(callSiteBinder, position, wasNullVariable), metadata.getFunctionRegistry());


Expand All @@ -262,6 +220,7 @@ private void generateProjectMethod(ClassDefinition classDefinition, CallSiteBind
.append(projection.accept(visitor, scope)) .append(projection.accept(visitor, scope))
.append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType()))
.ret(); .ret();
return method;
} }


private static List<Integer> getInputChannels(Iterable<RowExpression> expressions) private static List<Integer> getInputChannels(Iterable<RowExpression> expressions)
Expand Down Expand Up @@ -289,16 +248,7 @@ private static List<Parameter> toBlockParameters(List<Integer> inputChannels)
return parameters.build(); return parameters.build();
} }


private static ByteCodeNode pushBlockVariables(Scope scope, List<Integer> inputs) private static RowExpressionVisitor<Scope, ByteCodeNode> fieldReferenceCompiler(final CallSiteBinder callSiteBinder, final Variable positionVariable, final Variable wasNullVariable)
{
ByteCodeBlock block = new ByteCodeBlock();
for (int channel : inputs) {
block.append(scope.getVariable("block_" + channel));
}
return block;
}

private RowExpressionVisitor<Scope, ByteCodeNode> fieldReferenceCompiler(final CallSiteBinder callSiteBinder, final Variable positionVariable, final Variable wasNullVariable)
{ {
return new RowExpressionVisitor<Scope, ByteCodeNode>() return new RowExpressionVisitor<Scope, ByteCodeNode>()
{ {
Expand Down Expand Up @@ -349,4 +299,50 @@ public ByteCodeNode visitConstant(ConstantExpression literal, Scope scope)
} }
}; };
} }

private static Map<RowExpression, List<Variable>> getExpressionInputBlocks(List<RowExpression> projections, RowExpression filter, Map<Integer, Variable> channelBlock)
{
Map<RowExpression, List<Variable>> inputBlocksBuilder = new HashMap<>();

for (RowExpression projection : projections) {
List<Variable> inputBlocks = getInputChannels(projection).stream()
.map(channelBlock::get)
.collect(toList());

List<Variable> existingVariables = inputBlocksBuilder.get(projection);
// Constant expressions or expressions that are reused, should reference the same input blocks
checkState(existingVariables == null || existingVariables.equals(inputBlocks), "malformed RowExpression");
inputBlocksBuilder.put(projection, inputBlocks);
}

List<Variable> filterBlocks = getInputChannels(filter).stream()
.map(channelBlock::get)
.collect(toList());

inputBlocksBuilder.put(filter, filterBlocks);

return inputBlocksBuilder;
}

private static ByteCodeExpression invokeFilter(ByteCodeExpression objRef, ByteCodeExpression session, List<? extends ByteCodeExpression> blockVariables, ByteCodeExpression position)
{
List<ByteCodeExpression> params = ImmutableList.<ByteCodeExpression>builder()
.add(session)
.addAll(blockVariables)
.add(position)
.build();

return objRef.invoke("filter", boolean.class, params);
}

private static ByteCodeNode invokeProject(Variable objRef, Variable session, List<Variable> blockVariables, ByteCodeExpression position, Variable blockBuilder, MethodDefinition projectionMethod)
{
List<ByteCodeExpression> params = ImmutableList.<ByteCodeExpression>builder()
.add(session)
.addAll(blockVariables)
.add(position)
.add(blockBuilder)
.build();
return new ByteCodeBlock().append(objRef.invoke(projectionMethod, params));
}
} }

0 comments on commit ff4003e

Please sign in to comment.