Skip to content

Commit

Permalink
Add stage id/plan node id to generated page projection class name
Browse files Browse the repository at this point in the history
  • Loading branch information
nezihyigitbasi committed Jul 20, 2017
1 parent 3ecc315 commit 4df9290
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 27 deletions.
Expand Up @@ -69,7 +69,7 @@ protected List<? extends OperatorFactory> createOperatorFactories()
// and quantity < 24; // and quantity < 24;
OperatorFactory tableScanOperator = createTableScanOperator(0, new PlanNodeId("test"), "lineitem", "extendedprice", "discount", "shipdate", "quantity"); OperatorFactory tableScanOperator = createTableScanOperator(0, new PlanNodeId("test"), "lineitem", "extendedprice", "discount", "shipdate", "quantity");


Supplier<PageProjection> projection = new PageFunctionCompiler(localQueryRunner.getMetadata()).compileProjection(field(0, BIGINT)); Supplier<PageProjection> projection = new PageFunctionCompiler(localQueryRunner.getMetadata()).compileProjection(field(0, BIGINT), Optional.empty());


FilterAndProjectOperator.FilterAndProjectOperatorFactory tpchQuery6Operator = new FilterAndProjectOperator.FilterAndProjectOperatorFactory( FilterAndProjectOperator.FilterAndProjectOperatorFactory tpchQuery6Operator = new FilterAndProjectOperator.FilterAndProjectOperatorFactory(
1, 1,
Expand Down
Expand Up @@ -160,7 +160,7 @@ private SqlTaskExecution(
List<DriverFactory> driverFactories; List<DriverFactory> driverFactories;
try { try {
LocalExecutionPlan localExecutionPlan = planner.plan( LocalExecutionPlan localExecutionPlan = planner.plan(
taskContext.getSession(), taskContext,
fragment.getRoot(), fragment.getRoot(),
fragment.getSymbols(), fragment.getSymbols(),
fragment.getPartitioningScheme(), fragment.getPartitioningScheme(),
Expand Down
Expand Up @@ -72,7 +72,7 @@ public DynamicTupleFilterFactory(int filterOperatorId, PlanNodeId planNodeId, in
this.outputTypes = ImmutableList.copyOf(outputTypes); this.outputTypes = ImmutableList.copyOf(outputTypes);
PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(metadata); PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(metadata);
this.outputProjections = IntStream.range(0, outputTypes.size()) this.outputProjections = IntStream.range(0, outputTypes.size())
.mapToObj(field -> pageFunctionCompiler.compileProjection(Expressions.field(field, outputTypes.get(field)))) .mapToObj(field -> pageFunctionCompiler.compileProjection(Expressions.field(field, outputTypes.get(field)), Optional.empty()))
.collect(toImmutableList()); .collect(toImmutableList());
} }


Expand Down
Expand Up @@ -97,12 +97,12 @@ public Supplier<CursorProcessor> compileCursorProcessor(Optional<RowExpression>
}; };
} }


public Supplier<PageProcessor> compilePageProcessor(Optional<RowExpression> filter, List<? extends RowExpression> projections) public Supplier<PageProcessor> compilePageProcessor(Optional<RowExpression> filter, List<? extends RowExpression> projections, Optional<String> classNameSuffix)
{ {
PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(metadata); PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(metadata);
Optional<Supplier<PageFilter>> filterFunctionSupplier = filter.map(pageFunctionCompiler::compileFilter); Optional<Supplier<PageFilter>> filterFunctionSupplier = filter.map(pageFunctionCompiler::compileFilter);
List<Supplier<PageProjection>> pageProjectionSuppliers = projections.stream() List<Supplier<PageProjection>> pageProjectionSuppliers = projections.stream()
.map(pageFunctionCompiler::compileProjection) .map(projection -> pageFunctionCompiler.compileProjection(projection, classNameSuffix))
.collect(toImmutableList()); .collect(toImmutableList());


return () -> { return () -> {
Expand All @@ -114,6 +114,11 @@ public Supplier<PageProcessor> compilePageProcessor(Optional<RowExpression> filt
}; };
} }


public Supplier<PageProcessor> compilePageProcessor(Optional<RowExpression> filter, List<? extends RowExpression> projections)
{
return compilePageProcessor(filter, projections, Optional.empty());
}

private <T> Class<? extends T> compile(Optional<RowExpression> filter, List<RowExpression> projections, BodyCompiler bodyCompiler, Class<? extends T> superType) private <T> Class<? extends T> compile(Optional<RowExpression> filter, List<RowExpression> projections, BodyCompiler bodyCompiler, Class<? extends T> superType)
{ {
// create filter and project page iterator class // create filter and project page iterator class
Expand Down
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.bytecode.FieldDefinition; import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.ParameterizedType;
import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable; import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.ForLoop; import com.facebook.presto.bytecode.control.ForLoop;
Expand Down Expand Up @@ -58,6 +59,7 @@
import javax.inject.Inject; import javax.inject.Inject;


import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.function.Consumer; import java.util.function.Consumer;
Expand Down Expand Up @@ -104,7 +106,7 @@ public PageFunctionCompiler(Metadata metadata)
this.determinismEvaluator = new DeterminismEvaluator(metadata.getFunctionRegistry()); this.determinismEvaluator = new DeterminismEvaluator(metadata.getFunctionRegistry());
} }


public Supplier<PageProjection> compileProjection(RowExpression projection) public Supplier<PageProjection> compileProjection(RowExpression projection, Optional<String> classNameSuffix)
{ {
requireNonNull(projection, "projection is null"); requireNonNull(projection, "projection is null");


Expand All @@ -123,7 +125,7 @@ public Supplier<PageProjection> compileProjection(RowExpression projection)
PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(projection); PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(projection);


CallSiteBinder callSiteBinder = new CallSiteBinder(); CallSiteBinder callSiteBinder = new CallSiteBinder();
ClassDefinition classDefinition = defineProjectionClass(result.getRewrittenExpression(), result.getInputChannels(), callSiteBinder); ClassDefinition classDefinition = defineProjectionClass(result.getRewrittenExpression(), result.getInputChannels(), callSiteBinder, classNameSuffix);


Class<? extends PageProjection> projectionClass; Class<? extends PageProjection> projectionClass;
try { try {
Expand All @@ -143,11 +145,21 @@ public Supplier<PageProjection> compileProjection(RowExpression projection)
}; };
} }


private ClassDefinition defineProjectionClass(RowExpression projection, InputChannels inputChannels, CallSiteBinder callSiteBinder) private ParameterizedType generateProjectionClassName(Optional<String> classNameSuffix)
{
StringBuilder className = new StringBuilder(PageProjection.class.getSimpleName());
classNameSuffix.ifPresent(suffix -> className.append("_").append(suffix.replace('.', '_')));
return makeClassName(className.toString());
}

private ClassDefinition defineProjectionClass(RowExpression projection,
InputChannels inputChannels,
CallSiteBinder callSiteBinder,
Optional<String> classNameSuffix)
{ {
ClassDefinition classDefinition = new ClassDefinition( ClassDefinition classDefinition = new ClassDefinition(
a(PUBLIC, FINAL), a(PUBLIC, FINAL),
makeClassName(PageProjection.class.getSimpleName()), generateProjectionClassName(classNameSuffix),
type(Object.class), type(Object.class),
type(PageProjection.class)); type(PageProjection.class));


Expand Down
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.execution.QueryPerformanceFetcher; import com.facebook.presto.execution.QueryPerformanceFetcher;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.execution.buffer.OutputBuffer; import com.facebook.presto.execution.buffer.OutputBuffer;
import com.facebook.presto.execution.buffer.PagesSerdeFactory; import com.facebook.presto.execution.buffer.PagesSerdeFactory;
Expand Down Expand Up @@ -57,6 +58,7 @@
import com.facebook.presto.operator.SetBuilderOperator.SetSupplier; import com.facebook.presto.operator.SetBuilderOperator.SetSupplier;
import com.facebook.presto.operator.SourceOperatorFactory; import com.facebook.presto.operator.SourceOperatorFactory;
import com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory; import com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.TaskOutputOperator.TaskOutputFactory; import com.facebook.presto.operator.TaskOutputOperator.TaskOutputFactory;
import com.facebook.presto.operator.TopNOperator.TopNOperatorFactory; import com.facebook.presto.operator.TopNOperator.TopNOperatorFactory;
import com.facebook.presto.operator.TopNRowNumberOperator; import com.facebook.presto.operator.TopNRowNumberOperator;
Expand Down Expand Up @@ -296,7 +298,7 @@ public LocalExecutionPlanner(
} }


public LocalExecutionPlan plan( public LocalExecutionPlan plan(
Session session, TaskContext taskContext,
PlanNode plan, PlanNode plan,
Map<Symbol, Type> types, Map<Symbol, Type> types,
PartitioningScheme partitioningScheme, PartitioningScheme partitioningScheme,
Expand All @@ -307,7 +309,7 @@ public LocalExecutionPlan plan(
partitioningScheme.getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioningScheme.getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION) ||
partitioningScheme.getPartitioning().getHandle().equals(SINGLE_DISTRIBUTION) || partitioningScheme.getPartitioning().getHandle().equals(SINGLE_DISTRIBUTION) ||
partitioningScheme.getPartitioning().getHandle().equals(COORDINATOR_DISTRIBUTION)) { partitioningScheme.getPartitioning().getHandle().equals(COORDINATOR_DISTRIBUTION)) {
return plan(session, plan, outputLayout, types, new TaskOutputFactory(outputBuffer)); return plan(taskContext, plan, outputLayout, types, new TaskOutputFactory(outputBuffer));
} }


// We can convert the symbols directly into channels, because the root must be a sink and therefore the layout is fixed // We can convert the symbols directly into channels, because the root must be a sink and therefore the layout is fixed
Expand Down Expand Up @@ -342,7 +344,7 @@ public LocalExecutionPlan plan(
.collect(toImmutableList()); .collect(toImmutableList());
} }


PartitionFunction partitionFunction = nodePartitioningManager.getPartitionFunction(session, partitioningScheme, partitionChannelTypes); PartitionFunction partitionFunction = nodePartitioningManager.getPartitionFunction(taskContext.getSession(), partitioningScheme, partitionChannelTypes);
OptionalInt nullChannel = OptionalInt.empty(); OptionalInt nullChannel = OptionalInt.empty();
Set<Symbol> partitioningColumns = partitioningScheme.getPartitioning().getColumns(); Set<Symbol> partitioningColumns = partitioningScheme.getPartitioning().getColumns();


Expand All @@ -353,7 +355,7 @@ public LocalExecutionPlan plan(
} }


return plan( return plan(
session, taskContext,
plan, plan,
outputLayout, outputLayout,
types, types,
Expand All @@ -367,13 +369,14 @@ public LocalExecutionPlan plan(
maxPagePartitioningBufferSize)); maxPagePartitioningBufferSize));
} }


public LocalExecutionPlan plan(Session session, public LocalExecutionPlan plan(TaskContext taskContext,
PlanNode plan, PlanNode plan,
List<Symbol> outputLayout, List<Symbol> outputLayout,
Map<Symbol, Type> types, Map<Symbol, Type> types,
OutputFactory outputOperatorFactory) OutputFactory outputOperatorFactory)
{ {
LocalExecutionPlanContext context = new LocalExecutionPlanContext(session, types); Session session = taskContext.getSession();
LocalExecutionPlanContext context = new LocalExecutionPlanContext(taskContext, types);


PhysicalOperation physicalOperation = plan.accept(new Visitor(session), context); PhysicalOperation physicalOperation = plan.accept(new Visitor(session), context);


Expand Down Expand Up @@ -441,7 +444,7 @@ private static void addLookupOuterDrivers(LocalExecutionPlanContext context)


private static class LocalExecutionPlanContext private static class LocalExecutionPlanContext
{ {
private final Session session; private final TaskContext taskContext;
private final Map<Symbol, Type> types; private final Map<Symbol, Type> types;
private final List<DriverFactory> driverFactories; private final List<DriverFactory> driverFactories;
private final Optional<IndexSourceContext> indexSourceContext; private final Optional<IndexSourceContext> indexSourceContext;
Expand All @@ -453,19 +456,19 @@ private static class LocalExecutionPlanContext
private boolean inputDriver = true; private boolean inputDriver = true;
private OptionalInt driverInstanceCount = OptionalInt.empty(); private OptionalInt driverInstanceCount = OptionalInt.empty();


public LocalExecutionPlanContext(Session session, Map<Symbol, Type> types) public LocalExecutionPlanContext(TaskContext taskContext, Map<Symbol, Type> types)
{ {
this(session, types, new ArrayList<>(), Optional.empty(), new AtomicInteger(0)); this(taskContext, types, new ArrayList<>(), Optional.empty(), new AtomicInteger(0));
} }


private LocalExecutionPlanContext( private LocalExecutionPlanContext(
Session session, TaskContext taskContext,
Map<Symbol, Type> types, Map<Symbol, Type> types,
List<DriverFactory> driverFactories, List<DriverFactory> driverFactories,
Optional<IndexSourceContext> indexSourceContext, Optional<IndexSourceContext> indexSourceContext,
AtomicInteger nextPipelineId) AtomicInteger nextPipelineId)
{ {
this.session = session; this.taskContext = taskContext;
this.types = types; this.types = types;
this.driverFactories = driverFactories; this.driverFactories = driverFactories;
this.indexSourceContext = indexSourceContext; this.indexSourceContext = indexSourceContext;
Expand All @@ -484,7 +487,12 @@ private List<DriverFactory> getDriverFactories()


public Session getSession() public Session getSession()
{ {
return session; return taskContext.getSession();
}

public StageId getStageId()
{
return taskContext.getTaskId().getStageId();
} }


public Map<Symbol, Type> getTypes() public Map<Symbol, Type> getTypes()
Expand Down Expand Up @@ -520,12 +528,12 @@ private void setInputDriver(boolean inputDriver)
public LocalExecutionPlanContext createSubContext() public LocalExecutionPlanContext createSubContext()
{ {
checkState(!indexSourceContext.isPresent(), "index build plan can not have sub-contexts"); checkState(!indexSourceContext.isPresent(), "index build plan can not have sub-contexts");
return new LocalExecutionPlanContext(session, types, driverFactories, indexSourceContext, nextPipelineId); return new LocalExecutionPlanContext(taskContext, types, driverFactories, indexSourceContext, nextPipelineId);
} }


public LocalExecutionPlanContext createIndexSourceSubContext(IndexSourceContext indexSourceContext) public LocalExecutionPlanContext createIndexSourceSubContext(IndexSourceContext indexSourceContext)
{ {
return new LocalExecutionPlanContext(session, types, driverFactories, Optional.of(indexSourceContext), nextPipelineId); return new LocalExecutionPlanContext(taskContext, types, driverFactories, Optional.of(indexSourceContext), nextPipelineId);
} }


public OptionalInt getDriverInstanceCount() public OptionalInt getDriverInstanceCount()
Expand Down Expand Up @@ -1070,7 +1078,7 @@ private PhysicalOperation visitScanFilterAndProject(
try { try {
if (columns != null) { if (columns != null) {
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(translatedFilter, translatedProjections, sourceNode.getId()); Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(translatedFilter, translatedProjections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections); Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));


SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory( SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory(
context.getNextOperatorId(), context.getNextOperatorId(),
Expand All @@ -1085,7 +1093,7 @@ private PhysicalOperation visitScanFilterAndProject(
return new PhysicalOperation(operatorFactory, outputMappings); return new PhysicalOperation(operatorFactory, outputMappings);
} }
else { else {
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections); Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));


OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory( OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory(
context.getNextOperatorId(), context.getNextOperatorId(),
Expand Down
Expand Up @@ -597,7 +597,7 @@ public List<Driver> createDrivers(Session session, @Language("SQL") String sql,


// plan query // plan query
LocalExecutionPlan localExecutionPlan = executionPlanner.plan( LocalExecutionPlan localExecutionPlan = executionPlanner.plan(
session, taskContext,
subplan.getFragment().getRoot(), subplan.getFragment().getRoot(),
subplan.getFragment().getPartitioningScheme().getOutputLayout(), subplan.getFragment().getPartitioningScheme().getOutputLayout(),
plan.getTypes(), plan.getTypes(),
Expand Down
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test; import org.testng.annotations.Test;


import java.util.Optional;
import java.util.function.Supplier; import java.util.function.Supplier;


import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
Expand All @@ -35,6 +36,7 @@
import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.sql.relational.Expressions.field;
import static com.facebook.presto.testing.TestingConnectorSession.SESSION; import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail; import static org.testng.Assert.fail;


public class TestPageFunctionCompiler public class TestPageFunctionCompiler
Expand All @@ -51,7 +53,7 @@ public void testFailureDoesNotCorruptFutureResults()
field(0, BIGINT), field(0, BIGINT),
constant(10L, BIGINT)); constant(10L, BIGINT));


Supplier<PageProjection> projectionSupplier = functionCompiler.compileProjection(add10); Supplier<PageProjection> projectionSupplier = functionCompiler.compileProjection(add10, Optional.empty());
PageProjection projection = projectionSupplier.get(); PageProjection projection = projectionSupplier.get();


// process good page and verify we got the expected number of result rows // process good page and verify we got the expected number of result rows
Expand All @@ -75,6 +77,25 @@ public void testFailureDoesNotCorruptFutureResults()
assertEquals(goodPage.getPositionCount(), goodResult.getPositionCount()); assertEquals(goodPage.getPositionCount(), goodResult.getPositionCount());
} }


@Test
public void testGeneratedClassName()
{
PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager());
RowExpression add10 = call(
Signature.internalOperator(ADD, BIGINT.getTypeSignature(), ImmutableList.of(BIGINT.getTypeSignature(), BIGINT.getTypeSignature())),
BIGINT,
field(0, BIGINT),
constant(10L, BIGINT));

String planNodeId = "7";
String stageId = "20170707_223500_67496_zguwn.2";
String classSuffix = stageId + "_" + planNodeId;
Supplier<PageProjection> projectionSupplier = functionCompiler.compileProjection(add10, Optional.of(classSuffix));
PageProjection projection = projectionSupplier.get();
// class name should look like PageProjection_20170707_223500_67496_zguwn_2_7_XX
assertTrue(projection.getClass().getSimpleName().startsWith("PageProjection_" + stageId.replace('.', '_') + "_" + planNodeId));
}

private static Page createLongBlockPage(long... values) private static Page createLongBlockPage(long... values)
{ {
BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(values.length); BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(values.length);
Expand Down

0 comments on commit 4df9290

Please sign in to comment.