Skip to content

Commit

Permalink
[apache#1398] [FOLLOW UP]
Browse files Browse the repository at this point in the history
  • Loading branch information
qijiale76 committed Jan 17, 2024
1 parent 071c46e commit 40cfd3e
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.serializer.Serializer;
import org.apache.hadoop.mapreduce.RssMRUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
Expand Down Expand Up @@ -341,7 +341,7 @@ ShuffleBlockInfo createShuffleBlock(SortWriteBuffer wb) {
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId =
RssMRUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId));
ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId));
uncompressedDataLen += data.length;
// add memory to indicate bytes which will be sent to shuffle server
inSendListBytes.addAndGet(wb.getDataLength());
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,17 @@ public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int a
if (appAttemptId < 1) {
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
long lowBytes = (long) taskAttemptID.getId() - (appAttemptId - 1) * 1000;
if (lowBytes > Constants.MAX_ATTEMPT_ID) {
throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed");
long highBytes = (long) taskAttemptID.getId() - (appAttemptId - 1) * 1000L;
if (highBytes > Constants.MAX_ATTEMPT_ID) {
throw new RssException(
"TaskAttempt " + taskAttemptID + " low bytes " + highBytes + " exceed");
}
long highBytes = taskAttemptID.getTaskID().getId();
if (highBytes > Constants.MAX_TASK_ID || highBytes < 0) {
long lowBytes = taskAttemptID.getTaskID().getId();
if (lowBytes > Constants.MAX_TASK_ID || highBytes < 0) {
throw new RssException(
"TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed");
"TaskAttempt " + taskAttemptID + " high bytes " + lowBytes + " exceed");
}
return (highBytes << (Constants.MAX_ATTEMPT_LENGTH)) + lowBytes;
return (highBytes << (Constants.TASK_ID_MAX_LENGTH)) + lowBytes;
}

public static TaskAttemptID createMRTaskAttemptId(
Expand All @@ -68,7 +69,7 @@ public static TaskAttemptID createMRTaskAttemptId(
throw new RssException("appAttemptId " + appAttemptId + " is wrong");
}
TaskID taskID =
new TaskID(jobID, taskType, (int) (rssTaskAttemptId >> Constants.MAX_ATTEMPT_LENGTH));
new TaskID(jobID, taskType, (int) (rssTaskAttemptId >> Constants.ATTEMPT_ID_MAX_LENGTH));
return new TaskAttemptID(
taskID, (int) (rssTaskAttemptId & Constants.MAX_ATTEMPT_ID) + 1000 * (appAttemptId - 1));
}
Expand Down Expand Up @@ -221,33 +222,6 @@ public static String getString(Configuration rssJobConf, String key, String defa
return rssJobConf.get(key, defaultValue);
}

public static long getBlockId(long partitionId, long taskAttemptId, long nextSeqNo) {
if (taskAttemptId < 0 || taskAttemptId > Constants.MAX_TASK_ATTEMPT_ID) {
throw new RssException(
"Can't support attemptId ["
+ taskAttemptId
+ "], the max value should be "
+ Constants.MAX_TASK_ATTEMPT_ID);
}
if (nextSeqNo < 0 || nextSeqNo > Constants.MAX_SEQUENCE_NO) {
throw new RssException(
"Can't support sequence ["
+ nextSeqNo
+ "], the max value should be "
+ Constants.MAX_SEQUENCE_NO);
}
if (partitionId < 0 || partitionId > Constants.MAX_PARTITION_ID) {
throw new RssException(
"Can't support partitionId["
+ partitionId
+ "], the max value should be "
+ Constants.MAX_PARTITION_ID);
}
return (nextSeqNo << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH))
+ (partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH)
+ taskAttemptId;
}

public static long getTaskAttemptId(long blockId) {
return blockId & Constants.MAX_TASK_ATTEMPT_ID;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.apache.hadoop.mapred.Task;
import org.apache.hadoop.mapred.TaskStatus;
import org.apache.hadoop.mapred.TaskUmbilicalProtocol;
import org.apache.hadoop.mapreduce.MRIdHelper;
import org.apache.hadoop.mapreduce.RssMRConfig;
import org.apache.hadoop.mapreduce.RssMRUtils;
import org.apache.hadoop.util.Progress;
Expand All @@ -43,6 +42,7 @@
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
import org.apache.uniffle.client.util.DefaultIdHelper;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.hadoop.shim.HadoopShimImpl;
Expand Down Expand Up @@ -229,7 +229,7 @@ public RawKeyValueIterator run() throws IOException, InterruptedException {
taskIdBitmap,
serverInfoList,
readerJobConf,
new MRIdHelper(),
new DefaultIdHelper(),
expectedTaskIdsBitmapFilterEnable,
RssMRConfig.toRssConf(rssJobConf));
ShuffleReadClient shuffleReadClient =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.hadoop.mapred.JobConf;
import org.junit.jupiter.api.Test;

import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.client.util.RssClientConfig;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;
Expand Down Expand Up @@ -70,10 +71,10 @@ public void blockConvertTest() {
TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1);
long blockId = RssMRUtils.getBlockId(1, taskAttemptId, 0);
long blockId = ClientUtils.getBlockId(1, taskAttemptId, 0);
long newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
blockId = RssMRUtils.getBlockId(2, taskAttemptId, 2);
blockId = ClientUtils.getBlockId(2, taskAttemptId, 2);
newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
}
Expand All @@ -87,7 +88,7 @@ public void partitionIdConvertBlockTest() {
long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1;
for (int partitionId = 0; partitionId <= 3000; partitionId++) {
for (int seqNo = 0; seqNo <= 10; seqNo++) {
long blockId = RssMRUtils.getBlockId(Long.valueOf(partitionId), taskAttemptId, seqNo);
long blockId = ClientUtils.getBlockId(Long.valueOf(partitionId), taskAttemptId, seqNo);
int newPartitionId =
Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask);
assertEquals(partitionId, newPartitionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,12 @@ public <K, V> ShuffleWriter<K, V> getWriter(
} else {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
LOG.info("RssHandle appId is {}, shuffleId is {}, taskId is {}, attemptId is {}", rssHandle.getAppId(),
rssHandle.getShuffleId(), taskId, taskAttemptId);
LOG.info(
"RssHandle appId is {}, shuffleId is {}, taskId is {}, attemptId is {}",
rssHandle.getAppId(),
rssHandle.getShuffleId(),
taskId,
taskAttemptId);
return new RssShuffleWriter<>(
rssHandle.getAppId(),
shuffleId,
Expand Down
51 changes: 8 additions & 43 deletions client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,42 +151,6 @@ public static String uniqueIdentifierToAttemptId(String uniqueIdentifier) {
return StringUtils.join(ids, "_", 0, 7);
}

public static long getBlockId(long partitionId, long taskAttemptId, long nextSeqNo) {
LOG.info(
"GetBlockId, partitionId:{}, taskAttemptId:{}, nextSeqNo:{}",
partitionId,
taskAttemptId,
nextSeqNo);
if (taskAttemptId < 0 || taskAttemptId > Constants.MAX_TASK_ATTEMPT_ID) {
throw new RssException(
"Can't support taskAttemptId ["
+ taskAttemptId
+ "], the max value should be "
+ Constants.MAX_TASK_ATTEMPT_ID);
}
if (nextSeqNo < 0 || nextSeqNo > Constants.MAX_SEQUENCE_NO) {
throw new RssException(
"Can't support sequence ["
+ nextSeqNo
+ "], the max value should be "
+ Constants.MAX_SEQUENCE_NO);
}
if (partitionId < 0 || partitionId > Constants.MAX_PARTITION_ID) {
throw new RssException(
"Can't support partitionId ["
+ partitionId
+ "], the max value should be "
+ Constants.MAX_PARTITION_ID);
}
return (nextSeqNo << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH))
+ (partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH)
+ taskAttemptId;
}

public static long getTaskAttemptId(long blockId) {
return blockId & Constants.MAX_TASK_ATTEMPT_ID;
}

public static int estimateTaskConcurrency(Configuration jobConf, int mapNum, int reduceNum) {
double dynamicFactor =
jobConf.getDouble(
Expand Down Expand Up @@ -277,16 +241,17 @@ private static int mapVertexId(String vertexName) {
}

public static long convertTaskAttemptIdToLong(TezTaskAttemptID taskAttemptID) {
long lowBytes = taskAttemptID.getId();
if (lowBytes > Constants.MAX_ATTEMPT_ID) {
throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed");
long highBytes = taskAttemptID.getId();
if (highBytes > Constants.MAX_ATTEMPT_ID) {
throw new RssException(
"TaskAttempt " + taskAttemptID + " low bytes " + highBytes + " exceed");
}
long highBytes = taskAttemptID.getTaskID().getId();
if (highBytes > Constants.MAX_TASK_ID || highBytes < 0) {
long lowBytes = taskAttemptID.getTaskID().getId();
if (lowBytes > Constants.MAX_TASK_ID || lowBytes < 0) {
throw new RssException(
"TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed.");
"TaskAttempt " + taskAttemptID + " high bytes " + lowBytes + " exceed.");
}
long id = (highBytes << (Constants.MAX_ATTEMPT_LENGTH)) + lowBytes;
long id = (highBytes << (Constants.TASK_ID_MAX_LENGTH)) + lowBytes;
LOG.info("ConvertTaskAttemptIdToLong taskAttemptID:{}, id is {}, .", taskAttemptID, id);
return id;
}
Expand Down
28 changes: 0 additions & 28 deletions client-tez/src/main/java/org/apache/tez/common/TezIdHelper.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.tez.common.CallableWithNdc;
import org.apache.tez.common.RssTezConfig;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezIdHelper;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.shuffle.FetchResult;
Expand All @@ -43,6 +42,7 @@
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
import org.apache.uniffle.client.util.DefaultIdHelper;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.UnitConverter;
Expand Down Expand Up @@ -191,7 +191,7 @@ protected FetchResult callInternal() throws Exception {
taskIdBitmap,
new ArrayList<>(serverInfoSet),
hadoopConf,
new TezIdHelper(),
new DefaultIdHelper(),
expectedTaskIdsBitmapFilterEnable,
RssTezConfig.toRssConf(this.conf));
ShuffleReadClient shuffleReadClient =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
import org.apache.tez.common.InputContextUtils;
import org.apache.tez.common.RssTezConfig;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezIdHelper;
import org.apache.tez.common.TezUtilsInternal;
import org.apache.tez.common.UmbilicalUtils;
import org.apache.tez.common.counters.TaskCounter;
Expand Down Expand Up @@ -96,6 +95,7 @@
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
import org.apache.uniffle.client.util.DefaultIdHelper;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
Expand Down Expand Up @@ -1863,7 +1863,7 @@ private RssTezShuffleDataFetcher constructRssFetcherForPartition(
taskIdBitmap,
shuffleServerInfoList,
hadoopConf,
new TezIdHelper(),
new DefaultIdHelper(),
expectedTaskIdsBitmapFilterEnable,
RssTezConfig.toRssConf(conf));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.serializer.Serializer;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
Expand All @@ -47,6 +46,7 @@

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
Expand Down Expand Up @@ -359,7 +359,7 @@ ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) {
final long crc32 = ChecksumUtils.getCrc32(compressed);
compressTime += System.currentTimeMillis() - start;
final long blockId =
RssTezUtils.getBlockId((long) partitionId, taskAttemptId, getNextSeqNo(partitionId));
ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId));
LOG.info("blockId is {}", blockId);
uncompressedDataLen += data.length;
// add memory to indicate bytes which will be sent to shuffle server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.tez.dag.records.TezVertexID;
import org.junit.jupiter.api.Test;

import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;
Expand Down Expand Up @@ -80,11 +81,11 @@ public void blockConvertTest() {
TezTaskID tId = TezTaskID.getInstance(vId, 389);
TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2);
long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
long blockId = RssTezUtils.getBlockId(1, taskAttemptId, 0);
long newTaskAttemptId = RssTezUtils.getTaskAttemptId(blockId);
long blockId = ClientUtils.getBlockId(1, taskAttemptId, 0);
long newTaskAttemptId = ClientUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
blockId = RssTezUtils.getBlockId(2, taskAttemptId, 2);
newTaskAttemptId = RssTezUtils.getTaskAttemptId(blockId);
blockId = ClientUtils.getBlockId(2, taskAttemptId, 2);
newTaskAttemptId = ClientUtils.getTaskAttemptId(blockId);
assertEquals(taskAttemptId, newTaskAttemptId);
}

Expand All @@ -99,7 +100,7 @@ public void testPartitionIdConvertBlock() {
long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1;
for (int partitionId = 0; partitionId <= 3000; partitionId++) {
for (int seqNo = 0; seqNo <= 10; seqNo++) {
long blockId = RssTezUtils.getBlockId(Long.valueOf(partitionId), taskAttemptId, seqNo);
long blockId = ClientUtils.getBlockId(partitionId, taskAttemptId, seqNo);
int newPartitionId =
Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask);
assertEquals(partitionId, newPartitionId);
Expand Down

0 comments on commit 40cfd3e

Please sign in to comment.