diff --git a/presto-docs/src/main/sphinx/functions/tdigest.rst b/presto-docs/src/main/sphinx/functions/tdigest.rst
index 78db33a5f4ba..607e3e952e89 100644
--- a/presto-docs/src/main/sphinx/functions/tdigest.rst
+++ b/presto-docs/src/main/sphinx/functions/tdigest.rst
@@ -96,3 +96,8 @@ Functions
     values in the digest). This is an inverse of ``destructure_tdigest``.
     This function is particularly useful for adding externally-created
     tdigests to Presto.
+
+.. function:: merge_tdigest(array<tdigest<double>>) -> tdigest<double>
+
+    Returns a merged ``tdigest`` of the T-digests in an array. This is the
+    scalar complement to the aggregation function ``merge``.
diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/TDigestFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/TDigestFunctions.java
index 3a331e1b57a5..90ba357ea973 100644
--- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/TDigestFunctions.java
+++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/TDigestFunctions.java
@@ -20,6 +20,7 @@
 import com.facebook.presto.common.type.StandardTypes;
 import com.facebook.presto.spi.function.Description;
 import com.facebook.presto.spi.function.ScalarFunction;
+import com.facebook.presto.spi.function.SqlNullable;
 import com.facebook.presto.spi.function.SqlType;
 import com.facebook.presto.tdigest.Centroid;
 import com.facebook.presto.tdigest.TDigest;
@@ -181,4 +182,35 @@ public static Slice constructTDigest(
 
         return tDigest.serialize();
     }
+
+    /**
+     * Merges every non-null T-digest in {@code input} into one digest.
+     *
+     * @param input block holding the serialized T-digests of the array argument
+     * @return the serialized merged digest, or {@code null} when the array is
+     *         empty or contains only nulls
+     */
+    @ScalarFunction(value = "merge_tdigest", visibility = EXPERIMENTAL)
+    @Description("Merge an array of TDigests into a single TDigest")
+    @SqlType("tdigest(double)")
+    @SqlNullable
+    public static Slice merge_tdigest(@SqlType("array(tdigest(double))") Block input)
+    {
+        // An empty array never enters the loop, so it falls through to the null result below.
+        TDigest output = null;
+        for (int i = 0; i < input.getPositionCount(); i++) {
+            if (input.isNull(i)) {
+                continue;
+            }
+            TDigest tdigest = createTDigest(input.getSlice(i, 0, input.getSliceLength(i)));
+            if (output == null) {
+                // The first non-null element seeds the result; later ones are folded into it.
+                output = tdigest;
+            }
+            else {
+                output.merge(tdigest);
+            }
+        }
+        return output == null ? null : output.serialize();
+    }
 }
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestTDigestFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestTDigestFunctions.java
index 48f955a6902c..d6d56c01a130 100644
--- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestTDigestFunctions.java
+++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestTDigestFunctions.java
@@ -665,6 +665,94 @@ public void testConstructTDigestInverse()
         assertEquals(constructedSqlVarbinary, sqlVarbinary);
     }
 
+    @Test
+    public void testMergeTDigestNullInput()
+    {
+        functionAssertions.assertFunction("merge_tdigest(null)", TDIGEST_DOUBLE, null);
+    }
+
+    @Test
+    public void testMergeTDigestEmptyArray()
+    {
+        functionAssertions.assertFunction("merge_tdigest(array[])", TDIGEST_DOUBLE, null);
+    }
+
+    @Test
+    public void testMergeTDigestEmptyArrayOfNull()
+    {
+        functionAssertions.assertFunction("merge_tdigest(array[null])", TDIGEST_DOUBLE, null);
+    }
+
+    @Test
+    public void testMergeTDigestEmptyArrayOfNulls()
+    {
+        functionAssertions.assertFunction("merge_tdigest(array[null, null, null])", TDIGEST_DOUBLE, null);
+    }
+
+    @Test
+    public void testMergeTDigests()
+    {
+        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
+        addAll(digest1, 0.1);
+        TDigest digest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
+        addAll(digest2, 0.2);
+        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
+                format("merge_tdigest(cast(array[%s, %s] as array(tdigest(double))))",
+                        toSqlString(digest1),
+                        toSqlString(digest2)),
+                TDIGEST_DOUBLE,
+                SqlVarbinary.class);
+        digest1.merge(digest2);
+        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
+    }
+
+    @Test
+    public void testMergeTDigestOneNull()
+    {
+        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
+        addAll(digest1, 0.1);
+        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
+                format("merge_tdigest(cast(array[%s, null] as array(tdigest(double))))",
+                        toSqlString(digest1)),
+                TDIGEST_DOUBLE,
+                SqlVarbinary.class);
+        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
+    }
+
+    @Test
+    public void testMergeTDigestOneNullFirst()
+    {
+        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
+        addAll(digest1, 0.1);
+        TDigest digest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
+        addAll(digest2, 0.2);
+        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
+                format("merge_tdigest(cast(array[null, %s, %s] as array(tdigest(double))))",
+                        toSqlString(digest1),
+                        toSqlString(digest2)),
+                TDIGEST_DOUBLE,
+                SqlVarbinary.class);
+        digest1.merge(digest2);
+        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
+    }
+
+    @Test
+    public void testMergeTDigestOneNullMiddle()
+    {
+        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
+        addAll(digest1, 0.1);
+        TDigest digest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
+        addAll(digest2, 0.2);
+        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
+                format("merge_tdigest(cast(array[%s, null, %s] as array(tdigest(double))))",
+                        toSqlString(digest1),
+                        toSqlString(digest2)),
+                TDIGEST_DOUBLE,
+                SqlVarbinary.class);
+        digest1.merge(digest2);
+        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
+    }
+
     // disabled because test takes almost 10s
     @Test(enabled = false)
     public void testBinomialDistribution()