Support zstd compression for shuffle data #91

Merged · 5 commits · Jan 30, 2023

17 changes: 17 additions & 0 deletions src/main/java/com/uber/rss/common/Compression.java
@@ -14,12 +14,16 @@

package com.uber.rss.common;

import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdOutputStream;
import com.uber.rss.exceptions.RssException;
import com.uber.rss.exceptions.RssUnsupportedCompressionException;
import net.jpountz.lz4.*;
import net.jpountz.xxhash.XXHashFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.zip.Checksum;
@@ -28,6 +32,7 @@ public class Compression {
private static final Logger logger = LoggerFactory.getLogger(Compression.class);

public final static String COMPRESSION_CODEC_LZ4 = "lz4";
public final static String COMPRESSION_CODEC_ZSTD = "zstd";

private static final int defaultLz4BlockSize = 65536;
private static final int defaultLz4ChecksumSeed = -1756908916;
@@ -41,6 +46,12 @@ public static OutputStream compressStream(OutputStream stream, String codec) {
LZ4Compressor compressor = LZ4Factory.fastestInstance().fastCompressor();
Checksum defaultLz4Checksum = XXHashFactory.fastestInstance().newStreamingHash32(defaultLz4ChecksumSeed).asChecksum();
return new LZ4BlockOutputStream(stream, defaultLz4BlockSize, compressor, defaultLz4Checksum, true);
} else if (codec.equals(Compression.COMPRESSION_CODEC_ZSTD)) {
try {
return new ZstdOutputStream(stream);
} catch (IOException e) {
throw new RssException("Failed to create ZstdOutputStream", e);
}
} else {
throw new RssUnsupportedCompressionException(String.format("Unsupported compression codec: %s", codec));
}
@@ -55,6 +66,12 @@ public static InputStream decompressStream(InputStream stream, String codec) {
LZ4FastDecompressor decompressor = LZ4Factory.fastestInstance().fastDecompressor();
Checksum defaultLz4Checksum = XXHashFactory.fastestInstance().newStreamingHash32(defaultLz4ChecksumSeed).asChecksum();
return new LZ4BlockInputStream(stream, decompressor, defaultLz4Checksum, false);
} else if (codec.equals(Compression.COMPRESSION_CODEC_ZSTD)) {
try {
return new ZstdInputStream(stream);
} catch (IOException e) {
throw new RssException("Failed to create ZstdInputStream", e);
}
} else {
throw new RssUnsupportedCompressionException(String.format("Unsupported compression codec: %s", codec));
}
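
For reference, a minimal round-trip sketch of the helper above. The Compression class, its zstd codec constant, and the compressStream/decompressStream signatures come from this diff; the wrapper object, payload, and buffer size are illustrative and assume zstd-jni is on the classpath:

// Illustrative sketch only; not part of the PR.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets
import com.uber.rss.common.Compression

object ZstdRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val payload = "shuffle block payload".getBytes(StandardCharsets.UTF_8)

    // Compress through the shared helper; "zstd" selects the new ZstdOutputStream branch.
    val compressedOut = new ByteArrayOutputStream()
    val zstdOut = Compression.compressStream(compressedOut, Compression.COMPRESSION_CODEC_ZSTD)
    zstdOut.write(payload)
    zstdOut.close()

    // Decompress with the matching codec name; an unknown codec throws
    // RssUnsupportedCompressionException.
    val zstdIn = Compression.decompressStream(
      new ByteArrayInputStream(compressedOut.toByteArray), Compression.COMPRESSION_CODEC_ZSTD)
    val restored = new ByteArrayOutputStream()
    val buf = new Array[Byte](4096)
    var n = zstdIn.read(buf)
    while (n != -1) {
      restored.write(buf, 0, n)
      n = zstdIn.read(buf)
    }
    assert(java.util.Arrays.equals(restored.toByteArray, payload))
  }
}
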
29 changes: 25 additions & 4 deletions src/main/java/com/uber/rss/tools/PartitionFileChecker.java
@@ -14,8 +14,13 @@

package com.uber.rss.tools;

import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdInputStream;
import com.uber.rss.common.Compression;
import com.uber.rss.exceptions.RssInvalidDataException;
import com.uber.rss.util.ByteBufUtils;
import com.uber.rss.util.StreamUtils;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import net.jpountz.lz4.LZ4BlockInputStream;
@@ -30,8 +35,8 @@
*/
public class PartitionFileChecker {
private String filePath;
private String fileCompressCodec = "lz4";
private String blockCompressCodec = "lz4";
private String fileCompressCodec = Compression.COMPRESSION_CODEC_LZ4;
private String blockCompressCodec = Compression.COMPRESSION_CODEC_LZ4;

public void run() {
ByteBuf dataBlockStreamData = Unpooled.buffer(1000);
@@ -40,8 +45,10 @@ public void run() {
// Read data block stream from file
try (FileInputStream fileInputStream = new FileInputStream(filePath)) {
InputStream inputStream = fileInputStream;
if (fileCompressCodec.equals("lz4")) {
if (fileCompressCodec.equals(Compression.COMPRESSION_CODEC_LZ4)) {
inputStream = new LZ4BlockInputStream(fileInputStream);
} else if (fileCompressCodec.equals(Compression.COMPRESSION_CODEC_ZSTD)) {
inputStream = new ZstdInputStream(fileInputStream);
}
while (true) {
byte[] bytes = StreamUtils.readBytes(inputStream, Long.BYTES);
@@ -59,7 +66,7 @@ public void run() {
throw new RuntimeException(e);
}

if (blockCompressCodec.equals("lz4")) {
if (blockCompressCodec.equals(Compression.COMPRESSION_CODEC_LZ4)) {
dataBlockStreamUncompressedData = Unpooled.buffer(1000);

LZ4FastDecompressor decompressor = LZ4Factory.fastestInstance().fastDecompressor();
@@ -73,6 +80,20 @@ public void run() {
decompressor.decompress(compressedBytes, uncompressedBytes);
dataBlockStreamUncompressedData.writeBytes(uncompressedBytes);
}
} else if (blockCompressCodec.equals(Compression.COMPRESSION_CODEC_ZSTD)) {
dataBlockStreamUncompressedData = Unpooled.buffer(1000);
while (dataBlockStreamData.readableBytes() > 0) {
int compressedLen = dataBlockStreamData.readInt();
int uncompressedLen = dataBlockStreamData.readInt();
byte[] compressedBytes = new byte[compressedLen];
byte[] uncompressedBytes = new byte[uncompressedLen];
dataBlockStreamData.readBytes(compressedBytes);
long decompressResult = Zstd.decompress(compressedBytes, uncompressedBytes);
if (Zstd.isError(decompressResult)) {
throw new RssInvalidDataException("Failed to decompress zstd data, returned value: " + decompressResult);
}
dataBlockStreamUncompressedData.writeBytes(uncompressedBytes);
}
}

while (dataBlockStreamUncompressedData.readableBytes() > 0) {
10 changes: 10 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/RssOpts.scala
@@ -177,4 +177,14 @@ object RssOpts {
.doc("Create lazy connections from mappers to RSS servers just before sending the shuffle data")
.booleanConf
.createWithDefault(true)
val compression: ConfigEntry[String] =
ConfigBuilder("spark.shuffle.rss.compress")
.doc("type of compression for shuffle data, supported: lz4, zstd")
.stringConf
.createWithDefault("lz4")
val zstdCompressionLevel: ConfigEntry[Int] =
ConfigBuilder("spark.shuffle.rss.compress.zstd.level")
.doc("level of zstd compression")
.intConf
.createWithDefault(1)
}
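
The two entries above only define the knobs. A minimal sketch of opting into zstd from an application's SparkConf; the spark.shuffle.manager value assumes the RssShuffleManager class touched later in this diff, and level 3 is just an illustrative choice:

// Illustrative sketch only; not part of the PR.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager")
  .set("spark.shuffle.rss.compress", "zstd")           // defaults to "lz4"
  .set("spark.shuffle.rss.compress.zstd.level", "3")   // defaults to 1; only read when the codec is zstd
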
11 changes: 10 additions & 1 deletion src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala
@@ -19,7 +19,7 @@ import java.util.Random
import java.util.function.Supplier
import com.uber.rss.{RssBuildInfo, StreamServerConfig}
import com.uber.rss.clients.{LazyWriteClient, MultiServerAsyncWriteClient, MultiServerHeartbeatClient, MultiServerSyncWriteClient, MultiServerWriteClient, PooledWriteClientFactory, ServerConnectionCacheUpdateRefresher, ServerConnectionStringCache, ServerConnectionStringResolver, ServerReplicationGroupUtil, ShuffleWriteConfig}
import com.uber.rss.common.{AppShuffleId, AppTaskAttemptId, ServerDetail, ServerList}
import com.uber.rss.common.{AppShuffleId, AppTaskAttemptId, Compression, ServerDetail, ServerList}
import com.uber.rss.exceptions.{RssException, RssInvalidStateException, RssNoServerAvailableException, RssServerResolveException}
import com.uber.rss.metadata.{ServiceRegistry, ServiceRegistryUtils, StandaloneServiceRegistryClient, ZooKeeperServiceRegistry}
import com.uber.rss.metrics.{M3Stats, ShuffleClientStageMetrics, ShuffleClientStageMetricsKey}
@@ -279,13 +279,21 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
try {
writeClient.connect()

val compressionLevel = if (Compression.COMPRESSION_CODEC_ZSTD.equals(conf.get(RssOpts.compression))) {
conf.get(RssOpts.zstdCompressionLevel)
} else {
0
}

new RssShuffleWriter(
rssShuffleHandle.user,
new ServerList(rssShuffleHandle.rssServers.map(_.toServerDetail()).toArray),
writeClient,
mapInfo,
rssShuffleHandle.numMaps,
serializer,
conf.get(RssOpts.compression),
CompressionOptions(compressionLevel),
bufferOptions,
rssShuffleHandle.dependency,
shuffleClientStageMetrics,
@@ -332,6 +340,7 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
startPartition = startPartition,
endPartition = endPartition,
serializer = serializer,
decompression = conf.get(RssOpts.compression),
context = context,
shuffleDependency = rssShuffleHandle.dependency,
numMaps = rssShuffleHandle.numMaps,
@@ -28,6 +28,7 @@ class RssShuffleReader[K, C](
startPartition: Int,
endPartition: Int,
serializer: Serializer,
decompression: String,
context: TaskContext,
shuffleDependency: ShuffleDependency[K, _, C],
numMaps: Int,
@@ -55,6 +56,7 @@
startPartition = startPartition,
endPartition = endPartition,
serializer = serializer,
decompression = decompression,
numMaps = numMaps,
rssServers = rssServers,
partitionFanout = partitionFanout,
29 changes: 25 additions & 4 deletions src/main/scala/org/apache/spark/shuffle/RssShuffleWriter.scala
@@ -14,10 +14,12 @@

package org.apache.spark.shuffle

import com.github.luben.zstd.Zstd

import java.nio.ByteBuffer
import java.util.concurrent.{CompletableFuture, TimeUnit}
import com.uber.rss.clients.ShuffleDataWriter
import com.uber.rss.common.{AppTaskAttemptId, ServerList}
import com.uber.rss.common.{AppTaskAttemptId, Compression, ServerList}
import com.uber.rss.exceptions.RssInvalidStateException
import com.uber.rss.metrics.ShuffleClientStageMetrics
import net.jpountz.lz4.LZ4Factory
@@ -28,13 +30,17 @@ import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.rss.{BufferManagerOptions, RssUtils, WriteBufferManager, WriterAggregationManager, WriterNoAggregationManager}

case class CompressionOptions(level: Int = 1)

class RssShuffleWriter[K, V, C](
user: String,
rssServers: ServerList,
writeClient: ShuffleDataWriter,
mapInfo: AppTaskAttemptId,
numMaps: Int,
serializer: Serializer,
compression: String,
compressionOptions: CompressionOptions,
bufferOptions: BufferManagerOptions,
shuffleDependency: ShuffleDependency[K, V, C],
stageMetrics: ShuffleClientStageMetrics,
@@ -71,7 +77,11 @@

logInfo(s"Using ${writerManager.getClass} as the shuffle writer manager.")

private val compressor = LZ4Factory.fastestInstance.fastCompressor
private val compressor = if (Compression.COMPRESSION_CODEC_ZSTD.equals(compression)) {
null
} else {
LZ4Factory.fastestInstance.fastCompressor
}

private def getPartition(key: K): Int = {
if (shouldPartition) partitioner.getPartition(key) else 0
@@ -201,8 +211,19 @@

private def createDataBlock(buffer: Array[Byte]): ByteBuffer = {
val uncompressedByteCount = buffer.size
val compressedBuffer = new Array[Byte](compressor.maxCompressedLength(uncompressedByteCount))
val compressedByteCount = compressor.compress(buffer, compressedBuffer)
var compressedBuffer: Array[Byte] = null
var compressedByteCount: Int = 0
if (Compression.COMPRESSION_CODEC_ZSTD.equals(compression)) {
compressedBuffer = new Array[Byte](uncompressedByteCount)
val n = Zstd.compress(compressedBuffer, buffer, compressionOptions.level)
if (Zstd.isError(n)) {
throw new RssInvalidStateException(s"Failed to run zstd compress for data block, zstd returned value: $n")
}
compressedByteCount = n.toInt
} else {
compressedBuffer = new Array[Byte](compressor.maxCompressedLength(uncompressedByteCount))
compressedByteCount = compressor.compress(buffer, compressedBuffer)
}
val dataBlockByteBuffer = ByteBuffer.allocate(Integer.BYTES + Integer.BYTES + compressedByteCount)
dataBlockByteBuffer.putInt(compressedByteCount)
dataBlockByteBuffer.putInt(uncompressedByteCount)
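
Each block produced by createDataBlock above is framed as a 4-byte compressed length, a 4-byte uncompressed length, and then the compressed payload. Below is a standalone sketch of writing and re-reading that framing with zstd-jni's one-shot API; unlike the diff, which sizes the destination to the uncompressed length and relies on the error check, this sketch sizes it with Zstd.compressBound so incompressible input also fits:

// Illustrative sketch only; not part of the PR.
import java.nio.ByteBuffer
import com.github.luben.zstd.Zstd

object ZstdBlockFramingSketch {
  // Frame layout: [compressedLen: Int][uncompressedLen: Int][compressed payload]
  def frameBlock(raw: Array[Byte], level: Int): ByteBuffer = {
    val dst = new Array[Byte](Zstd.compressBound(raw.length.toLong).toInt)
    val written = Zstd.compress(dst, raw, level)
    if (Zstd.isError(written)) {
      throw new IllegalStateException("zstd compress failed: " + Zstd.getErrorName(written))
    }
    val block = ByteBuffer.allocate(Integer.BYTES + Integer.BYTES + written.toInt)
    block.putInt(written.toInt)   // compressed length
    block.putInt(raw.length)      // uncompressed length
    block.put(dst, 0, written.toInt)
    block.flip()
    block
  }

  def unframeBlock(block: ByteBuffer): Array[Byte] = {
    val compressedLen = block.getInt()
    val uncompressedLen = block.getInt()
    val src = new Array[Byte](compressedLen)
    block.get(src)
    val out = new Array[Byte](uncompressedLen)
    val n = Zstd.decompress(out, src)
    if (Zstd.isError(n)) {
      throw new IllegalStateException("zstd decompress failed: " + Zstd.getErrorName(n))
    }
    out
  }
}
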
@@ -37,6 +37,7 @@ class BlockDownloaderPartitionRangeRecordIterator[K, C](
startPartition: Int,
endPartition: Int,
serializer: Serializer,
decompression: String,
numMaps: Int,
rssServers: ServerList,
partitionFanout: Int,
@@ -138,6 +139,7 @@ class BlockDownloaderPartitionRangeRecordIterator[K, C](
shuffleId,
partition,
serializer,
decompression,
downloader,
shuffleReadMetrics)
} else {
@@ -15,9 +15,10 @@
package org.apache.spark.shuffle.rss

import java.util.concurrent.TimeUnit

import com.esotericsoftware.kryo.io.Input
import com.github.luben.zstd.Zstd
import com.uber.rss.clients.{ShuffleDataReader, TaskDataBlock}
import com.uber.rss.common.Compression
import com.uber.rss.exceptions.{RssInvalidDataException, RssInvalidStateException}
import com.uber.rss.metrics.M3Stats
import com.uber.rss.util.{ByteBufUtils, ExceptionUtils}
@@ -27,14 +28,21 @@ import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.shuffle.FetchFailedException

import java.util

class BlockDownloaderPartitionRecordIterator[K, C](
shuffleId: Int,
partition: Int,
serializer: Serializer,
decompression: String,
downloader: ShuffleDataReader,
shuffleReadMetrics: ShuffleReadMetrics) extends Iterator[Product2[K, C]] with Logging {

private val decompressor: LZ4FastDecompressor = LZ4Factory.fastestInstance.fastDecompressor()
private val lz4Decompressor: LZ4FastDecompressor = if (Compression.COMPRESSION_CODEC_ZSTD.equals(decompression)) {
null
} else {
LZ4Factory.fastestInstance.fastDecompressor()
}

private var downloaderEof = false

@@ -164,12 +172,23 @@
val compressedLen = ByteBufUtils.readInt(bytes, 0)
val uncompressedLen = ByteBufUtils.readInt(bytes, Integer.BYTES)
val uncompressedBytes = new Array[Byte](uncompressedLen)
val count = decompressor.decompress(bytes, Integer.BYTES + Integer.BYTES, uncompressedBytes, 0, uncompressedLen)
decompressTime += (System.nanoTime() - decompressStartTime)
if (count != compressedLen) {
throw new RssInvalidDataException(
s"Data corrupted for shuffle $shuffleId partition $partition, expected compressed length: $compressedLen, but it is: $count, " + String.valueOf(downloader))
if (Compression.COMPRESSION_CODEC_ZSTD.equals(decompression)) {
// TODO Zstd in Spark 2.4 does not support decompress method with a range from source byte array
// Better to use Zstd.decompressByteArray for Spark version higher than 2.4 to avoid copying bytes
val sourceBytes = util.Arrays.copyOfRange(bytes, Integer.BYTES + Integer.BYTES, bytes.length)
val n = Zstd.decompress(uncompressedBytes, sourceBytes)
if (Zstd.isError(n)) {
throw new RssInvalidDataException(
s"Data corrupted for shuffle $shuffleId partition $partition, failed to decompress zstd, decompress returned: $n, " + String.valueOf(downloader))
}
} else {
val count = lz4Decompressor.decompress(bytes, Integer.BYTES + Integer.BYTES, uncompressedBytes, 0, uncompressedLen)
if (count != compressedLen) {
throw new RssInvalidDataException(
s"Data corrupted for shuffle $shuffleId partition $partition, expected compressed length: $compressedLen, but it is: $count, " + String.valueOf(downloader))
}
}
decompressTime += (System.nanoTime() - decompressStartTime)

deserializationInput = new Input(uncompressedBytes, 0, uncompressedLen)
deserializationStream = serializerInstance.deserializeStream(deserializationInput)
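
On the read side, the TODO above notes that the Arrays.copyOfRange copy could be avoided by decompressing straight from an offset into the block. A hedged sketch of that alternative, assuming a zstd-jni version that exposes the offset-based Zstd.decompressByteArray overload:

// Illustrative sketch only; not part of the PR.
import com.github.luben.zstd.Zstd

// bytes holds one block as framed by the writer:
// [compressedLen: Int][uncompressedLen: Int][compressed payload]
def decompressBlockWithoutCopy(bytes: Array[Byte], uncompressedLen: Int): Array[Byte] = {
  val headerLen = Integer.BYTES + Integer.BYTES
  val out = new Array[Byte](uncompressedLen)
  val n = Zstd.decompressByteArray(
    out, 0, uncompressedLen,
    bytes, headerLen, bytes.length - headerLen)
  if (Zstd.isError(n)) {
    throw new IllegalStateException("zstd decompress failed: " + Zstd.getErrorName(n))
  }
  out
}
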
@@ -30,7 +30,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.{MockTaskContext, RssOpts, RssShuffleReader, RssShuffleWriter}
import org.apache.spark.shuffle.{CompressionOptions, MockTaskContext, RssOpts, RssShuffleReader, RssShuffleWriter}
import org.apache.spark.{HashPartitioner, MapOutputTrackerMaster, ShuffleDependency, SparkConf, SparkContext, SparkEnv}

import scala.collection.JavaConverters._
@@ -323,6 +323,8 @@ class RssStressTool extends Logging {
mapInfo = new AppTaskAttemptId(appMapId, taskAttemptId),
numMaps = numMaps,
serializer = new KryoSerializer(sparkConf),
compressionOptions = CompressionOptions(),
compression = Compression.COMPRESSION_CODEC_LZ4,
bufferOptions = BufferManagerOptions(writerBufferSize, 256 * 1024 * 1024, writerBufferSpill),
shuffleDependency = shuffleDependency,
stageMetrics = new ShuffleClientStageMetrics(new ShuffleClientStageMetricsKey("user1", "queue=1")),
@@ -356,6 +358,7 @@
startPartition = readPartitionId,
endPartition = readPartitionId + 1,
serializer = shuffleDependency.serializer,
decompression = Compression.COMPRESSION_CODEC_LZ4,
context = new MockTaskContext(1, 0, taskAttemptIdSeed.incrementAndGet()),
shuffleDependency = shuffleDependency,
numMaps = numMaps,