Skip to content

Commit

Permalink
Encapsulate InternalJoinFilterFunction in JoinFilterFunctionCompiler
Browse files Browse the repository at this point in the history
InternalJoinFilterFunction is an internal detail of JoinFilterFunction and
should not be exposed to callers. The JoinFilterFuction implementation has
direct access the build pages for performance reasons, and thus can not be
created until the end of the build.  A new JoinFilterFunctionFactory is
added to support this delayed creation.
  • Loading branch information
dain committed Oct 24, 2016
1 parent b490cee commit 596ca8d
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 91 deletions.
Expand Up @@ -15,6 +15,7 @@


import com.facebook.presto.spi.Page; import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory;
import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -44,7 +45,7 @@ private enum State {
private final SettableLookupSourceSupplier lookupSourceSupplier; private final SettableLookupSourceSupplier lookupSourceSupplier;
private final List<Integer> hashChannels; private final List<Integer> hashChannels;
private final Optional<Integer> hashChannel; private final Optional<Integer> hashChannel;
private final Optional<InternalJoinFilterFunction> filterFunction; private final Optional<JoinFilterFunctionFactory> filterFunctionFactory;


private final int expectedPositions; private final int expectedPositions;
private State state = State.NOT_CREATED; private State state = State.NOT_CREATED;
Expand All @@ -57,7 +58,7 @@ public HashBuilderOperatorFactory(
List<Integer> hashChannels, List<Integer> hashChannels,
Optional<Integer> hashChannel, Optional<Integer> hashChannel,
boolean outer, boolean outer,
Optional<InternalJoinFilterFunction> filterFunction, Optional<JoinFilterFunctionFactory> filterFunctionFactory,
int expectedPositions) int expectedPositions)
{ {
this.operatorId = operatorId; this.operatorId = operatorId;
Expand All @@ -69,7 +70,7 @@ public HashBuilderOperatorFactory(


this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null")); this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null"));
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null"); this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
this.filterFunction = requireNonNull(filterFunction, "filterFunction is null"); this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null");


this.expectedPositions = expectedPositions; this.expectedPositions = expectedPositions;
} }
Expand Down Expand Up @@ -97,7 +98,7 @@ public Operator createOperator(DriverContext driverContext)
lookupSourceSupplier, lookupSourceSupplier,
hashChannels, hashChannels,
hashChannel, hashChannel,
filterFunction, filterFunctionFactory,
expectedPositions); expectedPositions);
} }


Expand All @@ -118,7 +119,7 @@ public OperatorFactory duplicate()
private final SettableLookupSourceSupplier lookupSourceSupplier; private final SettableLookupSourceSupplier lookupSourceSupplier;
private final List<Integer> hashChannels; private final List<Integer> hashChannels;
private final Optional<Integer> hashChannel; private final Optional<Integer> hashChannel;
private final Optional<InternalJoinFilterFunction> filterFunction; private final Optional<JoinFilterFunctionFactory> filterFunctionFactory;


private final PagesIndex pagesIndex; private final PagesIndex pagesIndex;


Expand All @@ -129,7 +130,7 @@ public HashBuilderOperator(
SettableLookupSourceSupplier lookupSourceSupplier, SettableLookupSourceSupplier lookupSourceSupplier,
List<Integer> hashChannels, List<Integer> hashChannels,
Optional<Integer> hashChannel, Optional<Integer> hashChannel,
Optional<InternalJoinFilterFunction> filterFunction, Optional<JoinFilterFunctionFactory> filterFunctionFactory,
int expectedPositions) int expectedPositions)
{ {
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
Expand All @@ -138,7 +139,7 @@ public HashBuilderOperator(


this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null")); this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null"));
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null"); this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
this.filterFunction = requireNonNull(filterFunction, "filterFunction is null"); this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null");


this.pagesIndex = new PagesIndex(lookupSourceSupplier.getTypes(), expectedPositions); this.pagesIndex = new PagesIndex(lookupSourceSupplier.getTypes(), expectedPositions);
} }
Expand All @@ -163,7 +164,7 @@ public void finish()
} }


// After this point the LookupSource will take over our memory reservation, and ours will be zero // After this point the LookupSource will take over our memory reservation, and ours will be zero
LookupSource lookupSource = pagesIndex.createLookupSource(hashChannels, hashChannel, filterFunction); LookupSource lookupSource = pagesIndex.createLookupSource(operatorContext.getSession(), hashChannels, hashChannel, filterFunctionFactory);
lookupSourceSupplier.setLookupSource(lookupSource, operatorContext); lookupSourceSupplier.setLookupSource(lookupSource, operatorContext);
finished = true; finished = true;
} }
Expand Down
Expand Up @@ -13,13 +13,15 @@
*/ */
package com.facebook.presto.operator; package com.facebook.presto.operator;


import com.facebook.presto.Session;
import com.facebook.presto.spi.Page; import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory;
import com.facebook.presto.sql.gen.OrderingCompiler; import com.facebook.presto.sql.gen.OrderingCompiler;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger; import io.airlift.log.Logger;
Expand Down Expand Up @@ -302,9 +304,9 @@ private PagesIndexOrdering createPagesIndexComparator(List<Integer> sortChannels
return orderingCompiler.compilePagesIndexOrdering(sortTypes, sortChannels, sortOrders); return orderingCompiler.compilePagesIndexOrdering(sortTypes, sortChannels, sortOrders);
} }


public LookupSource createLookupSource(List<Integer> joinChannels) public LookupSource createLookupSource(Session session, List<Integer> joinChannels)
{ {
return createLookupSource(joinChannels, Optional.empty(), Optional.empty()); return createLookupSource(session, joinChannels, Optional.empty(), Optional.empty());
} }


public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Optional<Integer> hashChannel) public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Optional<Integer> hashChannel)
Expand All @@ -321,13 +323,12 @@ public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Opt
return new SimplePagesHashStrategy(types, ImmutableList.<List<Block>>copyOf(channels), joinChannels, hashChannel); return new SimplePagesHashStrategy(types, ImmutableList.<List<Block>>copyOf(channels), joinChannels, hashChannel);
} }


private Optional<JoinFilterFunction> createJoinFilterFunction(Optional<InternalJoinFilterFunction> filterFunctionOptional, List<List<Block>> channels) public LookupSource createLookupSource(Session session, List<Integer> joinChannels, Optional<Integer> hashChannel, Optional<JoinFilterFunctionFactory> filterFunctionFactory)
{ {
return filterFunctionOptional.map(filterFunction -> joinCompiler.compileJoinFilterFunctionFactory(filterFunction).create(filterFunction, channels)); Optional<JoinFilterFunction> joinFilterFunction = filterFunctionFactory.map(factory -> factory.create(
} session.toConnectorSession(),
ImmutableList.copyOf(channels)));


public LookupSource createLookupSource(List<Integer> joinChannels, Optional<Integer> hashChannel, Optional<InternalJoinFilterFunction> filterFunction)
{
if (!joinChannels.isEmpty()) { if (!joinChannels.isEmpty()) {
// todo compiled implementation of lookup join does not support when we are joining with empty join channels. // todo compiled implementation of lookup join does not support when we are joining with empty join channels.
// This code path will trigger only for OUTER joins. To fix that we need to add support for // This code path will trigger only for OUTER joins. To fix that we need to add support for
Expand All @@ -340,7 +341,7 @@ public LookupSource createLookupSource(List<Integer> joinChannels, Optional<Inte
valueAddresses, valueAddresses,
ImmutableList.copyOf(channels), ImmutableList.copyOf(channels),
hashChannel, hashChannel,
createJoinFilterFunction(filterFunction, ImmutableList.copyOf(channels))); joinFilterFunction);
} }
catch (Exception e) { catch (Exception e) {
log.error(e, "Lookup source compile failed for types=%s error=%s", types, e); log.error(e, "Lookup source compile failed for types=%s error=%s", types, e);
Expand All @@ -355,7 +356,7 @@ public LookupSource createLookupSource(List<Integer> joinChannels, Optional<Inte
hashChannel hashChannel
); );


return new InMemoryJoinHash(valueAddresses, hashStrategy, createJoinFilterFunction(filterFunction, ImmutableList.copyOf(channels))); return new InMemoryJoinHash(valueAddresses, hashStrategy, joinFilterFunction);
} }


@Override @Override
Expand Down
Expand Up @@ -15,6 +15,7 @@


import com.facebook.presto.spi.Page; import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory;
import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -43,7 +44,7 @@ public static class ParallelHashBuildOperatorFactory
private final PartitionedLookupSourceSupplier lookupSourceSupplier; private final PartitionedLookupSourceSupplier lookupSourceSupplier;
private final List<Integer> hashChannels; private final List<Integer> hashChannels;
private final Optional<Integer> preComputedHashChannel; private final Optional<Integer> preComputedHashChannel;
private final Optional<InternalJoinFilterFunction> filterFunction; private final Optional<JoinFilterFunctionFactory> filterFunctionFactory;


private final int expectedPositions; private final int expectedPositions;


Expand All @@ -58,7 +59,7 @@ public ParallelHashBuildOperatorFactory(
List<Integer> hashChannels, List<Integer> hashChannels,
Optional<Integer> preComputedHashChannel, Optional<Integer> preComputedHashChannel,
boolean outer, boolean outer,
Optional<InternalJoinFilterFunction> filterFunction, Optional<JoinFilterFunctionFactory> filterFunctionFactory,
int expectedPositions, int expectedPositions,
int partitionCount) int partitionCount)
{ {
Expand All @@ -76,7 +77,7 @@ public ParallelHashBuildOperatorFactory(
checkArgument(!hashChannels.isEmpty(), "hashChannels is empty"); checkArgument(!hashChannels.isEmpty(), "hashChannels is empty");
this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null")); this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null"));
this.preComputedHashChannel = requireNonNull(preComputedHashChannel, "preComputedHashChannel is null"); this.preComputedHashChannel = requireNonNull(preComputedHashChannel, "preComputedHashChannel is null");
this.filterFunction = requireNonNull(filterFunction, "filterFunction is null"); this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null");


this.expectedPositions = expectedPositions; this.expectedPositions = expectedPositions;
} }
Expand All @@ -103,7 +104,7 @@ public Operator createOperator(DriverContext driverContext)
partitionIndex, partitionIndex,
hashChannels, hashChannels,
preComputedHashChannel, preComputedHashChannel,
filterFunction, filterFunctionFactory,
expectedPositions); expectedPositions);


partitionIndex++; partitionIndex++;
Expand All @@ -129,7 +130,7 @@ public OperatorFactory duplicate()


private final List<Integer> hashChannels; private final List<Integer> hashChannels;
private final Optional<Integer> preComputedHashChannel; private final Optional<Integer> preComputedHashChannel;
private final Optional<InternalJoinFilterFunction> filterFunction; private final Optional<JoinFilterFunctionFactory> filterFunctionFactory;


private final PagesIndex index; private final PagesIndex index;


Expand All @@ -141,12 +142,12 @@ public ParallelHashBuildOperator(
int partitionIndex, int partitionIndex,
List<Integer> hashChannels, List<Integer> hashChannels,
Optional<Integer> preComputedHashChannel, Optional<Integer> preComputedHashChannel,
Optional<InternalJoinFilterFunction> filterFunction, Optional<JoinFilterFunctionFactory> filterFunctionFactory,
int expectedPositions) int expectedPositions)
{ {
this.operatorContext = operatorContext; this.operatorContext = operatorContext;
this.partitionIndex = partitionIndex; this.partitionIndex = partitionIndex;
this.filterFunction = filterFunction; this.filterFunctionFactory = filterFunctionFactory;


this.index = new PagesIndex(lookupSourceSupplier.getTypes(), expectedPositions); this.index = new PagesIndex(lookupSourceSupplier.getTypes(), expectedPositions);
this.lookupSourceSupplier = lookupSourceSupplier; this.lookupSourceSupplier = lookupSourceSupplier;
Expand Down Expand Up @@ -175,7 +176,7 @@ public void finish()
} }
finishing = true; finishing = true;


LookupSource lookupSource = index.createLookupSource(hashChannels, preComputedHashChannel, filterFunction); LookupSource lookupSource = index.createLookupSource(operatorContext.getSession(), hashChannels, preComputedHashChannel, filterFunctionFactory);
lookupSourceSupplier.setLookupSource(partitionIndex, lookupSource); lookupSourceSupplier.setLookupSource(partitionIndex, lookupSource);


operatorContext.setMemoryReservation(lookupSource.getInMemorySizeInBytes()); operatorContext.setMemoryReservation(lookupSource.getInMemorySizeInBytes());
Expand Down
Expand Up @@ -13,6 +13,7 @@
*/ */
package com.facebook.presto.operator.index; package com.facebook.presto.operator.index;


import com.facebook.presto.Session;
import com.facebook.presto.operator.DriverContext; import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.LookupSource; import com.facebook.presto.operator.LookupSource;
import com.facebook.presto.operator.PagesIndex; import com.facebook.presto.operator.PagesIndex;
Expand All @@ -34,6 +35,7 @@


public class IndexSnapshotBuilder public class IndexSnapshotBuilder
{ {
private final Session session;
private final int expectedPositions; private final int expectedPositions;
private final List<Type> outputTypes; private final List<Type> outputTypes;
private final List<Type> missingKeysTypes; private final List<Type> missingKeysTypes;
Expand Down Expand Up @@ -63,6 +65,7 @@ public IndexSnapshotBuilder(List<Type> outputTypes,
requireNonNull(maxMemoryInBytes, "maxMemoryInBytes is null"); requireNonNull(maxMemoryInBytes, "maxMemoryInBytes is null");
checkArgument(expectedPositions > 0, "expectedPositions must be greater than zero"); checkArgument(expectedPositions > 0, "expectedPositions must be greater than zero");


this.session = driverContext.getSession();
this.outputTypes = ImmutableList.copyOf(outputTypes); this.outputTypes = ImmutableList.copyOf(outputTypes);
this.expectedPositions = expectedPositions; this.expectedPositions = expectedPositions;
this.keyOutputChannels = ImmutableList.copyOf(keyOutputChannels); this.keyOutputChannels = ImmutableList.copyOf(keyOutputChannels);
Expand All @@ -81,7 +84,7 @@ public IndexSnapshotBuilder(List<Type> outputTypes,


this.outputPagesIndex = new PagesIndex(outputTypes, expectedPositions); this.outputPagesIndex = new PagesIndex(outputTypes, expectedPositions);
this.missingKeysIndex = new PagesIndex(missingKeysTypes.build(), expectedPositions); this.missingKeysIndex = new PagesIndex(missingKeysTypes.build(), expectedPositions);
this.missingKeys = missingKeysIndex.createLookupSource(this.missingKeysChannels); this.missingKeys = missingKeysIndex.createLookupSource(session, this.missingKeysChannels);
} }


public List<Type> getOutputTypes() public List<Type> getOutputTypes()
Expand Down Expand Up @@ -118,7 +121,7 @@ public IndexSnapshot createIndexSnapshot(UnloadedIndexKeyRecordSet indexKeysReco
} }
pages.clear(); pages.clear();


LookupSource lookupSource = outputPagesIndex.createLookupSource(keyOutputChannels, keyOutputHashChannel, Optional.empty()); LookupSource lookupSource = outputPagesIndex.createLookupSource(session, keyOutputChannels, keyOutputHashChannel, Optional.empty());


// Build a page containing the keys that produced no output rows, so in future requests can skip these keys // Build a page containing the keys that produced no output rows, so in future requests can skip these keys
PageBuilder missingKeysPageBuilder = new PageBuilder(missingKeysIndex.getTypes()); PageBuilder missingKeysPageBuilder = new PageBuilder(missingKeysIndex.getTypes());
Expand Down Expand Up @@ -146,7 +149,7 @@ public IndexSnapshot createIndexSnapshot(UnloadedIndexKeyRecordSet indexKeysReco
// only update missing keys if we have new missing keys // only update missing keys if we have new missing keys
if (!missingKeysPageBuilder.isEmpty()) { if (!missingKeysPageBuilder.isEmpty()) {
missingKeysIndex.addPage(missingKeysPage); missingKeysIndex.addPage(missingKeysPage);
missingKeys = missingKeysIndex.createLookupSource(missingKeysChannels); missingKeys = missingKeysIndex.createLookupSource(session, missingKeysChannels);
} }


return new IndexSnapshot(lookupSource, missingKeys); return new IndexSnapshot(lookupSource, missingKeys);
Expand Down
Expand Up @@ -27,11 +27,9 @@
import com.facebook.presto.bytecode.expression.BytecodeExpression; import com.facebook.presto.bytecode.expression.BytecodeExpression;
import com.facebook.presto.bytecode.instruction.LabelNode; import com.facebook.presto.bytecode.instruction.LabelNode;
import com.facebook.presto.operator.InMemoryJoinHash; import com.facebook.presto.operator.InMemoryJoinHash;
import com.facebook.presto.operator.InternalJoinFilterFunction;
import com.facebook.presto.operator.JoinFilterFunction; import com.facebook.presto.operator.JoinFilterFunction;
import com.facebook.presto.operator.LookupSource; import com.facebook.presto.operator.LookupSource;
import com.facebook.presto.operator.PagesHashStrategy; import com.facebook.presto.operator.PagesHashStrategy;
import com.facebook.presto.operator.StandardJoinFilterFunction;
import com.facebook.presto.spi.Page; import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.Block;
Expand All @@ -48,7 +46,6 @@
import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongArrayList;


import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
Expand Down Expand Up @@ -96,21 +93,6 @@ public Class<? extends PagesHashStrategy> load(CacheKey key)
} }
}); });


private final LoadingCache<Class<? extends InternalJoinFilterFunction>, Class<? extends JoinFilterFunction>> joinFilterFunctionClasses = CacheBuilder.newBuilder().maximumSize(1000).build(
new CacheLoader<Class<? extends InternalJoinFilterFunction>, Class<? extends JoinFilterFunction>>()
{
@Override
public Class<? extends JoinFilterFunction> load(Class<? extends InternalJoinFilterFunction> key)
throws Exception
{
return IsolatedClass.isolateClass(
new DynamicClassLoader(getClass().getClassLoader()),
JoinFilterFunction.class,
StandardJoinFilterFunction.class
);
}
});

public LookupSourceFactory compileLookupSourceFactory(List<? extends Type> types, List<Integer> joinChannels) public LookupSourceFactory compileLookupSourceFactory(List<? extends Type> types, List<Integer> joinChannels)
{ {
try { try {
Expand All @@ -121,21 +103,6 @@ public LookupSourceFactory compileLookupSourceFactory(List<? extends Type> types
} }
} }


public JoinFilterFunctionFactory compileJoinFilterFunctionFactory(InternalJoinFilterFunction internalJoinFilterFunction)
{
return ((filterFunction, channels) -> {
try {
return joinFilterFunctionClasses
.get(internalJoinFilterFunction.getClass())
.getConstructor(InternalJoinFilterFunction.class, List.class)
.newInstance(filterFunction, channels);
}
catch (ExecutionException | UncheckedExecutionException | ExecutionError | NoSuchMethodException | InstantiationException | IllegalAccessException | InvocationTargetException e) {
throw Throwables.propagate(e.getCause());
}
});
}

public PagesHashStrategyFactory compilePagesHashStrategyFactory(List<Type> types, List<Integer> joinChannels) public PagesHashStrategyFactory compilePagesHashStrategyFactory(List<Type> types, List<Integer> joinChannels)
{ {
requireNonNull(types, "types is null"); requireNonNull(types, "types is null");
Expand Down Expand Up @@ -751,11 +718,6 @@ public PagesHashStrategy createPagesHashStrategy(List<? extends List<Block>> cha
} }
} }


public interface JoinFilterFunctionFactory
{
JoinFilterFunction create(InternalJoinFilterFunction filterFunction, List<List<Block>> channels);
}

private static final class CacheKey private static final class CacheKey
{ {
private final List<Type> types; private final List<Type> types;
Expand Down

0 comments on commit 596ca8d

Please sign in to comment.