Enforce broadcast memory limits in Presto on Spark
arhimondr committed Oct 5, 2020
1 parent 13a8f64 commit 79a8c8f
Showing 1 changed file with 23 additions and 0 deletions.
@@ -126,6 +126,8 @@
import java.util.concurrent.Future;

import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalBroadcastMemoryLimit;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxBroadcastMemory;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.execution.QueryState.FAILED;
import static com.facebook.presto.execution.QueryState.FINISHED;
@@ -160,6 +162,7 @@
import static com.google.common.util.concurrent.Futures.getUnchecked;
import static io.airlift.units.DataSize.succinctBytes;
import static java.lang.Math.max;
import static java.lang.String.format;
import static java.nio.file.Files.notExists;
import static java.util.Collections.unmodifiableList;
import static java.util.Objects.requireNonNull;
@@ -897,9 +900,29 @@ private <T extends PrestoSparkTaskOutput> RddAndMore<T> createRdd(SubPlan subPla
PlanFragment childFragment = child.getFragment();
if (childFragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) {
RddAndMore<PrestoSparkSerializedPage> childRdd = createRdd(child, PrestoSparkSerializedPage.class);

// TODO: The driver might still OOM on a very large broadcast, think of how to prevent that from happening
List<PrestoSparkSerializedPage> broadcastPages = childRdd.collectAndDestroyDependencies().stream()
.map(Tuple2::_2)
.collect(toList());

long compressedBroadcastSizeInBytes = broadcastPages.stream()
.mapToLong(page -> page.getBytes().length)
.sum();
long uncompressedBroadcastSizeInBytes = broadcastPages.stream()
.mapToLong(PrestoSparkSerializedPage::getUncompressedSizeInBytes)
.sum();
DataSize maxBroadcastSize = getQueryMaxBroadcastMemory(session);
long maxBroadcastSizeInBytes = maxBroadcastSize.toBytes();

if (compressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) {
throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Compressed broadcast size: %s", succinctBytes(compressedBroadcastSizeInBytes)));
}

if (uncompressedBroadcastSizeInBytes > maxBroadcastSizeInBytes) {
throw exceededLocalBroadcastMemoryLimit(maxBroadcastSize, format("Uncompressed broadcast size: %s", succinctBytes(compressedBroadcastSizeInBytes)));
}

Broadcast<List<PrestoSparkSerializedPage>> broadcast = sparkContext.broadcast(broadcastPages);
broadcastInputs.put(childFragment.getId(), broadcast);
broadcastDependencies.add(broadcast);
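For context, the check introduced by this hunk can be read as: collect the serialized broadcast pages on the driver, total their compressed and uncompressed sizes, and fail the query if either total exceeds the configured limit. Below is a minimal, self-contained sketch of that pattern. The nested SerializedPage record and LimitExceededException class are simplified stand-ins for illustration only, not the actual Presto types (PrestoSparkSerializedPage, ExceededMemoryLimitException).

import java.util.List;

// A minimal sketch of the broadcast memory limit check added in this commit.
final class BroadcastSizeCheck
{
    // Stand-in for PrestoSparkSerializedPage: the compressed payload plus its uncompressed size.
    record SerializedPage(byte[] compressedBytes, long uncompressedSizeInBytes) {}

    // Stand-in for the memory-limit exception Presto raises.
    static class LimitExceededException extends RuntimeException
    {
        LimitExceededException(String message)
        {
            super(message);
        }
    }

    // Mirrors the pattern in createRdd: after the child RDD is collected on the
    // driver, total both the compressed (serialized) and uncompressed page sizes
    // and fail fast if either exceeds the configured broadcast memory limit.
    static void enforceBroadcastLimit(List<SerializedPage> pages, long maxBroadcastSizeInBytes)
    {
        long compressedSizeInBytes = pages.stream()
                .mapToLong(page -> page.compressedBytes().length)
                .sum();
        long uncompressedSizeInBytes = pages.stream()
                .mapToLong(SerializedPage::uncompressedSizeInBytes)
                .sum();

        if (compressedSizeInBytes > maxBroadcastSizeInBytes) {
            throw new LimitExceededException(
                    "Compressed broadcast size " + compressedSizeInBytes + " exceeds limit of " + maxBroadcastSizeInBytes + " bytes");
        }
        if (uncompressedSizeInBytes > maxBroadcastSizeInBytes) {
            throw new LimitExceededException(
                    "Uncompressed broadcast size " + uncompressedSizeInBytes + " exceeds limit of " + maxBroadcastSizeInBytes + " bytes");
        }
    }

    public static void main(String[] args)
    {
        // A single 3 MB page against a 2 MB limit triggers the exception.
        SerializedPage page = new SerializedPage(new byte[3 * 1024 * 1024], 3L * 1024 * 1024);
        enforceBroadcastLimit(List.of(page), 2L * 1024 * 1024);
    }
}

The sketch accumulates totals as long so that a very large set of pages cannot overflow an int accumulator and slip past the check, and it tests the compressed and uncompressed totals separately because either one can dominate memory use: the compressed bytes are what the driver holds and broadcasts, while the uncompressed size is what the pages expand to once they are read.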
