Skip to content

Commit

Permalink
Add ascending/descending support for sort_array
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Jul 31, 2015
1 parent 80fc0f8 commit 0edab9c
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 22 deletions.
12 changes: 7 additions & 5 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,17 +904,19 @@ def size(col):


@since(1.5)
def sort_array(col, asc=True):
    """
    Collection function: sorts the input array for the given column in ascending
    or descending order.

    :param col: name of column or expression
    :param asc: if True (default), sort in ascending order; if False, descending.

    >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
    >>> df.select(sort_array(df.data).alias('r')).collect()
    [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
    >>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
    [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
    """
    sc = SparkContext._active_spark_context
    # Delegate to the JVM-side functions.sort_array(Column, Boolean) overload.
    return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))


class UserDefinedFunction(object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,30 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
}

/**
 * Sorts the input array in ascending / descending order according to the natural ordering of
 * the array elements and returns it.
 */
case class SortArray(base: Expression, ascendingOrder: Expression)
  extends BinaryExpression with ExpectsInputTypes with CodegenFallback {

  // Single-argument constructor keeps the pre-existing SortArray(child) behavior:
  // ascending order by default.
  def this(e: Expression) = this(e, Literal(true))

  override def left: Expression = base
  override def right: Expression = ascendingOrder
  // The result array has the same type as the input array.
  override def dataType: DataType = base.dataType
  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

// Only arrays of AtomicType elements are sortable: natural ordering is not
// defined for nested/complex element types.
override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
  case ArrayType(_: AtomicType, _) => TypeCheckResult.TypeCheckSuccess
  case ArrayType(n, _) => TypeCheckResult.TypeCheckFailure(
    s"Type $n is not the AtomicType, we can not perform the ordering operations")
  case other =>
    TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other")
}

@transient
private lazy val lt: (Any, Any) => Boolean = {
val ordering = child.dataType match {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
}

Expand All @@ -76,8 +82,27 @@ case class SortArray(child: Expression)
}
}

override def nullSafeEval(value: Any): Seq[Any] = {
value.asInstanceOf[Seq[Any]].sortWith(lt)
@transient
private lazy val gt: (Any, Any) => Boolean = {
  // Safe: checkInputDataTypes guarantees the element type is an AtomicType.
  val ordering = base.dataType match {
    case ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
  }

  // "Greater-than" predicate used by sortWith for descending order; nulls sort last.
  (left, right) => {
    if (left == null && right == null) {
      // Must be false for equal operands: a sortWith predicate has to define a
      // strict ordering, otherwise the underlying sort's comparator contract is
      // violated and sorting may throw at runtime.
      false
    } else if (left == null) {
      false
    } else if (right == null) {
      true
    } else {
      ordering.compare(left, right) > 0
    }
  }
}

// Sorts the (non-null) input array. `ascending` selects between the `lt`
// (ascending) and `gt` (descending) comparison predicates.
override def nullSafeEval(array: Any, ascending: Any): Seq[Any] = {
array.asInstanceOf[Seq[Any]].sortWith(if (ascending.asInstanceOf[Boolean]) lt else gt)
}

override def prettyName: String = "sort_array"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))

checkEvaluation(SortArray(a0), Seq(1, 2, 3))
checkEvaluation(SortArray(a1), Seq[Integer]())
checkEvaluation(SortArray(a2), Seq("a", "b"))
checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
checkEvaluation(new SortArray(a1), Seq[Integer]())
checkEvaluation(new SortArray(a2), Seq("a", "b"))
checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))

checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
Expand Down
10 changes: 9 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2200,8 +2200,16 @@ object functions {
* @group collection_funcs
* @since 1.5.0
*/
def sort_array(e: Column): Column = sort_array(e, asc = true)

/**
 * Sorts the input array for the given column in ascending / descending order,
 * according to the natural ordering of the array elements.
 *
 * @param e   the array column to sort
 * @param asc true for ascending order, false for descending
 * @group collection_funcs
 * @since 1.5.0
 */
def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr)

//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,19 +280,38 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(Seq[Int](), Seq[String]()),
Row(null, null))
)
checkAnswer(
df.select(sort_array($"a", false), sort_array($"b", false)),
Seq(
Row(Seq(3, 2, 1), Seq("c", "b", "a")),
Row(Seq[Int](), Seq[String]()),
Row(null, null))
)
checkAnswer(
df.selectExpr("sort_array(a)", "sort_array(b)"),
Seq(
Row(Seq(1, 2, 3), Seq("a", "b", "c")),
Row(Seq[Int](), Seq[String]()),
Row(null, null))
)
checkAnswer(
df.selectExpr("sort_array(a, true)", "sort_array(b, false)"),
Seq(
Row(Seq(1, 2, 3), Seq("c", "b", "a")),
Row(Seq[Int](), Seq[String]()),
Row(null, null))
)

val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
assert(intercept[AnalysisException] {
df2.selectExpr("sort_array(a)").collect()
}.getMessage().contains("Type ArrayType(ArrayType(IntegerType,false),true) " +
"is not supported for ordering operations"))
}.getMessage().contains("Type ArrayType(IntegerType,false) is not the AtomicType, " +
"we can not perform the ordering operations"))

val df3 = Seq(("xxx", "x")).toDF("a", "b")
assert(intercept[AnalysisException] {
df3.selectExpr("sort_array(a)").collect()
}.getMessage().contains("ArrayType(AtomicType) is expected, but we got StringType"))
}

test("array size function") {
Expand Down

0 comments on commit 0edab9c

Please sign in to comment.