From 40cfd3eaf967c9594ffccf8f400cd07875ca4e27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BD=90=E5=AE=B6=E4=B9=90=2826731624=29?= Date: Wed, 17 Jan 2024 17:06:23 +0800 Subject: [PATCH] [#1398] [FOLLOW UP] --- .../hadoop/mapred/SortWriteBufferManager.java | 4 +- .../apache/hadoop/mapreduce/MRIdHelper.java | 28 ---------- .../apache/hadoop/mapreduce/RssMRUtils.java | 44 ++++------------ .../mapreduce/task/reduce/RssShuffle.java | 4 +- .../hadoop/mapreduce/RssMRUtilsTest.java | 7 +-- .../spark/shuffle/RssShuffleManager.java | 8 ++- .../org/apache/tez/common/RssTezUtils.java | 51 +++---------------- .../org/apache/tez/common/TezIdHelper.java | 28 ---------- .../shuffle/impl/RssTezFetcherTask.java | 4 +- .../orderedgrouped/RssShuffleScheduler.java | 4 +- .../sort/buffer/WriteBufferManager.java | 4 +- .../apache/tez/common/RssTezUtilsTest.java | 11 ++-- .../apache/tez/common/TezIdHelperTest.java | 37 -------------- .../uniffle/client/util/ClientUtils.java | 12 +++-- .../apache/uniffle/common/util/Constants.java | 6 +-- 15 files changed, 52 insertions(+), 200 deletions(-) delete mode 100644 client-mr/core/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java delete mode 100644 client-tez/src/main/java/org/apache/tez/common/TezIdHelper.java delete mode 100644 client-tez/src/test/java/org/apache/tez/common/TezIdHelperTest.java diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java index 964d154b8b..585641f49e 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java @@ -37,12 +37,12 @@ import com.google.common.util.concurrent.Uninterruptibles; import org.apache.hadoop.io.RawComparator; import org.apache.hadoop.io.serializer.Serializer; -import org.apache.hadoop.mapreduce.RssMRUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.response.SendShuffleDataResult; +import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.compression.Codec; @@ -341,7 +341,7 @@ ShuffleBlockInfo createShuffleBlock(SortWriteBuffer wb) { final long crc32 = ChecksumUtils.getCrc32(compressed); compressTime += System.currentTimeMillis() - start; final long blockId = - RssMRUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId)); + ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId)); uncompressedDataLen += data.length; // add memory to indicate bytes which will be sent to shuffle server inSendListBytes.addAndGet(wb.getDataLength()); diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java deleted file mode 100644 index e20a3d87fc..0000000000 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/MRIdHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.mapreduce; - -import org.apache.uniffle.common.util.IdHelper; - -public class MRIdHelper implements IdHelper { - - @Override - public long getTaskAttemptId(long blockId) { - return RssMRUtils.getTaskAttemptId(blockId); - } -} diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java index 4ce86f086e..63ba9df678 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java @@ -50,16 +50,17 @@ public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int a if (appAttemptId < 1) { throw new RssException("appAttemptId " + appAttemptId + " is wrong"); } - long lowBytes = (long) taskAttemptID.getId() - (appAttemptId - 1) * 1000; - if (lowBytes > Constants.MAX_ATTEMPT_ID) { - throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed"); + long highBytes = (long) taskAttemptID.getId() - (appAttemptId - 1) * 1000L; + if (highBytes > Constants.MAX_ATTEMPT_ID) { + throw new RssException( + "TaskAttempt " + taskAttemptID + " low bytes " + highBytes + " exceed"); } - long highBytes = taskAttemptID.getTaskID().getId(); - if (highBytes > Constants.MAX_TASK_ID || highBytes < 0) { + long lowBytes = taskAttemptID.getTaskID().getId(); + if (lowBytes > Constants.MAX_TASK_ID || highBytes < 0) { throw new RssException( - "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed"); + "TaskAttempt " + taskAttemptID + " high bytes " + lowBytes + " exceed"); } - return (highBytes << (Constants.MAX_ATTEMPT_LENGTH)) + lowBytes; + return (highBytes << (Constants.TASK_ID_MAX_LENGTH)) + lowBytes; } public static TaskAttemptID createMRTaskAttemptId( @@ -68,7 +69,7 @@ public static TaskAttemptID createMRTaskAttemptId( throw new RssException("appAttemptId " + appAttemptId + " is wrong"); } TaskID taskID = - new TaskID(jobID, taskType, (int) (rssTaskAttemptId >> Constants.MAX_ATTEMPT_LENGTH)); + new TaskID(jobID, taskType, (int) (rssTaskAttemptId >> Constants.ATTEMPT_ID_MAX_LENGTH)); return new TaskAttemptID( taskID, (int) (rssTaskAttemptId & Constants.MAX_ATTEMPT_ID) + 1000 * (appAttemptId - 1)); } @@ -221,33 +222,6 @@ public static String getString(Configuration rssJobConf, String key, String defa return rssJobConf.get(key, defaultValue); } - public static long getBlockId(long partitionId, long taskAttemptId, long nextSeqNo) { - if (taskAttemptId < 0 || taskAttemptId > Constants.MAX_TASK_ATTEMPT_ID) { - throw new RssException( - "Can't support attemptId [" - + taskAttemptId - + "], the max value should be " - + Constants.MAX_TASK_ATTEMPT_ID); - } - if (nextSeqNo < 0 || nextSeqNo > Constants.MAX_SEQUENCE_NO) { - throw new RssException( - "Can't support sequence [" - + nextSeqNo - + "], the max value should be " - + Constants.MAX_SEQUENCE_NO); - } - if (partitionId < 0 || partitionId > Constants.MAX_PARTITION_ID) { - throw new RssException( - "Can't support partitionId[" - + partitionId - + "], the max value should be " - + Constants.MAX_PARTITION_ID); - } - return (nextSeqNo << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)) - + (partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH) - + taskAttemptId; - } - public static long getTaskAttemptId(long blockId) { return blockId & Constants.MAX_TASK_ATTEMPT_ID; } diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java index 47d0615464..81b8c1c852 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java @@ -33,7 +33,6 @@ import org.apache.hadoop.mapred.Task; import org.apache.hadoop.mapred.TaskStatus; import org.apache.hadoop.mapred.TaskUmbilicalProtocol; -import org.apache.hadoop.mapreduce.MRIdHelper; import org.apache.hadoop.mapreduce.RssMRConfig; import org.apache.hadoop.mapreduce.RssMRUtils; import org.apache.hadoop.util.Progress; @@ -43,6 +42,7 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.request.CreateShuffleReadClientRequest; +import org.apache.uniffle.client.util.DefaultIdHelper; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.hadoop.shim.HadoopShimImpl; @@ -229,7 +229,7 @@ public RawKeyValueIterator run() throws IOException, InterruptedException { taskIdBitmap, serverInfoList, readerJobConf, - new MRIdHelper(), + new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable, RssMRConfig.toRssConf(rssJobConf)); ShuffleReadClient shuffleReadClient = diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java index cb5c2c6562..e0c9ab8cb5 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java @@ -23,6 +23,7 @@ import org.apache.hadoop.mapred.JobConf; import org.junit.jupiter.api.Test; +import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.client.util.RssClientConfig; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.util.Constants; @@ -70,10 +71,10 @@ public void blockConvertTest() { TaskID taskId = new TaskID(jobID, TaskType.MAP, 233); TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1); long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1); - long blockId = RssMRUtils.getBlockId(1, taskAttemptId, 0); + long blockId = ClientUtils.getBlockId(1, taskAttemptId, 0); long newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId); assertEquals(taskAttemptId, newTaskAttemptId); - blockId = RssMRUtils.getBlockId(2, taskAttemptId, 2); + blockId = ClientUtils.getBlockId(2, taskAttemptId, 2); newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId); assertEquals(taskAttemptId, newTaskAttemptId); } @@ -87,7 +88,7 @@ public void partitionIdConvertBlockTest() { long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1; for (int partitionId = 0; partitionId <= 3000; partitionId++) { for (int seqNo = 0; seqNo <= 10; seqNo++) { - long blockId = RssMRUtils.getBlockId(Long.valueOf(partitionId), taskAttemptId, seqNo); + long blockId = ClientUtils.getBlockId(Long.valueOf(partitionId), taskAttemptId, seqNo); int newPartitionId = Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask); assertEquals(partitionId, newPartitionId); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index d1473b66b6..b189706ca1 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -454,8 +454,12 @@ public ShuffleWriter getWriter( } else { writeMetrics = context.taskMetrics().shuffleWriteMetrics(); } - LOG.info("RssHandle appId is {}, shuffleId is {}, taskId is {}, attemptId is {}", rssHandle.getAppId(), - rssHandle.getShuffleId(), taskId, taskAttemptId); + LOG.info( + "RssHandle appId is {}, shuffleId is {}, taskId is {}, attemptId is {}", + rssHandle.getAppId(), + rssHandle.getShuffleId(), + taskId, + taskAttemptId); return new RssShuffleWriter<>( rssHandle.getAppId(), shuffleId, diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java index be1f010ebb..e30dc1a185 100644 --- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java +++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java @@ -151,42 +151,6 @@ public static String uniqueIdentifierToAttemptId(String uniqueIdentifier) { return StringUtils.join(ids, "_", 0, 7); } - public static long getBlockId(long partitionId, long taskAttemptId, long nextSeqNo) { - LOG.info( - "GetBlockId, partitionId:{}, taskAttemptId:{}, nextSeqNo:{}", - partitionId, - taskAttemptId, - nextSeqNo); - if (taskAttemptId < 0 || taskAttemptId > Constants.MAX_TASK_ATTEMPT_ID) { - throw new RssException( - "Can't support taskAttemptId [" - + taskAttemptId - + "], the max value should be " - + Constants.MAX_TASK_ATTEMPT_ID); - } - if (nextSeqNo < 0 || nextSeqNo > Constants.MAX_SEQUENCE_NO) { - throw new RssException( - "Can't support sequence [" - + nextSeqNo - + "], the max value should be " - + Constants.MAX_SEQUENCE_NO); - } - if (partitionId < 0 || partitionId > Constants.MAX_PARTITION_ID) { - throw new RssException( - "Can't support partitionId [" - + partitionId - + "], the max value should be " - + Constants.MAX_PARTITION_ID); - } - return (nextSeqNo << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH)) - + (partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH) - + taskAttemptId; - } - - public static long getTaskAttemptId(long blockId) { - return blockId & Constants.MAX_TASK_ATTEMPT_ID; - } - public static int estimateTaskConcurrency(Configuration jobConf, int mapNum, int reduceNum) { double dynamicFactor = jobConf.getDouble( @@ -277,16 +241,17 @@ private static int mapVertexId(String vertexName) { } public static long convertTaskAttemptIdToLong(TezTaskAttemptID taskAttemptID) { - long lowBytes = taskAttemptID.getId(); - if (lowBytes > Constants.MAX_ATTEMPT_ID) { - throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed"); + long highBytes = taskAttemptID.getId(); + if (highBytes > Constants.MAX_ATTEMPT_ID) { + throw new RssException( + "TaskAttempt " + taskAttemptID + " low bytes " + highBytes + " exceed"); } - long highBytes = taskAttemptID.getTaskID().getId(); - if (highBytes > Constants.MAX_TASK_ID || highBytes < 0) { + long lowBytes = taskAttemptID.getTaskID().getId(); + if (lowBytes > Constants.MAX_TASK_ID || lowBytes < 0) { throw new RssException( - "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed."); + "TaskAttempt " + taskAttemptID + " high bytes " + lowBytes + " exceed."); } - long id = (highBytes << (Constants.MAX_ATTEMPT_LENGTH)) + lowBytes; + long id = (highBytes << (Constants.TASK_ID_MAX_LENGTH)) + lowBytes; LOG.info("ConvertTaskAttemptIdToLong taskAttemptID:{}, id is {}, .", taskAttemptID, id); return id; } diff --git a/client-tez/src/main/java/org/apache/tez/common/TezIdHelper.java b/client-tez/src/main/java/org/apache/tez/common/TezIdHelper.java deleted file mode 100644 index 79748ecaf9..0000000000 --- a/client-tez/src/main/java/org/apache/tez/common/TezIdHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.tez.common; - -import org.apache.uniffle.common.util.IdHelper; - -public class TezIdHelper implements IdHelper { - - @Override - public long getTaskAttemptId(long blockId) { - return RssTezUtils.getTaskAttemptId(blockId); - } -} diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java index 65ee99b4bb..2c693dab93 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java @@ -29,7 +29,6 @@ import org.apache.tez.common.CallableWithNdc; import org.apache.tez.common.RssTezConfig; import org.apache.tez.common.RssTezUtils; -import org.apache.tez.common.TezIdHelper; import org.apache.tez.runtime.api.InputContext; import org.apache.tez.runtime.library.common.InputAttemptIdentifier; import org.apache.tez.runtime.library.common.shuffle.FetchResult; @@ -43,6 +42,7 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.request.CreateShuffleReadClientRequest; +import org.apache.uniffle.client.util.DefaultIdHelper; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.util.UnitConverter; @@ -191,7 +191,7 @@ protected FetchResult callInternal() throws Exception { taskIdBitmap, new ArrayList<>(serverInfoSet), hadoopConf, - new TezIdHelper(), + new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable, RssTezConfig.toRssConf(this.conf)); ShuffleReadClient shuffleReadClient = diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java index 89146ae945..e557a1689c 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java @@ -66,7 +66,6 @@ import org.apache.tez.common.InputContextUtils; import org.apache.tez.common.RssTezConfig; import org.apache.tez.common.RssTezUtils; -import org.apache.tez.common.TezIdHelper; import org.apache.tez.common.TezUtilsInternal; import org.apache.tez.common.UmbilicalUtils; import org.apache.tez.common.counters.TaskCounter; @@ -96,6 +95,7 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.request.CreateShuffleReadClientRequest; +import org.apache.uniffle.client.util.DefaultIdHelper; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; @@ -1863,7 +1863,7 @@ private RssTezShuffleDataFetcher constructRssFetcherForPartition( taskIdBitmap, shuffleServerInfoList, hadoopConf, - new TezIdHelper(), + new DefaultIdHelper(), expectedTaskIdsBitmapFilterEnable, RssTezConfig.toRssConf(conf)); diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java index ab9124a256..9283e12907 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java @@ -37,7 +37,6 @@ import com.google.common.util.concurrent.Uninterruptibles; import org.apache.hadoop.io.RawComparator; import org.apache.hadoop.io.serializer.Serializer; -import org.apache.tez.common.RssTezUtils; import org.apache.tez.common.counters.TezCounter; import org.apache.tez.dag.records.TezDAGID; import org.apache.tez.dag.records.TezTaskAttemptID; @@ -47,6 +46,7 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.response.SendShuffleDataResult; +import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.compression.Codec; @@ -359,7 +359,7 @@ ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) { final long crc32 = ChecksumUtils.getCrc32(compressed); compressTime += System.currentTimeMillis() - start; final long blockId = - RssTezUtils.getBlockId((long) partitionId, taskAttemptId, getNextSeqNo(partitionId)); + ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId)); LOG.info("blockId is {}", blockId); uncompressedDataLen += data.length; // add memory to indicate bytes which will be sent to shuffle server diff --git a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java index c21173d15e..01a0e033b2 100644 --- a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java +++ b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java @@ -31,6 +31,7 @@ import org.apache.tez.dag.records.TezVertexID; import org.junit.jupiter.api.Test; +import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.util.Constants; @@ -80,11 +81,11 @@ public void blockConvertTest() { TezTaskID tId = TezTaskID.getInstance(vId, 389); TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2); long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId); - long blockId = RssTezUtils.getBlockId(1, taskAttemptId, 0); - long newTaskAttemptId = RssTezUtils.getTaskAttemptId(blockId); + long blockId = ClientUtils.getBlockId(1, taskAttemptId, 0); + long newTaskAttemptId = ClientUtils.getTaskAttemptId(blockId); assertEquals(taskAttemptId, newTaskAttemptId); - blockId = RssTezUtils.getBlockId(2, taskAttemptId, 2); - newTaskAttemptId = RssTezUtils.getTaskAttemptId(blockId); + blockId = ClientUtils.getBlockId(2, taskAttemptId, 2); + newTaskAttemptId = ClientUtils.getTaskAttemptId(blockId); assertEquals(taskAttemptId, newTaskAttemptId); } @@ -99,7 +100,7 @@ public void testPartitionIdConvertBlock() { long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1; for (int partitionId = 0; partitionId <= 3000; partitionId++) { for (int seqNo = 0; seqNo <= 10; seqNo++) { - long blockId = RssTezUtils.getBlockId(Long.valueOf(partitionId), taskAttemptId, seqNo); + long blockId = ClientUtils.getBlockId(partitionId, taskAttemptId, seqNo); int newPartitionId = Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask); assertEquals(partitionId, newPartitionId); diff --git a/client-tez/src/test/java/org/apache/tez/common/TezIdHelperTest.java b/client-tez/src/test/java/org/apache/tez/common/TezIdHelperTest.java deleted file mode 100644 index 739e897795..0000000000 --- a/client-tez/src/test/java/org/apache/tez/common/TezIdHelperTest.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.tez.common; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class TezIdHelperTest { - - @Test - public void testTetTaskAttemptId() { - TezIdHelper tezIdHelper = new TezIdHelper(); - assertEquals(0, tezIdHelper.getTaskAttemptId(27262976)); - assertEquals(1, tezIdHelper.getTaskAttemptId(27262977)); - assertEquals( - 0, RssTezUtils.taskIdStrToTaskId("attempt_1680867852986_0012_1_01_000000_0_10003")); - assertEquals( - tezIdHelper.getTaskAttemptId(27262976), - RssTezUtils.taskIdStrToTaskId("attempt_1680867852986_0012_1_01_000000_0_10003")); - } -} diff --git a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java index 89f59c9fa1..76404702c8 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java +++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java @@ -34,6 +34,10 @@ public class ClientUtils { + public static long getTaskAttemptId(long blockId) { + return blockId & Constants.MAX_TASK_ATTEMPT_ID; + } + // BlockId is long and composed of partitionId, executorId and AtomicInteger. // AtomicInteger is first 18 bit, max value is 2^18 - 1 // partitionId is next 24 bit, max value is 2^24 - 1 @@ -68,10 +72,7 @@ public static Long getBlockId(long partitionId, long taskAttemptId, long atomicI public static Long getTaskAttemptId(long taskId, long attemptId) { if (taskId > Constants.MAX_TASK_ID) { throw new IllegalArgumentException( - "Can't support taskId[" - + taskId - + "], the max value should be " - + Constants.MAX_TASK_ID); + "Can't support taskId[" + taskId + "], the max value should be " + Constants.MAX_TASK_ID); } if (attemptId > Constants.MAX_ATTEMPT_ID) { throw new IllegalArgumentException( @@ -80,7 +81,8 @@ public static Long getTaskAttemptId(long taskId, long attemptId) { + "], the max value should be " + Constants.MAX_ATTEMPT_ID); } - return taskId + attemptId << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH - Constants.ATTEMPT_ID_MAX_LENGTH); + return taskId + attemptId + << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH - Constants.ATTEMPT_ID_MAX_LENGTH); } public static RemoteStorageInfo fetchRemoteStorage( diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java b/common/src/main/java/org/apache/uniffle/common/util/Constants.java index 3b3c0d81d4..41ae6fa8b0 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java +++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java @@ -38,10 +38,8 @@ private Constants() {} public static final long MAX_SEQUENCE_NO = (1 << Constants.ATOMIC_INT_MAX_LENGTH) - 1; public static final long MAX_PARTITION_ID = (1 << Constants.PARTITION_ID_MAX_LENGTH) - 1; public static final long MAX_TASK_ATTEMPT_ID = (1 << Constants.TASK_ATTEMPT_ID_MAX_LENGTH) - 1; - public static final int MAX_TASK_LENGTH = 17; - public static final int MAX_ATTEMPT_LENGTH = 4; - public static final long MAX_TASK_ID = (1 << MAX_TASK_LENGTH) - 1; - public static final long MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1; + public static final long MAX_TASK_ID = (1 << TASK_ID_MAX_LENGTH) - 1; + public static final long MAX_ATTEMPT_ID = (1 << ATTEMPT_ID_MAX_LENGTH) - 1; public static final long INVALID_BLOCK_ID = -1L; public static final String KEY_SPLIT_CHAR = "/"; public static final String COMMA_SPLIT_CHAR = ",";