From 071c46ebf42dfb9a69fb2a103c8bd9e7e7d40a34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BD=90=E5=AE=B6=E4=B9=90=2826731624=29?= Date: Tue, 16 Jan 2024 12:29:10 +0800 Subject: [PATCH] [#1398] fix(MR)(TEZ): Limit attemptId to 4 bit and move it from 18 bit atomicInt to 21 bit taskAttemptId in 63 bit BlockId. --- .../hadoop/mapred/SortWriteBufferManager.java | 4 +- .../apache/hadoop/mapreduce/RssMRUtils.java | 63 ++++++------------- .../org/apache/tez/common/RssTezUtils.java | 56 +++++------------ .../sort/buffer/WriteBufferManager.java | 2 +- .../uniffle/client/util/ClientUtils.java | 4 +- .../apache/uniffle/common/util/Constants.java | 6 +- 6 files changed, 46 insertions(+), 89 deletions(-) diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java index 46e94f23cb..964d154b8b 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java @@ -401,8 +401,8 @@ List> getWaitSendBuffers() { } // it's run in single thread, and is not thread safe - private int getNextSeqNo(int partitionId) { - partitionToSeqNo.computeIfAbsent(partitionId, key -> 0); + private long getNextSeqNo(int partitionId) { + partitionToSeqNo.putIfAbsent(partitionId, 0); int seqNo = partitionToSeqNo.get(partitionId); partitionToSeqNo.put(partitionId, seqNo + 1); return seqNo; diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java index 67ca72b54c..4ce86f086e 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java @@ -42,27 +42,24 @@ public class RssMRUtils { private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class); - private static final int MAX_ATTEMPT_LENGTH = 6; - private static final long MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1; // Class TaskAttemptId have two field id and mapId, rss taskAttemptID have 21 bits, // mapId is 19 bits, id is 2 bits. MR have a trick logic, taskAttemptId will increase // 1000 * (appAttemptId - 1), so we will decrease it. public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int appAttemptId) { - long lowBytes = taskAttemptID.getTaskID().getId(); - if (lowBytes > Constants.MAX_TASK_ATTEMPT_ID) { - throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed"); - } if (appAttemptId < 1) { throw new RssException("appAttemptId " + appAttemptId + " is wrong"); } - long highBytes = (long) taskAttemptID.getId() - (appAttemptId - 1) * 1000; - if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) { + long lowBytes = (long) taskAttemptID.getId() - (appAttemptId - 1) * 1000; + if (lowBytes > Constants.MAX_ATTEMPT_ID) { + throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed"); + } + long highBytes = taskAttemptID.getTaskID().getId(); + if (highBytes > Constants.MAX_TASK_ID || highBytes < 0) { throw new RssException( "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed"); } - return (highBytes << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - + lowBytes; + return (highBytes << (Constants.MAX_ATTEMPT_LENGTH)) + lowBytes; } public static TaskAttemptID createMRTaskAttemptId( @@ -71,13 +68,9 @@ public static TaskAttemptID createMRTaskAttemptId( throw new RssException("appAttemptId " + appAttemptId + " is wrong"); } TaskID taskID = - new TaskID(jobID, taskType, (int) (rssTaskAttemptId & Constants.MAX_TASK_ATTEMPT_ID)); + new TaskID(jobID, taskType, (int) (rssTaskAttemptId >> Constants.MAX_ATTEMPT_LENGTH)); return new TaskAttemptID( - taskID, - (int) - (rssTaskAttemptId - >> (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - + 1000 * (appAttemptId - 1)); + taskID, (int) (rssTaskAttemptId & Constants.MAX_ATTEMPT_ID) + 1000 * (appAttemptId - 1)); } public static ShuffleWriteClient createShuffleClient(JobConf jobConf) { @@ -228,18 +221,18 @@ public static String getString(Configuration rssJobConf, String key, String defa return rssJobConf.get(key, defaultValue); } - public static long getBlockId(long partitionId, long taskAttemptId, int nextSeqNo) { - long attemptId = - taskAttemptId >> (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH); - if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) { + public static long getBlockId(long partitionId, long taskAttemptId, long nextSeqNo) { + if (taskAttemptId < 0 || taskAttemptId > Constants.MAX_TASK_ATTEMPT_ID) { throw new RssException( - "Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID); + "Can't support attemptId [" + + taskAttemptId + + "], the max value should be " + + Constants.MAX_TASK_ATTEMPT_ID); } - long atomicInt = (nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId; - if (atomicInt < 0 || atomicInt > Constants.MAX_SEQUENCE_NO) { + if (nextSeqNo < 0 || nextSeqNo > Constants.MAX_SEQUENCE_NO) { throw new RssException( "Can't support sequence [" - + atomicInt + + nextSeqNo + "], the max value should be " + Constants.MAX_SEQUENCE_NO); } @@ -250,29 +243,13 @@ public static long getBlockId(long partitionId, long taskAttemptId, int nextSeqN + "], the max value should be " + Constants.MAX_PARTITION_ID); } - long taskId = - taskAttemptId - - (attemptId - << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)); - if (taskId < 0 || taskId > Constants.MAX_TASK_ATTEMPT_ID) { - throw new RssException( - "Can't support taskId[" - + taskId - + "], the max value should be " - + Constants.MAX_TASK_ATTEMPT_ID); - } - return (atomicInt << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)) + return (nextSeqNo << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)) + (partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH) - + taskId; + + taskAttemptId; } public static long getTaskAttemptId(long blockId) { - long mapId = blockId & Constants.MAX_TASK_ATTEMPT_ID; - long attemptId = - (blockId >> (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - & MAX_ATTEMPT_ID; - return (attemptId << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - + mapId; + return blockId & Constants.MAX_TASK_ATTEMPT_ID; } public static int estimateTaskConcurrency(JobConf jobConf) { diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java index 4b9f055a49..be1f010ebb 100644 --- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java +++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java @@ -59,9 +59,6 @@ public class RssTezUtils { private static final Logger LOG = LoggerFactory.getLogger(RssTezUtils.class); - private static final int MAX_ATTEMPT_LENGTH = 6; - private static final long MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1; - public static final String HOST_NAME = "hostname"; public static final String UNDERLINE_DELIMITER = "_"; @@ -154,57 +151,40 @@ public static String uniqueIdentifierToAttemptId(String uniqueIdentifier) { return StringUtils.join(ids, "_", 0, 7); } - public static long getBlockId(long partitionId, long taskAttemptId, int nextSeqNo) { + public static long getBlockId(long partitionId, long taskAttemptId, long nextSeqNo) { LOG.info( "GetBlockId, partitionId:{}, taskAttemptId:{}, nextSeqNo:{}", partitionId, taskAttemptId, nextSeqNo); - long attemptId = - taskAttemptId >> (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH); - if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) { + if (taskAttemptId < 0 || taskAttemptId > Constants.MAX_TASK_ATTEMPT_ID) { throw new RssException( - "Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID); + "Can't support taskAttemptId [" + + taskAttemptId + + "], the max value should be " + + Constants.MAX_TASK_ATTEMPT_ID); } - long atomicInt = (nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId; - if (atomicInt < 0 || atomicInt > Constants.MAX_SEQUENCE_NO) { + if (nextSeqNo < 0 || nextSeqNo > Constants.MAX_SEQUENCE_NO) { throw new RssException( "Can't support sequence [" - + atomicInt + + nextSeqNo + "], the max value should be " + Constants.MAX_SEQUENCE_NO); } if (partitionId < 0 || partitionId > Constants.MAX_PARTITION_ID) { throw new RssException( - "Can't support partitionId[" + "Can't support partitionId [" + partitionId + "], the max value should be " + Constants.MAX_PARTITION_ID); } - long taskId = - taskAttemptId - - (attemptId - << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)); - - if (taskId < 0 || taskId > Constants.MAX_TASK_ATTEMPT_ID) { - throw new RssException( - "Can't support taskId[" - + taskId - + "], the max value should be " - + Constants.MAX_TASK_ATTEMPT_ID); - } - return (atomicInt << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)) + return (nextSeqNo << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)) + (partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH) - + taskId; + + taskAttemptId; } public static long getTaskAttemptId(long blockId) { - long mapId = blockId & Constants.MAX_TASK_ATTEMPT_ID; - long attemptId = - (blockId >> (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - & MAX_ATTEMPT_ID; - return (attemptId << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - + mapId; + return blockId & Constants.MAX_TASK_ATTEMPT_ID; } public static int estimateTaskConcurrency(Configuration jobConf, int mapNum, int reduceNum) { @@ -297,18 +277,16 @@ private static int mapVertexId(String vertexName) { } public static long convertTaskAttemptIdToLong(TezTaskAttemptID taskAttemptID) { - long lowBytes = taskAttemptID.getTaskID().getId(); - if (lowBytes > Constants.MAX_TASK_ATTEMPT_ID) { + long lowBytes = taskAttemptID.getId(); + if (lowBytes > Constants.MAX_ATTEMPT_ID) { throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed"); } - long highBytes = taskAttemptID.getId(); - if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) { + long highBytes = taskAttemptID.getTaskID().getId(); + if (highBytes > Constants.MAX_TASK_ID || highBytes < 0) { throw new RssException( "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed."); } - long id = - (highBytes << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - + lowBytes; + long id = (highBytes << (Constants.MAX_ATTEMPT_LENGTH)) + lowBytes; LOG.info("ConvertTaskAttemptIdToLong taskAttemptID:{}, id is {}, .", taskAttemptID, id); return id; } diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java index b88beddaf8..ab9124a256 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java @@ -424,7 +424,7 @@ List> getWaitSendBuffers() { return waitSendBuffers; } - private int getNextSeqNo(int partitionId) { + private long getNextSeqNo(int partitionId) { partitionToSeqNo.putIfAbsent(partitionId, 0); int seqNo = partitionToSeqNo.get(partitionId); partitionToSeqNo.put(partitionId, seqNo + 1); diff --git a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java index 60484d0a9c..89f59c9fa1 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java +++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java @@ -35,9 +35,9 @@ public class ClientUtils { // BlockId is long and composed of partitionId, executorId and AtomicInteger. - // AtomicInteger is first 19 bit, max value is 2^19 - 1 + // AtomicInteger is first 18 bit, max value is 2^18 - 1 // partitionId is next 24 bit, max value is 2^24 - 1 - // taskAttemptId is rest of 20 bit, max value is 2^20 - 1 + // taskAttemptId is rest of 21 bit, max value is 2^21 - 1 public static Long getBlockId(long partitionId, long taskAttemptId, long atomicInt) { if (atomicInt < 0 || atomicInt > Constants.MAX_SEQUENCE_NO) { throw new IllegalArgumentException( diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java b/common/src/main/java/org/apache/uniffle/common/util/Constants.java index f49449d08d..3b3c0d81d4 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java +++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java @@ -38,8 +38,10 @@ private Constants() {} public static final long MAX_SEQUENCE_NO = (1 << Constants.ATOMIC_INT_MAX_LENGTH) - 1; public static final long MAX_PARTITION_ID = (1 << Constants.PARTITION_ID_MAX_LENGTH) - 1; public static final long MAX_TASK_ATTEMPT_ID = (1 << Constants.TASK_ATTEMPT_ID_MAX_LENGTH) - 1; - public static final long MAX_TASK_ID = (1 << Constants.TASK_ID_MAX_LENGTH) - 1; - public static final long MAX_ATTEMPT_ID = (1 << Constants.ATTEMPT_ID_MAX_LENGTH) - 1; + public static final int MAX_TASK_LENGTH = 17; + public static final int MAX_ATTEMPT_LENGTH = 4; + public static final long MAX_TASK_ID = (1 << MAX_TASK_LENGTH) - 1; + public static final long MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1; public static final long INVALID_BLOCK_ID = -1L; public static final String KEY_SPLIT_CHAR = "/"; public static final String COMMA_SPLIT_CHAR = ",";