Skip to content

Commit

Permalink
Add sort_array support
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Jul 31, 2015
1 parent 9307f56 commit a42b678
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 7 deletions.
15 changes: 15 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
'sha1',
'sha2',
'size',
'sort_array',
'sparkPartitionId',
'struct',
'udf',
Expand Down Expand Up @@ -902,6 +903,20 @@ def size(col):
return Column(sc._jvm.functions.size(_to_java_column(col)))


@since(1.5)
def sort_array(col):
"""
Collection function: sorts the input array for the given column in ascending order.
:param col: name of column or expression
>>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
>>> df.select(sort_array(df.data)).collect()
[Row(sort_array(data)=[1, 2, 3]), Row(sort_array(data)=[1]), Row(sort_array(data)=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.sort_array(_to_java_column(col)))


class UserDefinedFunction(object):
"""
User defined function in Python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ object FunctionRegistry {

// collection functions
expression[Size]("size"),
expression[SortArray]("sort_array"),

// misc functions
expression[Crc32]("crc32"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._

/**
Expand All @@ -39,3 +39,42 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
}
}

/**
* Sorts the input array in ascending order according to the natural ordering of
* the array elements and returns it.
*/
case class SortArray(child: Expression)
extends UnaryExpression with ExpectsInputTypes with CodegenFallback {

override def dataType: DataType = child.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

@transient
private lazy val lt: (Any, Any) => Boolean = {
val ordering = child.dataType match {
case ArrayType(elementType, _) => elementType match {
case n: AtomicType => n.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}
}

(left, right) => {
if (left == null && right == null) {
false
} else if (left == null) {
true
} else if (right == null) {
false
} else {
ordering.compare(left, right) < 0
}
}
}

override def nullSafeEval(value: Any): Seq[Any] = {
value.asInstanceOf[Seq[Any]].sortWith(lt)
}

override def prettyName: String = "sort_array"
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,16 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))

checkEvaluation(SortArray(a0), Seq(1, 2, 3))
checkEvaluation(SortArray(a1), Seq[Integer]())
checkEvaluation(SortArray(a2), Seq("a", "b"))

checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
}
11 changes: 7 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2186,18 +2186,21 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////

/**
* Returns length of array or map
* Returns length of array or map.
*
* @group collection_funcs
* @since 1.5.0
*/
def size(columnName: String): Column = size(Column(columnName))
def size(e: Column): Column = Size(e.expr)

/**
* Returns length of array or map
* Sorts the input array for the given column in ascending order,
* according to the natural ordering of the array elements.
*
* @group collection_funcs
* @since 1.5.0
*/
def size(column: Column): Column = Size(column.expr)
def sort_array(e: Column): Column = SortArray(e.expr)


//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,36 @@ class DataFrameFunctionsSuite extends QueryTest {
)
}

test("sort_array function") {
val df = Seq(
(Array[Int](2, 1, 3), Array("b", "c", "a")),
(Array[Int](), Array[String]()),
(null, null)
).toDF("a", "b")
checkAnswer(
df.select(sort_array($"a"), sort_array($"b")),
Seq(
Row(Seq(1, 2, 3), Seq("a", "b", "c")),
Row(Seq[Int](), Seq[String]()),
Row(null, null))
)
checkAnswer(
df.selectExpr("sort_array(a)", "sort_array(b)"),
Seq(
Row(Seq(1, 2, 3), Seq("a", "b", "c")),
Row(Seq[Int](), Seq[String]()),
Row(null, null))
)
}

test("array size function") {
val df = Seq(
(Array[Int](1, 2), "x"),
(Array[Int](), "y"),
(Array[Int](1, 2, 3), "z")
).toDF("a", "b")
checkAnswer(
df.select(size("a")),
df.select(size($"a")),
Seq(Row(2), Row(0), Row(3))
)
checkAnswer(
Expand All @@ -290,7 +312,7 @@ class DataFrameFunctionsSuite extends QueryTest {
(Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
).toDF("a", "b")
checkAnswer(
df.select(size("a")),
df.select(size($"a")),
Seq(Row(2), Row(0), Row(3))
)
checkAnswer(
Expand Down

0 comments on commit a42b678

Please sign in to comment.