From df13ca05c475e98bf5c218a4503513065611a47f Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 27 Jun 2024 21:39:06 +0800 Subject: [PATCH] [SPARK-48735][SQL] Performance Improvement for BIN function ### What changes were proposed in this pull request? This PR implemented a long-to-binary form UTF8String method directly to improve the performance of the BIN function. It omits the procedure of encoding/decoding and array copying. ### Why are the changes needed? performance improvement ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - new unit tests - offline benchmarking ~2x ### Was this patch authored or co-authored using generative AI tooling? no Closes #47119 from yaooqinn/SPARK-48735. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../apache/spark/unsafe/types/UTF8String.java | 19 +++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 17 ++++++++++ .../expressions/mathExpressions.scala | 8 ++--- .../analyzer-results/ansi/math.sql.out | 28 ++++++++++++++++ .../sql-tests/analyzer-results/math.sql.out | 28 ++++++++++++++++ .../test/resources/sql-tests/inputs/math.sql | 5 +++ .../sql-tests/results/ansi/math.sql.out | 32 +++++++++++++++++++ .../resources/sql-tests/results/math.sql.out | 32 +++++++++++++++++++ 8 files changed, 164 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 12a7b06232ee7..49d3088f8a2f0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -102,6 +102,8 @@ public final class UTF8String implements Comparable, Externalizable, private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); + public static final UTF8String ZERO_UTF8 = UTF8String.fromString("0"); + /** * Creates an UTF8String from byte array, which should be encoded in UTF-8. @@ -1867,4 +1869,21 @@ public void read(Kryo kryo, Input in) { in.read((byte[]) base); } + /** + * Convert a long value to its binary format stripping leading zeros. + */ + public static UTF8String toBinaryString(long val) { + int zeros = Long.numberOfLeadingZeros(val); + if (zeros == Long.SIZE) { + return UTF8String.ZERO_UTF8; + } else { + int length = Long.SIZE - zeros; + byte[] bytes = new byte[length]; + do { + bytes[--length] = (byte) ((val & 0x1) == 1 ? '1': '0'); + val >>>= 1; + } while (length > 0); + return fromBytes(bytes); + } + } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index f9b351697e8b3..07793a24e5eed 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -1110,4 +1110,21 @@ public void isValid() { testIsValid("0x9C 0x76 0x17", "0xEF 0xBF 0xBD 0x76 0x17"); } + @Test + public void toBinaryString() { + assertEquals(ZERO_UTF8, UTF8String.toBinaryString(0)); + assertEquals(UTF8String.fromString("1"), UTF8String.toBinaryString(1)); + assertEquals(UTF8String.fromString("10"), UTF8String.toBinaryString(2)); + assertEquals(UTF8String.fromString("100"), UTF8String.toBinaryString(4)); + assertEquals(UTF8String.fromString("111"), UTF8String.toBinaryString(7)); + assertEquals( + UTF8String.fromString("1111111111111111111111111111111111111111111111111111111111110011"), + UTF8String.toBinaryString(-13)); + assertEquals( + UTF8String.fromString("1000000000000000000000000000000000000000000000000000000000000000"), + UTF8String.toBinaryString(Long.MIN_VALUE)); + assertEquals( + UTF8String.fromString("111111111111111111111111111111111111111111111111111111111111111"), + UTF8String.toBinaryString(Long.MAX_VALUE)); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 5981b42aead85..00274a16b888b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1008,11 +1008,10 @@ case class Bin(child: Expression) override def dataType: DataType = SQLConf.get.defaultStringType protected override def nullSafeEval(input: Any): Any = - UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) + UTF8String.toBinaryString(input.asInstanceOf[Long]) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c) => - s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") + defineCodeGen(ctx, ev, c => s"UTF8String.toBinaryString($c)") } override protected def withNewChildInternal(newChild: Expression): Bin = copy(child = newChild) @@ -1021,7 +1020,6 @@ case class Bin(child: Expression) object Hex { private final val hexDigits = Array[Byte]('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F') - private final val ZERO_UTF8 = UTF8String.fromBytes(Array[Byte]('0')) // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 val unhexDigits = { @@ -1053,7 +1051,7 @@ object Hex { def hex(num: Long): UTF8String = { val zeros = jl.Long.numberOfLeadingZeros(num) - if (zeros == jl.Long.SIZE) return ZERO_UTF8 + if (zeros == jl.Long.SIZE) return UTF8String.ZERO_UTF8 val len = (jl.Long.SIZE - zeros + 3) / 4 var numBuf = num val value = new Array[Byte](len) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out index 7eb7fcff356a4..8d59b678e92f7 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out @@ -431,3 +431,31 @@ SELECT conv('-9223372036854775807', 36, 10) -- !query analysis Project [conv(-9223372036854775807, 36, 10, true) AS conv(-9223372036854775807, 36, 10)#x] +- OneRowRelation + + +-- !query +SELECT BIN(0) +-- !query analysis +Project [bin(cast(0 as bigint)) AS bin(0)#x] ++- OneRowRelation + + +-- !query +SELECT BIN(25) +-- !query analysis +Project [bin(cast(25 as bigint)) AS bin(25)#x] ++- OneRowRelation + + +-- !query +SELECT BIN(25L) +-- !query analysis +Project [bin(25) AS bin(25)#x] ++- OneRowRelation + + +-- !query +SELECT BIN(25.5) +-- !query analysis +Project [bin(cast(25.5 as bigint)) AS bin(25.5)#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out index e4dd1994b2c9e..0d9b9267cd089 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out @@ -431,3 +431,31 @@ SELECT conv('-9223372036854775807', 36, 10) -- !query analysis Project [conv(-9223372036854775807, 36, 10, false) AS conv(-9223372036854775807, 36, 10)#x] +- OneRowRelation + + +-- !query +SELECT BIN(0) +-- !query analysis +Project [bin(cast(0 as bigint)) AS bin(0)#x] ++- OneRowRelation + + +-- !query +SELECT BIN(25) +-- !query analysis +Project [bin(cast(25 as bigint)) AS bin(25)#x] ++- OneRowRelation + + +-- !query +SELECT BIN(25L) +-- !query analysis +Project [bin(25) AS bin(25)#x] ++- OneRowRelation + + +-- !query +SELECT BIN(25.5) +-- !query analysis +Project [bin(cast(25.5 as bigint)) AS bin(25.5)#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/math.sql b/sql/core/src/test/resources/sql-tests/inputs/math.sql index 96fb0eeef7ac3..398a8b3290b18 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/math.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/math.sql @@ -77,3 +77,8 @@ SELECT conv('9223372036854775808', 10, 16); SELECT conv('92233720368547758070', 10, 16); SELECT conv('9223372036854775807', 36, 10); SELECT conv('-9223372036854775807', 36, 10); + +SELECT BIN(0); +SELECT BIN(25); +SELECT BIN(25L); +SELECT BIN(25.5); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out index 8cd1536d7f726..9b886218f3ad9 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out @@ -797,3 +797,35 @@ org.apache.spark.SparkArithmeticException "fragment" : "conv('-9223372036854775807', 36, 10)" } ] } + + +-- !query +SELECT BIN(0) +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT BIN(25) +-- !query schema +struct +-- !query output +11001 + + +-- !query +SELECT BIN(25L) +-- !query schema +struct +-- !query output +11001 + + +-- !query +SELECT BIN(25.5) +-- !query schema +struct +-- !query output +11001 diff --git a/sql/core/src/test/resources/sql-tests/results/math.sql.out b/sql/core/src/test/resources/sql-tests/results/math.sql.out index d3df5cb933574..88a857a00f0f6 100644 --- a/sql/core/src/test/resources/sql-tests/results/math.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/math.sql.out @@ -493,3 +493,35 @@ SELECT conv('-9223372036854775807', 36, 10) struct -- !query output 18446744073709551615 + + +-- !query +SELECT BIN(0) +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT BIN(25) +-- !query schema +struct +-- !query output +11001 + + +-- !query +SELECT BIN(25L) +-- !query schema +struct +-- !query output +11001 + + +-- !query +SELECT BIN(25.5) +-- !query schema +struct +-- !query output +11001