diff --git a/VERSION b/VERSION
index 4ecb664..d144648 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.0-SNAPSHOT
\ No newline at end of file
+0.2.0-SNAPSHOT
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/FunctionRegistration.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/FunctionRegistration.scala
new file mode 100644
index 0000000..dc27f83
--- /dev/null
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/FunctionRegistration.scala
@@ -0,0 +1,7 @@
+package com.swoop.alchemy.spark.expressions
+
+import org.apache.spark.sql.SparkSession
+
+trait FunctionRegistration {
+  def registerFunctions(spark: SparkSession): Unit
+}
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala
new file mode 100644
index 0000000..a86eeb7
--- /dev/null
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/NativeFunctionRegistration.scala
@@ -0,0 +1,85 @@
+package com.swoop.alchemy.spark.expressions
+
+import org.apache.spark.sql.EncapsulationViolator.createAnalysisException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, RuntimeReplaceable}
+
+import scala.reflect.ClassTag
+import scala.util.{Failure, Success, Try}
+
+// based on Spark's FunctionRegistry @ossSpark
+trait NativeFunctionRegistration extends FunctionRegistration {
+
+  type FunctionBuilder = Seq[Expression] => Expression
+
+  def expressions: Map[String, (ExpressionInfo, FunctionBuilder)]
+
+
+  def registerFunctions(fr: FunctionRegistry): Unit = {
+    expressions.foreach { case (name, (info, builder)) => fr.registerFunction(FunctionIdentifier(name), info, builder) }
+  }
+
+  def registerFunctions(spark: SparkSession): Unit = {
+    registerFunctions(spark.sessionState.functionRegistry)
+  }
+
+  /** Creates a registry entry, `name -> (ExpressionInfo, FunctionBuilder)`, for the expression type `T`. */
+  protected def expression[T <: Expression](name: String)
+    (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
+
+    // For `RuntimeReplaceable`, skip the constructor with the most arguments, which is the main
+    // constructor and contains the non-parameter `child`, and should not be used as a function builder.
+    val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
+      val all = tag.runtimeClass.getConstructors
+      val maxNumArgs = all.map(_.getParameterCount).max
+      all.filterNot(_.getParameterCount == maxNumArgs)
+    } else {
+      tag.runtimeClass.getConstructors
+    }
+    // See if we can find a constructor that accepts Seq[Expression]
+    val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
+    val builder = (expressions: Seq[Expression]) => {
+      if (varargCtor.isDefined) {
+        // If there is a constructor that accepts Seq[Expression], use that one.
+        Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) =>
+            // the exception is an invocation exception. To get a meaningful message, we need the
+            // cause.
+            throw createAnalysisException(e.getCause.getMessage)
+        }
+      } else {
+        // Otherwise, find a constructor that matches the number of arguments, and use that.
+        val params = Seq.fill(expressions.size)(classOf[Expression])
+        val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
+          throw createAnalysisException(s"Invalid number of arguments for function $name")
+        }
+        Try(f.newInstance(expressions: _*).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) =>
+            // the exception is an invocation exception. To get a meaningful message, we need the
+            // cause.
+            throw createAnalysisException(e.getCause.getMessage)
+        }
+      }
+    }
+
+    (name, (expressionInfo[T](name), builder))
+  }
+
+  /**
+   * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name.
+   */
+  protected def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
+    val clazz = scala.reflect.classTag[T].runtimeClass
+    val df = clazz.getAnnotation(classOf[ExpressionDescription])
+    if (df != null) {
+      new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended())
+    } else {
+      new ExpressionInfo(clazz.getCanonicalName, name)
+    }
+  }
+
+}
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/WithHelper.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/WithHelper.scala
new file mode 100644
index 0000000..6cf7ca7
--- /dev/null
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/WithHelper.scala
@@ -0,0 +1,15 @@
+package com.swoop.alchemy.spark.expressions
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
+
+trait WithHelper {
+  def withExpr(expr: Expression): Column = new Column(expr)
+
+  def withAggregateFunction(
+    func: AggregateFunction,
+    isDistinct: Boolean = false): Column = {
+    new Column(func.toAggregateExpression(isDistinct))
+  }
+}
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala
new file mode 100644
index 0000000..48621ab
--- /dev/null
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/BoundHLL.scala
@@ -0,0 +1,52 @@
+package com.swoop.alchemy.spark.expressions.hll
+
+import org.apache.spark.sql
+import org.apache.spark.sql.Column
+
+
+/** Convenience trait for using HyperLogLog functions with a consistent error.
+  * Spark's own [[sql.functions.approx_count_distinct()]] as well as the granular HLL functions
+  * [[HLLFunctions.hll_init()]] and [[HLLFunctions.hll_init_collection()]] will be
+  * automatically parameterized by [[BoundHLL.hllError]].
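+  *
+  * Illustrative usage sketch (the DataFrame `df` and the column name are assumed, not part of this change):
+  * {{{
+  * object hll extends BoundHLL {
+  *   val hllError = 0.02
+  * }
+  * import hll._
+  *
+  * df.select(hll_init("user_id"))               // parameterized by hllError = 0.02
+  * df.select(approx_count_distinct("user_id"))  // same error bound as above
+  * }}}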
+  */
+trait BoundHLL extends Serializable {
+
+  def hllError: Double
+
+  def approx_count_distinct(col: Column): Column =
+    sql.functions.approx_count_distinct(col, hllError)
+
+  def approx_count_distinct(colName: String): Column =
+    sql.functions.approx_count_distinct(colName, hllError)
+
+  def hll_init(col: Column): Column =
+    functions.hll_init(col, hllError)
+
+  def hll_init(columnName: String): Column =
+    functions.hll_init(columnName, hllError)
+
+  def hll_init_collection(col: Column): Column =
+    functions.hll_init_collection(col, hllError)
+
+  def hll_init_collection(columnName: String): Column =
+    functions.hll_init_collection(columnName, hllError)
+
+  def hll_init_agg(col: Column): Column =
+    functions.hll_init_agg(col, hllError)
+
+  def hll_init_agg(columnName: String): Column =
+    functions.hll_init_agg(columnName, hllError)
+
+  def hll_init_collection_agg(col: Column): Column =
+    functions.hll_init_collection_agg(col, hllError)
+
+  def hll_init_collection_agg(columnName: String): Column =
+    functions.hll_init_collection_agg(columnName, hllError)
+
+}
+
+object BoundHLL {
+  def apply(error: Double): BoundHLL = new BoundHLL {
+    def hllError: Double = error
+  }
+}
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala
new file mode 100644
index 0000000..82cb008
--- /dev/null
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunction.scala
@@ -0,0 +1,47 @@
+package com.swoop.alchemy.spark.expressions.hll
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{InterpretedHashFunction, XXH64}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Hash function for Spark data values that is suitable for cardinality counting. Unlike Spark's built-in hashing,
+ * it differentiates between different data types and accounts for nulls.
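+ *
+ * For example, all of the following produce distinct hashes (see CardinalityHashFunctionTest):
+ * {{{
+ * CardinalityXxHash64Function.hash(null, ArrayType(StringType), 0L)
+ * CardinalityXxHash64Function.hash(ArrayData.toArrayData(Array.empty), ArrayType(StringType), 0L)
+ * CardinalityXxHash64Function.hash(ArrayData.toArrayData(Array(null)), ArrayType(StringType), 0L)
+ * }}}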
+ */
+abstract class CardinalityHashFunction extends InterpretedHashFunction {
+
+  override def hash(value: Any, dataType: DataType, seed: Long): Long = {
+
+    def hashWithTag(typeTag: Long) =
+      super.hash(value, dataType, hashLong(typeTag, seed))
+
+    value match {
+      // change null handling to differentiate between things like Array.empty and Array(null)
+      case null => hashLong(seed, seed)
+      // add type tags to differentiate between values on their own or in complex types
+      case _: Array[Byte] => hashWithTag(-3698894927619418744L)
+      case _: UTF8String => hashWithTag(-8468821688391060513L)
+      case _: ArrayData => hashWithTag(-1666055126678331734L)
+      case _: MapData => hashWithTag(5587693012926141532L)
+      case _: InternalRow => hashWithTag(-891294170547231607L)
+      // pass through everything else (simple types)
+      case _ => super.hash(value, dataType, seed)
+    }
+  }
+
+}
+
+
+object CardinalityXxHash64Function extends CardinalityHashFunction {
+
+  override protected def hashInt(i: Int, seed: Long): Long = XXH64.hashInt(i, seed)
+
+  override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed)
+
+  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+    XXH64.hashUnsafeBytes(base, offset, len, seed)
+  }
+
+}
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala
new file mode 100644
index 0000000..6988c44
--- /dev/null
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionRegistration.scala
@@ -0,0 +1,17 @@
+package com.swoop.alchemy.spark.expressions.hll
+
+import com.swoop.alchemy.spark.expressions.NativeFunctionRegistration
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+
+object HLLFunctionRegistration extends NativeFunctionRegistration {
+
+  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
+    expression[HyperLogLogInitSimple]("hll_init"),
+    expression[HyperLogLogInitCollection]("hll_init_collection"),
+    expression[HyperLogLogInitSimpleAgg]("hll_init_agg"),
+    expression[HyperLogLogInitCollectionAgg]("hll_init_collection_agg"),
+    expression[HyperLogLogMerge]("hll_merge"),
+    expression[HyperLogLogCardinality]("hll_cardinality")
+  )
+
+}
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala
new file mode 100644
index 0000000..ad6c2a7
--- /dev/null
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala
@@ -0,0 +1,444 @@
+package com.swoop.alchemy.spark.expressions.hll
+
+import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
+import com.swoop.alchemy.spark.expressions.WithHelper
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.aggregate.{HyperLogLogPlusPlus, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, UnaryExpression}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
+
+trait HyperLogLogInit extends Expression {
+  def relativeSD: Double
+
+  // This formula for `p` came from org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus:93
+  protected[this] val p: Int = Math.ceil(2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)).toInt
+
+  require(p >= 4, "HLL++ requires at least 4 bits for addressing. Use a lower error, at most 39%.")
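+  // Worked example of the formula above (illustrative): the default relativeSD = 0.05 gives
+  // p = ceil(2 * log2(1.106 / 0.05)) = ceil(8.94) = 9; relativeSD = 0.39 gives the minimum p = 4,
+  // while relativeSD = 0.40 gives p = 3 and fails the requirement above.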
+
+  override def dataType: DataType = BinaryType
+
+  def child: Expression
+
+  def offer(value: Any, buffer: HyperLogLogPlus): HyperLogLogPlus
+
+  def createHll = new HyperLogLogPlus(p, 0)
+
+  def hash(value: Any, dataType: DataType, seed: Long): Long = CardinalityXxHash64Function.hash(value, dataType, seed)
+
+  def hash(value: Any, dataType: DataType): Long = {
+    // Using 0L as the seed results in a hash of 0L for empty arrays, which breaks our cardinality estimation tests due
+    // to the improbably high number of leading zeros. Instead, use some other arbitrary "normal" long.
+    hash(value, dataType, 6705405522910076594L)
+  }
+}
+
+
+trait HyperLogLogSimple extends HyperLogLogInit {
+  def offer(value: Any, buffer: HyperLogLogPlus): HyperLogLogPlus = {
+    buffer.offerHashed(hash(value, child.dataType))
+    buffer
+  }
+}
+
+
+trait HyperLogLogCollection extends HyperLogLogInit {
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    child.dataType match {
+      case _: ArrayType | _: MapType | _: NullType => TypeCheckResult.TypeCheckSuccess
+      case _ => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array and map input.")
+    }
+
+  def offer(value: Any, buffer: HyperLogLogPlus): HyperLogLogPlus = {
+    value match {
+      case arr: ArrayData =>
+        child.dataType match {
+          case ArrayType(et, _) => arr.foreach(et, (_, v) => {
+            if (v != null) buffer.offerHashed(hash(v, et))
+          })
+          case dt => throw new UnsupportedOperationException(s"Unknown DataType for ArrayData: $dt")
+        }
+      case map: MapData =>
+        child.dataType match {
+          case MapType(kt, vt, _) => map.foreach(kt, vt, (k, v) => {
+            buffer.offerHashed(hash(v, vt, hash(k, kt))) // chain key and value hash
+          })
+          case dt => throw new UnsupportedOperationException(s"Unknown DataType for MapData: $dt")
+        }
+      case _: NullType => // do nothing
+      case _ => throw new UnsupportedOperationException(s"$prettyName only supports array and map input.")
+    }
+    buffer
+  }
+}
+
+
+trait HyperLogLogInitSingle extends UnaryExpression with HyperLogLogInit with CodegenFallback {
+  override def nullable: Boolean = child.nullable
+
+  override def nullSafeEval(value: Any): Any =
+    offer(value, createHll).getBytes
+}
+
+trait HyperLogLogInitAgg extends NullableSketchAggregation with HyperLogLogInit {
+
+  override def update(buffer: Option[HyperLogLogPlus], inputRow: InternalRow): Option[HyperLogLogPlus] = {
+    val value = child.eval(inputRow)
+    if (value != null) {
+      Some(offer(value, buffer.getOrElse(createHll)))
+    } else {
+      buffer
+    }
+  }
+}
+
+trait NullableSketchAggregation extends TypedImperativeAggregate[Option[HyperLogLogPlus]] {
+
+  override def createAggregationBuffer(): Option[HyperLogLogPlus] = None
+
+  override def merge(buffer: Option[HyperLogLogPlus], other: Option[HyperLogLogPlus]): Option[HyperLogLogPlus] =
+    (buffer, other) match {
+      case (Some(a), Some(b)) =>
+        a.addAll(b)
+        Some(a)
+      case (a, None) => a
+      case (None, b) => b
+      case _ => None
+    }
+
+  override def eval(buffer: Option[HyperLogLogPlus]): Any =
+    buffer.map(_.getBytes).orNull
+
+  def child: Expression
+
+  override def children: Seq[Expression] = Seq(child)
+
+  override def nullable: Boolean = child.nullable
+
+  override def serialize(hll: Option[HyperLogLogPlus]): Array[Byte] =
+    hll.map(_.getBytes).orNull
+
+  override def deserialize(bytes: Array[Byte]): Option[HyperLogLogPlus] =
+    if (bytes == null) None else Option(HyperLogLogPlus.Builder.build(bytes))
+}
+
+
+/**
+ * HyperLogLog++ (HLL++) is a state-of-the-art cardinality estimation algorithm.
+ *
+ * This version creates a composable "sketch" for each input row.
+ * All expression values are treated as simple values.
+ *
+ * @param child      the expression to estimate the cardinality of
+ * @param relativeSD defines the maximum estimation error allowed
+ */
+@ExpressionDescription(
+  usage =
+    """
+    _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" produced by HyperLogLog++.
+      `relativeSD` defines the maximum estimation error allowed.
+    """)
+case class HyperLogLogInitSimple(
+  override val child: Expression,
+  override val relativeSD: Double = 0.05)
+  extends HyperLogLogInitSingle with HyperLogLogSimple {
+
+  def this(child: Expression) = this(child, relativeSD = 0.05)
+
+  def this(child: Expression, relativeSD: Expression) = {
+    this(
+      child = child,
+      relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD)
+    )
+  }
+
+  override def prettyName: String = "hll_init"
+}
+
+
+/**
+ * HyperLogLog++ (HLL++) is a state-of-the-art cardinality estimation algorithm.
+ *
+ * This version combines all input in each aggregate group into a single "sketch".
+ * All expression values are treated as simple values.
+ *
+ * @param child      the expression to estimate the cardinality of
+ * @param relativeSD defines the maximum estimation error allowed
+ */
+@ExpressionDescription(
+  usage =
+    """
+    _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" produced by HyperLogLog++.
+      `relativeSD` defines the maximum estimation error allowed.
+    """)
+case class HyperLogLogInitSimpleAgg(
+  override val child: Expression,
+  override val relativeSD: Double = 0.05,
+  override val mutableAggBufferOffset: Int = 0,
+  override val inputAggBufferOffset: Int = 0)
+  extends HyperLogLogInitAgg with HyperLogLogSimple {
+
+  def this(child: Expression) = this(child, relativeSD = 0.05)
+
+  def this(child: Expression, relativeSD: Expression) = {
+    this(
+      child = child,
+      relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD),
+      mutableAggBufferOffset = 0,
+      inputAggBufferOffset = 0)
+  }
+
+  override def withNewMutableAggBufferOffset(newOffset: Int): HyperLogLogInitSimpleAgg =
+    copy(mutableAggBufferOffset = newOffset)
+
+  override def withNewInputAggBufferOffset(newOffset: Int): HyperLogLogInitSimpleAgg =
+    copy(inputAggBufferOffset = newOffset)
+
+  override def prettyName: String = "hll_init_agg"
+}
+
+/**
+ * HyperLogLog++ (HLL++) is a state-of-the-art cardinality estimation algorithm.
+ *
+ * This version creates a composable "sketch" for each input row.
+ * The expression must be a collection (Array, Map), and collection elements are treated as individual values.
+ *
+ * @param child      the expression to estimate the cardinality of
+ * @param relativeSD defines the maximum estimation error allowed
+ */
+@ExpressionDescription(
+  usage =
+    """
+    _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" produced by HyperLogLog++.
+      `relativeSD` defines the maximum estimation error allowed.
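+
+      Illustrative example (assumes the functions are registered via HLLFunctionRegistration):
+        > SELECT hll_cardinality(hll_merge(_FUNC_(array(1, 2, 3))));
+         3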
+ """) +case class HyperLogLogInitCollection( + override val child: Expression, + override val relativeSD: Double = 0.05) + extends HyperLogLogInitSingle with HyperLogLogCollection { + + def this(child: Expression) = this(child, relativeSD = 0.05) + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD) + ) + } + + override def prettyName: String = "hll_init_collection" +} + + +/** + * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. + * + * This version combines all input in each aggregate group into a a single "sketch". + * If `expr` is a collection (Array, Map), collection elements are treated as individual values. + * + * @param child to estimate the cardinality of + * @param relativeSD defines the maximum estimation error allowed + */ +@ExpressionDescription( + usage = + """ + _FUNC_(expr[, relativeSD]) - Returns the composable "sketch" by HyperLogLog++. + `relativeSD` defines the maximum estimation error allowed. + """) +case class HyperLogLogInitCollectionAgg( + child: Expression, + relativeSD: Double = 0.05, + override val mutableAggBufferOffset: Int = 0, + override val inputAggBufferOffset: Int = 0) + extends HyperLogLogInitAgg with HyperLogLogCollection { + + def this(child: Expression) = this(child, relativeSD = 0.05) + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + } + + override def withNewMutableAggBufferOffset(newOffset: Int): HyperLogLogInitCollectionAgg = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): HyperLogLogInitCollectionAgg = + copy(inputAggBufferOffset = newOffset) + + override def prettyName: String = "hll_init_collection_agg" +} + + +/** + * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. + * + * This version merges the "sketches" into a combined binary composable representation. + * + * @param child "sketch" to merge + */ +@ExpressionDescription( + usage = + """ + _FUNC_(expr) - Returns the merged HLL++ sketch. 
+ """) +case class HyperLogLogMerge( + child: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends NullableSketchAggregation { + + def this(child: Expression) = this(child, 0, 0) + + override def update(buffer: Option[HyperLogLogPlus], inputRow: InternalRow): Option[HyperLogLogPlus] = { + val value = child.eval(inputRow) + if (value != null) { + val hll = value match { + case b: Array[Byte] => HyperLogLogPlus.Builder.build(b) + case _ => throw new IllegalStateException(s"$prettyName only supports Array[Byte]") + } + buffer.map(_.merge(hll).asInstanceOf[HyperLogLogPlus]) + .orElse(Option(hll)) + } else { + buffer + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + child.dataType match { + case BinaryType => TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports binary input") + } + } + + override def dataType: DataType = BinaryType + + override def withNewMutableAggBufferOffset(newOffset: Int): HyperLogLogMerge = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): HyperLogLogMerge = + copy(inputAggBufferOffset = newOffset) + + override def prettyName: String = "hll_merge" +} + + +/** + * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. + * + * Returns the estimated cardinality of an HLL++ "sketch" + * + * @param child HLL+ "sketch" + */ +@ExpressionDescription( + usage = + """ + _FUNC_(expr) - Returns the estimated cardinality of the binary representation produced by HyperLogLog++. + """) +case class HyperLogLogCardinality(override val child: Expression) extends UnaryExpression with ExpectsInputTypes with CodegenFallback { + + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def dataType: DataType = LongType + + override def nullable: Boolean = child.nullable + + override def checkInputDataTypes(): TypeCheckResult = { + child.dataType match { + case BinaryType => TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports binary input") + } + } + + override def nullSafeEval(input: Any): Long = { + val data = input.asInstanceOf[Array[Byte]] + HyperLogLogPlus.Builder.build(data).cardinality() + } + + override def prettyName: String = "hll_cardinality" + +} + +object functions extends HLLFunctions + +trait HLLFunctions extends WithHelper { + + def hll_init(e: Column, relativeSD: Double): Column = withExpr { + HyperLogLogInitSimple(e.expr, relativeSD) + } + + def hll_init(columnName: String, relativeSD: Double): Column = + hll_init(col(columnName), relativeSD) + + def hll_init(e: Column): Column = withExpr { + HyperLogLogInitSimple(e.expr) + } + + def hll_init(columnName: String): Column = + hll_init(col(columnName)) + + def hll_init_collection(e: Column, relativeSD: Double): Column = withExpr { + HyperLogLogInitCollection(e.expr, relativeSD) + } + + def hll_init_collection(columnName: String, relativeSD: Double): Column = + hll_init_collection(col(columnName), relativeSD) + + def hll_init_collection(e: Column): Column = withExpr { + HyperLogLogInitCollection(e.expr) + } + + def hll_init_collection(columnName: String): Column = + hll_init_collection(col(columnName)) + + def hll_init_agg(e: Column, relativeSD: Double): Column = withAggregateFunction { + HyperLogLogInitSimpleAgg(e.expr, relativeSD) + } + + def hll_init_agg(columnName: String, relativeSD: Double): Column = + hll_init_agg(col(columnName), relativeSD) + + def 
+trait HLLFunctions extends WithHelper {
+
+  def hll_init(e: Column, relativeSD: Double): Column = withExpr {
+    HyperLogLogInitSimple(e.expr, relativeSD)
+  }
+
+  def hll_init(columnName: String, relativeSD: Double): Column =
+    hll_init(col(columnName), relativeSD)
+
+  def hll_init(e: Column): Column = withExpr {
+    HyperLogLogInitSimple(e.expr)
+  }
+
+  def hll_init(columnName: String): Column =
+    hll_init(col(columnName))
+
+  def hll_init_collection(e: Column, relativeSD: Double): Column = withExpr {
+    HyperLogLogInitCollection(e.expr, relativeSD)
+  }
+
+  def hll_init_collection(columnName: String, relativeSD: Double): Column =
+    hll_init_collection(col(columnName), relativeSD)
+
+  def hll_init_collection(e: Column): Column = withExpr {
+    HyperLogLogInitCollection(e.expr)
+  }
+
+  def hll_init_collection(columnName: String): Column =
+    hll_init_collection(col(columnName))
+
+  def hll_init_agg(e: Column, relativeSD: Double): Column = withAggregateFunction {
+    HyperLogLogInitSimpleAgg(e.expr, relativeSD)
+  }
+
+  def hll_init_agg(columnName: String, relativeSD: Double): Column =
+    hll_init_agg(col(columnName), relativeSD)
+
+  def hll_init_agg(e: Column): Column = withAggregateFunction {
+    HyperLogLogInitSimpleAgg(e.expr)
+  }
+
+  def hll_init_agg(columnName: String): Column =
+    hll_init_agg(col(columnName))
+
+  def hll_init_collection_agg(e: Column, relativeSD: Double): Column = withAggregateFunction {
+    HyperLogLogInitCollectionAgg(e.expr, relativeSD)
+  }
+
+  def hll_init_collection_agg(columnName: String, relativeSD: Double): Column =
+    hll_init_collection_agg(col(columnName), relativeSD)
+
+  def hll_init_collection_agg(e: Column): Column = withAggregateFunction {
+    HyperLogLogInitCollectionAgg(e.expr)
+  }
+
+  def hll_init_collection_agg(columnName: String): Column =
+    hll_init_collection_agg(col(columnName))
+
+  def hll_merge(e: Column): Column = withAggregateFunction {
+    HyperLogLogMerge(e.expr, 0, 0)
+  }
+
+  def hll_merge(columnName: String): Column =
+    hll_merge(col(columnName))
+
+  def hll_cardinality(e: Column): Column = withExpr {
+    HyperLogLogCardinality(e.expr)
+  }
+
+  def hll_cardinality(columnName: String): Column =
+    hll_cardinality(col(columnName))
+}
diff --git a/alchemy/src/main/scala/org/apache/spark/sql/EncapsulationViolator.scala b/alchemy/src/main/scala/org/apache/spark/sql/EncapsulationViolator.scala
new file mode 100644
index 0000000..f807dc9
--- /dev/null
+++ b/alchemy/src/main/scala/org/apache/spark/sql/EncapsulationViolator.scala
@@ -0,0 +1,49 @@
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, NamedExpression}
+import org.apache.spark.sql.internal.SessionState
+import org.apache.spark.sql.types.{DataType, Metadata, StructType}
+import org.json4s.JsonAST.JValue
+
+object EncapsulationViolator {
+
+  def createAnalysisException(message: String): AnalysisException =
+    new AnalysisException(message)
+
+  def parseDataType(jv: JValue): DataType =
+    DataType.parseDataType(jv)
+
+  object implicits {
+
+    implicit class EncapsulationViolationSparkSessionOps(val underlying: SparkSession) extends AnyVal {
+      def evSessionState: SessionState = underlying.sessionState
+    }
+
+    implicit class EncapsulationViolationRowOps(val underlying: GenericRow) extends AnyVal {
+      def evValues: Array[Any] = underlying.values
+    }
+
+    implicit class EncapsulationViolationColumnOps(val underlying: Column) extends AnyVal {
+      def evNamed: NamedExpression = underlying.named
+
+      def metadata: Metadata = underlying.expr match {
+        case ne: NamedExpression => ne.metadata
+        case other => Metadata.empty
+      }
+    }
+
+    implicit class EncapsulationViolationDataTypeOps(val underlying: DataType) extends AnyVal {
+      def isSameType(other: DataType): Boolean = underlying.sameType(other)
+
+      def jValue: JValue = underlying.jsonValue
+
+      def toNullable: DataType = underlying.asNullable
+    }
+
+    implicit class EncapsulationViolationStructTypeOps(val underlying: StructType) extends AnyVal {
+      def evMerge(that: StructType): StructType = underlying.merge(that)
+    }
+
+  }
+
+}
diff --git a/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunctionTest.scala b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunctionTest.scala
new file mode 100644
index 0000000..9a525e1
--- /dev/null
+++ b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/CardinalityHashFunctionTest.scala
@@ -0,0 +1,62 @@
+package com.swoop.alchemy.spark.expressions.hll
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.scalatest.{Matchers, WordSpec}
+
+class CardinalityHashFunctionTest extends WordSpec with Matchers {
+
+  "Cardinality hash functions" should {
+    "account for nulls" in {
+
+      val a = UTF8String.fromString("a")
+
+      allDistinct(Seq(
+        null,
+        Array.empty[Byte],
+        Array.apply(1.toByte)
+      ), BinaryType)
+
+      allDistinct(Seq(
+        null,
+        UTF8String.fromString(""),
+        a
+      ), StringType)
+
+      allDistinct(Seq(
+        null,
+        ArrayData.toArrayData(Array.empty),
+        ArrayData.toArrayData(Array(null)),
+        ArrayData.toArrayData(Array(null, null)),
+        ArrayData.toArrayData(Array(a, null)),
+        ArrayData.toArrayData(Array(null, a))
+      ), ArrayType(StringType))
+
+
+      allDistinct(Seq(
+        null,
+        ArrayBasedMapData(Map.empty),
+        ArrayBasedMapData(Map(null.asInstanceOf[String] -> null))
+      ), MapType(StringType, StringType))
+
+      allDistinct(Seq(
+        null,
+        InternalRow(null),
+        InternalRow(a)
+      ), new StructType().add("foo", StringType))
+
+      allDistinct(Seq(
+        InternalRow(null, a),
+        InternalRow(a, null)
+      ), new StructType().add("foo", StringType).add("bar", StringType))
+    }
+  }
+
+  def allDistinct(values: Seq[Any], dataType: DataType): Unit = {
+    val hashed = values.map(x => CardinalityXxHash64Function.hash(x, dataType, 0))
+    hashed.distinct.length should be(hashed.length)
+  }
+
+}
diff --git a/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionsTest.scala b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionsTest.scala
new file mode 100644
index 0000000..446e1ba
--- /dev/null
+++ b/alchemy/src/test/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctionsTest.scala
@@ -0,0 +1,219 @@
+package com.swoop.alchemy.spark.expressions.hll
+
+import com.swoop.alchemy.spark.expressions.hll.functions.{hll_init_collection, hll_init_collection_agg, _}
+import com.swoop.spark.test.HiveSqlSpec
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{array, col, lit, map}
+import org.apache.spark.sql.types._
+import org.scalatest.{Matchers, WordSpec}
+
+
+object HLLFunctionsTestHelpers {
+  System.setSecurityManager(null)
+
+  case class Data(c1: Int, c2: String, c3: Array[Int], c4: Map[String, String], c5: Array[String])
+
+  object Data {
+    def apply(c1: Int, c2: String): Data = Data(c1, c2, null, null, null)
+  }
+
+  case class Data2(c1: Array[String], c2: Map[String, String])
+
+}
+
+class HLLFunctionsTest extends WordSpec with Matchers with HiveSqlSpec {
+
+  import HLLFunctionsTestHelpers._
+
+  lazy val spark = sqlc.sparkSession
+
+  "HyperLogLog functions" should {
+
+    "not allow relativeSD > 39%" in {
+      val err = "requirement failed: HLL++ requires at least 4 bits for addressing. Use a lower error, at most 39%."
+      val c = lit(null)
+
+      noException should be thrownBy hll_init(c, 0.39)
+
+      the[IllegalArgumentException] thrownBy {
+        hll_init(c, 0.40)
+      } should have message err
+
+      noException should be thrownBy hll_init_collection(c, 0.39)
+
+      the[IllegalArgumentException] thrownBy {
+        hll_init_collection(c, 0.40)
+      } should have message err
+
+    }
+
+    "register native org.apache.spark.sql.ext.functions" in {
+      HLLFunctionRegistration.registerFunctions(spark)
+
+      noException should be thrownBy spark.sql(
+        """select
+          | hll_cardinality(hll_merge(hll_init(array(1,2,3)))),
+          | hll_cardinality(hll_merge(hll_init_collection(array(1,2,3)))),
+          | hll_cardinality(hll_init_agg(array(1,2,3))),
+          | hll_cardinality(hll_init_collection_agg(array(1,2,3))),
+          | hll_cardinality(hll_merge(hll_init(array(1,2,3), 0.05))),
+          | hll_cardinality(hll_merge(hll_init_collection(array(1,2,3), 0.05))),
+          | hll_cardinality(hll_init_agg(array(1,2,3), 0.05)),
+          | hll_cardinality(hll_init_collection_agg(array(1,2,3), 0.05))
+        """.stripMargin
+      )
+    }
+
+
+    "estimate cardinality of simple types and collections" in {
+      val a123 = array(lit(1), lit(2), lit(3))
+
+      val simpleValues = Seq(
+        lit(null).cast(IntegerType),
+        lit(""),
+        a123
+      ).map(hll_init)
+
+      val collections = Seq(
+        lit(null).cast(ArrayType(IntegerType)),
+        array(),
+        map(),
+        a123
+      ).map(hll_init_collection)
+
+      val results = cardinalities(spark.range(1).select(simpleValues ++ collections: _*))
+
+      results should be(Seq(
+        /* simple types */ 0, 1, 1,
+        /* collections */ 0, 0, 0, 3
+      ))
+    }
+    // @todo merge tests with grouping
+    "estimate cardinality correctly" in {
+      import spark.implicits._
+
+      val df = spark.createDataset[Data](Seq[Data](
+        Data(1, "a", Array(1, 2, 3), Map("a" -> "A"), Array.empty),
+        Data(2, "b", Array(2, 3, 1), Map("b" -> "B"), Array(null)),
+        Data(2, "b", Array(2, 3, 1), Map("b" -> "B"), Array(null, null)),
+        Data(3, "c", Array(3, 1, 2), Map("a" -> "A", "b" -> "B"), null),
+        Data(2, "b", Array(1, 1, 1), Map("b" -> "B", "c" -> "C"), null),
+        Data(3, "c", Array(2, 2, 2), Map("c" -> "C", "a" -> null), null),
+        Data(4, "d", null, null, null),
+        Data(4, "d", null, null, null),
+        Data(5, "e", Array.empty, Map.empty, null),
+        Data(5, "e", Array.empty, Map.empty, null)
+      ))
+
+      val results = cardinalities(merge(df.select(
+        hll_init('c1),
+        hll_init('c2),
+        hll_init('c3),
+        hll_init('c4),
+        hll_init('c5),
+        hll_init_collection('c3),
+        hll_init_collection('c4),
+        hll_init_collection('c5)
+      )))
+
+      results should be(Seq(
+        5, // 5 unique simple values
+        5, // 5 unique simple values
+        6, // 6 unique arrays (treated as simple types, nulls not counted)
+        6, // 6 unique maps (treated as simple types, nulls not counted)
+        3, // 3 unique arrays
+        3, // 3 unique values across all arrays
+        4, // 4 unique (k, v) pairs across all maps
+        0  // 0 unique values across all arrays, nulls not counted
+      ))
+    }
+    "estimate multiples correctly" in {
+      import spark.implicits._
+
+      val createSampleData =
+        spark.createDataset(Seq(
+          Data(1, "a"),
+          Data(2, "b"),
+          Data(2, "b"),
+          Data(3, "c"),
+          Data(4, "d")
+        )).select(hll_init('c1), hll_init('c2))
+
+      val results = cardinalities(merge(createSampleData union createSampleData))
+
+      results should be(Seq(4, 4))
+    }
+  }
+
+  "HyperLogLog aggregate functions" should {
+    // @todo merge tests with grouping
+    "estimate cardinality correctly" in {
+      import spark.implicits._
+
+      val df = spark.createDataset[Data](Seq[Data](
+        Data(1, "a", Array(1, 2, 3), Map("a" -> "A"), Array.empty),
+        Data(2, "b", Array(2, 3, 1), Map("b" -> "B"), Array(null)),
+        Data(2, "b", Array(2, 3, 1), Map("b" -> "B"), Array(null, null)),
+        Data(3, "c", Array(3, 1, 2), Map("a" -> "A", "b" -> "B"), null),
+        Data(2, "b", Array(1, 1, 1), Map("b" -> "B", "c" -> "C"), null),
+        Data(3, "c", Array(2, 2, 2), Map("c" -> "C", "a" -> null), null),
+        Data(4, "d", null, null, null),
+        Data(4, "d", null, null, null),
+        Data(5, "e", Array.empty, Map.empty, null),
+        Data(5, "e", Array.empty, Map.empty, null)
+      ))
+
+      val results = cardinalities(df.select(
+        hll_init_agg('c1),
+        hll_init_agg('c2),
+        hll_init_agg('c3),
+        hll_init_agg('c4),
+        hll_init_agg('c5),
+        hll_init_collection_agg('c3),
+        hll_init_collection_agg('c4),
+        hll_init_collection_agg('c5)
+      ))
+
+      results should be(Seq(
+        5, // 5 unique simple values
+        5, // 5 unique simple values
+        6, // 6 unique arrays (treated as simple types, nulls not counted)
+        6, // 6 unique maps (treated as simple types, nulls not counted)
+        3, // 3 unique arrays
+        3, // 3 unique values across all arrays
+        4, // 4 unique (k, v) pairs across all maps
+        0  // 0 unique values across all arrays, nulls not counted
+      ))
+    }
+    "estimate multiples correctly" in {
+      import spark.implicits._
+
+      val createSampleData =
+        spark.createDataset(Seq(
+          Data(1, "a"),
+          Data(2, "b"),
+          Data(2, "b"),
+          Data(3, "c"),
+          Data(4, "d")
+        )).select(hll_init_agg('c1), hll_init_agg('c2))
+
+      val results = cardinalities(createSampleData union createSampleData)
+
+      results should be(Seq(4, 4))
+    }
+  }
+
+  def merge(df: DataFrame): DataFrame =
+    df.select(
+      df.columns.zipWithIndex.map { case (name, idx) =>
+        hll_merge(col(name)).as(s"c$idx")
+      }: _*
+    )
+
+  def cardinalities(df: DataFrame): Seq[Long] =
+    df.select(
+      df.columns.zipWithIndex.map { case (name, idx) =>
+        hll_cardinality(col(name)).as(s"c$idx")
+      }: _*
+    ).head.toSeq.map(_.asInstanceOf[Long])
+}
diff --git a/build.sbt b/build.sbt
index 29b0eec..131a41f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -19,22 +19,31 @@ lazy val alchemy = (project in file("."))
     resourceDirectory in Compile := baseDirectory.value / "alchemy/src/main/resources",
     resourceDirectory in Test := baseDirectory.value / "alchemy/src/test/resources",
     libraryDependencies ++= Seq(
-      scalaTest % Test withSources()
-    )
+      scalaTest % Test withSources(),
+      "com.swoop" %% "spark-test-sugar" % "1.5.0" % Test withSources()
+    ),
+    libraryDependencies ++= sparkDependencies,
+    fork in Test := true // required for Spark
   )
 
 lazy val test = (project in file("alchemy-test"))
   .settings(
     name := "spark-alchemy-test",
     libraryDependencies ++= Seq(
-      "org.apache.spark" %% "spark-core" % sparkVersion % "provided" withSources(),
-      "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" withSources()
-        excludeAll ExclusionRule(organization = "org.mortbay.jetty"),
-      "org.apache.spark" %% "spark-hive" % sparkVersion % "provided" withSources(),
       scalaTest % Test withSources()
-    )
+    ),
+    libraryDependencies ++= sparkDependencies
   )
 
+lazy val sparkDependencies = Seq(
+  "org.apache.logging.log4j" % "log4j-api" % "2.7" % "provided" withSources(),
+  "org.apache.logging.log4j" % "log4j-core" % "2.7" % "provided" withSources(),
+  "org.apache.spark" %% "spark-core" % sparkVersion % "provided" withSources(),
+  "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" withSources()
+    excludeAll ExclusionRule(organization = "org.mortbay.jetty"),
+  "org.apache.spark" %% "spark-hive" % sparkVersion % "provided" withSources()
+)
+
 enablePlugins(BuildInfoPlugin)
 enablePlugins(GitVersioning, GitBranchPrompt)
 enablePlugins(MicrositesPlugin)