Skip to content

Commit

Permalink
Java bindings for approx_percentile (#9094)
Browse files Browse the repository at this point in the history
This PR builds on #8983 and adds Java bindings.

Authors:
  - Andy Grove (https://github.com/andygrove)
  - https://github.com/nvdbaranec

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #9094
  • Loading branch information
andygrove committed Sep 24, 2021
1 parent ba76310 commit 0b89459
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 3 deletions.
47 changes: 46 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Aggregation.java
Expand Up @@ -65,7 +65,9 @@ enum Kind {
M2(26),
MERGE_M2(27),
RANK(28),
DENSE_RANK(29);
DENSE_RANK(29),
TDIGEST(30), // This can take a delta argument for accuracy level
MERGE_TDIGEST(31); // This can take a delta argument for accuracy level

final int nativeId;

Expand Down Expand Up @@ -864,6 +866,44 @@ static MergeM2Aggregation mergeM2() {
return new MergeM2Aggregation();
}

/**
 * Aggregation backed by a t-digest. The same class serves both the aggregation that builds
 * a t-digest ({@code Kind.TDIGEST}) and the one that merges existing t-digests
 * ({@code Kind.MERGE_TDIGEST}); the kind passed to the constructor selects which native
 * aggregation is created.
 */
static class TDigestAggregation extends Aggregation {
  private final int delta;

  /**
   * @param kind  the aggregation kind to create natively
   * @param delta accuracy argument forwarded unchanged to the native aggregation
   */
  public TDigestAggregation(Kind kind, int delta) {
    super(kind);
    this.delta = delta;
  }

  /**
   * Create the native aggregation for this kind and delta.
   * @return the value returned by the native createTDigestAgg call
   */
  @Override
  long createNativeInstance() {
    return Aggregation.createTDigestAgg(kind.nativeId, delta);
  }

  @Override
  public int hashCode() {
    return 31 * kind.hashCode() + delta;
  }

  /**
   * Two instances are equal only when both the kind and the delta match.
   */
  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof TDigestAggregation)) {
      return false;
    }
    TDigestAggregation o = (TDigestAggregation) other;
    // kind takes part in hashCode(), so it must take part in equality too; otherwise a
    // TDIGEST and a MERGE_TDIGEST with the same delta would be "equal" with different hashes.
    return o.kind == this.kind && o.delta == this.delta;
  }
}

/**
 * Create an aggregation of kind TDIGEST.
 * @param delta accuracy argument handed to the TDigestAggregation
 * @return a new TDigestAggregation for Kind.TDIGEST
 */
static TDigestAggregation createTDigest(int delta) {
  final Kind kind = Kind.TDIGEST;
  return new TDigestAggregation(kind, delta);
}

/**
 * Create an aggregation of kind MERGE_TDIGEST.
 * @param delta accuracy argument handed to the TDigestAggregation
 * @return a new TDigestAggregation for Kind.MERGE_TDIGEST
 */
static TDigestAggregation mergeTDigest(int delta) {
  final Kind kind = Kind.MERGE_TDIGEST;
  return new TDigestAggregation(kind, delta);
}

/**
* Create one of the aggregations that only needs a kind, no other parameters. This does not
* work for all types and for code safety reasons each kind is added separately.
Expand Down Expand Up @@ -909,4 +949,9 @@ static MergeM2Aggregation mergeM2() {
* Create a merge sets aggregation.
*/
private static native long createMergeSetsAgg(boolean nullsEqual, boolean nansEqual);

/**
 * Create a TDigest aggregation.
 * @param kind native id of the aggregation kind; the caller in this class passes
 *             kind.nativeId of a TDIGEST or MERGE_TDIGEST aggregation
 * @param delta accuracy argument for the t-digest, passed through from the Java aggregation
 * @return the native aggregation as a long, returned as-is from createNativeInstance()
 */
private static native long createTDigestAgg(int kind, int delta);
}
38 changes: 37 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ColumnView.java
Expand Up @@ -1423,12 +1423,39 @@ public Scalar reduce(ReductionAggregation aggregation, DType outType) {
}
}

/**
 * Compute approximate percentiles from this ColumnVector, which must hold the centroids
 * produced by a t-digest aggregation. The percentiles are copied into a temporary column
 * that is closed again before this method returns.
 *
 * @param percentiles Required percentiles [0,1]
 * @return Column containing the approximate percentile values as a list of doubles, in
 * the same order as the input percentiles
 */
public final ColumnVector approxPercentile(double[] percentiles) {
  final ColumnVector result;
  try (ColumnVector percentilesColumn = ColumnVector.fromDoubles(percentiles)) {
    result = approxPercentile(percentilesColumn);
  }
  return result;
}

/**
 * Compute approximate percentiles from this ColumnVector, which must hold the centroids
 * produced by a t-digest aggregation.
 *
 * @param percentiles Column containing percentiles [0,1]
 * @return Column containing the approximate percentile values as a list of doubles, in
 * the same order as the input percentiles
 */
public final ColumnVector approxPercentile(ColumnVector percentiles) {
  long digestView = getNativeView();
  long percentilesView = percentiles.getNativeView();
  // The native call hands back a handle that the new ColumnVector wraps.
  long resultHandle = approxPercentile(digestView, percentilesView);
  return new ColumnVector(resultHandle);
}

/**
* Calculate various quantiles of this ColumnVector. It is assumed that this is already sorted
* in the desired order.
* @param method the method used to calculate the quantiles
* @param quantiles the quantile values [0,1]
* @return the quantiles as doubles, in the same order passed in. The type can be changed in the future
*/
public final ColumnVector quantile(QuantileMethod method, double[] quantiles) {
return new ColumnVector(quantile(getNativeView(), method.nativeId, quantiles));
Expand Down Expand Up @@ -3544,6 +3571,15 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
*/
private static native long upperStrings(long cudfViewHandle);

/**
 * Native method to compute approximate percentiles from a t-digest column.
 * @param cudfColumnHandle native view handle of the T-Digest column
 * @param percentilesHandle native view handle of the column holding the requested percentiles
 * @return native handle of the resulting cudf column, used to construct the Java column
 * by the approxPercentile method.
 * @throws CudfException declared by the native method
 */
private static native long approxPercentile(long cudfColumnHandle, long percentilesHandle) throws CudfException;

/**
 * Native method to compute quantiles.
 * @param cudfColumnHandle native view handle of the input column
 * @param quantileMethod native id of the QuantileMethod to use
 * @param quantiles the quantile values to compute
 * @return native handle of the resulting cudf column, wrapped in a ColumnVector by the
 * quantile method
 * @throws CudfException declared by the native method
 */
private static native long quantile(long cudfColumnHandle, int quantileMethod, double[] quantiles) throws CudfException;

private static native long rollingWindow(
Expand Down
22 changes: 22 additions & 0 deletions java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
Expand Up @@ -293,4 +293,26 @@ public static GroupByAggregation mergeSets(NullEquality nullEquality, NaNEqualit
public static GroupByAggregation mergeM2() {
  // Wrap the shared mergeM2 aggregation so it can be used as a group-by aggregation.
  final Aggregation wrapped = Aggregation.mergeM2();
  return new GroupByAggregation(wrapped);
}

/**
 * Compute a t-digest from a fixed-width numeric input column.
 *
 * @param delta Required accuracy (number of buckets).
 * @return A list of centroids per grouping, where each centroid has a mean value and a
 * weight. The number of centroids will be {@code <= delta}.
 */
public static GroupByAggregation createTDigest(int delta) {
  final Aggregation wrapped = Aggregation.createTDigest(delta);
  return new GroupByAggregation(wrapped);
}

/**
 * Merge previously computed t-digests.
 *
 * @param delta Required accuracy (number of buckets).
 * @return A list of centroids per grouping, where each centroid has a mean value and a
 * weight. The number of centroids will be {@code <= delta}.
 */
public static GroupByAggregation mergeTDigest(int delta) {
  final Aggregation wrapped = Aggregation.mergeTDigest(delta);
  return new GroupByAggregation(wrapped);
}
}
23 changes: 22 additions & 1 deletion java/src/main/native/src/AggregationJni.cpp
Expand Up @@ -85,7 +85,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv
return cudf::make_rank_aggregation();
case 29: // DENSE_RANK
return cudf::make_dense_rank_aggregation();

default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
}
}();
Expand Down Expand Up @@ -131,6 +130,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createDdofAgg(JNIEnv *en
CATCH_STD(env, 0);
}

// JNI entry point for Aggregation.createTDigestAgg: builds the group-by t-digest aggregation
// selected by `kind` with the given `delta` and returns the released pointer as a jlong.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createTDigestAgg(JNIEnv *env,
                                                                          jclass class_object,
                                                                          jint kind, jint delta) {
  try {
    cudf::jni::auto_set_device(env);

    // These numbers come from Aggregation.java and must stay in sync
    constexpr jint TDIGEST_KIND = 30;
    constexpr jint MERGE_TDIGEST_KIND = 31;

    // Held as the base cudf::aggregation type so the released pointer has that type.
    std::unique_ptr<cudf::aggregation> agg;
    if (kind == TDIGEST_KIND) {
      agg = cudf::make_tdigest_aggregation<cudf::groupby_aggregation>(delta);
    } else if (kind == MERGE_TDIGEST_KIND) {
      agg = cudf::make_merge_tdigest_aggregation<cudf::groupby_aggregation>(delta);
    } else {
      throw std::logic_error("Unsupported TDigest Aggregation Operation");
    }
    return reinterpret_cast<jlong>(agg.release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCountLikeAgg(JNIEnv *env,
jclass class_object,
jint kind,
Expand Down
19 changes: 19 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Expand Up @@ -289,6 +289,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_scan(JNIEnv *env, jclass,
CATCH_STD(env, 0);
}

// JNI entry point for ColumnView.approxPercentile: wraps the input column handle as a
// structs_column_view, calls cudf::percentile_approx with the percentiles column, and returns
// the released result column pointer as a jlong.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_approxPercentile(JNIEnv *env, jclass clazz,
                                                                        jlong input_column,
                                                                        jlong percentiles_column) {
  JNI_NULL_CHECK(env, input_column, "input_column native handle is null", 0);
  JNI_NULL_CHECK(env, percentiles_column, "percentiles_column native handle is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const *n_input_column = reinterpret_cast<cudf::column_view const *>(input_column);
    auto const *n_percentiles_column =
        reinterpret_cast<cudf::column_view const *>(percentiles_column);
    // The wrapper view is only needed for the duration of this call, so it lives on the stack
    // rather than being heap-allocated.
    cudf::structs_column_view const input_view(*n_input_column);
    std::unique_ptr<cudf::column> result =
        cudf::percentile_approx(input_view, *n_percentiles_column);
    return reinterpret_cast<jlong>(result.release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_quantile(JNIEnv *env, jclass clazz,
jlong input_column,
jint quantile_method,
Expand Down
100 changes: 100 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Expand Up @@ -3484,6 +3484,106 @@ void testGroupByReplaceNulls() {
}
}

/**
 * Groups five double values under keys a, a, b, c, d, builds a t-digest per key with a
 * delta of 100, and checks the 0.25/0.50/0.75 approximate percentiles of each group.
 */
@Test
void testGroupByApproxPercentileReproCase() {
  double[] percentiles = {0.25, 0.50, 0.75};
  try (Table input = new Table.TestBuilder()
           .column("a", "a", "b", "c", "d")
           .column(1084.0, 1719.0, 15948.0, 148029.0, 1269761.0)
           .build();
       Table digests = input
           .groupBy(0)
           .aggregate(GroupByAggregation.createTDigest(100).onColumn(1));
       // Order by key so the rows line up with the expected lists below.
       Table digestsByKey = digests.orderBy(OrderByArg.asc(0));
       ColumnVector actual = digestsByKey.getColumn(1).approxPercentile(percentiles);
       ColumnVector expected = ColumnVector.fromLists(
           new ListType(false, new BasicType(false, DType.FLOAT64)),
           Arrays.asList(1084.0, 1084.0, 1719.0),
           Arrays.asList(15948.0, 15948.0, 15948.0),
           Arrays.asList(148029.0, 148029.0, 148029.0),
           Arrays.asList(1269761.0, 1269761.0, 1269761.0)
       )) {
    assertColumnsAreEqual(expected, actual);
  }
}

/**
 * Builds a t-digest per key (delta 1000) over two groups of three values each and checks
 * the 0.25/0.50/0.75 approximate percentiles of each group.
 */
@Test
void testGroupByApproxPercentile() {
  double[] percentiles = {0.25, 0.50, 0.75};
  try (Table input = new Table.TestBuilder()
           .column("a", "a", "a", "b", "b", "b")
           .column(100, 150, 160, 70, 110, 160)
           .build();
       Table digests = input
           .groupBy(0)
           .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1));
       // Order by key so the rows line up with the expected lists below.
       Table digestsByKey = digests.orderBy(OrderByArg.asc(0));
       ColumnVector actual = digestsByKey.getColumn(1).approxPercentile(percentiles);
       ColumnVector expected = ColumnVector.fromLists(
           new ListType(false, new BasicType(false, DType.FLOAT64)),
           Arrays.asList(100d, 150d, 160d),
           Arrays.asList(70d, 110d, 160d)
       )) {
    assertColumnsAreEqual(expected, actual);
  }
}

/**
 * Builds the same per-key t-digests twice, concatenates them, merges them per key with
 * mergeTDigest, and checks the 0.25/0.50/0.75 approximate percentiles of the merged digests.
 */
@Test
void testMergeApproxPercentile() {
  double[] percentiles = {0.25, 0.50, 0.75};
  try (Table input = new Table.TestBuilder()
           .column("a", "a", "a", "b", "b", "b")
           .column(100, 150, 160, 70, 110, 160)
           .build();
       Table firstDigests = input
           .groupBy(0)
           .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1));
       Table secondDigests = input
           .groupBy(0)
           .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1));
       Table allDigests = Table.concatenate(firstDigests, secondDigests);
       Table mergedDigests = allDigests
           .groupBy(0)
           .aggregate(GroupByAggregation.mergeTDigest(1000).onColumn(1));
       // Order by key so the rows line up with the expected lists below.
       Table mergedByKey = mergedDigests.orderBy(OrderByArg.asc(0));
       ColumnVector actual = mergedByKey.getColumn(1).approxPercentile(percentiles);
       ColumnVector expected = ColumnVector.fromLists(
           new ListType(false, new BasicType(false, DType.FLOAT64)),
           Arrays.asList(100d, 150d, 160d),
           Arrays.asList(70d, 110d, 160d)
       )) {
    assertColumnsAreEqual(expected, actual);
  }
}

/**
 * Same flow as the other merge test (build digests twice, concatenate, merge per key) but
 * with the value sets of the two keys swapped, then checks the 0.25/0.50/0.75 approximate
 * percentiles of the merged digests.
 */
@Test
void testMergeApproxPercentile2() {
  double[] percentiles = {0.25, 0.50, 0.75};
  try (Table input = new Table.TestBuilder()
           .column("a", "a", "a", "b", "b", "b")
           .column(70, 110, 160, 100, 150, 160)
           .build();
       Table firstDigests = input
           .groupBy(0)
           .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1));
       Table secondDigests = input
           .groupBy(0)
           .aggregate(GroupByAggregation.createTDigest(1000).onColumn(1));
       Table allDigests = Table.concatenate(firstDigests, secondDigests);
       Table mergedDigests = allDigests
           .groupBy(0)
           .aggregate(GroupByAggregation.mergeTDigest(1000).onColumn(1));
       // Order by key so the rows line up with the expected lists below.
       Table mergedByKey = mergedDigests.orderBy(OrderByArg.asc(0));
       ColumnVector actual = mergedByKey.getColumn(1).approxPercentile(percentiles);
       ColumnVector expected = ColumnVector.fromLists(
           new ListType(false, new BasicType(false, DType.FLOAT64)),
           Arrays.asList(70d, 110d, 160d),
           Arrays.asList(100d, 150d, 160d)
       )) {
    assertColumnsAreEqual(expected, actual);
  }
}

@Test
void testGroupByUniqueCount() {
try (Table t1 = new Table.TestBuilder()
Expand Down

0 comments on commit 0b89459

Please sign in to comment.