From 0edab9c72909e98a86c2074f7e5cb68834273654 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 28 Jul 2015 19:57:22 -0700 Subject: [PATCH] Add ascending/descending support for sort_array --- python/pyspark/sql/functions.py | 12 +++-- .../expressions/collectionOperations.scala | 47 ++++++++++++++----- .../CollectionFunctionsSuite.scala | 16 +++++-- .../org/apache/spark/sql/functions.scala | 10 +++- .../spark/sql/DataFrameFunctionsSuite.scala | 23 ++++++++- 5 files changed, 86 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 01a85e7a65732..f1416cbf35894 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -904,17 +904,19 @@ def size(col): @since(1.5) -def sort_array(col): +def sort_array(col, asc=True): """ Collection function: sorts the input array for the given column in ascending order. :param col: name of column or expression >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) - >>> df.select(sort_array(df.data)).collect() - [Row(sort_array(data)=[1, 2, 3]), Row(sort_array(data)=[1]), Row(sort_array(data)=[])] - """ + >>> df.select(sort_array(df.data).alias('r')).collect() + [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() + [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.sort_array(_to_java_column(col))) + return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) class UserDefinedFunction(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 093ccc0552f96..b63f75836700f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -42,24 +42,30 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType } /** - * Sorts the input array in ascending order according to the natural ordering of + * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. */ -case class SortArray(child: Expression) - extends UnaryExpression with ExpectsInputTypes with CodegenFallback { +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + def this(e: Expression) = this(e, Literal(true)) - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => TypeCheckResult.TypeCheckSuccess - case other => TypeCheckResult.TypeCheckFailure( - s"Type $other is not supported for ordering operations") + case _ @ ArrayType(n, _) => TypeCheckResult.TypeCheckFailure( + s"Type $n is not the AtomicType, we can not perform the ordering operations") + case other => + TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other") } @transient private lazy val lt: (Any, Any) => Boolean = { - val ordering = child.dataType match { + val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] } @@ -76,8 +82,27 @@ case class SortArray(child: Expression) } } - override def nullSafeEval(value: Any): Seq[Any] = { - 
value.asInstanceOf[Seq[Any]].sortWith(lt) + @transient + private lazy val gt: (Any, Any) => Boolean = { + val ordering = base.dataType match { + case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + } + + (left, right) => { + if (left == null && right == null) { + true + } else if (left == null) { + false + } else if (right == null) { + true + } else { + ordering.compare(left, right) > 0 + } + } + } + + override def nullSafeEval(array: Any, ascending: Any): Seq[Any] = { + array.asInstanceOf[Seq[Any]].sortWith(if (ascending.asInstanceOf[Boolean]) lt else gt) } override def prettyName: String = "sort_array" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 78286f4be03be..2c7e85c446ec6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -48,10 +48,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - checkEvaluation(SortArray(a0), Seq(1, 2, 3)) - checkEvaluation(SortArray(a1), Seq[Integer]()) - checkEvaluation(SortArray(a2), Seq("a", "b")) + checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) + checkEvaluation(new SortArray(a1), Seq[Integer]()) + checkEvaluation(new SortArray(a2), Seq("a", "b")) + checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) + checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) + checkEvaluation(SortArray(a2, 
Literal(true)), Seq("a", "b")) + checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) + checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) + checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) + checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 321a05b258c93..55b41577213a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2200,8 +2200,16 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def sort_array(e: Column): Column = SortArray(e.expr) + def sort_array(e: Column): Column = sort_array(e, true) + /** + * Sorts the input array for the given column in ascending / descending order, + * according to the natural ordering of the array elements. 
+ * + * @group collection_funcs + * @since 1.5.0 + */ + def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index db083e34fb955..46921d14256b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -280,6 +280,13 @@ class DataFrameFunctionsSuite extends QueryTest { Row(Seq[Int](), Seq[String]()), Row(null, null)) ) + checkAnswer( + df.select(sort_array($"a", false), sort_array($"b", false)), + Seq( + Row(Seq(3, 2, 1), Seq("c", "b", "a")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( @@ -287,12 +294,24 @@ class DataFrameFunctionsSuite extends QueryTest { Row(Seq[Int](), Seq[String]()), Row(null, null)) ) + checkAnswer( + df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), + Seq( + Row(Seq(1, 2, 3), Seq("c", "b", "a")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b") assert(intercept[AnalysisException] { df2.selectExpr("sort_array(a)").collect() - }.getMessage().contains("Type ArrayType(ArrayType(IntegerType,false),true) " + - "is not supported for ordering operations")) + }.getMessage().contains("Type ArrayType(IntegerType,false) is not the AtomicType, " + + "we can not perform the ordering operations")) + + val df3 = Seq(("xxx", "x")).toDF("a", "b") + assert(intercept[AnalysisException] { + df3.selectExpr("sort_array(a)").collect() + }.getMessage().contains("ArrayType(AtomicType) is expected, but we got 
StringType")) } test("array size function") {