From 9afd5b0feb4b7676fe014cf2399ad7e9f150b952 Mon Sep 17 00:00:00 2001 From: Zac Blanco Date: Fri, 26 Apr 2024 13:20:04 -0400 Subject: [PATCH] Improve KLL Sketch perf for non-grouped queries The current method for calculating memory usage has a hidden cost. Within getEstimatedKllInMemorySize we call getSerializedSizeBytes. The code for the serialized bytes size actually serializes the entire internal state to a byte array first before returning the length. This is expensive and should be avoided. I am working on a PR to the upstream library to add a less-costly method but until released, I would like to fix this as non-grouped execution doesn't need the memory accounting for every sketch input. --- .../sketch/kll/KllSketchAggregationState.java | 9 +++---- .../sketch/kll/KllSketchStateSerializer.java | 4 ++-- .../KllSketchWithKAggregationFunction.java | 24 +++++++++---------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java index efddc0145ad2..54e4d537afc6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java @@ -34,6 +34,7 @@ import java.util.Comparator; import java.util.Map; +import java.util.function.Supplier; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; @@ -60,7 +61,7 @@ public interface KllSketchAggregationState @Nullable KllItemsSketch getSketch(); - void addMemoryUsage(long value); + void addMemoryUsage(Supplier usage); Type getType(); @@ -115,7 +116,7 @@ public void setSketch(KllItemsSketch sketch) } @Override - public void addMemoryUsage(long value) + public void addMemoryUsage(Supplier usage) { // noop } @@ -161,9 +162,9 @@ public KllItemsSketch getSketch() } @Override - public void addMemoryUsage(long value) + public void addMemoryUsage(Supplier usage) { - accumulatedSizeInBytes += value; + accumulatedSizeInBytes += usage.get(); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchStateSerializer.java index d10010e83e29..bb5f88ad9462 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchStateSerializer.java @@ -65,8 +65,8 @@ public void deserialize(Block block, int index, KllSketchAggregationState state) KllSketchAggregationState.SketchParameters parameters = KllSketchAggregationState.getSketchParameters(type); // use heapify over wrap in order to get a writable sketch for updates and merges KllItemsSketch sketch = KllItemsSketch.heapify(memory, parameters.getComparator(), parameters.getSerde()); - state.addMemoryUsage(-getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType())); + state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType())); state.setSketch(sketch); - state.addMemoryUsage(getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType())); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchWithKAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchWithKAggregationFunction.java index 0064a6984676..69c5e000c9c7 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchWithKAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchWithKAggregationFunction.java @@ -54,9 +54,9 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql { initializeSketch(state, () -> Long::compareTo, ArrayOfLongsSerDe::new, k); KllItemsSketch sketch = state.getSketch(); - state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, long.class)); + state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, long.class)); state.getSketch().update(value); - state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, long.class)); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, long.class)); } @InputFunction @@ -65,9 +65,9 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql { initializeSketch(state, () -> Double::compareTo, ArrayOfDoublesSerDe::new, k); KllItemsSketch sketch = state.getSketch(); - state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, double.class)); + state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, double.class)); state.getSketch().update(value); - state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, double.class)); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, double.class)); } @InputFunction @@ -76,9 +76,9 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql { initializeSketch(state, () -> String::compareTo, ArrayOfStringsSerDe::new, k); KllItemsSketch sketch = state.getSketch(); - state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, Slice.class)); + state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, Slice.class)); state.getSketch().update(value.toStringUtf8()); - state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, Slice.class)); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, Slice.class)); } @InputFunction @@ -87,22 +87,22 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql { initializeSketch(state, () -> Boolean::compareTo, ArrayOfBooleansSerDe::new, k); KllItemsSketch sketch = state.getSketch(); - state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, boolean.class)); + state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, boolean.class)); state.getSketch().update(value); - state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, boolean.class)); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, boolean.class)); } @CombineFunction public static void combine(@AggregationState KllSketchAggregationState state, @AggregationState KllSketchAggregationState otherState) { if (state.getSketch() != null && otherState.getSketch() != null) { - state.addMemoryUsage(-getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType())); + state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType())); state.getSketch().merge(otherState.getSketch()); - state.addMemoryUsage(getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType())); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType())); } else if (state.getSketch() == null) { state.setSketch(otherState.getSketch()); - state.addMemoryUsage(getEstimatedKllInMemorySize(otherState.getSketch(), state.getType().getJavaType())); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(otherState.getSketch(), state.getType().getJavaType())); } } @@ -125,7 +125,7 @@ private static void initializeSketch(KllSketchAggregationState state, Suppli if (state.getSketch() == null) { KllItemsSketch sketch = KllItemsSketch.newHeapInstance((int) k, comparator.get(), serdeSupplier.get()); state.setSketch(sketch); - state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, state.getType().getJavaType())); + state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, state.getType().getJavaType())); } } }