Skip to content

Commit

Permalink
[apache#1608] fix(spark): re-assign only once for the same faulty ser…
Browse files Browse the repository at this point in the history
…ver in one stage
  • Loading branch information
zuston committed Mar 29, 2024
1 parent cbf4f6f commit e704b56
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;

/**
* Class for holding, 1. partition ID -> shuffle servers mapping. 2. remote storage info
Expand All @@ -41,7 +42,10 @@ public class ShuffleHandleInfo implements Serializable {
private Map<Integer, List<ShuffleServerInfo>> partitionToServers;

// partitionId -> replica -> failover servers
private Map<Integer, Map<Integer, List<ShuffleServerInfo>>> failoverPartitionServers;
private Map<Integer, Map<Integer, ShuffleServerInfo>> failoverPartitionServers;
// todo: support mores replacement servers for one faulty server.
private Map<String, ShuffleServerInfo> faultyServerReplacements;

// shuffle servers which is for store shuffle data
private Set<ShuffleServerInfo> shuffleServersForData;
// remoteStorage used for this job
Expand All @@ -62,16 +66,13 @@ public ShuffleHandleInfo(
this.shuffleServersForData.addAll(ssis);
}
this.remoteStorage = storageInfo;
this.faultyServerReplacements = JavaUtils.newConcurrentMap();
}

public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
return partitionToServers;
}

public Map<Integer, Map<Integer, List<ShuffleServerInfo>>> getFailoverPartitionServers() {
return failoverPartitionServers;
}

public Set<ShuffleServerInfo> getShuffleServersForData() {
return shuffleServersForData;
}
Expand All @@ -83,4 +84,34 @@ public RemoteStorageInfo getRemoteStorage() {
public int getShuffleId() {
return shuffleId;
}

public boolean isExistingFaultyServer(String serverId) {
return faultyServerReplacements.containsKey(serverId);
}

public ShuffleServerInfo useExistingReassignmentForMultiPartitions(
Set<Integer> partitionIds, String faultyServerId) {
return createNewReassignmentForMultiPartitions(partitionIds, faultyServerId, null);
}

public ShuffleServerInfo createNewReassignmentForMultiPartitions(
Set<Integer> partitionIds, String faultyServerId, ShuffleServerInfo replacement) {
if (replacement != null) {
faultyServerReplacements.put(faultyServerId, replacement);
}

replacement = faultyServerReplacements.get(faultyServerId);
for (Integer partitionId : partitionIds) {
List<ShuffleServerInfo> replicaServers = partitionToServers.get(partitionId);
for (int i = 0; i < replicaServers.size(); i++) {
if (replicaServers.get(i).getId().equals(faultyServerId)) {
Map<Integer, ShuffleServerInfo> replicaReplacements =
failoverPartitionServers.computeIfAbsent(
partitionId, k -> JavaUtils.newConcurrentMap());
replicaReplacements.put(i, replacement);
}
}
}
return replacement;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ boolean reassignAllShuffleServersForWholeStage(
int stageId, int stageAttemptNumber, int shuffleId, int numMaps);

ShuffleServerInfo reassignFaultyShuffleServerForTasks(
int shuffleId, Set<String> partitionIds, String faultyShuffleServerId);
int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public boolean reassignAllShuffleServersForWholeStage(

@Override
public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
int shuffleId, Set<String> partitionIds, String faultyShuffleServerId) {
int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
return mock(ShuffleServerInfo.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
*/
private Map<String, Boolean> serverAssignedInfos = JavaUtils.newConcurrentMap();

private Map<String, ShuffleServerInfo> reassignedFaultyServers = JavaUtils.newConcurrentMap();

public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -888,39 +886,27 @@ public synchronized boolean reassignAllShuffleServersForWholeStage(
}
}

// this is only valid on driver side that exposed to being invoked by grpc server
@Override
public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
int shuffleId, Set<String> partitionIds, String faultyShuffleServerId) {
ShuffleServerInfo newShuffleServerInfo =
reassignedFaultyServers.computeIfAbsent(
faultyShuffleServerId,
id -> {
ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, id);
ShuffleHandleInfo shuffleHandleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
for (String partitionId : partitionIds) {
Integer partitionIdInteger = Integer.valueOf(partitionId);
List<ShuffleServerInfo> shuffleServerInfoList =
shuffleHandleInfo.getPartitionToServers().get(partitionIdInteger);
for (int i = 0; i < shuffleServerInfoList.size(); i++) {
if (shuffleServerInfoList.get(i).getId().equals(faultyShuffleServerId)) {
shuffleHandleInfo
.getFailoverPartitionServers()
.computeIfAbsent(partitionIdInteger, k -> Maps.newHashMap());
shuffleHandleInfo
.getFailoverPartitionServers()
.get(partitionIdInteger)
.computeIfAbsent(i, j -> Lists.newArrayList())
.add(newAssignedServer);
}
}
}
return newAssignedServer;
});
return newShuffleServerInfo;
}
int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
ShuffleHandleInfo handleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
synchronized (handleInfo) {
// find out whether this server has been marked faulty in this shuffle
// if it has been reassigned, directly return the replacement server.
if (handleInfo.isExistingFaultyServer(faultyShuffleServerId)) {
return handleInfo.useExistingReassignmentForMultiPartitions(
partitionIds, faultyShuffleServerId);
}

public Map<String, ShuffleServerInfo> getReassignedFaultyServers() {
return reassignedFaultyServers;
// get the newer server to replace faulty server.
ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, faultyShuffleServerId);
if (newAssignedServer != null) {
handleInfo.createNewReassignmentForMultiPartitions(
partitionIds, faultyShuffleServerId, newAssignedServer);
}
return newAssignedServer;
}
}

private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
import scala.collection.Seq;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.apache.hadoop.conf.Configuration;
Expand Down Expand Up @@ -145,8 +143,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
*/
private Map<String, Boolean> serverAssignedInfos;

private Map<String, ShuffleServerInfo> reassignedFaultyServers;

public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;
boolean supportsRelocation =
Expand Down Expand Up @@ -275,7 +271,6 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.shuffleIdToShuffleHandleInfo = JavaUtils.newConcurrentMap();
this.failuresShuffleServerIds = Sets.newHashSet();
this.serverAssignedInfos = JavaUtils.newConcurrentMap();
this.reassignedFaultyServers = JavaUtils.newConcurrentMap();
}

public CompletableFuture<Long> sendData(AddBlockEvent event) {
Expand Down Expand Up @@ -1180,39 +1175,27 @@ public synchronized boolean reassignAllShuffleServersForWholeStage(
}
}

// this is only valid on driver side that exposed to being invoked by grpc server
@Override
public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
int shuffleId, Set<String> partitionIds, String faultyShuffleServerId) {
ShuffleServerInfo newShuffleServerInfo =
reassignedFaultyServers.computeIfAbsent(
faultyShuffleServerId,
id -> {
ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, id);
ShuffleHandleInfo shuffleHandleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
for (String partitionId : partitionIds) {
Integer partitionIdInteger = Integer.valueOf(partitionId);
List<ShuffleServerInfo> shuffleServerInfoList =
shuffleHandleInfo.getPartitionToServers().get(partitionIdInteger);
for (int i = 0; i < shuffleServerInfoList.size(); i++) {
if (shuffleServerInfoList.get(i).getId().equals(faultyShuffleServerId)) {
shuffleHandleInfo
.getFailoverPartitionServers()
.computeIfAbsent(partitionIdInteger, k -> Maps.newHashMap());
shuffleHandleInfo
.getFailoverPartitionServers()
.get(partitionIdInteger)
.computeIfAbsent(i, j -> Lists.newArrayList())
.add(newAssignedServer);
}
}
}
return newAssignedServer;
});
return newShuffleServerInfo;
}
int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
ShuffleHandleInfo handleInfo = shuffleIdToShuffleHandleInfo.get(shuffleId);
synchronized (handleInfo) {
// find out whether this server has been marked faulty in this shuffle
// if it has been reassigned, directly return the replacement server.
if (handleInfo.isExistingFaultyServer(faultyShuffleServerId)) {
return handleInfo.useExistingReassignmentForMultiPartitions(
partitionIds, faultyShuffleServerId);
}

public Map<String, ShuffleServerInfo> getReassignedFaultyServers() {
return reassignedFaultyServers;
// get the newer server to replace faulty server.
ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, faultyShuffleServerId);
if (newAssignedServer != null) {
handleInfo.createNewReassignmentForMultiPartitions(
partitionIds, faultyShuffleServerId, newAssignedServer);
}
return newAssignedServer;
}
}

private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -121,6 +122,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {

private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<>();

private final Map<String, ShuffleServerInfo> faultyServers = new HashMap<>();

// Only for tests
@VisibleForTesting
public RssShuffleWriter(
Expand Down Expand Up @@ -461,7 +464,6 @@ private void reSendFailedBlockIds(Set<TrackingBlockStatus> failedBlockStatusSet)
List<ShuffleBlockInfo> failedBlockInfoList = Lists.newArrayList();
Map<ShuffleServerInfo, List<TrackingBlockStatus>> faultyServerToPartitions =
failedBlockStatusSet.stream().collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()));
Map<String, ShuffleServerInfo> faultyServers = shuffleManager.getReassignedFaultyServers();
faultyServerToPartitions.entrySet().stream()
.forEach(
t -> {
Expand Down Expand Up @@ -531,6 +533,9 @@ private ShuffleServerInfo reAssignFaultyShuffleServer(
throw new RssException(
"reassign server response with statusCode[" + response.getStatusCode() + "]");
}
if (response.getShuffleServer() == null) {
throw new RssException("empty newer reassignment server!");
}
return response.getShuffleServer();
} catch (Exception e) {
throw new RssException(
Expand Down
2 changes: 1 addition & 1 deletion proto/src/main/proto/Rss.proto
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ message ReassignServersReponse{

message RssReassignFaultyShuffleServerRequest{
int32 shuffleId = 1;
repeated string partitionIds = 2;
repeated int32 partitionIds = 2;
string faultyShuffleServerId = 3;
}

Expand Down

0 comments on commit e704b56

Please sign in to comment.