Skip to content

Commit

Permalink
Rewrite parallel hash build
Browse files (browse the repository at this point in the history)
  • Loading branch information
dain committed Apr 27, 2016
1 parent 2707dd3 commit 6d27491
Show file tree
Hide file tree
Showing 25 changed files with 987 additions and 1,046 deletions.
Expand Up @@ -74,7 +74,7 @@ protected List<Driver> createDrivers(TaskContext taskContext)
} }


// hash build // hash build
HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(2, new PlanNodeId("test"), source.getTypes(), Ints.asList(0), hashChannel, 1_500_000); HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(2, new PlanNodeId("test"), source.getTypes(), Ints.asList(0), hashChannel, false, 1_500_000);
driversBuilder.add(hashBuilder); driversBuilder.add(hashBuilder);
DriverFactory hashBuildDriverFactory = new DriverFactory(true, false, driversBuilder.build(), OptionalInt.empty()); DriverFactory hashBuildDriverFactory = new DriverFactory(true, false, driversBuilder.build(), OptionalInt.empty());
Driver hashBuildDriver = hashBuildDriverFactory.createDriver(taskContext.addPipelineContext(true, false).addDriverContext()); Driver hashBuildDriver = hashBuildDriverFactory.createDriver(taskContext.addPipelineContext(true, false).addDriverContext());
Expand Down
Expand Up @@ -41,7 +41,14 @@ public HashBuildBenchmark(LocalQueryRunner localQueryRunner)
protected List<Driver> createDrivers(TaskContext taskContext) protected List<Driver> createDrivers(TaskContext taskContext)
{ {
OperatorFactory ordersTableScan = createTableScanOperator(0, new PlanNodeId("test"), "orders", "orderkey", "totalprice"); OperatorFactory ordersTableScan = createTableScanOperator(0, new PlanNodeId("test"), "orders", "orderkey", "totalprice");
HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(1, new PlanNodeId("test"), ordersTableScan.getTypes(), Ints.asList(0), Optional.empty(), 1_500_000); HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(
1,
new PlanNodeId("test"),
ordersTableScan.getTypes(),
Ints.asList(0),
Optional.empty(),
false,
1_500_000);


DriverFactory driverFactory = new DriverFactory(true, true, ImmutableList.of(ordersTableScan, hashBuilder), OptionalInt.empty()); DriverFactory driverFactory = new DriverFactory(true, true, ImmutableList.of(ordersTableScan, hashBuilder), OptionalInt.empty());
Driver driver = driverFactory.createDriver(taskContext.addPipelineContext(true, true).addDriverContext()); Driver driver = driverFactory.createDriver(taskContext.addPipelineContext(true, true).addDriverContext());
Expand Down
Expand Up @@ -53,7 +53,14 @@ protected List<Driver> createDrivers(TaskContext taskContext)
{ {
if (lookupSourceSupplier == null) { if (lookupSourceSupplier == null) {
OperatorFactory ordersTableScan = createTableScanOperator(0, new PlanNodeId("test"), "orders", "orderkey", "totalprice"); OperatorFactory ordersTableScan = createTableScanOperator(0, new PlanNodeId("test"), "orders", "orderkey", "totalprice");
HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(1, new PlanNodeId("test"), ordersTableScan.getTypes(), Ints.asList(0), Optional.empty(), 1_500_000); HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(
1,
new PlanNodeId("test"),
ordersTableScan.getTypes(),
Ints.asList(0),
Optional.empty(),
false,
1_500_000);


DriverContext driverContext = taskContext.addPipelineContext(false, false).addDriverContext(); DriverContext driverContext = taskContext.addPipelineContext(false, false).addDriverContext();
Driver driver = new DriverFactory(false, false, ImmutableList.of(ordersTableScan, hashBuilder), OptionalInt.empty()).createDriver(driverContext); Driver driver = new DriverFactory(false, false, ImmutableList.of(ordersTableScan, hashBuilder), OptionalInt.empty()).createDriver(driverContext);
Expand Down
Expand Up @@ -34,26 +34,31 @@ public class HashBuilderOperator
public static class HashBuilderOperatorFactory public static class HashBuilderOperatorFactory
implements OperatorFactory implements OperatorFactory
{ {
// Lifecycle of the factory. It starts as NOT_CREATED; createOperator() requires
// NOT_CREATED and then moves to CREATED, so a second createOperator() call fails
// its state check; close() sets CLOSED.
private enum State {
NOT_CREATED, CREATED, CLOSED
}

private final int operatorId; private final int operatorId;
private final PlanNodeId planNodeId; private final PlanNodeId planNodeId;
private final SettableLookupSourceSupplier lookupSourceSupplier; private final SettableLookupSourceSupplier lookupSourceSupplier;
private final List<Integer> hashChannels; private final List<Integer> hashChannels;
private final Optional<Integer> hashChannel; private final Optional<Integer> hashChannel;


private final int expectedPositions; private final int expectedPositions;
private boolean closed; private State state = State.NOT_CREATED;


public HashBuilderOperatorFactory( public HashBuilderOperatorFactory(
int operatorId, int operatorId,
PlanNodeId planNodeId, PlanNodeId planNodeId,
List<Type> types, List<Type> types,
List<Integer> hashChannels, List<Integer> hashChannels,
Optional<Integer> hashChannel, Optional<Integer> hashChannel,
boolean outer,
int expectedPositions) int expectedPositions)
{ {
this.operatorId = operatorId; this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
this.lookupSourceSupplier = new SettableLookupSourceSupplier(requireNonNull(types, "types is null")); this.lookupSourceSupplier = new SettableLookupSourceSupplier(requireNonNull(types, "types is null"), outer);


Preconditions.checkArgument(!hashChannels.isEmpty(), "hashChannels is empty"); Preconditions.checkArgument(!hashChannels.isEmpty(), "hashChannels is empty");
this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null")); this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null"));
Expand All @@ -76,7 +81,9 @@ public List<Type> getTypes()
@Override @Override
public Operator createOperator(DriverContext driverContext) public Operator createOperator(DriverContext driverContext)
{ {
checkState(!closed, "Factory is already closed"); checkState(state == State.NOT_CREATED, "Only one hash build operator can be created");
state = State.CREATED;

OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, HashBuilderOperator.class.getSimpleName()); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, HashBuilderOperator.class.getSimpleName());
return new HashBuilderOperator( return new HashBuilderOperator(
operatorContext, operatorContext,
Expand All @@ -89,13 +96,13 @@ public Operator createOperator(DriverContext driverContext)
@Override @Override
public void close() public void close()
{ {
closed = true; state = State.CLOSED;
} }


@Override @Override
public OperatorFactory duplicate() public OperatorFactory duplicate()
{ {
return new HashBuilderOperatorFactory(operatorId, planNodeId, lookupSourceSupplier.getTypes(), hashChannels, hashChannel, expectedPositions); throw new UnsupportedOperationException("Hash build can not be duplicated");
} }
} }


Expand Down Expand Up @@ -145,8 +152,9 @@ public void finish()
return; return;
} }


// After this point the SharedLookupSource will take over our memory reservation, and ours will be zero // After this point the LookupSource will take over our memory reservation, and ours will be zero
lookupSourceSupplier.setLookupSource(new SharedLookupSource(pagesIndex.createLookupSource(hashChannels, hashChannel), operatorContext)); LookupSource lookupSource = pagesIndex.createLookupSource(hashChannels, hashChannel);
lookupSourceSupplier.setLookupSource(lookupSource, operatorContext);
finished = true; finished = true;
} }


Expand Down
Expand Up @@ -32,12 +32,12 @@
public class LookupJoinOperator public class LookupJoinOperator
implements Operator, Closeable implements Operator, Closeable
{ {
private final ListenableFuture<? extends LookupSource> lookupSourceFuture;
private final LookupSourceSupplier lookupSourceSupplier;

private final OperatorContext operatorContext; private final OperatorContext operatorContext;
private final JoinProbeFactory joinProbeFactory;
private final List<Type> types; private final List<Type> types;
private final ListenableFuture<? extends LookupSource> lookupSourceFuture;
private final JoinProbeFactory joinProbeFactory;
private final Runnable onClose;

private final PageBuilder pageBuilder; private final PageBuilder pageBuilder;


private final boolean probeOnOuterSide; private final boolean probeOnOuterSide;
Expand All @@ -51,28 +51,23 @@ public class LookupJoinOperator


public LookupJoinOperator( public LookupJoinOperator(
OperatorContext operatorContext, OperatorContext operatorContext,
LookupSourceSupplier lookupSourceSupplier, List<Type> types,
List<Type> probeTypes,
JoinType joinType, JoinType joinType,
JoinProbeFactory joinProbeFactory) ListenableFuture<LookupSource> lookupSourceFuture,
JoinProbeFactory joinProbeFactory,
Runnable onClose)
{ {
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));


// todo pass in desired projection requireNonNull(joinType, "joinType is null");
this.lookupSourceSupplier = requireNonNull(lookupSourceSupplier, "lookupSourceSupplier is null");
lookupSourceSupplier.retain();
requireNonNull(probeTypes, "probeTypes is null");

this.lookupSourceFuture = lookupSourceSupplier.getLookupSource(operatorContext);
this.joinProbeFactory = joinProbeFactory;

// Cannot use switch case here, because javac will synthesize an inner class and cause IllegalAccessError // Cannot use switch case here, because javac will synthesize an inner class and cause IllegalAccessError
probeOnOuterSide = joinType == PROBE_OUTER || joinType == FULL_OUTER; probeOnOuterSide = joinType == PROBE_OUTER || joinType == FULL_OUTER;


this.types = ImmutableList.<Type>builder() this.lookupSourceFuture = requireNonNull(lookupSourceFuture, "lookupSourceFuture is null");
.addAll(probeTypes) this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null");
.addAll(lookupSourceSupplier.getTypes()) this.onClose = requireNonNull(onClose, "onClose is null");
.build();
this.pageBuilder = new PageBuilder(types); this.pageBuilder = new PageBuilder(types);
} }


Expand Down Expand Up @@ -101,12 +96,7 @@ public boolean isFinished()


// if finished drop references so memory is freed early // if finished drop references so memory is freed early
if (finished) { if (finished) {
if (lookupSource != null) { close();
lookupSource.close();
lookupSource = null;
}
probe = null;
pageBuilder.reset();
} }
return finished; return finished;
} }
Expand Down Expand Up @@ -177,16 +167,18 @@ public Page getOutput()
@Override @Override
public void close() public void close()
{ {
if (lookupSource != null) {
lookupSource.close();
lookupSource = null;
}
// Closing the lookupSource is always safe to do, but we don't want to release the supplier multiple times, since it's reference counted // Closing the lookupSource is always safe to do, but we don't want to release the supplier multiple times, since it's reference counted
if (closed) { if (closed) {
return; return;
} }
closed = true; closed = true;
lookupSourceSupplier.release(); probe = null;
pageBuilder.reset();
onClose.run();
// closing lookup source is only here for index join
if (lookupSource != null) {
lookupSource.close();
}
} }


private boolean joinCurrentPosition() private boolean joinCurrentPosition()
Expand Down
Expand Up @@ -15,27 +15,36 @@


import com.facebook.presto.operator.LookupJoinOperators.JoinType; import com.facebook.presto.operator.LookupJoinOperators.JoinType;
import com.facebook.presto.operator.LookupOuterOperator.LookupOuterOperatorFactory; import com.facebook.presto.operator.LookupOuterOperator.LookupOuterOperatorFactory;
import com.facebook.presto.operator.LookupOuterOperator.OuterLookupSourceSupplier; import com.facebook.presto.operator.LookupSource.OuterPositionIterator;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.SettableFuture;


import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.function.Consumer;


import static com.facebook.presto.operator.LookupJoinOperators.JoinType.INNER;
import static com.facebook.presto.operator.LookupJoinOperators.JoinType.PROBE_OUTER;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;


public class LookupJoinOperatorFactory public class LookupJoinOperatorFactory
implements JoinOperatorFactory implements JoinOperatorFactory
{ {
private final int operatorId; private final int operatorId;
private final PlanNodeId planNodeId; private final PlanNodeId planNodeId;
private final LookupSourceSupplier lookupSourceSupplier;
private final List<Type> probeTypes; private final List<Type> probeTypes;
private final List<Type> buildTypes;
private final JoinType joinType; private final JoinType joinType;
private final List<Type> types; private final LookupSourceSupplier lookupSourceSupplier;
private final JoinProbeFactory joinProbeFactory; private final JoinProbeFactory joinProbeFactory;
private final Optional<OperatorFactory> outerOperatorFactory;
private final ReferenceCount referenceCount;
private boolean closed; private boolean closed;


public LookupJoinOperatorFactory(int operatorId, public LookupJoinOperatorFactory(int operatorId,
Expand All @@ -47,40 +56,84 @@ public LookupJoinOperatorFactory(int operatorId,
{ {
this.operatorId = operatorId; this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
this.lookupSourceSupplier = lookupSourceSupplier; this.lookupSourceSupplier = requireNonNull(lookupSourceSupplier, "lookupSourceSupplier is null");
this.probeTypes = probeTypes; this.probeTypes = ImmutableList.copyOf(requireNonNull(probeTypes, "probeTypes is null"));
this.joinType = joinType; this.buildTypes = ImmutableList.copyOf(lookupSourceSupplier.getTypes());
this.joinType = requireNonNull(joinType, "joinType is null");
this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null");


this.joinProbeFactory = joinProbeFactory; this.referenceCount = new ReferenceCount();


this.types = ImmutableList.<Type>builder() if (joinType == INNER || joinType == PROBE_OUTER) {
.addAll(probeTypes) // when all join operators finish, destroy the lookup source (freeing the memory)
.addAll(lookupSourceSupplier.getTypes()) this.referenceCount.getFreeFuture().addListener(lookupSourceSupplier::destroy, directExecutor());
.build(); this.outerOperatorFactory = Optional.empty();
}
else {
// when all join operators finish, set the outer position future to start the outer operator
SettableFuture<OuterPositionIterator> outerPositionsFuture = SettableFuture.create();
this.referenceCount.getFreeFuture().addListener(() -> {
// lookup source may not be finished yet, so add a listener
Futures.addCallback(
lookupSourceSupplier.getLookupSource(),
new OnSuccessFutureCallback<>(lookupSource -> outerPositionsFuture.set(lookupSource.getOuterPositionIterator())));
}, directExecutor());

// when output operator finishes, destroy the lookup source
Runnable onOperatorClose = () -> {
// lookup source may not be finished yet, so add a listener, to free the memory
lookupSourceSupplier.getLookupSource().addListener(lookupSourceSupplier::destroy, directExecutor());
};
this.outerOperatorFactory = Optional.of(new LookupOuterOperatorFactory(operatorId, planNodeId, outerPositionsFuture, probeTypes, buildTypes, onOperatorClose));
}
} }


public int getOperatorId() private LookupJoinOperatorFactory(LookupJoinOperatorFactory other)
{ {
return operatorId; requireNonNull(other, "other is null");
operatorId = other.operatorId;
planNodeId = other.planNodeId;
probeTypes = other.probeTypes;
buildTypes = other.buildTypes;
joinType = other.joinType;
lookupSourceSupplier = other.lookupSourceSupplier;
joinProbeFactory = other.joinProbeFactory;
referenceCount = other.referenceCount;
outerOperatorFactory = other.outerOperatorFactory;

referenceCount.retain();
} }


public List<Type> getProbeTypes() public int getOperatorId()
{ {
return probeTypes; return operatorId;
} }


@Override @Override
public List<Type> getTypes() public List<Type> getTypes()
{ {
return types; return ImmutableList.<Type>builder()
.addAll(probeTypes)
.addAll(buildTypes)
.build();
} }


@Override @Override
public Operator createOperator(DriverContext driverContext) public Operator createOperator(DriverContext driverContext)
{ {
checkState(!closed, "Factory is already closed"); checkState(!closed, "Factory is already closed");
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LookupJoinOperator.class.getSimpleName()); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LookupJoinOperator.class.getSimpleName());
return new LookupJoinOperator(operatorContext, lookupSourceSupplier, probeTypes, joinType, joinProbeFactory);
lookupSourceSupplier.setTaskContext(driverContext.getPipelineContext().getTaskContext());

referenceCount.retain();
return new LookupJoinOperator(
operatorContext,
getTypes(),
joinType,
lookupSourceSupplier.getLookupSource(),
joinProbeFactory,
referenceCount::release);
} }


@Override @Override
Expand All @@ -90,21 +143,41 @@ public void close()
return; return;
} }
closed = true; closed = true;
lookupSourceSupplier.release(); referenceCount.release();
} }


@Override @Override
public OperatorFactory duplicate() public OperatorFactory duplicate()
{ {
return new LookupJoinOperatorFactory(operatorId, planNodeId, lookupSourceSupplier, probeTypes, joinType, joinProbeFactory); return new LookupJoinOperatorFactory(this);
} }


@Override @Override
public Optional<OperatorFactory> createOuterOperatorFactory() public Optional<OperatorFactory> createOuterOperatorFactory()
{ {
if (lookupSourceSupplier instanceof OuterLookupSourceSupplier) { return outerOperatorFactory;
return Optional.of(new LookupOuterOperatorFactory(operatorId, planNodeId, (OuterLookupSourceSupplier) lookupSourceSupplier, probeTypes)); }

// We use a public class to avoid access problems with the isolated class loaders
public static class OnSuccessFutureCallback<T>
implements FutureCallback<T>
{
private final Consumer<T> onSuccess;

public OnSuccessFutureCallback(Consumer<T> onSuccess)
{
this.onSuccess = onSuccess;
}

@Override
public void onSuccess(T result)
{
onSuccess.accept(result);
}

@Override
public void onFailure(Throwable t)
{
} }
return Optional.empty();
} }
} }

0 comments on commit 6d27491

Please sign in to comment.