diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 89a2a5ceaa9bf..fb542e6cff81a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -51,6 +51,7 @@
     'sha1',
     'sha2',
     'size',
+    'sort_array',
     'sparkPartitionId',
     'struct',
     'udf',
@@ -570,8 +571,10 @@ def length(col):
 def format_number(col, d):
     """Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
     and returns the result as a string.
+
     :param col: the column name of the numeric value to be formatted
     :param d: the N decimal places
+
     >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
     [Row(v=u'5.0000')]
     """
@@ -968,6 +971,23 @@ def soundex(col):
     return Column(sc._jvm.functions.size(_to_java_column(col)))
 
 
+@since(1.5)
+def sort_array(col, asc=True):
+    """
+    Collection function: sorts the input array for the given column in ascending or descending order.
+
+    :param col: name of column or expression
+
+    >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
+    >>> df.select(sort_array(df.data).alias('r')).collect()
+    [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
+    >>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
+    [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ee44cbcba68e7..6e144518bb009 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -233,6 +233,7 @@ object FunctionRegistry {
 
     // collection functions
     expression[Size]("size"),
+    expression[SortArray]("sort_array"),
 
     // misc functions
     expression[Crc32]("crc32"),
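Taken together, the Python wrapper and the expression[SortArray]("sort_array") registration above make the function resolvable by name from SQL text as well as from the DataFrame DSL. A minimal SQL-side sketch, assuming a Spark 1.5-era SQLContext; the object name, app name, and temp table name are illustrative only, not part of this patch:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object SortArraySqlDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sort_array-sql-demo").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // One array column; registerTempTable makes it visible to SQL queries.
    val df = Seq(Tuple1(Seq(2, 1, 3))).toDF("data")
    df.registerTempTable("t")

    // The name "sort_array" is resolved through FunctionRegistry at analysis time.
    sqlContext.sql("SELECT sort_array(data) AS s FROM t").show()         // [1, 2, 3]
    sqlContext.sql("SELECT sort_array(data, false) AS s FROM t").show()  // [3, 2, 1]
  }
}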
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 1a00dbc254de1..0a530596a98c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import java.util.Comparator
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.types._
 
 /**
@@ -39,3 +42,78 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
     nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
   }
 }
+
+/**
+ * Sorts the input array in ascending / descending order according to the natural ordering of
+ * the array elements and returns it.
+ */
+case class SortArray(base: Expression, ascendingOrder: Expression)
+  extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+
+  def this(e: Expression) = this(e, Literal(true))
+
+  override def left: Expression = base
+  override def right: Expression = ascendingOrder
+  override def dataType: DataType = base.dataType
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
+
+  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+    case ArrayType(_: AtomicType, _) => TypeCheckResult.TypeCheckSuccess
+    case ArrayType(n, _) => TypeCheckResult.TypeCheckFailure(
+      s"Type $n is not an AtomicType; cannot perform ordering operations")
+    case other =>
+      TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other")
+  }
+
+  @transient
+  private lazy val lt = {
+    val ordering = base.dataType match {
+      case ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+    }
+
+    new Comparator[Any]() {
+      override def compare(o1: Any, o2: Any): Int = {
+        if (o1 == null && o2 == null) {
+          0
+        } else if (o1 == null) {
+          -1
+        } else if (o2 == null) {
+          1
+        } else {
+          ordering.compare(o1, o2)
+        }
+      }
+    }
+  }
+
+  @transient
+  private lazy val gt = {
+    val ordering = base.dataType match {
+      case ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+    }
+
+    new Comparator[Any]() {
+      override def compare(o1: Any, o2: Any): Int = {
+        if (o1 == null && o2 == null) {
+          0
+        } else if (o1 == null) {
+          1
+        } else if (o2 == null) {
+          -1
+        } else {
+          -ordering.compare(o1, o2)
+        }
+      }
+    }
+  }
+
+  override def nullSafeEval(array: Any, ascending: Any): Any = {
+    val data = array.asInstanceOf[ArrayData].toArray().asInstanceOf[Array[AnyRef]]
+    java.util.Arrays.sort(
+      data,
+      if (ascending.asInstanceOf[Boolean]) lt else gt)
+    new GenericArrayData(data.asInstanceOf[Array[Any]])
+  }
+
+  override def prettyName: String = "sort_array"
+}
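A note on the lt and gt comparators above: they encode the null-ordering convention that nulls sort first in ascending order and last in descending order, while gt otherwise simply negates the natural ordering. The following standalone sketch, plain Scala with no Spark dependencies and purely illustrative names, demonstrates the ascending case:

import java.util.Comparator

object NullsFirstAscendingDemo {
  def main(args: Array[String]): Unit = {
    // Same null handling as SortArray's lt: null sorts before any non-null value.
    val asc = new Comparator[Integer] {
      override def compare(o1: Integer, o2: Integer): Int = {
        if (o1 == null && o2 == null) 0
        else if (o1 == null) -1
        else if (o2 == null) 1
        else o1.compareTo(o2)
      }
    }

    val data = Array[Integer](3, null, 1)
    java.util.Arrays.sort(data, asc)
    println(data.mkString(", "))  // prints: null, 1, 3
  }
}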
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 28c41b57169f9..2c7e85c446ec6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -43,4 +43,26 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
     checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
   }
+
+  test("Sort Array") {
+    val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+    val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
+    val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
+
+    checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
+    checkEvaluation(new SortArray(a1), Seq[Integer]())
+    checkEvaluation(new SortArray(a2), Seq("a", "b"))
+    checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
+    checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
+    checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
+    checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
+    checkEvaluation(SortArray(a3, Literal(true)), Seq(null, "a", "b"))
+    checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
+    checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
+    checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
+    checkEvaluation(SortArray(a3, Literal(false)), Seq("b", "a", null))
+
+    checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 57bb00a7417af..3c9421f5cd14b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2223,19 +2223,30 @@ object functions {
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   /**
-   * Returns length of array or map
+   * Returns length of array or map.
+   *
    * @group collection_funcs
    * @since 1.5.0
    */
-  def size(columnName: String): Column = size(Column(columnName))
+  def size(e: Column): Column = Size(e.expr)
 
   /**
-   * Returns length of array or map
+   * Sorts the input array for the given column in ascending order,
+   * according to the natural ordering of the array elements.
+   *
    * @group collection_funcs
    * @since 1.5.0
    */
-  def size(column: Column): Column = Size(column.expr)
+  def sort_array(e: Column): Column = sort_array(e, true)
+
+  /**
+   * Sorts the input array for the given column in ascending / descending order,
+   * according to the natural ordering of the array elements.
+   *
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr)
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
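For reference, the new Scala API added above reads like this from user code. A minimal sketch assuming a local Spark 1.5-era SQLContext; the object, app, and column names are illustrative:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.sort_array

object SortArrayDslDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sort_array-dsl-demo").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = Seq(Tuple1(Seq("b", "c", "a"))).toDF("letters")

    // The single-argument overload defaults to ascending order.
    df.select(sort_array($"letters")).show()               // [a, b, c]
    // The two-argument overload flips to descending order.
    df.select(sort_array($"letters", asc = false)).show()  // [c, b, a]
  }
}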
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 1baec5d37699d..46921d14256b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -267,6 +267,53 @@ class DataFrameFunctionsSuite extends QueryTest {
     )
   }
 
+  test("sort_array function") {
+    val df = Seq(
+      (Array[Int](2, 1, 3), Array("b", "c", "a")),
+      (Array[Int](), Array[String]()),
+      (null, null)
+    ).toDF("a", "b")
+    checkAnswer(
+      df.select(sort_array($"a"), sort_array($"b")),
+      Seq(
+        Row(Seq(1, 2, 3), Seq("a", "b", "c")),
+        Row(Seq[Int](), Seq[String]()),
+        Row(null, null))
+    )
+    checkAnswer(
+      df.select(sort_array($"a", false), sort_array($"b", false)),
+      Seq(
+        Row(Seq(3, 2, 1), Seq("c", "b", "a")),
+        Row(Seq[Int](), Seq[String]()),
+        Row(null, null))
+    )
+    checkAnswer(
+      df.selectExpr("sort_array(a)", "sort_array(b)"),
+      Seq(
+        Row(Seq(1, 2, 3), Seq("a", "b", "c")),
+        Row(Seq[Int](), Seq[String]()),
+        Row(null, null))
+    )
+    checkAnswer(
+      df.selectExpr("sort_array(a, true)", "sort_array(b, false)"),
+      Seq(
+        Row(Seq(1, 2, 3), Seq("c", "b", "a")),
+        Row(Seq[Int](), Seq[String]()),
+        Row(null, null))
+    )
+
+    val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
+    assert(intercept[AnalysisException] {
+      df2.selectExpr("sort_array(a)").collect()
+    }.getMessage().contains("Type ArrayType(IntegerType,false) is not an AtomicType; " +
+      "cannot perform ordering operations"))
+
+    val df3 = Seq(("xxx", "x")).toDF("a", "b")
+    assert(intercept[AnalysisException] {
+      df3.selectExpr("sort_array(a)").collect()
+    }.getMessage().contains("ArrayType(AtomicType) is expected, but we got StringType"))
+  }
+
   test("array size function") {
     val df = Seq(
       (Array[Int](1, 2), "x"),
@@ -274,7 +321,7 @@
       (Array[Int](), "y"),
       (Array[Int](1, 2, 3), "z")
     ).toDF("a", "b")
     checkAnswer(
-      df.select(size("a")),
+      df.select(size($"a")),
       Seq(Row(2), Row(0), Row(3))
     )
     checkAnswer(
@@ -290,7 +337,7 @@
       (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
     ).toDF("a", "b")
     checkAnswer(
-      df.select(size("a")),
+      df.select(size($"a")),
       Seq(Row(2), Row(0), Row(3))
     )
     checkAnswer(