Adds HyperLogLog++ functions for Spark
Includes native function registration and a hashing strategy that handles all Spark data types
pidge committed Nov 10, 2018
1 parent 41b3e9e commit e7d2e54
Showing 12 changed files with 1,014 additions and 8 deletions.
VERSION: 2 changes (1 addition & 1 deletion)
@@ -1 +1 @@
-0.1.0-SNAPSHOT
+0.2.0-SNAPSHOT
@@ -0,0 +1,7 @@
package com.swoop.alchemy.spark.expressions

import org.apache.spark.sql.SparkSession

trait FunctionRegistration {
def registerFunctions(spark: SparkSession): Unit
}
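
For illustration, a minimal implementor of this trait might look like the following sketch. The UDF-based body is hypothetical and not part of this commit; it only shows where registration logic would live.

import org.apache.spark.sql.SparkSession

// Hypothetical example: install a plain Scala UDF when registerFunctions is called.
object PlusOneRegistration extends FunctionRegistration {
  def registerFunctions(spark: SparkSession): Unit =
    spark.udf.register("plus_one", (x: Long) => x + 1)
}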
@@ -0,0 +1,85 @@
package com.swoop.alchemy.spark.expressions

import org.apache.spark.sql.EncapsulationViolator.createAnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, RuntimeReplaceable}

import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

// based on Spark's FunctionRegistry @ossSpark
trait NativeFunctionRegistration extends FunctionRegistration {

type FunctionBuilder = Seq[Expression] => Expression

def expressions: Map[String, (ExpressionInfo, FunctionBuilder)]


def registerFunctions(fr: FunctionRegistry): Unit = {
expressions.foreach { case (name, (info, builder)) => fr.registerFunction(FunctionIdentifier(name), info, builder) }
}

def registerFunctions(spark: SparkSession): Unit = {
registerFunctions(spark.sessionState.functionRegistry)
}

/** Creates a registration entry, (name -> (ExpressionInfo, builder)), for the expression type T. */
protected def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {

// For `RuntimeReplaceable` expressions, skip the constructor with the most arguments: it is the
// main constructor, and its non-parameter `child` argument makes it unsuitable as a function builder.
val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
val all = tag.runtimeClass.getConstructors
val maxNumArgs = all.map(_.getParameterCount).max
all.filterNot(_.getParameterCount == maxNumArgs)
} else {
tag.runtimeClass.getConstructors
}
// See if we can find a constructor that accepts Seq[Expression]
val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) {
// If there is a constructor that accepts Seq[Expression], use it.
Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) =>
// The reflective call wraps the real error in an InvocationTargetException;
// use the cause to get a meaningful message.
throw createAnalysisException(e.getCause.getMessage)
}
} else {
// Otherwise, find a constructor that matches the number of arguments and use it.
val params = Seq.fill(expressions.size)(classOf[Expression])
val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
throw createAnalysisException(s"Invalid number of arguments for function $name")
}
Try(f.newInstance(expressions: _*).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) =>
// The reflective call wraps the real error in an InvocationTargetException;
// use the cause to get a meaningful message.
throw createAnalysisException(e.getCause.getMessage)
}
}
}

(name, (expressionInfo[T](name), builder))
}

/**
* Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name.
*/
protected def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
val clazz = scala.reflect.classTag[T].runtimeClass
val df = clazz.getAnnotation(classOf[ExpressionDescription])
if (df != null) {
new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended())
} else {
new ExpressionInfo(clazz.getCanonicalName, name)
}
}

}
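
As a usage sketch: a concrete registration only needs to supply the expressions map. Upper is one of Spark's built-in Catalyst expressions, used here purely for illustration, and my_upper is a hypothetical function name.

import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Upper}

object ExampleRegistration extends NativeFunctionRegistration {
  // expression[T] reflectively selects a constructor matching the argument count
  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
    expression[Upper]("my_upper")
  )
}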
@@ -0,0 +1,15 @@
package com.swoop.alchemy.spark.expressions

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

trait WithHelper {
def withExpr(expr: Expression): Column = new Column(expr)

def withAggregateFunction(
func: AggregateFunction,
isDistinct: Boolean = false): Column = {
new Column(func.toAggregateExpression(isDistinct))
}
}
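
A sketch of how this helper might be used (Count is Spark's built-in Catalyst aggregate, chosen purely for illustration):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.aggregate.Count

object ExampleFunctions extends WithHelper {
  // expose the Catalyst Count aggregate as a regular Column-returning function
  def my_count(col: Column): Column = withAggregateFunction(Count(col.expr))
}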
@@ -0,0 +1,52 @@
package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql
import org.apache.spark.sql.Column


/** Convenience trait for using the HyperLogLog functions with a consistent target error.
* Spark's own [[sql.functions.approx_count_distinct()]] as well as the granular HLL functions
* [[HLLFunctions.hll_init()]] and [[HLLFunctions.hll_init_collection()]] are automatically
* parameterized by [[BoundHLL.hllError]].
*/
trait BoundHLL extends Serializable {

def hllError: Double

def approx_count_distinct(col: Column): Column =
sql.functions.approx_count_distinct(col, hllError)

def approx_count_distinct(colName: String): Column =
sql.functions.approx_count_distinct(colName, hllError)

def hll_init(col: Column): Column =
functions.hll_init(col, hllError)

def hll_init(columnName: String): Column =
functions.hll_init(columnName, hllError)

def hll_init_collection(col: Column): Column =
functions.hll_init_collection(col, hllError)

def hll_init_collection(columnName: String): Column =
functions.hll_init_collection(columnName, hllError)

def hll_init_agg(col: Column): Column =
functions.hll_init_agg(col, hllError)

def hll_init_agg(columnName: String): Column =
functions.hll_init_agg(columnName, hllError)

def hll_init_collection_agg(col: Column): Column =
functions.hll_init_collection_agg(col, hllError)

def hll_init_collection_agg(columnName: String): Column =
functions.hll_init_collection_agg(columnName, hllError)

}

object BoundHLL {
def apply(error: Double): BoundHLL = new BoundHLL {
def hllError: Double = error
}
}
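
A usage sketch, assuming a DataFrame df with an id column (both names are illustrative):

val hll = BoundHLL(0.02) // target a 2% standard error everywhere
import hll._

df.select(approx_count_distinct("id")) // uses hllError instead of Spark's default
df.select(hll_init("id"))              // per-row HLL sketches at the same error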
@@ -0,0 +1,47 @@
package com.swoop.alchemy.spark.expressions.hll

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{InterpretedHashFunction, XXH64}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Hash function for Spark data values that is suitable for cardinality counting. Unlike Spark's
* built-in hashing, it differentiates between values of different data types and accounts for
* nulls inside complex types (e.g., Array.empty and Array(null) hash differently).
*/
abstract class CardinalityHashFunction extends InterpretedHashFunction {

override def hash(value: Any, dataType: DataType, seed: Long): Long = {

def hashWithTag(typeTag: Long) =
super.hash(value, dataType, hashLong(typeTag, seed))

value match {
// change null handling to differentiate between things like Array.empty and Array(null)
case null => hashLong(seed, seed)
// add type tags so that values hash differently on their own and nested in complex types
case _: Array[Byte] => hashWithTag(-3698894927619418744L)
case _: UTF8String => hashWithTag(-8468821688391060513L)
case _: ArrayData => hashWithTag(-1666055126678331734L)
case _: MapData => hashWithTag(5587693012926141532L)
case _: InternalRow => hashWithTag(-891294170547231607L)
// pass through everything else (simple types)
case _ => super.hash(value, dataType, seed)
}
}

}


object CardinalityXxHash64Function extends CardinalityHashFunction {

override protected def hashInt(i: Int, seed: Long): Long = XXH64.hashInt(i, seed)

override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed)

override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
XXH64.hashUnsafeBytes(base, offset, len, seed)
}

}
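
To illustrate the difference from Spark's built-in hashing, the following sketch (seed and values chosen arbitrarily) produces three distinct hashes for a null, an empty array, and an array containing only a null, because nulls are folded into the seed and complex types are tagged:

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

val seed = 42L
val hNull     = CardinalityXxHash64Function.hash(null, IntegerType, seed)
val hEmpty    = CardinalityXxHash64Function.hash(new GenericArrayData(Array.empty[Any]), ArrayType(IntegerType), seed)
val hNullElem = CardinalityXxHash64Function.hash(new GenericArrayData(Array[Any](null)), ArrayType(IntegerType), seed)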
@@ -0,0 +1,17 @@
package com.swoop.alchemy.spark.expressions.hll

import com.swoop.alchemy.spark.expressions.NativeFunctionRegistration
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo

object HLLFunctionRegistration extends NativeFunctionRegistration {

val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
expression[HyperLogLogInitSimple]("hll_init"),
expression[HyperLogLogInitCollection]("hll_init_collection"),
expression[HyperLogLogInitSimpleAgg]("hll_init_agg"),
expression[HyperLogLogInitCollectionAgg]("hll_init_collection_agg"),
expression[HyperLogLogMerge]("hll_merge"),
expression[HyperLogLogCardinality]("hll_cardinality")
)

}
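
A usage sketch, assuming a SparkSession spark and a table events with a user_id column (names are illustrative):

HLLFunctionRegistration.registerFunctions(spark)

spark.sql("SELECT hll_cardinality(hll_init_agg(user_id)) AS approx_users FROM events")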
