Skip to content

Commit

Permalink
Add approx_distinct for boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
ankitdixit authored and martint committed Feb 5, 2019
1 parent 20fb7d8 commit 7569071
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.annotations.VisibleForTesting;
import io.airlift.slice.Slice;
import io.airlift.stats.cardinality.HyperLogLog;
import io.prestosql.operator.aggregation.state.BooleanDistinctState;
import io.prestosql.operator.aggregation.state.HyperLogLogState;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
Expand Down Expand Up @@ -120,6 +121,12 @@ public static void input(
state.addMemoryUsage(hll.estimatedInMemorySize());
}

/**
 * Input function of {@code approx_distinct} for BOOLEAN values.
 * <p>
 * The state byte is treated as a two-bit set: mask {@code 1} is OR-ed in when
 * {@code true} is seen and mask {@code 2} when {@code false} is seen. Bits are only
 * ever added, so feeding the same value repeatedly leaves the state unchanged.
 *
 * @param state accumulator whose byte receives the bit for {@code value}
 * @param value the boolean input value
 * @param maxStandardError part of the SQL signature; not read by this method
 */
@InputFunction
public static void input(@AggregationState BooleanDistinctState state, @SqlType(StandardTypes.BOOLEAN) boolean value, @SqlType(StandardTypes.DOUBLE) double maxStandardError)
{
    // 1 marks "true seen", 2 marks "false seen"
    state.setByte((byte) (state.getByte() | (value ? 1 : 2)));
}

private static HyperLogLog getOrCreateHyperLogLog(HyperLogLogState state, double maxStandardError)
{
HyperLogLog hll = state.getHyperLogLog();
Expand Down Expand Up @@ -162,6 +169,12 @@ public static void combineState(@AggregationState HyperLogLogState state, @Aggre
}
}

/**
 * Combine function for the boolean state: stores the bitwise OR of both state bytes
 * into {@code state}, so any bit set in either state is set in the result.
 *
 * @param state accumulator that receives the merged byte
 * @param otherState accumulator whose byte is read and OR-ed in; it is not modified
 */
@CombineFunction
public static void combineState(@AggregationState BooleanDistinctState state, @AggregationState BooleanDistinctState otherState)
{
    state.setByte((byte) (state.getByte() | otherState.getByte()));
}

@OutputFunction(StandardTypes.BIGINT)
public static void evaluateFinal(@AggregationState HyperLogLogState state, BlockBuilder out)
{
Expand All @@ -173,4 +186,10 @@ public static void evaluateFinal(@AggregationState HyperLogLogState state, Block
BIGINT.writeLong(out, hyperLogLog.cardinality());
}
}

/**
 * Output function for the boolean state: writes the number of set bits in the state
 * byte to {@code out} as a BIGINT.
 *
 * @param state accumulator whose byte is read
 * @param out block builder that receives the count
 */
@OutputFunction(StandardTypes.BIGINT)
public static void evaluateFinal(@AggregationState BooleanDistinctState state, BlockBuilder out)
{
    // The byte is widened to int before counting; a negative byte would be sign-extended
    // and inflate the count, so this relies on only the low bits ever being set.
    BIGINT.writeLong(out, Integer.bitCount(state.getByte()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.prestosql.operator.aggregation;

import io.airlift.slice.Slice;
import io.prestosql.operator.aggregation.state.BooleanDistinctState;
import io.prestosql.operator.aggregation.state.HyperLogLogState;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
Expand Down Expand Up @@ -79,15 +80,33 @@ public static void input(
ApproximateCountDistinctAggregation.input(methodHandle, state, value, DEFAULT_STANDARD_ERROR);
}

/**
 * Single-argument BOOLEAN input function: forwards to
 * {@code ApproximateCountDistinctAggregation.input(state, value, maxStandardError)}
 * with {@code DEFAULT_STANDARD_ERROR} as the third argument.
 *
 * @param state accumulator passed through to the delegate
 * @param value the boolean input value
 */
@InputFunction
public static void input(@AggregationState BooleanDistinctState state, @SqlType(StandardTypes.BOOLEAN) boolean value)
{
    ApproximateCountDistinctAggregation.input(state, value, DEFAULT_STANDARD_ERROR);
}

/**
 * Combine function for the HyperLogLog-backed state; passes both states unchanged to
 * {@code ApproximateCountDistinctAggregation.combineState(state, otherState)}.
 */
@CombineFunction
public static void combineState(@AggregationState HyperLogLogState state, @AggregationState HyperLogLogState otherState)
{
ApproximateCountDistinctAggregation.combineState(state, otherState);
}

/**
 * Combine function for the boolean state; passes both states unchanged to
 * {@code ApproximateCountDistinctAggregation.combineState(state, otherState)}.
 *
 * @param state accumulator handed to the delegate as the merge target
 * @param otherState accumulator handed to the delegate as the merge source
 */
@CombineFunction
public static void combineState(@AggregationState BooleanDistinctState state, @AggregationState BooleanDistinctState otherState)
{
    ApproximateCountDistinctAggregation.combineState(state, otherState);
}

/**
 * Output function for the HyperLogLog-backed state; passes the state and block builder
 * unchanged to {@code ApproximateCountDistinctAggregation.evaluateFinal(state, out)}.
 */
@OutputFunction(StandardTypes.BIGINT)
public static void evaluateFinal(@AggregationState HyperLogLogState state, BlockBuilder out)
{
ApproximateCountDistinctAggregation.evaluateFinal(state, out);
}

/**
 * Output function for the boolean state; passes the state and block builder unchanged
 * to {@code ApproximateCountDistinctAggregation.evaluateFinal(state, out)}.
 *
 * @param state accumulator handed to the delegate
 * @param out block builder handed to the delegate
 */
@OutputFunction(StandardTypes.BIGINT)
public static void evaluateFinal(@AggregationState BooleanDistinctState state, BlockBuilder out)
{
    ApproximateCountDistinctAggregation.evaluateFinal(state, out);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.operator.aggregation.state;

import io.prestosql.spi.function.AccumulatorState;

/**
 * Accumulator state that holds a single {@code byte}.
 * <p>
 * The boolean {@code approx_distinct} aggregation uses the byte as a small bit set:
 * it ORs in 1 when {@code true} is seen and 2 when {@code false} is seen, and reports
 * the number of set bits as the distinct count.
 */
public interface BooleanDistinctState
extends AccumulatorState
{
/**
 * Returns the current state byte.
 */
byte getByte();

/**
 * Replaces the state byte with {@code value}.
 */
void setByte(byte value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void testMultiplePositionsPartial(double maxStandardError)
}
}

private void assertCount(List<Object> values, double maxStandardError, long expectedCount)
protected void assertCount(List<?> values, double maxStandardError, long expectedCount)
{
if (!values.isEmpty()) {
assertEquals(estimateGroupByCount(values, maxStandardError), expectedCount);
Expand All @@ -136,25 +136,25 @@ private void assertCount(List<Object> values, double maxStandardError, long expe
assertEquals(estimateCountPartial(values, maxStandardError), expectedCount);
}

private long estimateGroupByCount(List<Object> values, double maxStandardError)
private long estimateGroupByCount(List<?> values, double maxStandardError)
{
Object result = AggregationTestUtils.groupedAggregation(getAggregationFunction(), createPage(values, maxStandardError));
return (long) result;
}

private long estimateCount(List<Object> values, double maxStandardError)
private long estimateCount(List<?> values, double maxStandardError)
{
Object result = AggregationTestUtils.aggregation(getAggregationFunction(), createPage(values, maxStandardError));
return (long) result;
}

private long estimateCountPartial(List<Object> values, double maxStandardError)
private long estimateCountPartial(List<?> values, double maxStandardError)
{
Object result = AggregationTestUtils.partialAggregation(getAggregationFunction(), createPage(values, maxStandardError));
return (long) result;
}

private Page createPage(List<Object> values, double maxStandardError)
private Page createPage(List<?> values, double maxStandardError)
{
if (values.isEmpty()) {
return new Page(0);
Expand All @@ -169,7 +169,7 @@ private Page createPage(List<Object> values, double maxStandardError)
/**
* Produce a block with the given values in the last field.
*/
private static Block createBlock(Type type, List<Object> values)
private static Block createBlock(Type type, List<?> values)
{
BlockBuilder blockBuilder = type.createBlockBuilder(null, values.size());

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.operator.aggregation;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Booleans;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.Type;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

import static io.prestosql.metadata.FunctionKind.AGGREGATE;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.spi.type.DoubleType.DOUBLE;

/**
 * Exercises {@code approx_distinct(BOOLEAN, DOUBLE)} through the shared
 * {@link AbstractTestApproximateCountDistinct} harness, plus a few fixed input
 * sequences whose distinct count is checked exactly.
 */
public class TestApproximateCountDistinctBoolean
        extends AbstractTestApproximateCountDistinct
{
    @Override
    public InternalAggregationFunction getAggregationFunction()
    {
        // approx_distinct(boolean, double) -> bigint
        Signature signature = new Signature(
                "approx_distinct",
                AGGREGATE,
                BIGINT.getTypeSignature(),
                BOOLEAN.getTypeSignature(),
                DOUBLE.getTypeSignature());
        return metadata.getFunctionRegistry().getAggregateFunctionImplementation(signature);
    }

    @Override
    public Type getValueType()
    {
        return BOOLEAN;
    }

    @Override
    public Object randomValue()
    {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        return random.nextBoolean();
    }

    /**
     * A boolean column can hold only two distinct values.
     */
    @Override
    protected int getUniqueValuesCount()
    {
        return 2;
    }

    /**
     * Fixed boolean sequences; each row is passed as the varargs of
     * {@link #testNonEmptyInputs(boolean...)}.
     */
    @DataProvider(name = "inputSequences")
    public Object[][] inputSequences()
    {
        return new Object[][] {
                {true},
                {false},
                {true, false},
                {true, true, true},
                {false, false, false},
                {true, false, true, false},
        };
    }

    @Test(dataProvider = "inputSequences")
    public void testNonEmptyInputs(boolean... inputSequence)
    {
        List<Boolean> values = Booleans.asList(inputSequence);
        long expectedCount = distinctCount(values);
        assertCount(values, 0, expectedCount);
    }

    @Test
    public void testNoInput()
    {
        // No rows: the expected distinct count is zero.
        assertCount(ImmutableList.of(), 0, 0);
    }

    /**
     * Returns the exact number of distinct values in {@code inputSequence}.
     */
    private long distinctCount(List<Boolean> inputSequence)
    {
        ImmutableSet<Boolean> distinctValues = ImmutableSet.copyOf(inputSequence);
        return distinctValues.size();
    }
}

0 comments on commit 7569071

Please sign in to comment.