Skip to content

Commit

Permalink
add arraymerge tdigest function
Browse files Browse the repository at this point in the history
  • Loading branch information
ampampamp committed Aug 11, 2023
1 parent dc3b362 commit 2d0ad86
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
4 changes: 4 additions & 0 deletions presto-docs/src/main/sphinx/functions/tdigest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ Functions
values in the digest). This is an inverse of ``destructure_tdigest``.

This function is particularly useful for adding externally-created tdigests to Presto.

.. function:: merge_tdigest(array<tdigest<double>>) -> tdigest<double>
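Returns a single ``tdigest`` produced by merging all of the T-digests in the
input array. Null elements are skipped; the result is ``NULL`` if the array is
empty or contains only nulls. This is the scalar counterpart of the
aggregation function ``merge``.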
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.tdigest.Centroid;
import com.facebook.presto.tdigest.TDigest;
Expand Down Expand Up @@ -181,4 +182,29 @@ public static Slice constructTDigest(

return tDigest.serialize();
}

/**
 * Combines every non-null element of {@code input} into one T-digest.
 *
 * @param input array block whose elements are read as slices and turned into digests via {@code createTDigest}
 * @return the serialized merged digest, or {@code null} when the array has no
 *         positions or every position is null
 */
@ScalarFunction(value = "merge_tdigest", visibility = EXPERIMENTAL)
@Description("Merge an array of TDigests into a single TDigest")
@SqlType("tdigest(double)")
@SqlNullable
public static Slice merge_tdigest(@SqlType("array(tdigest(double))") Block input)
{
    TDigest merged = null;
    int positionCount = input.getPositionCount();
    for (int position = 0; position < positionCount; position++) {
        if (!input.isNull(position)) {
            TDigest current = createTDigest(input.getSlice(position, 0, input.getSliceLength(position)));
            if (merged == null) {
                // The first non-null element seeds the accumulator.
                merged = current;
            }
            else {
                merged.merge(current);
            }
        }
    }
    // Nothing was accumulated: empty array or only null elements.
    if (merged == null) {
        return null;
    }
    return merged.serialize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,94 @@ public void testConstructTDigestInverse()
assertEquals(constructedSqlVarbinary, sqlVarbinary);
}

// A NULL argument is expected to evaluate to NULL.
@Test
public void testMergeTDigestNullInput()
{
    String projection = "merge_tdigest(null)";
    functionAssertions.assertFunction(projection, TDIGEST_DOUBLE, null);
}

// An empty array is expected to evaluate to NULL.
@Test
public void testMergeTDigestEmptyArray()
{
    String projection = "merge_tdigest(array[])";
    functionAssertions.assertFunction(projection, TDIGEST_DOUBLE, null);
}

// An array holding a single null element is expected to evaluate to NULL.
@Test
public void testMergeTDigestEmptyArrayOfNull()
{
    String projection = "merge_tdigest(array[null])";
    functionAssertions.assertFunction(projection, TDIGEST_DOUBLE, null);
}

// An array holding only null elements is expected to evaluate to NULL.
@Test
public void testMergeTDigestEmptyArrayOfNulls()
{
    String projection = "merge_tdigest(array[null, null, null])";
    functionAssertions.assertFunction(projection, TDIGEST_DOUBLE, null);
}

// Merging two digests through SQL must match merging them directly in Java.
@Test
public void testMergeTDigests()
{
    TDigest first = createTDigest(STANDARD_COMPRESSION_FACTOR);
    addAll(first, 0.1);
    TDigest second = createTDigest(STANDARD_COMPRESSION_FACTOR);
    addAll(second, 0.2);

    // The query text is built from the digests before they are merged locally.
    String projection = format(
            "merge_tdigest(cast(array[%s, %s] as array(tdigest(double))))",
            toSqlString(first),
            toSqlString(second));
    SqlVarbinary actual = functionAssertions.selectSingleValue(projection, TDIGEST_DOUBLE, SqlVarbinary.class);

    first.merge(second);
    SqlVarbinary expected = new SqlVarbinary(first.serialize().getBytes());
    assertEquals(actual, expected);
}

// A trailing null element must leave the single real digest unchanged.
@Test
public void testMergeTDigestOneNull()
{
    TDigest only = createTDigest(STANDARD_COMPRESSION_FACTOR);
    addAll(only, 0.1);

    String projection = format(
            "merge_tdigest(cast(array[%s, null] as array(tdigest(double))))",
            toSqlString(only));
    SqlVarbinary actual = functionAssertions.selectSingleValue(projection, TDIGEST_DOUBLE, SqlVarbinary.class);

    SqlVarbinary expected = new SqlVarbinary(only.serialize().getBytes());
    assertEquals(actual, expected);
}

// A leading null element must not affect the merge of the two real digests.
@Test
public void testMergeTDigestOneNullFirst()
{
    TDigest first = createTDigest(STANDARD_COMPRESSION_FACTOR);
    addAll(first, 0.1);
    TDigest second = createTDigest(STANDARD_COMPRESSION_FACTOR);
    addAll(second, 0.2);

    // The query text is built from the digests before they are merged locally.
    String projection = format(
            "merge_tdigest(cast(array[null, %s, %s] as array(tdigest(double))))",
            toSqlString(first),
            toSqlString(second));
    SqlVarbinary actual = functionAssertions.selectSingleValue(projection, TDIGEST_DOUBLE, SqlVarbinary.class);

    first.merge(second);
    SqlVarbinary expected = new SqlVarbinary(first.serialize().getBytes());
    assertEquals(actual, expected);
}

// A null element between two real digests must not affect their merge.
@Test
public void testMergeTDigestOneNullMiddle()
{
    TDigest first = createTDigest(STANDARD_COMPRESSION_FACTOR);
    addAll(first, 0.1);
    TDigest second = createTDigest(STANDARD_COMPRESSION_FACTOR);
    addAll(second, 0.2);

    // The query text is built from the digests before they are merged locally.
    String projection = format(
            "merge_tdigest(cast(array[%s, null, %s] as array(tdigest(double))))",
            toSqlString(first),
            toSqlString(second));
    SqlVarbinary actual = functionAssertions.selectSingleValue(projection, TDIGEST_DOUBLE, SqlVarbinary.class);

    first.merge(second);
    SqlVarbinary expected = new SqlVarbinary(first.serialize().getBytes());
    assertEquals(actual, expected);
}

// disabled because test takes almost 10s
@Test(enabled = false)
public void testBinomialDistribution()
Expand Down

0 comments on commit 2d0ad86

Please sign in to comment.