From 163e3f1df94f6b7d3dadb46a87dbb3a2bade3f95 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 19 Jul 2015 16:48:47 -0700
Subject: [PATCH] [SPARK-8241][SQL] string function: concat_ws.

I also changed the semantics of concat w.r.t. null back to the same
behavior as Hive. That is to say, concat now returns null if any input
is null.

Author: Reynold Xin

Closes #7504 from rxin/concat_ws and squashes the following commits:

83fd950 [Reynold Xin] Fixed type casting.
3ae85f7 [Reynold Xin] Write null better.
cdc7be6 [Reynold Xin] Added code generation for pure string mode.
a61c4e4 [Reynold Xin] Updated comments.
2d51406 [Reynold Xin] [SPARK-8241][SQL] string function: concat_ws.
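For reference, a minimal sketch of the new semantics through the DataFrame API added in this patch (the sample DataFrame is hypothetical and assumes sqlContext.implicits._ is in scope; the expected results mirror the tests in this patch):

    import org.apache.spark.sql.functions._

    val df = Seq(("a", "b", null: String)).toDF("a", "b", "c")

    // concat now follows Hive: a single null input makes the whole result null.
    df.select(concat($"a", $"b"))                 // => "ab"
    df.select(concat($"a", $"b", $"c"))           // => null

    // concat_ws skips null inputs and returns null only when the
    // separator itself is null.
    df.select(concat_ws("||", $"a", $"b", $"c"))  // => "a||b"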
---
 .../catalyst/analysis/FunctionRegistry.scala  | 11 ++-
 .../expressions/stringOperations.scala        | 72 ++++++++++++++++---
 .../org/apache/spark/sql/types/DataType.scala |  2 +-
 .../analysis/HiveTypeCoercionSuite.scala      | 11 ++-
 .../expressions/StringExpressionsSuite.scala  | 31 +++++++-
 .../org/apache/spark/sql/functions.scala      | 24 +++++++
 .../spark/sql/StringFunctionsSuite.scala      | 19 +++--
 .../execution/HiveCompatibilitySuite.scala    |  4 +-
 .../apache/spark/unsafe/types/UTF8String.java | 58 +++++++++++++--
 .../spark/unsafe/types/UTF8StringSuite.java   | 62 ++++++++++++----
 10 files changed, 256 insertions(+), 38 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 4b256adcc60c6..71e87b98d86fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -153,6 +153,7 @@ object FunctionRegistry {
     expression[Ascii]("ascii"),
     expression[Base64]("base64"),
     expression[Concat]("concat"),
+    expression[ConcatWs]("concat_ws"),
     expression[Encode]("encode"),
     expression[Decode]("decode"),
     expression[FormatNumber]("format_number"),
@@ -211,7 +212,10 @@ object FunctionRegistry {
     val builder = (expressions: Seq[Expression]) => {
       if (varargCtor.isDefined) {
         // If there is an apply method that accepts Seq[Expression], use that one.
-        varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
+        Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) => throw new AnalysisException(e.getMessage)
+        }
       } else {
         // Otherwise, find a ctor method that matches the number of arguments, and use that.
         val params = Seq.fill(expressions.size)(classOf[Expression])
@@ -221,7 +225,10 @@ object FunctionRegistry {
           case Failure(e) =>
             throw new AnalysisException(s"Invalid number of arguments for function $name")
         }
-        f.newInstance(expressions : _*).asInstanceOf[Expression]
+        Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) => throw new AnalysisException(e.getMessage)
+        }
       }
     }
     (name, builder)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 560b1bc2d889f..5f8ac716f79a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * An expression that concatenates multiple input strings into a single string.
- * Input expressions that are evaluated to nulls are skipped.
- *
- * For example, `concat("a", null, "b")` is evaluated to `"ab"`.
- *
- * Note that this is different from Hive since Hive outputs null if any input is null.
- * We never output null.
+ * If any input is null, concat returns null.
  */
 case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
   override def dataType: DataType = StringType
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def eval(input: InternalRow): Any = {
@@ -56,15 +51,76 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val evals = children.map(_.gen(ctx))
-    val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
+    val inputs = evals.map { eval =>
+      s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+    }.mkString(", ")
     evals.map(_.code).mkString("\n") + s"""
       boolean ${ev.isNull} = false;
       UTF8String ${ev.primitive} = UTF8String.concat($inputs);
+      if (${ev.primitive} == null) {
+        ${ev.isNull} = true;
+      }
     """
   }
 }
 
+/**
+ * An expression that concatenates multiple input strings or arrays of strings into a single
+ * string, using a given separator (the first child).
+ *
+ * Returns null if the separator is null. Otherwise, concat_ws skips all null values.
+ */
+case class ConcatWs(children: Seq[Expression])
+  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+  require(children.nonEmpty, s"$prettyName requires at least one argument.")
+
+  override def prettyName: String = "concat_ws"
+
+  /** The first child (separator) is a string; the rest are either strings or arrays of strings. */
+  override def inputTypes: Seq[AbstractDataType] = {
+    val arrayOrStr = TypeCollection(ArrayType(StringType), StringType)
+    StringType +: Seq.fill(children.size - 1)(arrayOrStr)
+  }
+
+  override def dataType: DataType = StringType
+
+  override def nullable: Boolean = children.head.nullable
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = {
+    val flatInputs = children.flatMap { child =>
+      child.eval(input) match {
+        case s: UTF8String => Iterator(s)
+        case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
+        case null => Iterator(null.asInstanceOf[UTF8String])
+      }
+    }
+    UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*)
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    if (children.forall(_.dataType == StringType)) {
+      // All children are strings. In that case we can construct a fixed size array.
+      val evals = children.map(_.gen(ctx))
+
+      val inputs = evals.map { eval =>
+        s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+      }.mkString(", ")
+
+      evals.map(_.code).mkString("\n") + s"""
+        UTF8String ${ev.primitive} = UTF8String.concatWs($inputs);
+        boolean ${ev.isNull} = ${ev.primitive} == null;
+      """
+    } else {
+      // Contains a mix of strings and arrays. Fall back to interpreted mode for now.
+      super.genCode(ctx, ev)
+    }
+  }
+}
+
+
 trait StringRegexExpression extends ImplicitCastInputTypes {
   self: BinaryExpression =>
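To make the flattening in eval concrete, here is a hedged sketch of evaluating a mixed string/array call (the expression wiring mirrors the StringExpressionsSuite cases added below):

    import org.apache.spark.sql.catalyst.expressions.{ConcatWs, EmptyRow, Literal}
    import org.apache.spark.sql.types.{ArrayType, StringType}

    // concat_ws('-', 'a', array('b', NULL, 'c')): array children are flattened
    // into the input sequence, nulls are skipped, and the rest are joined.
    val expr = ConcatWs(Seq(
      Literal.create("-", StringType),                             // separator
      Literal.create("a", StringType),
      Literal.create(Seq("b", null, "c"), ArrayType(StringType))))
    expr.eval(EmptyRow)  // => "a-b-c" (as a UTF8String)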
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 2d133eea19fe0..e98fd2583b931 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = this
 
-  override private[sql] def acceptsType(other: DataType): Boolean = this == other
+  override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
 }
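This one-line change is what lets ConcatWs accept arrays regardless of their containsNull flag: sameType compares data types while ignoring nullability, whereas == does not. A sketch of the distinction (note that sameType is private[sql], so this only illustrates the internal comparison):

    import org.apache.spark.sql.types._

    val expected = ArrayType(StringType)         // containsNull = true by default
    val actual   = ArrayType(StringType, false)  // elements known to be non-null

    expected == actual          // false: containsNull differs
    expected.sameType(actual)   // true: nullability is ignored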
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index f9442bccc4a7a..7ee2333a81dfe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -37,7 +37,6 @@ class HiveTypeCoercionSuite extends PlanTest {
     shouldCast(NullType, IntegerType, IntegerType)
     shouldCast(NullType, DecimalType, DecimalType.Unlimited)
 
-    // TODO: write the entire implicit cast table out for test cases.
     shouldCast(ByteType, IntegerType, IntegerType)
     shouldCast(IntegerType, IntegerType, IntegerType)
     shouldCast(IntegerType, LongType, LongType)
@@ -86,6 +85,16 @@ class HiveTypeCoercionSuite extends PlanTest {
       DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
       shouldCast(tpe, NumericType, tpe)
     }
+
+    shouldCast(
+      ArrayType(StringType, false),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, false))
+
+    shouldCast(
+      ArrayType(StringType, true),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, true))
   }
 
   test("ineligible implicit type cast") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 0ed567a90dd1f..96f433be8b065 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("concat") {
     def testConcat(inputs: String*): Unit = {
-      val expected = inputs.filter(_ != null).mkString
+      val expected = if (inputs.contains(null)) null else inputs.mkString
      checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
     }
 
@@ -46,6 +46,35 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // scalastyle:on
   }
 
+  test("concat_ws") {
+    def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
+      val inputExprs = inputs.map {
+        case s: Seq[_] => Literal.create(s, ArrayType(StringType))
+        case null => Literal.create(null, StringType)
+        case s: String => Literal.create(s, StringType)
+      }
+      val sepExpr = Literal.create(sep, StringType)
+      checkEvaluation(ConcatWs(sepExpr +: inputExprs), expected, EmptyRow)
+    }
+
+    // scalastyle:off
+    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+    testConcatWs(null, null)
+    testConcatWs(null, null, "a", "b")
+    testConcatWs("", "")
+    testConcatWs("ab", "哈哈", "ab")
+    testConcatWs("a哈哈b", "哈哈", "a", "b")
+    testConcatWs("a哈哈b", "哈哈", "a", null, "b")
+    testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c")
+
+    testConcatWs("ab", "哈哈", Seq("ab"))
+    testConcatWs("a哈哈b", "哈哈", Seq("a", "b"))
+    testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d"))
+    testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String])
+    testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null))
+    // scalastyle:on
+  }
+
   test("StringComparison") {
     val row = create_row("abc", null)
     val c1 = 'a.string.at(0)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f67c89437bb4a..b5140dca0487f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1732,6 +1732,30 @@ object functions {
     concat((columnName +: columnNames).map(Column.apply): _*)
   }
 
+  /**
+   * Concatenates input strings together into a single string, using the given separator.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sep: String, exprs: Column*): Column = {
+    ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr))
+  }
+
+  /**
+   * Concatenates input strings together into a single string, using the given separator.
+   *
+   * This is the variant of concat_ws that takes in the column names.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
+    concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*)
+  }
+
   /**
    * Computes the length of a given string / binary value.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 4eff33ed45042..fe4de8d8b855f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
-      df.select(concat($"a", $"b", $"c")),
-      Row("ab"))
+      df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
+      Row("ab", null))
 
     checkAnswer(
-      df.selectExpr("concat(a, b, c)"),
-      Row("ab"))
+      df.selectExpr("concat(a, b)", "concat(a, b, c)"),
+      Row("ab", null))
   }
 
+  test("string concat_ws") {
+    val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(concat_ws("||", $"a", $"b", $"c")),
+      Row("a||b"))
+
+    checkAnswer(
+      df.selectExpr("concat_ws('||', a, b, c)"),
+      Row("a||b"))
+  }
+
   test("string Levenshtein distance") {
     val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 2689d904d6541..b12b3838e615c 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "timestamp_2",
     "timestamp_udf",
 
-    // Hive outputs NULL if any concat input has null. We never output null for concat.
-    "udf_concat",
-
     // Unlike Hive, we do support log base in (0, 1.0], therefore disable this
     "udf7"
   )
@@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_case",
     "udf_ceil",
     "udf_ceiling",
+    "udf_concat",
     "udf_concat_insert1",
     "udf_concat_insert2",
     "udf_concat_ws",
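Because concat_ws is registered in FunctionRegistry, the same semantics are reachable from SQL, which is also what the re-enabled udf_concat Hive compatibility test exercises. A hedged Scala sketch (the temp table t is hypothetical, holding one row ("a", "b", null)):

    sqlContext.sql("SELECT concat_ws('||', a, b, c) FROM t").collect()
    // => Array(Row("a||b"))
    sqlContext.sql("SELECT concat(a, b, c) FROM t").collect()
    // => Array(Row(null))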
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9723b6e0834b2..3eecd657e6ef9 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -397,26 +397,62 @@ public UTF8String lpad(int len, UTF8String pad) {
   }
 
   /**
-   * Concatenates input strings together into a single string. A null input is skipped.
-   * For example, concat("a", null, "c") would yield "ac".
+   * Concatenates input strings together into a single string. Returns null if any input is null.
    */
   public static UTF8String concat(UTF8String... inputs) {
-    if (inputs == null) {
-      return fromBytes(new byte[0]);
-    }
-
     // Compute the total length of the result.
     int totalLength = 0;
     for (int i = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         totalLength += inputs[i].numBytes;
+      } else {
+        return null;
       }
     }
 
     // Allocate a new byte array, and copy the inputs one by one into it.
     final byte[] result = new byte[totalLength];
     int offset = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      int len = inputs[i].numBytes;
+      PlatformDependent.copyMemory(
+        inputs[i].base, inputs[i].offset,
+        result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+        len);
+      offset += len;
+    }
+    return fromBytes(result);
+  }
+
+  /**
+   * Concatenates input strings together into a single string using the separator.
+   * A null input is skipped. For example, concat(",", "a", null, "c") would yield "a,c".
+   */
+  public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
+    if (separator == null) {
+      return null;
+    }
+
+    int numInputBytes = 0;  // total number of bytes from the inputs
+    int numInputs = 0;      // number of non-null inputs
     for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        numInputBytes += inputs[i].numBytes;
+        numInputs++;
+      }
+    }
+
+    if (numInputs == 0) {
+      // Return an empty string if there is no input, or all the inputs are null.
+      return fromBytes(new byte[0]);
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it.
+    // The size of the new array is the size of all inputs, plus the separators.
+    final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
+    int offset = 0;
+
+    for (int i = 0, j = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         int len = inputs[i].numBytes;
         PlatformDependent.copyMemory(
@@ -424,6 +460,16 @@ public static UTF8String concat(UTF8String... inputs) {
           inputs[i].base, inputs[i].offset,
           result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
           len);
         offset += len;
+
+        j++;
+        // Add separator if this is not the last input.
+        if (j < numInputs) {
+          PlatformDependent.copyMemory(
+            separator.base, separator.offset,
+            result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+            separator.numBytes);
+          offset += separator.numBytes;
+        }
       }
     }
     return fromBytes(result);
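Both methods follow the same two-pass pattern: measure first so the result array is allocated exactly once, then copy the bytes in a single walk, interleaving the separator in concatWs. A hedged sketch of the resulting behavior, calling the Java API from Scala (it mirrors the tests below):

    import org.apache.spark.unsafe.types.UTF8String.{concat, concatWs, fromString}

    concat(fromString("a"), fromString("b"))     // => "ab"
    concat(fromString("a"), null)                // => null (any null input)

    concatWs(fromString(","), fromString("a"), null, fromString("c"))  // => "a,c"
    concatWs(null, fromString("a"))              // => null (null separator)
    concatWs(fromString(","), null, null)        // => "" (all inputs null)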
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 0db7522b50c1a..7d0c49e2fb84c 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -88,16 +88,50 @@ public void upperAndLower() {
 
   @Test
   public void concatTest() {
-    assertEquals(concat(), fromString(""));
-    assertEquals(concat(null), fromString(""));
-    assertEquals(concat(fromString("")), fromString(""));
-    assertEquals(concat(fromString("ab")), fromString("ab"));
-    assertEquals(concat(fromString("a"), fromString("b")), fromString("ab"));
-    assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc"));
-    assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac"));
-    assertEquals(concat(fromString("a"), null, null), fromString("a"));
-    assertEquals(concat(null, null, null), fromString(""));
-    assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头"));
+    assertEquals(fromString(""), concat());
+    assertEquals(null, concat((UTF8String) null));
+    assertEquals(fromString(""), concat(fromString("")));
+    assertEquals(fromString("ab"), concat(fromString("ab")));
+    assertEquals(fromString("ab"), concat(fromString("a"), fromString("b")));
+    assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c")));
+    assertEquals(null, concat(fromString("a"), null, fromString("c")));
+    assertEquals(null, concat(fromString("a"), null, null));
+    assertEquals(null, concat(null, null, null));
+    assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头")));
+  }
+
+  @Test
+  public void concatWsTest() {
+    // Returns null if the separator is null.
+    assertEquals(null, concatWs(null, (UTF8String) null));
+    assertEquals(null, concatWs(null, fromString("a")));
+
+    // If the separator is not null, concatWs skips all null inputs and never returns null.
+    UTF8String sep = fromString("哈哈");
+    assertEquals(
+      fromString(""),
+      concatWs(sep, fromString("")));
+    assertEquals(
+      fromString("ab"),
+      concatWs(sep, fromString("ab")));
+    assertEquals(
+      fromString("a哈哈b"),
+      concatWs(sep, fromString("a"), fromString("b")));
+    assertEquals(
+      fromString("a哈哈b哈哈c"),
+      concatWs(sep, fromString("a"), fromString("b"), fromString("c")));
+    assertEquals(
+      fromString("a哈哈c"),
+      concatWs(sep, fromString("a"), null, fromString("c")));
+    assertEquals(
+      fromString("a"),
+      concatWs(sep, fromString("a"), null, null));
+    assertEquals(
+      fromString(""),
+      concatWs(sep, null, null, null));
+    assertEquals(
+      fromString("数据哈哈砖头"),
+      concatWs(sep, fromString("数据"), fromString("砖头")));
   }
 
   @Test
@@ -215,14 +249,18 @@ public void pad() {
     assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
     assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者")));
     assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者")));
-    assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者")));
+    assertEquals(
+      fromString("孙行者孙行者孙行数据砖头"),
+      fromString("数据砖头").lpad(12, fromString("孙行者")));
 
     assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????")));
     assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????")));
     assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????")));
     assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者")));
     assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者")));
-    assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者")));
+    assertEquals(
+      fromString("数据砖头孙行者孙行者孙行"),
+      fromString("数据砖头").rpad(12, fromString("孙行者")));
   }
 
   @Test