Skip to content

Commit

Permalink
[apache#1608][part-5] feat(spark3): always use the latest assignment …
Browse files Browse the repository at this point in the history
…and load balance for huge partition
  • Loading branch information
zuston committed Apr 17, 2024
1 parent 60fce8e commit 450a9ab
Show file tree
Hide file tree
Showing 22 changed files with 862 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ public class RssSparkConfig {
.withDescription(
"The memory spill switch triggered by Spark TaskMemoryManager, default value is false.");

public static final ConfigOption<Integer> RSS_PARTITION_REASSIGN_LOAD_BALANCE_SERVER_NUM =
ConfigOptions.key("rss.client.partitionReassign.loadBalanceServerNum")
.intType()
.defaultValue(5)
.withDescription(
"The shuffle server num for load balance of huge partition when partition reassign is triggered.");

public static final ConfigOption<Integer> RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES =
ConfigOptions.key("rss.client.partitionReassign.blockRetryMaxTimes")
.intType()
.defaultValue(1)
.withDescription("The block retry max times when partition reassign is enabled.");

public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";

public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.collections.CollectionUtils;

import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
import org.apache.uniffle.common.RemoteStorageInfo;
Expand All @@ -54,6 +55,8 @@ public class ShuffleHandleInfo implements Serializable {
private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionReplicaAssignedServers;
// faulty servers replacement mapping
private Map<String, Set<ShuffleServerInfo>> faultyServerToReplacements;
// The collection of partition ids that need to be load balanced, such as huge partition.
private Set<Integer> loadBalancePartitionCandidates = new HashSet<>();

public static final ShuffleHandleInfo EMPTY_HANDLE_INFO =
new ShuffleHandleInfo(-1, Collections.EMPTY_MAP, RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
Expand Down Expand Up @@ -131,17 +134,29 @@ protected boolean isMarkedAsFaultyServer(String serverId) {
return faultyServerToReplacements.containsKey(serverId);
}

public Set<ShuffleServerInfo> getExistingReplacements(String faultyServerId) {
public Set<ShuffleServerInfo> getReplacements(String faultyServerId) {
return faultyServerToReplacements.get(faultyServerId);
}

public void updateReassignment(
public void updateAssignment(
Set<Integer> partitionIds, String faultyServerId, Set<ShuffleServerInfo> replacements) {
updateAssignment(partitionIds, faultyServerId, replacements, new HashSet<>());
}

public void updateAssignment(
Set<Integer> partitionIds,
String faultyServerId,
Set<ShuffleServerInfo> replacements,
Set<Integer> needLoadBalancePartitionIds) {
if (replacements == null) {
return;
}

faultyServerToReplacements.put(faultyServerId, replacements);

if (CollectionUtils.isNotEmpty(needLoadBalancePartitionIds)) {
loadBalancePartitionCandidates.addAll(needLoadBalancePartitionIds);
}

// todo: optimize the multiple for performance
for (Integer partitionId : partitionIds) {
Map<Integer, List<ShuffleServerInfo>> replicaServers =
Expand Down Expand Up @@ -180,6 +195,39 @@ public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
return partitionToServers;
}

/**
* key: partitionId, value: the servers for multiple replicas.
*
* <p>Leveraging the partition reassign mechanism, it could support multiple servers for one
* partition replica for huge partition load balance or reassignment multiple times. But it will
* use the different policies.
*
* <p>For the former, this will use the hash to get one from the candidates. For the latter, this
* will always get the last one that is available for now.
*/
public Map<Integer, List<ShuffleServerInfo>> getLatestAssignmentPlan(long taskAttemptId) {
Map<Integer, List<ShuffleServerInfo>> plan = new HashMap<>();
for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
partitionReplicaAssignedServers.entrySet()) {
int partitionId = entry.getKey();
boolean isNeedLoadBalance = loadBalancePartitionCandidates.contains(partitionId);
Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
replicaServers.entrySet()) {
ShuffleServerInfo candidate;
int candidateSize = replicaServerEntry.getValue().size();
// todo: loop find the next candidate if current candidate is in faulty list.
if (isNeedLoadBalance) {
candidate = replicaServerEntry.getValue().get((int) (taskAttemptId % candidateSize));
} else {
candidate = replicaServerEntry.getValue().get(candidateSize - 1);
}
plan.computeIfAbsent(partitionId, x -> new ArrayList<>()).add(candidate);
}
}
return plan;
}

public PartitionDataReplicaRequirementTracking createPartitionReplicaTracking() {
PartitionDataReplicaRequirementTracking replicaRequirement =
new PartitionDataReplicaRequirementTracking(shuffleId, partitionReplicaAssignedServers);
Expand Down Expand Up @@ -224,6 +272,9 @@ public static RssProtos.ShuffleHandleInfo toProto(ShuffleHandleInfo handleInfo)
}

public static ShuffleHandleInfo fromProto(RssProtos.ShuffleHandleInfo handleProto) {
if (handleProto == null) {
return null;
}
Map<Integer, Map<Integer, List<ShuffleServerInfo>>> partitionToServers = new HashMap<>();
for (Map.Entry<Integer, RssProtos.PartitionReplicaServers> entry :
handleProto.getPartitionToServersMap().entrySet()) {
Expand All @@ -247,4 +298,8 @@ public static ShuffleHandleInfo fromProto(RssProtos.ShuffleHandleInfo handleProt
handle.remoteStorage = remoteStorageInfo;
return handle;
}

public Set<String> listFaultyServers() {
return faultyServerToReplacements.keySet();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.writer;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.shuffle.ShuffleHandleInfo;

import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;

/** This class is to wrap the shuffleHandleInfo to speed up the partitionAssignment getting. */
public class ShuffleHandleInfoWrapper {
private ShuffleHandleInfo handle;
private final long taskAttemptId;
private final Set<String> faultyServers;
private Map<Integer, List<ShuffleServerInfo>> latestAssignment;

public ShuffleHandleInfoWrapper(long taskAttemptId, ShuffleHandleInfo shuffleHandleInfo) {
this.taskAttemptId = taskAttemptId;
this.faultyServers = new HashSet<>();
this.update(shuffleHandleInfo, null);
}

public List<ShuffleServerInfo> retrievePartitionAssignment(int taskAttemptId) {
return latestAssignment.get(taskAttemptId);
}

public boolean isReassigned(String serverId) {
return faultyServers.contains(serverId);
}

public void update(ShuffleHandleInfo handle, String faultyServerId) {
if (handle == null) {
throw new RssException("Errors on updating shuffle handle by the empty handleInfo.");
}
this.handle = handle;
this.latestAssignment = handle.getLatestAssignmentPlan(taskAttemptId);
if (faultyServerId != null) {
this.faultyServers.add(faultyServerId);
}
}

// Only for tests
@VisibleForTesting
public ShuffleHandleInfo getRef() {
return handle;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ public class WriteBufferManager extends MemoryConsumer {
private ShuffleWriteMetrics shuffleWriteMetrics;
// cache partition -> records
private Map<Integer, WriterBuffer> buffers;
private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
private int serializerBufferSize;
private int bufferSegmentSize;
private long copyTime = 0;
Expand All @@ -98,6 +97,7 @@ public class WriteBufferManager extends MemoryConsumer {
private int memorySpillTimeoutSec;
private boolean isRowBased;
private BlockIdLayout blockIdLayout;
private Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc;

public WriteBufferManager(
int shuffleId,
Expand Down Expand Up @@ -127,19 +127,18 @@ public WriteBufferManager(
long taskAttemptId,
BufferManagerOptions bufferManagerOptions,
Serializer serializer,
Map<Integer, List<ShuffleServerInfo>> partitionToServers,
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf,
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc) {
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
this.bufferSize = bufferManagerOptions.getBufferSize();
this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
this.buffers = Maps.newHashMap();
this.shuffleId = shuffleId;
this.taskId = taskId;
this.taskAttemptId = taskAttemptId;
this.partitionToServers = partitionToServers;
this.shuffleWriteMetrics = shuffleWriteMetrics;
this.serializerBufferSize = bufferManagerOptions.getSerializerBufferSize();
this.bufferSegmentSize = bufferManagerOptions.getBufferSegmentSize();
Expand All @@ -164,6 +163,31 @@ public WriteBufferManager(
this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
this.memorySpillEnabled = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
this.blockIdLayout = BlockIdLayout.from(rssConf);
this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
}

public WriteBufferManager(
int shuffleId,
String taskId,
long taskAttemptId,
BufferManagerOptions bufferManagerOptions,
Serializer serializer,
Map<Integer, List<ShuffleServerInfo>> partitionToServers,
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf,
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
this(
shuffleId,
taskId,
taskAttemptId,
bufferManagerOptions,
serializer,
taskMemoryManager,
shuffleWriteMetrics,
rssConf,
spillFunc,
partitionId -> partitionToServers.get(partitionId));
}

/** add serialized columnar data directly when integrate with gluten */
Expand Down Expand Up @@ -344,7 +368,7 @@ protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb)
compressed.length,
crc32,
compressed,
partitionToServers.get(partitionId),
partitionAssignmentRetrieveFunc.apply(partitionId),
uncompressLength,
wb.getMemoryUsed(),
taskAttemptId);
Expand Down Expand Up @@ -573,4 +597,9 @@ public void setSpillFunc(
public void setSendSizeLimit(long sendSizeLimit) {
this.sendSizeLimit = sendSizeLimit;
}

public void setPartitionAssignmentRetrieveFunc(
Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc) {
this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.ShuffleHandleInfo;

import org.apache.uniffle.common.ShuffleServerInfo;

/**
* This is a proxy interface that mainly delegates the un-registration of shuffles to the
* MapOutputTrackerMaster on the driver. It provides a unified interface that hides implementation
Expand Down Expand Up @@ -78,6 +76,9 @@ public interface RssShuffleManagerInterface {
boolean reassignAllShuffleServersForWholeStage(
int stageId, int stageAttemptNumber, int shuffleId, int numMaps);

ShuffleServerInfo reassignFaultyShuffleServerForTasks(
int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId);
ShuffleHandleInfo reassignFaultyShuffleServerForTasks(
int shuffleId,
Set<Integer> partitionIds,
String faultyShuffleServerId,
Set<Integer> needLoadBalancePartitionIds);
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,17 @@ public void reassignShuffleServers(
public void reassignFaultyShuffleServer(
RssProtos.RssReassignFaultyShuffleServerRequest request,
StreamObserver<RssProtos.RssReassignFaultyShuffleServerResponse> responseObserver) {
ShuffleServerInfo shuffleServerInfo =
ShuffleHandleInfo handle =
shuffleManager.reassignFaultyShuffleServerForTasks(
request.getShuffleId(),
Sets.newHashSet(request.getPartitionIdsList()),
request.getFaultyShuffleServerId());
request.getFaultyShuffleServerId(),
Sets.newHashSet(request.getNeedLoadBalancePartitionIdsList()));
RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
RssProtos.RssReassignFaultyShuffleServerResponse reply =
RssProtos.RssReassignFaultyShuffleServerResponse.newBuilder()
.setStatus(code)
.setServer(ShuffleServerInfo.convertToShuffleServerId(shuffleServerInfo))
.setHandle(ShuffleHandleInfo.toProto(handle))
.build();
responseObserver.onNext(reply);
responseObserver.onCompleted();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void testReassignment() {

assertFalse(handleInfo.isMarkedAsFaultyServer("a"));
Set<Integer> partitions = Sets.newHashSet(1);
handleInfo.updateReassignment(partitions, "a", Sets.newHashSet(createFakeServerInfo("d")));
handleInfo.updateAssignment(partitions, "a", Sets.newHashSet(createFakeServerInfo("d")));
assertTrue(handleInfo.isMarkedAsFaultyServer("a"));
}

Expand All @@ -66,7 +66,7 @@ public void testListAllPartitionAssignmentServers() {

// case1
Set<Integer> partitions = Sets.newHashSet(2);
handleInfo.updateReassignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));
handleInfo.updateAssignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));

Map<Integer, List<ShuffleServerInfo>> partitionAssignment =
handleInfo.listPartitionAssignedServers();
Expand All @@ -77,15 +77,15 @@ public void testListAllPartitionAssignmentServers() {

// case2: reassign multiple times for one partition, it will not append the same replacement
// servers
handleInfo.updateReassignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));
handleInfo.updateAssignment(partitions, "c", Sets.newHashSet(createFakeServerInfo("d")));
partitionAssignment = handleInfo.listPartitionAssignedServers();
assertEquals(
Arrays.asList(createFakeServerInfo("c"), createFakeServerInfo("d")),
partitionAssignment.get(2));

// case3: reassign multiple times for one partition, it will append the non-existing replacement
// servers
handleInfo.updateReassignment(
handleInfo.updateAssignment(
partitions, "c", Sets.newHashSet(createFakeServerInfo("d"), createFakeServerInfo("e")));
partitionAssignment = handleInfo.listPartitionAssignedServers();
assertEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

import org.apache.spark.shuffle.ShuffleHandleInfo;

import org.apache.uniffle.common.ShuffleServerInfo;

import static org.mockito.Mockito.mock;

public class DummyRssShuffleManager implements RssShuffleManagerInterface {
Expand Down Expand Up @@ -69,8 +67,11 @@ public boolean reassignAllShuffleServersForWholeStage(
}

@Override
public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
return mock(ShuffleServerInfo.class);
public ShuffleHandleInfo reassignFaultyShuffleServerForTasks(
int shuffleId,
Set<Integer> partitionIds,
String faultyShuffleServerId,
Set<Integer> needLoadBalancePartitionIds) {
return mock(ShuffleHandleInfo.class);
}
}
Loading

0 comments on commit 450a9ab

Please sign in to comment.