diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java index 5a93c2b117..f040d469ea 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.function.TupleConsumer; public class AddBlockEvent { @@ -28,20 +29,19 @@ public class AddBlockEvent { private List shuffleDataInfoList; private List processedCallbackChain; + // The var is to indicate if the blocks fail to send, whether the writer will resend to + // re-assignment servers. + // if so, the failure blocks will not be released. + private boolean isResendEnabled = false; + + private TupleConsumer blockProcessedCallback; + public AddBlockEvent(String taskId, List shuffleDataInfoList) { this.taskId = taskId; this.shuffleDataInfoList = shuffleDataInfoList; this.processedCallbackChain = new ArrayList<>(); } - public AddBlockEvent( - String taskId, List shuffleBlockInfoList, Runnable callback) { - this.taskId = taskId; - this.shuffleDataInfoList = shuffleBlockInfoList; - this.processedCallbackChain = new ArrayList<>(); - addCallback(callback); - } - /** @param callback, should not throw any exception and execute fast. */ public void addCallback(Runnable callback) { processedCallbackChain.add(callback); @@ -59,6 +59,23 @@ public List getProcessedCallbackChain() { return processedCallbackChain; } + public void withBlockProcessedCallback( + TupleConsumer blockProcessedCallback) { + this.blockProcessedCallback = blockProcessedCallback; + } + + public TupleConsumer getBlockProcessedCallback() { + return blockProcessedCallback; + } + + public void enableBlockResend() { + this.isResendEnabled = true; + } + + public boolean isBlockResendEnabled() { + return isResendEnabled; + } + @Override public String toString() { return "AddBlockEvent: TaskId[" + taskId + "], " + shuffleDataInfoList; diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java index 30f649f688..32ac55639c 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java @@ -88,14 +88,22 @@ public CompletableFuture send(AddBlockEvent event) { () -> { String taskId = event.getTaskId(); List shuffleBlockInfoList = event.getShuffleDataInfoList(); + SendShuffleDataResult result = null; try { - SendShuffleDataResult result = + result = shuffleWriteClient.sendShuffleData( rssAppId, shuffleBlockInfoList, () -> !isValidTask(taskId)); putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds()); putFailedBlockSendTracker( taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker()); } finally { + Set succeedBlockIds = result.getSuccessBlockIds(); + for (ShuffleBlockInfo block : shuffleBlockInfoList) { + event + .getBlockProcessedCallback() + .accept(block, succeedBlockIds.contains(block.getBlockId())); + } + List callbackChain = Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST); for (Runnable runnable : callbackChain) { diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index d8261047fc..b8f141bc23 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -410,12 +410,10 @@ private void requestExecutorMemory(long leastMem) { public List buildBlockEvents(List shuffleBlockInfoList) { long totalSize = 0; - long memoryUsed = 0; List events = new ArrayList<>(); List shuffleBlockInfosPerEvent = Lists.newArrayList(); for (ShuffleBlockInfo sbi : shuffleBlockInfoList) { totalSize += sbi.getSize(); - memoryUsed += sbi.getFreeMemory(); shuffleBlockInfosPerEvent.add(sbi); // split shuffle data according to the size if (totalSize > sendSizeLimit) { @@ -427,20 +425,9 @@ public List buildBlockEvents(List shuffleBlockI + totalSize + " bytes"); } - // Use final temporary variables for closures - final long memoryUsedTemp = memoryUsed; - final List shuffleBlocksTemp = shuffleBlockInfosPerEvent; - events.add( - new AddBlockEvent( - taskId, - shuffleBlockInfosPerEvent, - () -> { - freeAllocatedMemory(memoryUsedTemp); - shuffleBlocksTemp.stream().forEach(x -> x.getData().release()); - })); + events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); shuffleBlockInfosPerEvent = Lists.newArrayList(); totalSize = 0; - memoryUsed = 0; } } if (!shuffleBlockInfosPerEvent.isEmpty()) { @@ -453,16 +440,7 @@ public List buildBlockEvents(List shuffleBlockI + " bytes"); } // Use final temporary variables for closures - final long memoryUsedTemp = memoryUsed; - final List shuffleBlocksTemp = shuffleBlockInfosPerEvent; - events.add( - new AddBlockEvent( - taskId, - shuffleBlockInfosPerEvent, - () -> { - freeAllocatedMemory(memoryUsedTemp); - shuffleBlocksTemp.stream().forEach(x -> x.getData().release()); - })); + events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); } return events; } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 2fc0340510..3241b21442 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -112,7 +112,8 @@ public class RssShuffleWriter extends ShuffleWriter { private final Set blockIds = Sets.newConcurrentHashSet(); private TaskContext taskContext; private SparkConf sparkConf; - private boolean taskFailRetry; + private boolean isBlockFailSentRetryEnabled; + private int blockFailSentMaxTimes = 2; /** used by columnar rss shuffle writer implementation */ protected final long taskAttemptId; @@ -189,7 +190,7 @@ private RssShuffleWriter( this.taskFailureCallback = taskFailureCallback; this.taskContext = context; this.sparkConf = sparkConf; - this.taskFailRetry = + this.isBlockFailSentRetryEnabled = sparkConf.getBoolean( RssClientConf.RSS_TASK_FAILED_RETRY_ENABLED.key(), RssClientConf.RSS_TASK_FAILED_RETRY_ENABLED.defaultValue()); @@ -265,8 +266,8 @@ private void writeImpl(Iterator> records) { long recordCount = 0; while (records.hasNext()) { recordCount++; - // Task should fast fail when sending data failed - checkIfBlocksFailed(); + + dataCheckOrRetry(); Product2 record = records.next(); K key = record._1(); @@ -359,6 +360,26 @@ protected List> postBlockEvent( List shuffleBlockInfoList) { List> futures = new ArrayList<>(); for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) { + event.withBlockProcessedCallback( + (block, isSuccessful) -> { + boolean isRelease = false; + if (!isBlockFailSentRetryEnabled) { + isRelease = true; + } else { + if (isSuccessful) { + isRelease = true; + } else { + if (block.getRetryCounter() >= blockFailSentMaxTimes - 1) { + isRelease = true; + } + } + } + + if (isRelease) { + bufferManager.freeAllocatedMemory(block.getFreeMemory()); + block.getData().release(); + } + }); event.addCallback( () -> { boolean ret = finishEventQueue.add(new Object()); @@ -382,7 +403,7 @@ protected void checkBlockSendResult(Set blockIds) { while (true) { try { finishEventQueue.clear(); - checkIfBlocksFailed(); + dataCheckOrRetry(); Set successBlockIds = shuffleManager.getSuccessBlockIds(taskId); blockIds.removeAll(successBlockIds); if (blockIds.isEmpty()) { @@ -418,12 +439,67 @@ protected void checkBlockSendResult(Set blockIds) { } } + private void dataCheckOrRetry() { + if (isBlockFailSentRetryEnabled) { + collectBlocksToResendOrFastFail(); + } else { + if (hasAnyBlockFailure()) { + throw new RssSendFailedException(); + } + } + } + + private boolean hasAnyBlockFailure() { + Set failedBlockIds = shuffleManager.getFailedBlockIds(taskId); + if (!failedBlockIds.isEmpty()) { + LOG.error( + "Errors on sending blocks for task[{}]. {} blocks can't be sent to remote servers: {}", + taskId, + failedBlockIds.size(), + shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers()); + return true; + } + return false; + } + + private void collectBlocksToResendOrFastFail() { + if (!isBlockFailSentRetryEnabled) { + return; + } + + FailedBlockSendTracker failedTracker = shuffleManager.getBlockIdsFailedSendTracker(taskId); + Set failedBlockIds = failedTracker.getFailedBlockIds(); + if (failedBlockIds == null || failedBlockIds.isEmpty()) { + return; + } + + Set resendCandidates = new HashSet<>(); + // to check whether the blocks resent exceed the max resend count. + for (Long blockId : failedBlockIds) { + List retryRecords = failedTracker.getFailedBlockStatus(blockId); + // todo: support retry times by config + if (retryRecords.size() >= blockFailSentMaxTimes) { + LOG.error( + "Partial blocks for taskId: [{}] retry exceeding the max retry times. Fast fail! faulty server list: {}", + taskId, + retryRecords.stream().map(x -> x.getShuffleServerInfo()).collect(Collectors.toSet())); + // fast fail if any blocks failure with multiple retry times + throw new RssSendFailedException(); + } + + // todo: if setting multi replica and another replica is succeed to send, no need to resend + resendCandidates.add(retryRecords.get(retryRecords.size() - 1)); + } + + resendFailedBlocks(resendCandidates); + } + private void checkIfBlocksFailed() { Set failedBlockIds = shuffleManager.getFailedBlockIds(taskId); - if (taskFailRetry && !failedBlockIds.isEmpty()) { + if (isBlockFailSentRetryEnabled && !failedBlockIds.isEmpty()) { Set shouldResendBlockSet = shouldResendBlockStatusSet(failedBlockIds); try { - reSendFailedBlockIds(shouldResendBlockSet); + resendFailedBlocks(shouldResendBlockSet); } catch (Exception e) { LOG.error("resend failed blocks failed.", e); } @@ -456,7 +532,7 @@ private Set shouldResendBlockStatusSet(Set failedBloc return resendBlockStatusSet; } - private void reSendFailedBlockIds(Set failedBlockStatusSet) { + private void resendFailedBlocks(Set failedBlockStatusSet) { List reAssignSeverBlockInfoList = Lists.newArrayList(); List failedBlockInfoList = Lists.newArrayList(); Map> faultyServerToPartitions = @@ -471,36 +547,37 @@ private void reSendFailedBlockIds(Set failedBlockStatusSet) .collect(Collectors.toSet()); ShuffleServerInfo dynamicShuffleServer = faultyServers.get(t.getKey().getId()); if (dynamicShuffleServer == null) { + // todo: merge multiple requests into one. dynamicShuffleServer = reAssignFaultyShuffleServer(partitionIds, t.getKey().getId()); faultyServers.put(t.getKey().getId(), dynamicShuffleServer); } ShuffleServerInfo finalDynamicShuffleServer = dynamicShuffleServer; - failedBlockStatusSet.forEach( - trackingBlockStatus -> { - ShuffleBlockInfo failedBlockInfo = trackingBlockStatus.getShuffleBlockInfo(); - failedBlockInfoList.add(failedBlockInfo); - reAssignSeverBlockInfoList.add( - new ShuffleBlockInfo( - failedBlockInfo.getShuffleId(), - failedBlockInfo.getPartitionId(), - failedBlockInfo.getBlockId(), - failedBlockInfo.getLength(), - failedBlockInfo.getCrc(), - failedBlockInfo.getData(), - Lists.newArrayList(finalDynamicShuffleServer), - failedBlockInfo.getUncompressLength(), - failedBlockInfo.getFreeMemory(), - taskAttemptId)); - }); + for (TrackingBlockStatus blockStatus : failedBlockStatusSet) { + ShuffleBlockInfo failedBlockInfo = blockStatus.getShuffleBlockInfo(); + failedBlockInfoList.add(failedBlockInfo); + ShuffleBlockInfo newBlock = + new ShuffleBlockInfo( + failedBlockInfo.getShuffleId(), + failedBlockInfo.getPartitionId(), + failedBlockInfo.getBlockId(), + failedBlockInfo.getLength(), + failedBlockInfo.getCrc(), + failedBlockInfo.getData(), + Lists.newArrayList(finalDynamicShuffleServer), + failedBlockInfo.getUncompressLength(), + failedBlockInfo.getFreeMemory(), + taskAttemptId); + newBlock.setRetryCounter(failedBlockInfo.getRetryCounter() + 1); + reAssignSeverBlockInfoList.add(newBlock); + } }); - clearFailedBlockIdsStates(failedBlockInfoList, faultyServers); + clearFailedBlockStates(failedBlockInfoList, faultyServers); processShuffleBlockInfos(reAssignSeverBlockInfoList); - checkIfBlocksFailed(); } - private void clearFailedBlockIdsStates( + private void clearFailedBlockStates( List failedBlockInfoList, Map faultyServers) { failedBlockInfoList.forEach( shuffleBlockInfo -> { diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java index 8de75d90d4..45493dc9e2 100644 --- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java +++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java @@ -37,6 +37,8 @@ public class ShuffleBlockInfo { private int uncompressLength; private long freeMemory; + private int retryCounter = 0; + public ShuffleBlockInfo( int shuffleId, int partitionId, @@ -84,6 +86,14 @@ public ShuffleBlockInfo( this.taskAttemptId = taskAttemptId; } + public int getRetryCounter() { + return retryCounter; + } + + public void setRetryCounter(int retryCounter) { + this.retryCounter = retryCounter; + } + public long getBlockId() { return blockId; } diff --git a/common/src/main/java/org/apache/uniffle/common/exception/RssSendFailedException.java b/common/src/main/java/org/apache/uniffle/common/exception/RssSendFailedException.java index 4ae3678fc0..c6a949292b 100644 --- a/common/src/main/java/org/apache/uniffle/common/exception/RssSendFailedException.java +++ b/common/src/main/java/org/apache/uniffle/common/exception/RssSendFailedException.java @@ -18,6 +18,11 @@ package org.apache.uniffle.common.exception; public class RssSendFailedException extends RssException { + + public RssSendFailedException() { + super(""); + } + public RssSendFailedException(String message) { super(message); } diff --git a/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java new file mode 100644 index 0000000000..2a46387022 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.function; + +@FunctionalInterface +public interface TupleConsumer { + void accept(T t, F f); +}