Skip to content

Commit

Permalink
Improve tests / fix serialization.
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Aug 18, 2014
1 parent f31b8ad commit c1f7114
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 17 deletions.
Expand Up @@ -127,7 +127,7 @@ object EmptyRow extends Row {
* the array is not copied, and thus could technically be mutated after creation, this is not
* allowed.
*/
class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
class GenericRow(protected[sql] val values: Array[Any]) extends Row {
/** No-arg constructor for serialization. */
def this() = this(null)

Expand Down
Expand Up @@ -435,15 +435,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin

leftEval.code ++ rightEval.code ++
q"""
val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
val iterator = rightSet.iterator
while (iterator.hasNext) {
leftSet.add(iterator.next())
}

val $nullTerm = false
val $primitiveTerm = leftSet
var $primitiveTerm: ${hashSetForType(elementType)} = null

{
val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
val iterator = rightSet.iterator
while (iterator.hasNext) {
leftSet.add(iterator.next())
}
$primitiveTerm = leftSet
}
""".children

case MaxOf(e1, e2) =>
Expand Down
Expand Up @@ -58,16 +58,17 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {

def eval(input: Row): Any = {
val itemEval = item.eval(input)
val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]

if (itemEval != null) {
val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
if (setEval != null) {
setEval.add(itemEval)
setEval
} else {
null
}
} else {
null
setEval
}
}

Expand Down
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.util.collection.OpenHashSet

import scala.reflect.ClassTag
Expand Down Expand Up @@ -123,23 +124,22 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {

private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
output.writeInt(hs.size)
val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]]
val iterator = hs.iterator
while(iterator.hasNext) {
val row = iterator.next()
rowSerializer.write(kryo, output, row)
rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values)
}
}

def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]]

val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
val numItems = input.readInt()
val set = new OpenHashSet[Any](numItems + 1)
var i = 0
while (i < numItems) {
val row = rowSerializer.read(kryo, input, classOf[Any].asInstanceOf[Class[Any]])
val row = new GenericRow(rowSerializer.read(kryo, input, classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
set.add(row)
i += 1
}
Expand Down
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.parquet._

Expand Down Expand Up @@ -149,7 +150,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
case CollectHashSet(exprs) if exprs.size == 1 => false
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
case _ => true
}

Expand Down
Expand Up @@ -32,6 +32,71 @@ case class TestData(a: Int, b: String)
*/
class HiveQuerySuite extends HiveComparisonTest {

createQueryTest("count distinct 0 values",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 'a' AS a FROM src LIMIT 0) table
""".stripMargin)

createQueryTest("count distinct 1 value strings",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 'a' AS a FROM src LIMIT 1 UNION ALL
| SELECT 'b' AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("count distinct 1 value",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 1 AS a FROM src LIMIT 1 UNION ALL
| SELECT 1 AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("count distinct 2 values",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 1 AS a FROM src LIMIT 1 UNION ALL
| SELECT 2 AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("count distinct 2 values including null",
"""
|SELECT COUNT(DISTINCT a, 1) FROM (
| SELECT 1 AS a FROM src LIMIT 1 UNION ALL
| SELECT 1 AS a FROM src LIMIT 1 UNION ALL
| SELECT null AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("count distinct 1 value + null",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 1 AS a FROM src LIMIT 1 UNION ALL
| SELECT 1 AS a FROM src LIMIT 1 UNION ALL
| SELECT null AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("count distinct 1 value long",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 1L AS a FROM src LIMIT 1 UNION ALL
| SELECT 1L AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("count distinct 2 values long",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 1L AS a FROM src LIMIT 1 UNION ALL
| SELECT 2L AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("count distinct 1 value + null long",
"""
|SELECT COUNT(DISTINCT a) FROM (
| SELECT 1L AS a FROM src LIMIT 1 UNION ALL
| SELECT 1L AS a FROM src LIMIT 1 UNION ALL
| SELECT null AS a FROM src LIMIT 1) table
""".stripMargin)

createQueryTest("null case",
"SELECT case when(true) then 1 else null end FROM src LIMIT 1")

Expand Down

0 comments on commit c1f7114

Please sign in to comment.