Skip to content

Commit

Permalink
Add JoinFilterFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
losipiuk authored and martint committed May 24, 2016
1 parent fdd802d commit 07df223
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 6 deletions.
@@ -0,0 +1,22 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.operator;

import com.facebook.presto.spi.block.Block;

public interface JoinFilterFunction
{
boolean filter(int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks);
}
Expand Up @@ -227,7 +227,13 @@ public Object evaluate(RecordCursor inputs)
public Object evaluate(int position, Block... inputs)
{
checkState(!optimize, "evaluate(int, Block...) not allowed for optimizer");
return visitor.process(expression, new PagePositionContext(position, inputs));
return visitor.process(expression, new SinglePagePositionContext(position, inputs));
}

public Object evaluate(int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks)
{
checkState(!optimize, "evaluate(int, Block[], int, Block[]) not allowed for optimizer");
return visitor.process(expression, new TwoPagesPositionContext(leftPosition, leftBlocks, rightPosition, rightBlocks));
}

public Object optimize(SymbolResolver inputs)
Expand Down Expand Up @@ -284,7 +290,7 @@ public Object visitFieldReference(FieldReference node, Object context)
int channel = node.getFieldIndex();
if (context instanceof PagePositionContext) {
PagePositionContext pagePositionContext = (PagePositionContext) context;
int position = pagePositionContext.getPosition();
int position = pagePositionContext.getPosition(channel);
Block block = pagePositionContext.getBlock(channel);

if (block.isNull(position)) {
Expand Down Expand Up @@ -1124,12 +1130,20 @@ private Object invokeOperator(OperatorType operatorType, List<? extends Type> ar
}
}

private static class PagePositionContext
private interface PagePositionContext
{
public Block getBlock(int channel);

public int getPosition(int channel);
}

private static class SinglePagePositionContext
implements PagePositionContext
{
private final int position;
private final Block[] blocks;

private PagePositionContext(int position, Block[] blocks)
private SinglePagePositionContext(int position, Block[] blocks)
{
this.position = position;
this.blocks = blocks;
Expand All @@ -1140,12 +1154,51 @@ public Block getBlock(int channel)
return blocks[channel];
}

public int getPosition()
public int getPosition(int channel)
{
return position;
}
}

private static class TwoPagesPositionContext
implements PagePositionContext
{
private final int leftPosition;
private final int rightPosition;
private final Block[] leftBlocks;
private final Block[] rightBlocks;

private TwoPagesPositionContext(int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks)
{
this.leftPosition = leftPosition;
this.rightPosition = rightPosition;
this.leftBlocks = leftBlocks;
this.rightBlocks = rightBlocks;
}

@Override
public Block getBlock(int channel)
{
if (channel < leftBlocks.length) {
return leftBlocks[channel];
}
else {
return rightBlocks[channel - leftBlocks.length];
}
}

@Override
public int getPosition(int channel)
{
if (channel < leftBlocks.length) {
return leftPosition;
}
else {
return rightPosition;
}
}
}

public static Object invoke(ConnectorSession session, ScalarFunctionImplementation function, List<Object> argumentValues)
{
MethodHandle handle = function.getMethodHandle();
Expand Down
Expand Up @@ -16,6 +16,7 @@
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.FilterFunction;
import com.facebook.presto.operator.JoinFilterFunction;
import com.facebook.presto.spi.RecordCursor;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
Expand All @@ -33,7 +34,7 @@
import static java.lang.Boolean.TRUE;

public class InterpretedFilterFunction
implements FilterFunction
implements FilterFunction, JoinFilterFunction
{
private final ExpressionInterpreter evaluator;
private final Set<Integer> inputChannels;
Expand Down Expand Up @@ -62,6 +63,12 @@ public InterpretedFilterFunction(
this.inputChannels = ImmutableSet.copyOf(inputReferenceExtractor.getInputChannels());
}

@Override
public boolean filter(int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks)
{
return evaluator.evaluate(leftPosition, leftBlocks, rightPosition, rightBlocks) == TRUE;
}

@Override
public boolean filter(int position, Block... blocks)
{
Expand Down

0 comments on commit 07df223

Please sign in to comment.