Introduce FaultTolerantPartitioningScheme

To encapsulate partition assignment logic

arhimondr committed Sep 16, 2022
1 parent 65e99a2 commit 45eb8ef

Showing 10 changed files with 243 additions and 149 deletions.
BucketNodeMap.java
@@ -52,4 +52,9 @@ public InternalNode getAssignedNode(Split split)
     {
         return getAssignedNode(getBucket(split));
     }
+
+    public ToIntFunction<Split> getSplitToBucketFunction()
+    {
+        return splitToBucket;
+    }
 }
FaultTolerantPartitioningScheme.java (new file)
@@ -0,0 +1,84 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.scheduler;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import io.trino.metadata.InternalNode;
+import io.trino.metadata.Split;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.ToIntFunction;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+public class FaultTolerantPartitioningScheme
+{
+    private final int partitionCount;
+    private final Optional<int[]> bucketToPartitionMap;
+    private final Optional<ToIntFunction<Split>> splitToBucketFunction;
+    private final Optional<List<InternalNode>> partitionToNodeMap;
+
+    @VisibleForTesting
+    FaultTolerantPartitioningScheme(
+            int partitionCount,
+            Optional<int[]> bucketToPartitionMap,
+            Optional<ToIntFunction<Split>> splitToBucketFunction,
+            Optional<List<InternalNode>> partitionToNodeMap)
+    {
+        checkArgument(partitionCount > 0, "partitionCount must be greater than zero");
+        this.partitionCount = partitionCount;
+        this.bucketToPartitionMap = requireNonNull(bucketToPartitionMap, "bucketToPartitionMap is null");
+        this.splitToBucketFunction = requireNonNull(splitToBucketFunction, "splitToBucketFunction is null");
+        requireNonNull(partitionToNodeMap, "partitionToNodeMap is null");
+        partitionToNodeMap.ifPresent(map -> checkArgument(
+                map.size() == partitionCount,
+                "partitionToNodeMap size (%s) must be equal to partitionCount (%s)",
+                map.size(),
+                partitionCount));
+        this.partitionToNodeMap = partitionToNodeMap.map(ImmutableList::copyOf);
+    }
+
+    public int getPartitionCount()
+    {
+        return partitionCount;
+    }
+
+    public Optional<int[]> getBucketToPartitionMap()
+    {
+        return bucketToPartitionMap;
+    }
+
+    public int getPartition(Split split)
+    {
+        checkState(bucketToPartitionMap.isPresent(), "bucketToPartitionMap is expected to be present");
+        checkState(splitToBucketFunction.isPresent(), "splitToBucketFunction is expected to be present");
+        int bucket = splitToBucketFunction.get().applyAsInt(split);
+        checkState(
+                bucketToPartitionMap.get().length > bucket,
+                "invalid bucketToPartitionMap size (%s), bucket to partition mapping not found for bucket %s",
+                bucketToPartitionMap.get().length,
+                bucket);
+        return bucketToPartitionMap.get()[bucket];
+    }
+
+    public Optional<InternalNode> getNodeRequirement(int partition)
+    {
+        checkArgument(partition < partitionCount, "partition is expected to be less than %s", partitionCount);
+        return partitionToNodeMap.map(map -> map.get(partition));
+    }
+}
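
The new class encapsulates a two-step lookup: splitToBucketFunction maps a split to a bucket, bucketToPartitionMap maps that bucket to a partition, and partitionToNodeMap optionally pins each partition to a node. A minimal, self-contained sketch of that flow, with String standing in for Trino's Split and InternalNode (the names and values here are illustrative, not part of the commit):

import java.util.List;
import java.util.Optional;
import java.util.function.ToIntFunction;

public class PartitioningSchemeSketch
{
    public static void main(String[] args)
    {
        // bucket -> partition: buckets 0 and 2 collapse into partition 0, buckets 1 and 3 into partition 1
        int[] bucketToPartition = {0, 1, 0, 1};
        // stand-in for BucketNodeMap.getSplitToBucketFunction(): hash a split to a bucket
        ToIntFunction<String> splitToBucket = split -> Math.floorMod(split.hashCode(), bucketToPartition.length);
        // partition -> node, sized to the partition count, as the constructor's checkArgument enforces
        List<String> partitionToNode = List.of("node-a", "node-b");

        String split = "split-42";
        int bucket = splitToBucket.applyAsInt(split);
        int partition = bucketToPartition[bucket];
        Optional<String> nodeRequirement = Optional.of(partitionToNode.get(partition));
        System.out.printf("split=%s -> bucket=%d -> partition=%d (node requirement: %s)%n",
                split, bucket, partition, nodeRequirement.orElse("any"));
    }
}

This mirrors getPartition and getNodeRequirement above: both optionals must be present for getPartition to succeed, while getNodeRequirement simply returns empty when no node map was supplied.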

FaultTolerantPartitioningSchemeFactory.java (new file)
@@ -0,0 +1,90 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.scheduler;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.Session;
+import io.trino.metadata.InternalNode;
+import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitioningHandle;
+
+import javax.annotation.concurrent.NotThreadSafe;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.IntStream;
+
+import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
+import static java.util.Objects.requireNonNull;
+
+@NotThreadSafe
+public class FaultTolerantPartitioningSchemeFactory
+{
+    private final NodePartitioningManager nodePartitioningManager;
+    private final Session session;
+    private final int partitionCount;
+
+    private final Map<PartitioningHandle, FaultTolerantPartitioningScheme> cache = new HashMap<>();
+
+    public FaultTolerantPartitioningSchemeFactory(NodePartitioningManager nodePartitioningManager, Session session, int partitionCount)
+    {
+        this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
+        this.session = requireNonNull(session, "session is null");
+        this.partitionCount = partitionCount;
+    }
+
+    public FaultTolerantPartitioningScheme get(PartitioningHandle handle)
+    {
+        return cache.computeIfAbsent(handle, this::create);
+    }
+
+    private FaultTolerantPartitioningScheme create(PartitioningHandle partitioningHandle)
+    {
+        if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) {
+            return new FaultTolerantPartitioningScheme(
+                    partitionCount,
+                    Optional.of(IntStream.range(0, partitionCount).toArray()),
+                    Optional.empty(),
+                    Optional.empty());
+        }
+        if (partitioningHandle.getCatalogHandle().isPresent()) {
+            // TODO This caps the number of partitions to the number of available nodes. Perhaps a better approach is required for fault tolerant execution.
+            BucketNodeMap bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle);
+            int bucketCount = bucketNodeMap.getBucketCount();
+            int[] bucketToPartition = new int[bucketCount];
+            // make sure all buckets mapped to the same node map to the same partition, such that locality requirements are respected in scheduling
+            Map<InternalNode, Integer> nodeToPartition = new HashMap<>();
+            List<InternalNode> partitionToNodeMap = new ArrayList<>();
+            for (int bucket = 0; bucket < bucketCount; bucket++) {
+                InternalNode node = bucketNodeMap.getAssignedNode(bucket);
+                Integer partitionId = nodeToPartition.get(node);
+                if (partitionId == null) {
+                    partitionId = partitionToNodeMap.size();
+                    nodeToPartition.put(node, partitionId);
+                    partitionToNodeMap.add(node);
+                }
+                bucketToPartition[bucket] = partitionId;
+            }
+            return new FaultTolerantPartitioningScheme(
+                    partitionToNodeMap.size(),
+                    Optional.of(bucketToPartition),
+                    Optional.of(bucketNodeMap.getSplitToBucketFunction()),
+                    Optional.of(ImmutableList.copyOf(partitionToNodeMap)));
+        }
+        return new FaultTolerantPartitioningScheme(1, Optional.empty(), Optional.empty(), Optional.empty());
+    }
+}
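
The connector-bucketed branch of create is the subtle part: every bucket assigned to the same node is folded into a single partition, so the resulting partition count equals the number of distinct nodes and the scheduler can honor bucket locality. A runnable sketch of just that folding step, with String standing in for InternalNode (hypothetical names, same logic as the loop above):

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BucketFoldingSketch
{
    // Fold buckets that share a node into one partition per node.
    static int[] bucketToPartition(List<String> bucketToNode, List<String> partitionToNode)
    {
        int[] result = new int[bucketToNode.size()];
        Map<String, Integer> nodeToPartition = new HashMap<>();
        for (int bucket = 0; bucket < bucketToNode.size(); bucket++) {
            String node = bucketToNode.get(bucket);
            Integer partitionId = nodeToPartition.get(node);
            if (partitionId == null) {
                // first bucket seen on this node: allocate the next partition id
                partitionId = partitionToNode.size();
                nodeToPartition.put(node, partitionId);
                partitionToNode.add(node);
            }
            result[bucket] = partitionId;
        }
        return result;
    }

    public static void main(String[] args)
    {
        List<String> partitionToNode = new ArrayList<>();
        int[] map = bucketToPartition(List.of("node-a", "node-b", "node-a", "node-b"), partitionToNode);
        // four buckets on two nodes yield two partitions: [0, 1, 0, 1] -> [node-a, node-b]
        System.out.println(java.util.Arrays.toString(map) + " -> " + partitionToNode);
    }
}

Note the TODO above: because partitions are derived from the current node assignments, the partition count for bucketed tables is capped by the cluster size.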

@@ -34,7 +34,6 @@
 import io.trino.execution.StageInfo;
 import io.trino.execution.TaskId;
 import io.trino.failuredetector.FailureDetector;
-import io.trino.metadata.InternalNode;
 import io.trino.metadata.Metadata;
 import io.trino.operator.RetryPolicy;
 import io.trino.server.DynamicFilterService;
@@ -44,7 +43,6 @@
 import io.trino.spi.exchange.ExchangeManager;
 import io.trino.spi.exchange.ExchangeSourceHandle;
 import io.trino.sql.planner.NodePartitioningManager;
-import io.trino.sql.planner.PartitioningHandle;
 import io.trino.sql.planner.PlanFragment;
 import io.trino.sql.planner.SubPlan;
 import io.trino.sql.planner.plan.PlanFragmentId;
@@ -60,8 +58,6 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.function.Function;
-import java.util.stream.IntStream;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Ticker.systemTicker;
@@ -76,7 +72,6 @@
 import static io.trino.SystemSessionProperties.getRetryPolicy;
 import static io.trino.execution.QueryState.FINISHING;
 import static io.trino.operator.RetryPolicy.TASK;
-import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
 import static java.util.Objects.requireNonNull;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static java.util.concurrent.TimeUnit.SECONDS;
@@ -214,8 +209,10 @@ private Scheduler createScheduler()
         });
 
         Session session = queryStateMachine.getSession();
-        int partitionCount = getFaultTolerantExecutionPartitionCount(session);
-        Function<PartitioningHandle, BucketToPartition> bucketToPartitionCache = createBucketToPartitionCache(nodePartitioningManager, session, partitionCount);
+        FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory = new FaultTolerantPartitioningSchemeFactory(
+                nodePartitioningManager,
+                session,
+                getFaultTolerantExecutionPartitionCount(session));
 
         ImmutableList.Builder<FaultTolerantStageScheduler> schedulers = ImmutableList.builder();
         Map<PlanFragmentId, Exchange> exchanges = new HashMap<>();
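
The remaining hunks swap the scheduler's previous Function-based BucketToPartition cache (removed further below) for this factory. Because the factory memoizes per PartitioningHandle, a stage's sink side and its consumer's source side resolve to the same FaultTolerantPartitioningScheme, keeping the sink's partition count consistent with the consumer's bucket-to-partition mapping. A tiny sketch of that memoization property (String handles and int[] schemes are stand-ins, not the real types):

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

public class SchemeCacheSketch
{
    public static void main(String[] args)
    {
        Map<String, int[]> cache = new HashMap<>();
        Function<String, int[]> factory = handle -> cache.computeIfAbsent(handle, h -> new int[] {0, 1, 2});

        // the same handle yields the same cached scheme instance
        int[] sinkSide = factory.apply("fixed-hash-distribution");
        int[] sourceSide = factory.apply("fixed-hash-distribution");
        System.out.println(sinkSide == sourceSide); // true
    }
}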
@@ -235,9 +232,10 @@
 
             boolean outputStage = stageManager.getOutputStage().getStageId().equals(stage.getStageId());
             ExchangeContext exchangeContext = new ExchangeContext(session.getQueryId(), new ExchangeId("external-exchange-" + stage.getStageId().getId()));
+            FaultTolerantPartitioningScheme sinkPartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioningScheme().getPartitioning().getHandle());
             Exchange exchange = exchangeManager.createExchange(
                     exchangeContext,
-                    partitionCount,
+                    sinkPartitioningScheme.getPartitionCount(),
                     // order of output records for coordinator consumed stages must be preserved as the stage
                     // may produce sorted dataset (for example an output of a global OrderByOperator)
                     outputStage);
@@ -256,7 +254,7 @@
                 sourceExchanges.put(childFragmentId, sourceExchange);
             }
 
-            BucketToPartition inputBucketToPartition = bucketToPartitionCache.apply(fragment.getPartitioning());
+            FaultTolerantPartitioningScheme sourcePartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioning());
             FaultTolerantStageScheduler scheduler = new FaultTolerantStageScheduler(
                     session,
                     stage,
@@ -269,10 +267,9 @@
                     (future, delay) -> scheduledExecutorService.schedule(() -> future.set(null), delay.toMillis(), MILLISECONDS),
                     systemTicker(),
                     exchange,
-                    bucketToPartitionCache.apply(fragment.getPartitioningScheme().getPartitioning().getHandle()).getBucketToPartitionMap(),
+                    sinkPartitioningScheme,
                     sourceExchanges.buildOrThrow(),
-                    inputBucketToPartition.getBucketToPartitionMap(),
-                    inputBucketToPartition.getBucketNodeMap(),
+                    sourcePartitioningScheme,
                     remainingTaskRetryAttemptsOverall,
                     taskRetryAttemptsPerTask,
                     maxTasksWaitingForNodePerStage,
@@ -511,68 +508,6 @@ private void closeNodeAllocator()
         }
     }
 
-    private static Function<PartitioningHandle, BucketToPartition> createBucketToPartitionCache(NodePartitioningManager nodePartitioningManager, Session session, int partitionCount)
-    {
-        Map<PartitioningHandle, BucketToPartition> cachingMap = new HashMap<>();
-        return partitioningHandle ->
-                cachingMap.computeIfAbsent(
-                        partitioningHandle,
-                        handle -> createBucketToPartitionMap(session, partitionCount, handle, nodePartitioningManager));
-    }
-
-    private static BucketToPartition createBucketToPartitionMap(
-            Session session,
-            int partitionCount,
-            PartitioningHandle partitioningHandle,
-            NodePartitioningManager nodePartitioningManager)
-    {
-        if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) {
-            return new BucketToPartition(Optional.of(IntStream.range(0, partitionCount).toArray()), Optional.empty());
-        }
-        if (partitioningHandle.getCatalogHandle().isPresent()) {
-            BucketNodeMap bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle);
-            int bucketCount = bucketNodeMap.getBucketCount();
-            int[] bucketToPartition = new int[bucketCount];
-            // make sure all buckets mapped to the same node map to the same partition, such that locality requirements are respected in scheduling
-            Map<InternalNode, Integer> nodeToPartition = new HashMap<>();
-            int nextPartitionId = 0;
-            for (int bucket = 0; bucket < bucketCount; bucket++) {
-                InternalNode node = bucketNodeMap.getAssignedNode(bucket);
-                Integer partitionId = nodeToPartition.get(node);
-                if (partitionId == null) {
-                    partitionId = nextPartitionId;
-                    nextPartitionId++;
-                    nodeToPartition.put(node, partitionId);
-                }
-                bucketToPartition[bucket] = partitionId;
-            }
-            return new BucketToPartition(Optional.of(bucketToPartition), Optional.of(bucketNodeMap));
-        }
-        return new BucketToPartition(Optional.empty(), Optional.empty());
-    }
-
-    private static class BucketToPartition
-    {
-        private final Optional<int[]> bucketToPartitionMap;
-        private final Optional<BucketNodeMap> bucketNodeMap;
-
-        private BucketToPartition(Optional<int[]> bucketToPartitionMap, Optional<BucketNodeMap> bucketNodeMap)
-        {
-            this.bucketToPartitionMap = requireNonNull(bucketToPartitionMap, "bucketToPartitionMap is null");
-            this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null");
-        }
-
-        public Optional<int[]> getBucketToPartitionMap()
-        {
-            return bucketToPartitionMap;
-        }
-
-        public Optional<BucketNodeMap> getBucketNodeMap()
-        {
-            return bucketNodeMap;
-        }
-    }
-
     private static boolean isFinishingOrDone(QueryStateMachine queryStateMachine)
     {
         QueryState queryState = queryStateMachine.getQueryState();

FaultTolerantStageScheduler.java
@@ -115,11 +115,10 @@ public class FaultTolerantStageScheduler
     private final int maxTasksWaitingForNodePerStage;
 
     private final Exchange sinkExchange;
-    private final Optional<int[]> sinkBucketToPartitionMap;
+    private final FaultTolerantPartitioningScheme sinkPartitioningScheme;
 
     private final Map<PlanFragmentId, Exchange> sourceExchanges;
-    private final Optional<int[]> sourceBucketToPartitionMap;
-    private final Optional<BucketNodeMap> sourceBucketNodeMap;
+    private final FaultTolerantPartitioningScheme sourcePartitioningScheme;
 
     private final DelayedFutureCompletor futureCompletor;
 
@@ -187,10 +186,9 @@ public FaultTolerantStageScheduler(
             DelayedFutureCompletor futureCompletor,
             Ticker ticker,
             Exchange sinkExchange,
-            Optional<int[]> sinkBucketToPartitionMap,
+            FaultTolerantPartitioningScheme sinkPartitioningScheme,
             Map<PlanFragmentId, Exchange> sourceExchanges,
-            Optional<int[]> sourceBucketToPartitionMap,
-            Optional<BucketNodeMap> sourceBucketNodeMap,
+            FaultTolerantPartitioningScheme sourcePartitioningScheme,
             AtomicInteger remainingRetryAttemptsOverall,
             int taskRetryAttemptsPerTask,
             int maxTasksWaitingForNodePerStage,
@@ -206,10 +204,9 @@
         this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null");
         this.futureCompletor = requireNonNull(futureCompletor, "futureCompletor is null");
         this.sinkExchange = requireNonNull(sinkExchange, "sinkExchange is null");
-        this.sinkBucketToPartitionMap = requireNonNull(sinkBucketToPartitionMap, "sinkBucketToPartitionMap is null");
+        this.sinkPartitioningScheme = requireNonNull(sinkPartitioningScheme, "sinkPartitioningScheme is null");
         this.sourceExchanges = ImmutableMap.copyOf(requireNonNull(sourceExchanges, "sourceExchanges is null"));
-        this.sourceBucketToPartitionMap = requireNonNull(sourceBucketToPartitionMap, "sourceBucketToPartitionMap is null");
-        this.sourceBucketNodeMap = requireNonNull(sourceBucketNodeMap, "sourceBucketNodeMap is null");
+        this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null");
         this.remainingRetryAttemptsOverall = requireNonNull(remainingRetryAttemptsOverall, "remainingRetryAttemptsOverall is null");
         this.maxRetryAttemptsPerTask = taskRetryAttemptsPerTask;
         this.maxTasksWaitingForNodePerStage = maxTasksWaitingForNodePerStage;
@@ -277,8 +274,7 @@ public synchronized void schedule()
                     stage.getFragment(),
                     exchangeSources,
                     stage::recordGetSplitTime,
-                    sourceBucketToPartitionMap,
-                    sourceBucketNodeMap);
+                    sourcePartitioningScheme);
         }
 
         while (!pendingPartitions.isEmpty() || !queuedPartitions.isEmpty() || !taskSource.isFinished()) {
@@ -393,7 +389,7 @@ private void startTask(int partition, NodeAllocator.NodeLease nodeLease, MemoryR
                 node,
                 partition,
                 attemptId,
-                sinkBucketToPartitionMap,
+                sinkPartitioningScheme.getBucketToPartitionMap(),
                 outputBuffers,
                 taskSplits,
                 allSourcePlanNodeIds,