Skip to content

Commit

Permalink
Remove bogus OperatorContext hack
Browse files Browse the repository at this point in the history
Also simplify the PagesIndex and LookupSource to not require OperatorContext
  • Loading branch information
cberner committed Apr 9, 2015
1 parent a232d37 commit 85d57b7
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 39 deletions.
Expand Up @@ -135,7 +135,9 @@ public void finish()
return;
}

LookupSource lookupSource = pagesIndex.createLookupSource(hashChannels, operatorContext, hashChannel);
LookupSource lookupSource = pagesIndex.createLookupSource(hashChannels, hashChannel);
// TODO: should we really be using the no-fail flag here?
operatorContext.setMemoryReservation(pagesIndex.getEstimatedSize().toBytes() + lookupSource.getInMemorySizeInBytes(), true);
lookupSourceSupplier.setLookupSource(lookupSource);
finished = true;
}
Expand Down
Expand Up @@ -42,20 +42,19 @@ public final class InMemoryJoinHash
private final int mask;
private final int[] key;
private final int[] positionLinks;
private final long size;
private final List<Type> hashTypes;

public InMemoryJoinHash(LongArrayList addresses, List<Type> hashTypes, PagesHashStrategy pagesHashStrategy, OperatorContext operatorContext)
public InMemoryJoinHash(LongArrayList addresses, List<Type> hashTypes, PagesHashStrategy pagesHashStrategy)
{
this.addresses = checkNotNull(addresses, "addresses is null");
this.hashTypes = ImmutableList.copyOf(checkNotNull(hashTypes, "hashTypes is null"));
this.pagesHashStrategy = checkNotNull(pagesHashStrategy, "pagesHashStrategy is null");
this.channelCount = pagesHashStrategy.getChannelCount();

checkNotNull(operatorContext, "operatorContext is null");

// reserve memory for the arrays
int hashSize = HashCommon.arraySize(addresses.size(), 0.75f);
operatorContext.reserveMemory(sizeOfIntArray(hashSize) + sizeOfIntArray(addresses.size()));
size = sizeOfIntArray(hashSize) + sizeOfIntArray(addresses.size());

mask = hashSize - 1;
key = new int[hashSize];
Expand Down Expand Up @@ -93,6 +92,12 @@ public final int getChannelCount()
return channelCount;
}

@Override
public long getInMemorySizeInBytes()
{
return size;
}

@Override
public long getJoinPosition(int position, Page page)
{
Expand Down
Expand Up @@ -23,6 +23,8 @@ public interface LookupSource
{
int getChannelCount();

long getInMemorySizeInBytes();

long getJoinPosition(int position, Page page, int rawHash);

long getJoinPosition(int position, Page page);
Expand Down
Expand Up @@ -277,12 +277,12 @@ private PagesIndexOrdering createPagesIndexComparator(List<Integer> sortChannels
return orderingCompiler.compilePagesIndexOrdering(sortTypes, sortChannels, sortOrders);
}

public LookupSource createLookupSource(List<Integer> joinChannels, OperatorContext operatorContext)
public LookupSource createLookupSource(List<Integer> joinChannels)
{
return createLookupSource(joinChannels, operatorContext, Optional.empty());
return createLookupSource(joinChannels, Optional.empty());
}

public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, OperatorContext operatorContext, Optional<Integer> hashChannel)
public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Optional<Integer> hashChannel)
{
try {
return joinCompiler.compilePagesHashStrategyFactory(types, joinChannels)
Expand All @@ -296,7 +296,7 @@ public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Ope
return new SimplePagesHashStrategy(types, ImmutableList.<List<Block>>copyOf(channels), joinChannels, hashChannel);
}

public LookupSource createLookupSource(List<Integer> joinChannels, OperatorContext operatorContext, Optional<Integer> hashChannel)
public LookupSource createLookupSource(List<Integer> joinChannels, Optional<Integer> hashChannel)
{
try {
LookupSourceFactory lookupSourceFactory = joinCompiler.compileLookupSourceFactory(types, joinChannels);
Expand All @@ -309,8 +309,7 @@ public LookupSource createLookupSource(List<Integer> joinChannels, OperatorConte
valueAddresses,
joinChannelTypes.build(),
ImmutableList.<List<Block>>copyOf(channels),
hashChannel,
operatorContext);
hashChannel);

return lookupSource;
}
Expand All @@ -329,7 +328,7 @@ public LookupSource createLookupSource(List<Integer> joinChannels, OperatorConte
for (Integer channel : joinChannels) {
hashTypes.add(types.get(channel));
}
return new InMemoryJoinHash(valueAddresses, hashTypes.build(), hashStrategy, operatorContext);
return new InMemoryJoinHash(valueAddresses, hashTypes.build(), hashStrategy);
}

@Override
Expand Down
Expand Up @@ -169,8 +169,8 @@ public PrePartitionedWindowOperator(
.collect(toImmutableList());

this.pagesIndex = new PagesIndex(sourceTypes, expectedPositions);
this.partitionHashStrategy = pagesIndex.createPagesHashStrategy(partitionChannels, operatorContext, Optional.<Integer>empty());
this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, operatorContext, Optional.<Integer>empty());
this.partitionHashStrategy = pagesIndex.createPagesHashStrategy(partitionChannels, Optional.<Integer>empty());
this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, Optional.<Integer>empty());

this.pageBuilder = new PageBuilder(this.types);
}
Expand Down
Expand Up @@ -171,8 +171,8 @@ public WindowOperator(
.collect(toImmutableList());

this.pagesIndex = new PagesIndex(sourceTypes, expectedPositions);
this.partitionHashStrategy = pagesIndex.createPagesHashStrategy(partitionChannels, operatorContext, Optional.empty());
this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, operatorContext, Optional.empty());
this.partitionHashStrategy = pagesIndex.createPagesHashStrategy(partitionChannels, Optional.empty());
this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, Optional.empty());

this.pageBuilder = new PageBuilder(this.types);
}
Expand Down
Expand Up @@ -365,6 +365,12 @@ public int getChannelCount()
return channelCount;
}

@Override
public long getInMemorySizeInBytes()
{
return 0;
}

@Override
public long getJoinPosition(int position, Page page, int rawHash)
{
Expand Down
Expand Up @@ -43,6 +43,12 @@ public int getChannelCount()
return indexLoader.getChannelCount();
}

@Override
public long getInMemorySizeInBytes()
{
return 0;
}

@Override
public long getJoinPosition(int position, Page page, int rawHash)
{
Expand Down
Expand Up @@ -15,17 +15,14 @@

import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.LookupSource;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.PagesIndex;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.index.UnloadedIndexKeyRecordSet.UnloadedIndexKeyRecordCursor;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import io.airlift.units.DataSize.Unit;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -45,7 +42,6 @@ public class IndexSnapshotBuilder
private final List<Integer> missingKeysChannels;

private final long maxMemoryInBytes;
private final OperatorContext bogusOperatorContext;
private PagesIndex outputPagesIndex;
private PagesIndex missingKeysIndex;
private LookupSource missingKeys;
Expand Down Expand Up @@ -83,16 +79,9 @@ public IndexSnapshotBuilder(List<Type> outputTypes,
this.missingKeysTypes = missingKeysTypes.build();
this.missingKeysChannels = missingKeysChannels.build();

// create a bogus operator context with unlimited memory for the pages index
boolean cpuTimerEnabled = driverContext.getPipelineContext().getTaskContext().isCpuTimerEnabled();
this.bogusOperatorContext = new TaskContext(driverContext.getTaskId(), driverContext.getExecutor(), driverContext.getSession(), new DataSize(Long.MAX_VALUE, Unit.BYTE), cpuTimerEnabled)
.addPipelineContext(true, true)
.addDriverContext()
.addOperatorContext(0, "operator");

this.outputPagesIndex = new PagesIndex(outputTypes, expectedPositions);
this.missingKeysIndex = new PagesIndex(missingKeysTypes.build(), expectedPositions);
this.missingKeys = missingKeysIndex.createLookupSource(this.missingKeysChannels, bogusOperatorContext);
this.missingKeys = missingKeysIndex.createLookupSource(this.missingKeysChannels);
}

public List<Type> getOutputTypes()
Expand Down Expand Up @@ -126,11 +115,10 @@ public IndexSnapshot createIndexSnapshot(UnloadedIndexKeyRecordSet indexKeysReco
checkState(!isMemoryExceeded(), "Max memory exceeded");
for (Page page : pages) {
outputPagesIndex.addPage(page);
bogusOperatorContext.setMemoryReservation(outputPagesIndex.getEstimatedSize().toBytes());
}
pages.clear();

LookupSource lookupSource = outputPagesIndex.createLookupSource(keyOutputChannels, bogusOperatorContext, keyOutputHashChannel);
LookupSource lookupSource = outputPagesIndex.createLookupSource(keyOutputChannels, keyOutputHashChannel);

// Build a page containing the keys that produced no output rows, so in future requests can skip these keys
PageBuilder missingKeysPageBuilder = new PageBuilder(missingKeysIndex.getTypes());
Expand Down Expand Up @@ -158,7 +146,7 @@ public IndexSnapshot createIndexSnapshot(UnloadedIndexKeyRecordSet indexKeysReco
// only update missing keys if we have new missing keys
if (!missingKeysPageBuilder.isEmpty()) {
missingKeysIndex.addPage(missingKeysPage);
missingKeys = missingKeysIndex.createLookupSource(missingKeysChannels, bogusOperatorContext);
missingKeys = missingKeysIndex.createLookupSource(missingKeysChannels);
}

return new IndexSnapshot(lookupSource, missingKeys);
Expand Down
Expand Up @@ -28,7 +28,6 @@
import com.facebook.presto.byteCode.instruction.LabelNode;
import com.facebook.presto.operator.InMemoryJoinHash;
import com.facebook.presto.operator.LookupSource;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.PagesHashStrategy;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.BlockBuilder;
Expand Down Expand Up @@ -524,18 +523,18 @@ public LookupSourceFactory(Class<? extends LookupSource> lookupSourceClass, Page
{
this.pagesHashStrategyFactory = pagesHashStrategyFactory;
try {
constructor = lookupSourceClass.getConstructor(LongArrayList.class, List.class, PagesHashStrategy.class, OperatorContext.class);
constructor = lookupSourceClass.getConstructor(LongArrayList.class, List.class, PagesHashStrategy.class);
}
catch (NoSuchMethodException e) {
throw Throwables.propagate(e);
}
}

public LookupSource createLookupSource(LongArrayList addresses, List<Type> types, List<List<com.facebook.presto.spi.block.Block>> channels, Optional<Integer> hashChannel, OperatorContext operatorContext)
public LookupSource createLookupSource(LongArrayList addresses, List<Type> types, List<List<com.facebook.presto.spi.block.Block>> channels, Optional<Integer> hashChannel)
{
PagesHashStrategy pagesHashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(channels, hashChannel);
try {
return constructor.newInstance(addresses, types, pagesHashStrategy, operatorContext);
return constructor.newInstance(addresses, types, pagesHashStrategy);
}
catch (Exception e) {
throw Throwables.propagate(e);
Expand Down
Expand Up @@ -20,9 +20,7 @@
import com.facebook.presto.operator.JoinProbe;
import com.facebook.presto.operator.JoinProbeFactory;
import com.facebook.presto.operator.LookupSource;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.ValuesOperator;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
Expand Down Expand Up @@ -82,7 +80,6 @@ public void testSingleChannel(boolean hashEnabled)
throws Exception
{
DriverContext driverContext = taskContext.addPipelineContext(true, true).addDriverContext();
OperatorContext operatorContext = driverContext.addOperatorContext(0, ValuesOperator.class.getSimpleName());

ImmutableList<Type> types = ImmutableList.<Type>of(VARCHAR);
LookupSourceFactory lookupSourceFactoryFactory = joinCompiler.compileLookupSourceFactory(types, Ints.asList(0));
Expand Down Expand Up @@ -112,7 +109,7 @@ public void testSingleChannel(boolean hashEnabled)
hashChannel = Optional.of(1);
channels = ImmutableList.of(channel, hashChannelBuilder.build());
}
LookupSource lookupSource = lookupSourceFactoryFactory.createLookupSource(addresses, types, channels, hashChannel, operatorContext);
LookupSource lookupSource = lookupSourceFactoryFactory.createLookupSource(addresses, types, channels, hashChannel);

JoinProbeCompiler joinProbeCompiler = new JoinProbeCompiler();
JoinProbeFactory probeFactory = joinProbeCompiler.internalCompileJoinProbe(types, Ints.asList(0), hashChannel);
Expand Down

0 comments on commit 85d57b7

Please sign in to comment.