Skip to content

Commit

Permalink
[SPARK-18429][SQL] implement a new Aggregate for CountMinSketch
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR implements a new aggregate function, `CountMinSketchAgg`, that generates a count-min sketch; it is a wrapper around `CountMinSketch`.

## How was this patch tested?

Added test cases.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes apache#15877 from wzhfy/cms.
  • Loading branch information
wzhfy authored and Robert Kruszewski committed Dec 2, 2016
1 parent 157430e commit cded7ed
Show file tree
Hide file tree
Showing 8 changed files with 710 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
* <li>{@link Integer}</li>
* <li>{@link Long}</li>
* <li>{@link String}</li>
* <li>{@link Float}</li>
* <li>{@link Double}</li>
* <li>{@link java.math.BigDecimal}</li>
* <li>{@link Boolean}</li>
* </ul>
* A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
* <ol>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Random;

Expand Down Expand Up @@ -152,6 +153,16 @@ public void add(Object item) {
public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
} else if (item instanceof BigDecimal) {
addString(((BigDecimal) item).toString(), count);
} else if (item instanceof byte[]) {
addBinary((byte[]) item, count);
} else if (item instanceof Float) {
addLong(Float.floatToIntBits((Float) item), count);
} else if (item instanceof Double) {
addLong(Double.doubleToLongBits((Double) item), count);
} else if (item instanceof Boolean) {
addLong(((Boolean) item) ? 1L : 0L, count);
} else {
addLong(Utils.integralToLong(item), count);
}
Expand Down Expand Up @@ -216,10 +227,6 @@ private int hash(long item, int count) {
return ((int) hash) % width;
}

private static int[] getHashBuckets(String key, int hashCount, int max) {
return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
}

private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
int[] result = new int[hashCount];
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
Expand All @@ -233,7 +240,18 @@ private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
@Override
public long estimateCount(Object item) {
if (item instanceof String) {
return estimateCountForStringItem((String) item);
return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item));
} else if (item instanceof BigDecimal) {
return estimateCountForBinaryItem(
Utils.getBytesFromUTF8String(((BigDecimal) item).toString()));
} else if (item instanceof byte[]) {
return estimateCountForBinaryItem((byte[]) item);
} else if (item instanceof Float) {
return estimateCountForLongItem(Float.floatToIntBits((Float) item));
} else if (item instanceof Double) {
return estimateCountForLongItem(Double.doubleToLongBits((Double) item));
} else if (item instanceof Boolean) {
return estimateCountForLongItem(((Boolean) item) ? 1L : 0L);
} else {
return estimateCountForLongItem(Utils.integralToLong(item));
}
Expand All @@ -247,7 +265,7 @@ private long estimateCountForLongItem(long item) {
return res;
}

private long estimateCountForStringItem(String item) {
private long estimateCountForBinaryItem(byte[] item) {
long res = Long.MAX_VALUE;
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.util.sketch

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import scala.reflect.ClassTag
import scala.util.Random
Expand All @@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
}

def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
def getProbeItem(item: T): Any = {
  // A byte array is represented by the string decoded from its content (UTF-8);
  // any other item is used as the probe key unchanged.
  item match {
    case raw: Array[Byte] => new String(raw, StandardCharsets.UTF_8)
    case other => other
  }
}

test(s"accuracy - $typeName") {
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)
Expand All @@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
sampledItems.groupBy(identity).mapValues(_.length.toLong)
sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}

val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
Expand All @@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

val probCorrect = {
val numErrors = allItems.map { item =>
val count = exactFreq.getOrElse(item, 0L)
val count = exactFreq.getOrElse(getProbeItem(item), 0L)
val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
if (ratio > epsOfTotalCount) 1 else 0
}.sum
Expand Down Expand Up @@ -135,6 +142,18 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite

testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }

testItemType[Float]("Float") { _.nextFloat() }

testItemType[Double]("Double") { _.nextDouble() }

testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) }

testItemType[Boolean]("Boolean") { _.nextBoolean() }

testItemType[Array[Byte]]("Binary") { r =>
Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20)))
}

test("incompatible merge") {
intercept[IncompatibleMergeException] {
CountMinSketch.create(10, 10, 1).mergeInPlace(null)
Expand Down
5 changes: 5 additions & 0 deletions sql/catalyst/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ object FunctionRegistry {
expression[VarianceSamp]("var_samp"),
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),

// string functions
expression[Ascii]("ascii"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch

/**
 * This function returns a count-min sketch of a column with the given eps, confidence and seed.
 * A count-min sketch is a probabilistic data structure used for summarizing streams of data in
 * sub-linear space, which is useful for equality predicates and join size estimation.
 * The result returned by the function is an array of bytes, which should be deserialized to a
 * `CountMinSketch` before usage.
 *
 * The eps, confidence and seed expressions must be foldable and must not evaluate to null;
 * this is enforced by `checkInputDataTypes`.
 *
 * @param child child expression that can produce column value with `child.eval(inputRow)`
 * @param epsExpression relative error, must be positive
 * @param confidenceExpression confidence, must be positive and less than 1.0
 * @param seedExpression random seed
 * @param mutableAggBufferOffset offset of this aggregate's buffer in the mutable aggregation
 *                               buffer
 * @param inputAggBufferOffset offset of this aggregate's buffer in the input aggregation buffer
 */
@ExpressionDescription(
  usage = """
    _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given eps,
      confidence and seed. The result is an array of bytes, which should be deserialized to a
      `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join
      size estimation.
  """)
case class CountMinSketchAgg(
    child: Expression,
    epsExpression: Expression,
    confidenceExpression: Expression,
    seedExpression: Expression,
    override val mutableAggBufferOffset: Int,
    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {

  /** Auxiliary constructor that starts with both aggregation buffer offsets at 0. */
  def this(
      child: Expression,
      epsExpression: Expression,
      confidenceExpression: Expression,
      seedExpression: Expression) = {
    this(child, epsExpression, confidenceExpression, seedExpression, 0, 0)
  }

  // Mark as lazy so that they are not evaluated during tree transformation.
  private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
  private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
  private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]

  /**
   * Validates the inputs in order: the default type check, then that eps, confidence and seed
   * are foldable, then that none of them evaluates to null, and finally that eps is positive
   * and confidence lies strictly between 0.0 and 1.0. The first failing check is reported.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    // Kept in declaration order so the checks short-circuit exactly as a chained `||` would.
    val sketchParameters = Seq(epsExpression, confidenceExpression, seedExpression)
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!sketchParameters.forall(_.foldable)) {
      TypeCheckFailure(
        "The eps, confidence or seed provided must be a literal or constant foldable")
    } else if (sketchParameters.exists(_.eval() == null)) {
      // The raw evaluation result is inspected here because casting null to a primitive type
      // (as the lazy vals above do) would not preserve the null.
      TypeCheckFailure("The eps, confidence or seed provided should not be null")
    } else if (eps <= 0D) {
      TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
    } else if (confidence <= 0D || confidence >= 1D) {
      TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
    } else {
      TypeCheckSuccess
    }
  }

  /** Creates an empty sketch configured with the evaluated eps, confidence and seed. */
  override def createAggregationBuffer(): CountMinSketch = {
    CountMinSketch.create(eps, confidence, seed)
  }

  /** Evaluates `child` on the input row and adds the value to the sketch; nulls are skipped. */
  override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
    val value = child.eval(input)
    // Ignore empty rows
    if (value != null) {
      child.dataType match {
        // `Decimal` and `UTF8String` are internal types in spark sql, we need to convert them
        // into acceptable types for `CountMinSketch`.
        case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
        // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
        // instead of `addString` to avoid unnecessary conversion.
        case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
        case _ => buffer.add(value)
      }
    }
  }

  /** Merges `input` into `buffer` in place. */
  override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit = {
    buffer.mergeInPlace(input)
  }

  /** The final result is the serialized form of the sketch. */
  override def eval(buffer: CountMinSketch): Any = serialize(buffer)

  /** Writes the sketch to an in-memory stream and returns the resulting bytes. */
  override def serialize(buffer: CountMinSketch): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    buffer.writeTo(out)
    out.toByteArray
  }

  /** Reads a sketch back from bytes produced by `serialize`. */
  override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
    val in = new ByteArrayInputStream(storageFormat)
    CountMinSketch.readFrom(in)
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CountMinSketchAgg =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  /**
   * Expected input types, in the same order as `children`: the column value, then eps (double),
   * confidence (double) and seed (integer).
   */
  override def inputTypes: Seq[AbstractDataType] = {
    Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
      DoubleType, DoubleType, IntegerType)
  }

  override def nullable: Boolean = false

  override def dataType: DataType = BinaryType

  override def children: Seq[Expression] =
    Seq(child, epsExpression, confidenceExpression, seedExpression)

  override def prettyName: String = "count_min_sketch"
}

0 comments on commit cded7ed

Please sign in to comment.