From 5203505ff5f446726aa29af4fd52f91106800830 Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 10 Sep 2019 14:54:47 -0700 Subject: [PATCH 01/14] Bring implementation into closer alignment with upstream. Step to ease merge conflict resolution and build failure problems when we pull in changes from upstream. --- .../shuffle/ShuffleExecutorComponents.java | 37 ------------ .../api/shuffle/ShuffleWriteSupport.java | 37 ------------ .../api}/ShuffleBlockInfo.java | 2 +- .../api}/ShuffleDataIO.java | 2 +- .../api}/ShuffleDriverComponents.java | 2 +- .../api/ShuffleExecutorComponents.java} | 22 +++++-- .../api}/ShuffleMapOutputWriter.java | 2 +- .../api}/ShufflePartitionWriter.java | 2 +- .../api}/SupportsTransferTo.java | 2 +- .../TransferrableWritableByteChannel.java | 2 +- .../sort/BypassMergeSortShuffleWriter.java | 18 +++--- ...faultTransferrableWritableByteChannel.java | 5 +- .../shuffle/sort/UnsafeShuffleWriter.java | 21 +++---- .../shuffle/sort/io/DefaultShuffleDataIO.java | 6 +- .../io/DefaultShuffleExecutorComponents.java | 57 ++++++++++++++----- .../io/DefaultShuffleMapOutputWriter.java | 8 +-- .../sort/io/DefaultShuffleWriteSupport.java | 52 ----------------- .../DefaultShuffleDriverComponents.java | 2 +- .../org/apache/spark/ContextCleaner.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 3 +- .../shuffle/BlockStoreShuffleReader.scala | 19 +++---- .../io/DefaultShuffleReadSupport.scala | 10 ++-- .../shuffle/sort/SortShuffleManager.scala | 10 ++-- .../shuffle/sort/SortShuffleWriter.scala | 6 +- .../util/collection/ExternalSorter.scala | 5 +- .../ShufflePartitionPairsWriter.scala | 2 +- .../sort/UnsafeShuffleWriterSuite.java | 24 ++++++-- .../DAGSchedulerShufflePluginSuite.scala | 5 +- .../ShuffleDriverComponentsSuite.scala | 43 ++++++++------ ...ypassMergeSortShuffleWriterBenchmark.scala | 15 +++-- .../BypassMergeSortShuffleWriterSuite.scala | 26 +++++---- .../sort/ShuffleWriterBenchmarkBase.scala | 7 ++- .../sort/SortShuffleWriterBenchmark.scala | 14 +++-- .../sort/UnsafeShuffleWriterBenchmark.scala | 14 +++-- .../DefaultShuffleMapOutputWriterSuite.scala | 4 +- 35 files changed, 219 insertions(+), 269 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java delete mode 100644 core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/ShuffleBlockInfo.java (98%) rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/ShuffleDataIO.java (96%) rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/ShuffleDriverComponents.java (97%) rename core/src/main/java/org/apache/spark/{api/shuffle/ShuffleReadSupport.java => shuffle/api/ShuffleExecutorComponents.java} (72%) rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/ShuffleMapOutputWriter.java (97%) rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/ShufflePartitionWriter.java (97%) rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/SupportsTransferTo.java (98%) rename core/src/main/java/org/apache/spark/{api/shuffle => shuffle/api}/TransferrableWritableByteChannel.java (98%) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java deleted file mode 100644 index 
a5fa032bf651d..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -import java.util.Map; - -/** - * :: Experimental :: - * An interface for building shuffle support for Executors - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleExecutorComponents { - void initializeExecutor(String appId, String execId, Map extraConfigs); - - ShuffleWriteSupport writes(); - - ShuffleReadSupport reads(); -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java deleted file mode 100644 index 7ee1d8a554073..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.api.shuffle; - -import java.io.IOException; - -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * An interface for deploying a shuffle map output writer - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleWriteSupport { - ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java similarity index 98% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java index 34daf2c137a12..66270a512b0e7 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import org.apache.spark.api.java.Optional; import org.apache.spark.storage.BlockManagerId; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java similarity index 96% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java index dd7c0ac7320cb..ac3d2a15fec5a 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import org.apache.spark.annotation.Experimental; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java similarity index 97% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java index 8b54968f9b134..cbc59bc7b6a05 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import java.io.IOException; import java.util.Map; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java similarity index 72% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java index 83947bd4d6fa4..87ad794707150 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -15,20 +15,30 @@ * limitations under the License. 
*/ -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; +package org.apache.spark.shuffle.api; import java.io.IOException; import java.io.InputStream; +import java.util.Map; + +import org.apache.spark.annotation.Experimental; /** * :: Experimental :: - * An interface for reading shuffle records. + * An interface for building shuffle support for Executors + * * @since 3.0.0 */ @Experimental -public interface ShuffleReadSupport { +public interface ShuffleExecutorComponents { + void initializeExecutor(String appId, String execId, Map extraConfigs); + + ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) throws IOException; + /** * Returns an underlying {@link Iterable} that will iterate * through shuffle data, given an iterable for the shuffle blocks to fetch. @@ -36,7 +46,7 @@ public interface ShuffleReadSupport { Iterable getPartitionReaders(Iterable blockMetadata) throws IOException; - default boolean shouldWrapStream() { + default boolean shouldWrapPartitionReaderStream() { return true; } } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java similarity index 97% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 025fc096faaad..14e020a5c63ad 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import java.io.IOException; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java similarity index 97% rename from core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index 74c928b0b9c8f..3d6c287014bf5 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import java.io.IOException; import java.io.OutputStream; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/shuffle/api/SupportsTransferTo.java similarity index 98% rename from core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java rename to core/src/main/java/org/apache/spark/shuffle/api/SupportsTransferTo.java index 866b61d0bafd9..ae8cb36b7e719 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/SupportsTransferTo.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import java.io.IOException; diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/api/TransferrableWritableByteChannel.java similarity index 98% rename from core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java rename to core/src/main/java/org/apache/spark/shuffle/api/TransferrableWritableByteChannel.java index 18234d7c4c944..76e0dfd8b5a05 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/TransferrableWritableByteChannel.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import java.io.Closeable; import java.io.IOException; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 63aee8ad50da3..dd10bff9ec45f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -26,6 +26,7 @@ import javax.annotation.Nullable; import org.apache.spark.api.java.Optional; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import scala.None$; import scala.Option; import scala.Product2; @@ -40,11 +41,10 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.SupportsTransferTo; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.SupportsTransferTo; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -90,7 +90,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int mapId; private final long mapTaskAttemptId; private final Serializer serializer; - private final ShuffleWriteSupport shuffleWriteSupport; + private final ShuffleExecutorComponents shuffleExecutorComponents; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; @@ -112,7 +112,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { long mapTaskAttemptId, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleWriteSupport shuffleWriteSupport) { + ShuffleExecutorComponents shuffleExecutorComponents) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -125,13 +125,13 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); - 
this.shuffleWriteSupport = shuffleWriteSupport; + this.shuffleExecutorComponents = shuffleExecutorComponents; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport + ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); try { if (!records.hasNext()) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java index 64ce851e392d2..cb8ac86972d35 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java @@ -20,12 +20,13 @@ import java.io.IOException; import java.nio.channels.FileChannel; import java.nio.channels.WritableByteChannel; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.SupportsTransferTo; import org.apache.spark.util.Utils; /** * This is used when transferTo is enabled but the shuffle plugin hasn't implemented - * {@link org.apache.spark.api.shuffle.SupportsTransferTo}. + * {@link SupportsTransferTo}. *

* This default implementation exists as a convenience to the unsafe shuffle writer and * the bypass merge sort shuffle writers. diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9627f1151f837..441718126bc92 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -24,6 +24,7 @@ import java.util.Iterator; import org.apache.spark.api.java.Optional; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.storage.BlockManagerId; import scala.Option; import scala.Product2; @@ -39,11 +40,10 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.api.shuffle.SupportsTransferTo; +import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.SupportsTransferTo; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; @@ -74,7 +74,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; - private final ShuffleWriteSupport shuffleWriteSupport; + private final ShuffleExecutorComponents shuffleExecutorComponents; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -111,7 +111,7 @@ public UnsafeShuffleWriter( TaskContext taskContext, SparkConf sparkConf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleWriteSupport shuffleWriteSupport) throws IOException { + ShuffleExecutorComponents shuffleExecutorComponents) { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -127,7 +127,7 @@ public UnsafeShuffleWriter( this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = writeMetrics; - this.shuffleWriteSupport = shuffleWriteSupport; + this.shuffleExecutorComponents = shuffleExecutorComponents; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -216,8 +216,9 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport - .createMapOutputWriter(shuffleId, + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter( + shuffleId, mapId, taskContext.taskAttemptId(), partitioner.numPartitions()); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java index 7c124c1fe68bc..a6faa6ac52ca6 100644 --- 
a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java @@ -18,9 +18,9 @@ package org.apache.spark.shuffle.sort.io; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.ShuffleDriverComponents; -import org.apache.spark.api.shuffle.ShuffleExecutorComponents; -import org.apache.spark.api.shuffle.ShuffleDataIO; +import org.apache.spark.shuffle.api.ShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleDataIO; import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; public class DefaultShuffleDataIO implements ShuffleDataIO { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index 3b5f9670d64d2..77edba8642728 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -17,12 +17,16 @@ package org.apache.spark.shuffle.sort.io; +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.io.InputStream; import org.apache.spark.MapOutputTracker; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; -import org.apache.spark.api.shuffle.ShuffleExecutorComponents; -import org.apache.spark.api.shuffle.ShuffleReadSupport; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.TaskContext; +import org.apache.spark.shuffle.api.ShuffleBlockInfo; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; @@ -33,36 +37,59 @@ public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { private final SparkConf sparkConf; + // Submodule for the read side for shuffles - implemented in Scala for ease of + // compatibility with previously written code. 
+ private DefaultShuffleReadSupport shuffleReadSupport; private BlockManager blockManager; private IndexShuffleBlockResolver blockResolver; - private MapOutputTracker mapOutputTracker; - private SerializerManager serializerManager; public DefaultShuffleExecutorComponents(SparkConf sparkConf) { this.sparkConf = sparkConf; } + @VisibleForTesting + public DefaultShuffleExecutorComponents( + SparkConf sparkConf, + BlockManager blockManager, + MapOutputTracker mapOutputTracker, + SerializerManager serializerManager, + IndexShuffleBlockResolver blockResolver) { + this.sparkConf = sparkConf; + this.blockManager = blockManager; + this.blockResolver = blockResolver; + this.shuffleReadSupport = new DefaultShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); + } + @Override public void initializeExecutor(String appId, String execId, Map extraConfigs) { blockManager = SparkEnv.get().blockManager(); - mapOutputTracker = SparkEnv.get().mapOutputTracker(); - serializerManager = SparkEnv.get().serializerManager(); + MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker(); + SerializerManager serializerManager = SparkEnv.get().serializerManager(); blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + shuffleReadSupport = new DefaultShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); } @Override - public ShuffleWriteSupport writes() { + public ShuffleMapOutputWriter createMapOutputWriter(int shuffleId, int mapId, long mapTaskAttemptId, int numPartitions) throws IOException { checkInitialized(); - return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); + return new DefaultShuffleMapOutputWriter( + shuffleId, + mapId, + numPartitions, + blockManager.shuffleServerId(), + TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); } @Override - public ShuffleReadSupport reads() { - checkInitialized(); - return new DefaultShuffleReadSupport(blockManager, - mapOutputTracker, - serializerManager, - sparkConf); + public Iterable getPartitionReaders(Iterable blockMetadata) throws IOException { + return shuffleReadSupport.getPartitionReaders(blockMetadata); + } + + @Override + public boolean shouldWrapPartitionReaderStream() { + return false; } private void checkInitialized() { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java index ad55b3db377f6..5c8d2d43dacf7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -30,10 +30,10 @@ import org.slf4j.LoggerFactory; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.SupportsTransferTo; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.SupportsTransferTo; +import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java deleted file mode 100644 index d6210f045840b..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort.io; - -import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.storage.BlockManagerId; - -public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { - - private final SparkConf sparkConf; - private final IndexShuffleBlockResolver blockResolver; - private final BlockManagerId shuffleServerId; - - public DefaultShuffleWriteSupport( - SparkConf sparkConf, - IndexShuffleBlockResolver blockResolver, - BlockManagerId shuffleServerId) { - this.sparkConf = sparkConf; - this.blockResolver = blockResolver; - this.shuffleServerId = shuffleServerId; - } - - @Override - public ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) { - return new DefaultShuffleMapOutputWriter( - shuffleId, mapId, numPartitions, shuffleServerId, - TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java index c6f43b91f90a0..c6893a49ed238 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java @@ -22,7 +22,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.SparkEnv; -import org.apache.spark.api.shuffle.ShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleDriverComponents; import org.apache.spark.internal.config.package$; import org.apache.spark.storage.BlockManagerMaster; diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bcd47ba0c29c1..98232380cc266 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -23,11 +23,11 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled import scala.collection.JavaConverters._ -import 
org.apache.spark.api.shuffle.ShuffleDriverComponents import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f359022716571..8b9cf7a2e95ec 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -43,7 +43,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.conda.CondaEnvironment import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents} +import org.apache.spark.api.shuffle.ShuffleDriverComponents import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} @@ -58,6 +58,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.status.{AppStatusSource, AppStatusStore} import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index a20a849cc6421..e614dbc8c9542 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,17 +17,14 @@ package org.apache.spark.shuffle -import java.io.InputStream - import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.api.java.Optional -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleExecutorComponents} import org.apache.spark.storage.ShuffleBlockAttemptId import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -42,7 +39,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, - shuffleReadSupport: ShuffleReadSupport, + shuffleExecutorComponents: ShuffleExecutorComponents, serializerManager: SerializerManager = SparkEnv.get.serializerManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, sparkConf: SparkConf = SparkEnv.get.conf) @@ -57,7 +54,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val streamsIterator = - 
shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { + shuffleExecutorComponents.getPartitionReaders(new Iterable[ShuffleBlockInfo] { override def iterator: Iterator[ShuffleBlockInfo] = { mapOutputTracker .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) @@ -76,18 +73,18 @@ private[spark] class BlockStoreShuffleReader[K, C]( } }.asJava).iterator() - val retryingWrappedStreams = streamsIterator.asScala.map(readSupportStream => { - if (shuffleReadSupport.shouldWrapStream()) { + val retryingWrappedStreams = streamsIterator.asScala.map(rawReaderStream => { + if (shuffleExecutorComponents.shouldWrapPartitionReaderStream()) { if (compressShuffle) { compressionCodec.compressedInputStream( - serializerManager.wrapForEncryption(readSupportStream)) + serializerManager.wrapForEncryption(rawReaderStream)) } else { - serializerManager.wrapForEncryption(readSupportStream) + serializerManager.wrapForEncryption(rawReaderStream) } } else { // The default implementation checks for corrupt streams, so it will already have // decompressed/decrypted the bytes - readSupportStream + rawReaderStream } }) diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index e18097c2c590a..6ab14e3780572 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -22,17 +22,17 @@ import java.io.InputStream import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.shuffle.api.ShuffleBlockInfo +import org.apache.spark.storage.{BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} class DefaultShuffleReadSupport( blockManager: BlockManager, mapOutputTracker: MapOutputTracker, serializerManager: SerializerManager, - conf: SparkConf) extends ShuffleReadSupport { + conf: SparkConf) { private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) @@ -41,7 +41,7 @@ class DefaultShuffleReadSupport( private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) - override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): + def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = { val iterableToReturn = if (blockMetadata.asScala.isEmpty) { @@ -70,8 +70,6 @@ class DefaultShuffleReadSupport( } iterableToReturn.asJava } - - override def shouldWrapStream(): Boolean = false } private class ShuffleBlockFetcherIterable( diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index c364c8d08db20..dbfbe6d689007 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ 
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -22,9 +22,11 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import org.apache.spark._ -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} + +import org.apache.spark.api.shuffle.ShuffleExecutorComponents import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} import org.apache.spark.util.Utils /** @@ -152,7 +154,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context, env.conf, metrics, - shuffleExecutorComponents.writes()) + shuffleExecutorComponents) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, @@ -161,10 +163,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context.taskAttemptId(), env.conf, metrics, - shuffleExecutorComponents.writes()) + shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes()) + shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 26f3f2267d44d..f0d3368d0a58d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -18,10 +18,10 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ -import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -29,7 +29,7 @@ private[spark] class SortShuffleWriter[K, V, C]( handle: BaseShuffleHandle[K, V, C], mapId: Int, context: TaskContext, - writeSupport: ShuffleWriteSupport) + shuffleExecutorComponents: ShuffleExecutorComponents) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency @@ -64,7 +64,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
- val mapOutputWriter = writeSupport.createMapOutputWriter( + val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) val location = mapOutputWriter.commitAllPartitions diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 14d34e1c47c8e..0c1af50e73fcf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -24,12 +24,13 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams - import org.apache.spark._ -import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter} + +import org.apache.spark.api.shuffle.ShufflePartitionWriter import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala index 8538a78b377c8..62f17a8e3cfbd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala @@ -19,9 +19,9 @@ package org.apache.spark.util.collection import java.io.{Closeable, FilterOutputStream, OutputStream} -import org.apache.spark.api.shuffle.ShufflePartitionWriter import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.shuffle.api.ShufflePartitionWriter import org.apache.spark.storage.BlockId /** diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 4c2e6ac6474da..698a7f72a722c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -22,7 +22,6 @@ import java.nio.file.Files; import java.util.*; -import org.mockito.stubbing.Answer; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -36,8 +35,10 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.stubbing.Answer; import org.apache.spark.HashPartitioner; +import org.apache.spark.MapOutputTracker; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; @@ -56,7 +57,7 @@ import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport; +import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -87,6 +88,8 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = 
RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; + @Mock(answer = RETURNS_SMART_NULLS) MapOutputTracker mapOutputTracker; + @Mock(answer = RETURNS_SMART_NULLS) SerializerManager serializerManager; @After public void tearDown() { @@ -181,14 +184,18 @@ private UnsafeShuffleWriter createWriter( conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter<>( blockManager, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId()) - ); + new DefaultShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver)); } private void assertSpillFilesWereCleanedUp() { @@ -548,7 +555,12 @@ public void testPeakMemoryUsed() throws Exception { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); + new DefaultShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver)); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala index 68bc5c2961e2d..39fbc3a1b5851 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.scheduler import java.util - import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} + +import org.apache.spark.api.shuffle.{ShuffleDriverComponents, ShuffleExecutorComponents} import org.apache.spark.internal.config import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO import org.apache.spark.storage.BlockManagerId diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala index 0abfa4d8d8413..e2ccb3fdce651 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.shuffle -import java.util +import java.io.InputStream +import java.lang.{Iterable => JIterable} +import java.util.{Map => JMap} import com.google.common.collect.ImmutableMap -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import 
org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleMapOutputWriter} +import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { test(s"test serialization of shuffle initialization conf to executors") { @@ -43,7 +44,7 @@ class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext } class TestShuffleDriverComponents extends ShuffleDriverComponents { - override def initializeApplication(): util.Map[String, String] = + override def initializeApplication(): JMap[String, String] = ImmutableMap.of("test-key", "test-value") } @@ -55,21 +56,29 @@ class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { } class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents { - override def initializeExecutor(appId: String, execId: String, - extraConfigs: util.Map[String, String]): Unit = { + + private var delegate = new DefaultShuffleExecutorComponents(sparkConf) + + override def initializeExecutor( + appId: String, execId: String, extraConfigs: JMap[String, String]): Unit = { assert(extraConfigs.get("test-key") == "test-value") + delegate.initializeExecutor(appId, execId, extraConfigs) + } + + override def createMapOutputWriter( + shuffleId: Int, + mapId: Int, + mapTaskAttemptId: Long, + numPartitions: Int): ShuffleMapOutputWriter = { + delegate.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions) } - override def writes(): ShuffleWriteSupport = { - val blockManager = SparkEnv.get.blockManager - val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager) - new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId) + override def getPartitionReaders( + blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] = { + delegate.getPartitionReaders(blockMetadata) } - override def reads(): ShuffleReadSupport = { - val blockManager = SparkEnv.get.blockManager - val mapOutputTracker = SparkEnv.get.mapOutputTracker - val serializerManager = SparkEnv.get.serializerManager - new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf) + override def shouldWrapPartitionReaderStream(): Boolean = { + delegate.shouldWrapPartitionReaderStream() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index dbcf09400c97e..48cb3800b698a 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -18,9 +18,9 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf + import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.{DefaultShuffleExecutorComponents} /** * Benchmark to measure performance for aggregate primitives. 
@@ -49,9 +49,12 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") - val shuffleWriteSupport = - new DefaultShuffleWriteSupport( - conf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new DefaultShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver) val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, @@ -60,7 +63,7 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase taskContext.taskAttemptId(), conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport + shuffleExecutorComponents ) shuffleWriter diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index bd241b5ebfaef..eed9c651ac66d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,12 +33,12 @@ import org.scalatest.BeforeAndAfterEach import scala.util.Random import org.apache.spark._ -import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -49,11 +49,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ + @Mock(answer = RETURNS_SMART_NULLS) private var serializerManager: SerializerManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var mapOutputTracker: MapOutputTracker = _ private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ - private var writeSupport: ShuffleWriteSupport = _ + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() @@ -140,8 +142,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte metricsSystem = null, taskMetrics = taskMetrics)) - writeSupport = - new DefaultShuffleWriteSupport(conf, blockResolver, blockManager.shuffleServerId) + shuffleExecutorComponents = new DefaultShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver) } override def afterEach(): Unit = { @@ -163,7 +169,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte taskContext.taskAttemptId(), 
conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport + shuffleExecutorComponents ) writer.write(Iterator.empty) writer.stop( /* success = */ true) @@ -189,7 +195,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte taskContext.taskAttemptId(), transferConf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport + shuffleExecutorComponents ) writer.write(records) writer.stop( /* success = */ true) @@ -214,7 +220,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte taskContext.taskAttemptId(), conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport + shuffleExecutorComponents ) writer.write(records) writer.stop( /* success = */ true) @@ -250,7 +256,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte taskContext.taskAttemptId(), conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport + shuffleExecutorComponents ) intercept[SparkException] { @@ -273,7 +279,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte taskContext.taskAttemptId(), conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport + shuffleExecutorComponents ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala index 26b92e5203b50..e883eb61a2763 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -29,7 +29,8 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random -import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.{HashPartitioner, MapOutputTracker, ShuffleDependency, SparkConf, TaskContext} + import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} @@ -50,9 +51,11 @@ abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { @Mock(answer = RETURNS_SMART_NULLS) protected var taskContext: TaskContext = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEnv: RpcEnv = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + // only used to retrieve info about the maps at the beginning, doesn't affect perf + @Mock(answer = RETURNS_SMART_NULLS) protected var mapOutputTracker: MapOutputTracker = _ protected val defaultConf: SparkConf = new SparkConf(loadDefaults = false) - protected val serializer: Serializer = new KryoSerializer(defaultConf) + protected val serializer: Serializer = new KryoSerializer(defaultConf) protected val partitioner: HashPartitioner = new HashPartitioner(10) protected val serializerManager: SerializerManager = new SerializerManager(serializer, defaultConf) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 7e7a86b3e6b2a..e6471ce2d8e93 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -22,8 +22,7 @@ import org.mockito.Mockito.when import 
org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. @@ -78,16 +77,19 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(taskContext) - val writeSupport = - new DefaultShuffleWriteSupport( - defaultConf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new DefaultShuffleExecutorComponents( + defaultConf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, shuffleHandle, 0, taskContext, - writeSupport) + shuffleExecutorComponents) shuffleWriter } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index b09ccb334e4f1..04a557cf4384a 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. 
@@ -44,9 +43,12 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) - val shuffleWriteSupport = - new DefaultShuffleWriteSupport( - conf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new DefaultShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver) TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( @@ -57,7 +59,7 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport) + shuffleExecutorComponents) } def writeBenchmarkWithSmallDataset(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala index 3ccb549912782..92960ad956ce2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -30,12 +30,12 @@ import org.mockito.MockitoAnnotations import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach - import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.api.shuffle.SupportsTransferTo + import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.api.SupportsTransferTo import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{ByteBufferInputStream, Utils} From 32d9b697e7e8f78961373fee12502842ea216734 Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 10 Sep 2019 15:58:45 -0700 Subject: [PATCH 02/14] Cherry-pick BypassMergeSortShuffleWriter changes and shuffle writer API changes --- .../spark/shuffle/api/ShuffleDataIO.java | 29 +- .../api/ShuffleExecutorComponents.java | 25 +- .../shuffle/api/ShuffleMapOutputWriter.java | 46 ++- .../shuffle/api/ShufflePartitionWriter.java | 66 ++++- .../api/WritableByteChannelWrapper.java | 42 +++ .../sort/BypassMergeSortShuffleWriter.java | 89 +++--- .../sort/io/LocalDiskShuffleDataIO.java | 40 +++ .../LocalDiskShuffleExecutorComponents.java | 73 +++++ .../io/LocalDiskShuffleMapOutputWriter.java | 261 ++++++++++++++++++ .../spark/internal/config/package.scala | 6 +- .../shuffle/sort/SortShuffleManager.scala | 3 +- .../scala/org/apache/spark/ShuffleSuite.scala | 18 +- .../BypassMergeSortShuffleWriterSuite.scala | 167 ++++------- ...LocalDiskShuffleMapOutputWriterSuite.scala | 147 ++++++++++ 14 files changed, 834 insertions(+), 178 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala diff --git 
a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java index ac3d2a15fec5a..5126f0c3577f8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java @@ -17,18 +17,37 @@ package org.apache.spark.shuffle.api; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for launching Shuffle related components - * + * :: Private :: + * An interface for plugging in modules for storing and reading temporary shuffle data. + *
<p>
+ * This is the root of a plugin system for storing shuffle bytes to arbitrary storage + * backends in the sort-based shuffle algorithm implemented by the + * {@link org.apache.spark.shuffle.sort.SortShuffleManager}. If another shuffle algorithm is + * needed instead of sort-based shuffle, one should implement + * {@link org.apache.spark.shuffle.ShuffleManager} instead. + *
<p>
+ * A single instance of this module is loaded per process in the Spark application. + * The default implementation reads and writes shuffle data from the local disks of + * the executor, and is the implementation of shuffle file storage that has remained + * consistent throughout most of Spark's history. + *
<p>
+ * Alternative implementations of shuffle data storage can be loaded via setting + * spark.shuffle.sort.io.plugin.class. * @since 3.0.0 */ -@Experimental +@Private public interface ShuffleDataIO { + String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; ShuffleDriverComponents driver(); + + /** + * Called once on executor processes to bootstrap the shuffle data storage modules that + * are only invoked on the executors. + */ ShuffleExecutorComponents executor(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java index 87ad794707150..8f3b6671c9482 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -18,21 +18,36 @@ package org.apache.spark.shuffle.api; import java.io.IOException; -import java.io.InputStream; import java.util.Map; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for building shuffle support for Executors + * :: Private :: + * An interface for building shuffle support for Executors. * * @since 3.0.0 */ -@Experimental +@Private public interface ShuffleExecutorComponents { + + /** + * Called once per executor to bootstrap this module with state that is specific to + * that executor, specifically the application ID and executor ID. + */ void initializeExecutor(String appId, String execId, Map extraConfigs); + /** + * Called once per map task to create a writer that will be responsible for persisting all the + * partitioned bytes written by that map task. + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId Within the shuffle, the identifier of the map task + * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. + * @param numPartitions The number of partitions that will be written by the map task. Some of + * these partitions may be empty. + */ ShuffleMapOutputWriter createMapOutputWriter( int shuffleId, int mapId, diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 14e020a5c63ad..9135293636e90 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -19,21 +19,53 @@ import java.io.IOException; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for creating and managing shuffle partition writers + * :: Private :: + * A top-level writer that returns child writers for persisting the output of a map task, + * and then commits all of the writes as one atomic operation. * * @since 3.0.0 */ -@Experimental +@Private public interface ShuffleMapOutputWriter { - ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException; + /** + * Creates a writer that can open an output stream to persist bytes targeted for a given reduce + * partition id. + *
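// A minimal sketch of the executor-side lifecycle described above: one ShuffleDataIO instance
// is created per process from the SparkConf, executor() returns the executor-side module,
// initializeExecutor() bootstraps it once per executor, and each map task attempt then asks
// for a single ShuffleMapOutputWriter. The wrapper method, its parameters, and the empty
// extra-config map are illustrative assumptions; the plugin used here is the local-disk
// implementation added by this patch (the default for spark.shuffle.sort.io.plugin.class).
import java.io.IOException;
import java.util.Collections;

import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.ShuffleDataIO;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO;

final class ShufflePluginLifecycleSketch {
  static ShuffleMapOutputWriter openWriterForMapTask(
      SparkConf conf,
      String appId,
      String execId,
      int shuffleId,
      int mapId,
      long mapTaskAttemptId,
      int numPartitions) throws IOException {
    ShuffleDataIO plugin = new LocalDiskShuffleDataIO(conf);
    ShuffleExecutorComponents executor = plugin.executor();
    // Bootstraps executor-local state (block manager, index/data file resolver).
    executor.initializeExecutor(appId, execId, Collections.emptyMap());
    // Partition writers are then requested from the returned writer in increasing partition order.
    return executor.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
  }
}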
<p>
+ * The chunk corresponds to bytes in the given reduce partition. This will not be called twice + * for the same partition within any given map task. The partition identifier will be in the + * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was + * provided upon the creation of this map output writer via + * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. + *
<p>
+ * Calls to this method will be invoked with monotonically increasing reducePartitionIds; each + * call to this method will be called with a reducePartitionId that is strictly greater than + * the reducePartitionIds given to any previous call to this method. This method is not + * guaranteed to be called for every partition id in the above described range. In particular, + * no guarantees are made as to whether or not this method will be called for empty partitions. + */ + ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException; + + /** + * Commits the writes done by all partition writers returned by all calls to this object's + * {@link #getPartitionWriter(int)}. + *
<p>
+ * This should ensure that the writes conducted by this module's partition writers are + * available to downstream reduce tasks. If this method throws any exception, this module's + * {@link #abort(Throwable)} method will be invoked before propagating the exception. + *
<p>
+ * This can also close any resources and clean up temporary state if necessary. + */ Optional commitAllPartitions() throws IOException; + /** + * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. + *
<p>
+ * This should invalidate the results of writing bytes. This can also close any resources and + * clean up temporary state if necessary. + */ void abort(Throwable error) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index 3d6c287014bf5..928875156a70f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -18,27 +18,81 @@ package org.apache.spark.shuffle.api; import java.io.IOException; +import java.util.Optional; import java.io.OutputStream; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for giving streams / channels for shuffle writes. + * :: Private :: + * An interface for opening streams to persist partition bytes to a backing data store. + *
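// A sketch of how a map task drives the ShuffleMapOutputWriter contract documented above:
// partition writers are requested in increasing reduce-partition order, each partition's
// stream is closed once its bytes are written, the output is committed exactly once, and
// abort() is invoked with the original failure if anything goes wrong. The byte-array input
// and the class/method names are illustrative assumptions.
import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;

final class MapOutputWriteLoopSketch {
  static void writeAllPartitions(
      ShuffleMapOutputWriter mapOutputWriter,
      byte[][] partitionBytes) throws IOException {
    try {
      for (int reducePartitionId = 0; reducePartitionId < partitionBytes.length; reducePartitionId++) {
        ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(reducePartitionId);
        OutputStream out = writer.openStream();
        try {
          out.write(partitionBytes[reducePartitionId]);
        } finally {
          // Ends this partition's writes; implementations may keep the backing resource open
          // until commitAllPartitions() or abort().
          out.close();
        }
      }
      mapOutputWriter.commitAllPartitions();
    } catch (Exception e) {
      try {
        mapOutputWriter.abort(e);
      } catch (Exception e2) {
        e.addSuppressed(e2);
      }
      throw e;
    }
  }
}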
<p>
+ * This writer stores bytes for one (mapper, reducer) pair, corresponding to one shuffle + * block. * * @since 3.0.0 */ -@Experimental +@Private public interface ShufflePartitionWriter { /** - * Opens and returns an underlying {@link OutputStream} that can write bytes to the underlying + * Open and return an {@link OutputStream} that can write bytes to the underlying * data store. + *
<p>
+ * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The output stream will only be used to write the bytes for this + * partition. The map task closes this output stream upon writing all the bytes for this + * block, or if the write fails for any reason. + *
<p>
+ * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same OutputStream instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link OutputStream#close()} does not close the resource, since it will be reused across + * partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. */ OutputStream openStream() throws IOException; /** - * Get the number of bytes written by this writer's stream returned by {@link #openStream()}. + * Opens and returns a {@link WritableByteChannelWrapper} for transferring bytes from + * input byte channels to the underlying shuffle data store. + *
<p>
+ * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The channel will only be used to write the bytes for this + * partition. The map task closes this channel upon writing all the bytes for this + * block, or if the write fails for any reason. + *
<p>
+ * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same channel instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel + * will be reused across partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + *
<p>
+ * This method is primarily for advanced optimizations where bytes can be copied from the input + * spill files to the output channel without copying data into memory. If such optimizations are + * not supported, the implementation should return {@link Optional#empty()}. By default, the + * implementation returns {@link Optional#empty()}. + *
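// A caller-side sketch of the channel-versus-stream choice described above: use the
// channel-based path when the implementation offers one, otherwise fall back to the output
// stream. Only the wrapper is closed on the channel path, never the underlying channel. The
// copy helpers are the same ones this patch uses elsewhere (Utils.copyFileStreamNIO and
// Utils.copyStream); the class and method names are illustrative assumptions.
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.util.Optional;

import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.util.Utils;

final class PartitionCopySketch {
  static void copyFileToPartition(File file, ShufflePartitionWriter writer) throws IOException {
    Optional<WritableByteChannelWrapper> maybeChannel = writer.openChannelWrapper();
    if (maybeChannel.isPresent()) {
      try (FileInputStream in = new FileInputStream(file);
          WritableByteChannelWrapper wrapper = maybeChannel.get()) {
        FileChannel inputChannel = in.getChannel();
        Utils.copyFileStreamNIO(inputChannel, wrapper.channel(), 0L, inputChannel.size());
      }
    } else {
      try (FileInputStream in = new FileInputStream(file);
          OutputStream out = writer.openStream()) {
        Utils.copyStream(in, out, false, false);
      }
    }
  }
}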
<p>
+ * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the + * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure + * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()}, + * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + */ + default Optional openChannelWrapper() throws IOException { + return Optional.empty(); + } + + /** + * Returns the number of bytes written either by this writer's output stream opened by + * {@link #openStream()} or the byte channel opened by {@link #openChannelWrapper()}. + *
<p>
+ * This can be different from the number of bytes given by the caller. For example, the + * stream might compress or encrypt the bytes before persisting the data to the backing + * data store. */ long getNumBytesWritten(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java new file mode 100644 index 0000000000000..a204903008a51 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.Closeable; +import java.nio.channels.WritableByteChannel; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * + * A thin wrapper around a {@link WritableByteChannel}. + *
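// A minimal sketch of a wrapper that follows the reuse guidance above: close() only records
// where this partition's writes ended, while the shared FileChannel stays open so the parent
// ShuffleMapOutputWriter can keep appending subsequent partitions until it commits or aborts.
// The class and field names are illustrative assumptions.
import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;

import org.apache.spark.shuffle.api.WritableByteChannelWrapper;

final class ReusedFileChannelWrapper implements WritableByteChannelWrapper {
  private final FileChannel sharedChannel;
  private final long startPosition;
  private long bytesWritten;

  ReusedFileChannelWrapper(FileChannel sharedChannel) throws IOException {
    this.sharedChannel = sharedChannel;
    this.startPosition = sharedChannel.position();
  }

  @Override
  public WritableByteChannel channel() {
    return sharedChannel;
  }

  @Override
  public void close() throws IOException {
    // Deliberately does not close sharedChannel; only record this partition's byte count.
    bytesWritten = sharedChannel.position() - startPosition;
  }

  long bytesWritten() {
    return bytesWritten;
  }
}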
<p>
+ * This is primarily provided for the local disk shuffle implementation to provide a + * {@link java.nio.channels.FileChannel} that keeps the channel open across partition writes. + * + * @since 3.0.0 + */ +@Private +public interface WritableByteChannelWrapper extends Closeable { + + /** + * The underlying channel to write bytes into. + */ + WritableByteChannel channel(); +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index dd10bff9ec45f..d6cc1d500e3d1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -23,6 +23,7 @@ import java.io.OutputStream; import java.nio.channels.Channels; import java.nio.channels.FileChannel; +import java.util.Optional; import javax.annotation.Nullable; import org.apache.spark.api.java.Optional; @@ -41,10 +42,10 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.api.SupportsTransferTo; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -119,6 +120,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; + this.mapTaskAttemptId = mapTaskAttemptId; this.shuffleId = dep.shuffleId(); this.mapTaskAttemptId = mapTaskAttemptId; this.partitioner = dep.partitioner(); @@ -136,11 +138,10 @@ public void write(Iterator> records) throws IOException { try { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; - Optional location = mapOutputWriter.commitAllPartitions(); + mapOutputWriter.commitAllPartitions(); mapStatus = MapStatus$.MODULE$.apply( - location.orNull(), - partitionLengths, - mapTaskAttemptId); + blockManager.shuffleServerId(), + partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -173,13 +174,14 @@ public void write(Iterator> records) throws IOException { } partitionLengths = writePartitionedData(mapOutputWriter); - Optional location = mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(location.orNull(), partitionLengths, mapTaskAttemptId); + mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } catch (Exception e) { try { mapOutputWriter.abort(e); } catch (Exception e2) { logger.error("Failed to abort the writer after failing to write map output.", e2); + e.addSuppressed(e2); } throw e; } @@ -208,36 +210,17 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro final File file = partitionWriterSegments[i].file(); ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); if (file.exists()) { - boolean copyThrewException = true; if (transferToEnabled) { - FileInputStream in = new FileInputStream(file); - TransferrableWritableByteChannel outputChannel = null; - try 
(FileChannel inputChannel = in.getChannel()) { - if (writer instanceof SupportsTransferTo) { - outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel(); - } else { - // Use default transferrable writable channel anyways in order to have parity with - // UnsafeShuffleWriter. - outputChannel = new DefaultTransferrableWritableByteChannel( - Channels.newChannel(writer.openStream())); - } - outputChannel.transferFrom(inputChannel, 0L, inputChannel.size()); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - Closeables.close(outputChannel, copyThrewException); + // Using WritableByteChannelWrapper to make resource closing consistent between + // this implementation and UnsafeShuffleWriter. + Optional maybeOutputChannel = writer.openChannelWrapper(); + if (maybeOutputChannel.isPresent()) { + writePartitionedDataWithChannel(file, maybeOutputChannel.get()); + } else { + writePartitionedDataWithStream(file, writer); } } else { - FileInputStream in = new FileInputStream(file); - OutputStream outputStream = null; - try { - outputStream = writer.openStream(); - Utils.copyStream(in, outputStream, false, false); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - Closeables.close(outputStream, copyThrewException); - } + writePartitionedDataWithStream(file, writer); } if (!file.delete()) { logger.error("Unable to delete file for partition {}", i); @@ -252,6 +235,42 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro return lengths; } + private void writePartitionedDataWithChannel( + File file, + WritableByteChannelWrapper outputChannel) throws IOException { + boolean copyThrewException = true; + try { + FileInputStream in = new FileInputStream(file); + try (FileChannel inputChannel = in.getChannel()) { + Utils.copyFileStreamNIO( + inputChannel, outputChannel.channel(), 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } finally { + Closeables.close(outputChannel, copyThrewException); + } + } + + private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer) + throws IOException { + boolean copyThrewException = true; + FileInputStream in = new FileInputStream(file); + OutputStream outputStream; + try { + outputStream = writer.openStream(); + try { + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(outputStream, copyThrewException); + } + } finally { + Closeables.close(in, copyThrewException); + } + } + @Override public Option stop(boolean success) { if (stopping) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java new file mode 100644 index 0000000000000..cabcb171ac23a --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleDataIO; + +/** + * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle + * storage and index file functionality that has historically been used from Spark 2.4 and earlier. + */ +public class LocalDiskShuffleDataIO implements ShuffleDataIO { + + private final SparkConf sparkConf; + + public LocalDiskShuffleDataIO(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public ShuffleExecutorComponents executor() { + return new LocalDiskShuffleExecutorComponents(sparkConf); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java new file mode 100644 index 0000000000000..f32306d4c37c7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import java.util.Map; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.BlockManager; + +public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { + + private final SparkConf sparkConf; + private BlockManager blockManager; + private IndexShuffleBlockResolver blockResolver; + + public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @VisibleForTesting + public LocalDiskShuffleExecutorComponents( + SparkConf sparkConf, + BlockManager blockManager, + IndexShuffleBlockResolver blockResolver) { + this.sparkConf = sparkConf; + this.blockManager = blockManager; + this.blockResolver = blockResolver; + } + + @Override + public void initializeExecutor(String appId, String execId, Map extraConfigs) { + blockManager = SparkEnv.get().blockManager(); + if (blockManager == null) { + throw new IllegalStateException("No blockManager available from the SparkEnv."); + } + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + } + + @Override + public ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return new LocalDiskShuffleMapOutputWriter( + shuffleId, mapId, numPartitions, blockResolver, sparkConf); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..add4634a61fb5 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.util.Utils; + +/** + * Implementation of {@link ShuffleMapOutputWriter} that replicates the functionality of shuffle + * persisting shuffle data to local disk alongside index files, identical to Spark's historic + * canonical shuffle storage mechanism. + */ +public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { + + private static final Logger log = + LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); + + private final int shuffleId; + private final int mapId; + private final IndexShuffleBlockResolver blockResolver; + private final long[] partitionLengths; + private final int bufferSize; + private int lastPartitionId = -1; + private long currChannelPosition; + + private final File outputFile; + private File outputTempFile; + private FileOutputStream outputFileStream; + private FileChannel outputFileChannel; + private BufferedOutputStream outputBufferedFileStream; + + public LocalDiskShuffleMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions, + IndexShuffleBlockResolver blockResolver, + SparkConf sparkConf) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.blockResolver = blockResolver; + this.bufferSize = + (int) (long) sparkConf.get( + package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.partitionLengths = new long[numPartitions]; + this.outputFile = blockResolver.getDataFile(shuffleId, mapId); + this.outputTempFile = null; + } + + @Override + public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException { + if (reducePartitionId <= lastPartitionId) { + throw new IllegalArgumentException("Partitions should be requested in increasing order."); + } + lastPartitionId = reducePartitionId; + if (outputTempFile == null) { + outputTempFile = Utils.tempFileWith(outputFile); + } + if (outputFileChannel != null) { + currChannelPosition = outputFileChannel.position(); + } else { + currChannelPosition = 0L; + } + return new LocalDiskShufflePartitionWriter(reducePartitionId); + } + + @Override + public void commitAllPartitions() throws IOException { + cleanUp(); + File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + } + + @Override + public void abort(Throwable error) throws IOException { + cleanUp(); + if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { + log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); + } + } + + private void cleanUp() throws IOException { + if (outputBufferedFileStream != null) { + outputBufferedFileStream.close(); + } + if (outputFileChannel != null) { + outputFileChannel.close(); + } + if (outputFileStream != null) { + outputFileStream.close(); + } + } + + private void initStream() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + } + if (outputBufferedFileStream == null) { + outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize); + } + } + + private void initChannel() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + } + if (outputFileChannel == null) { + outputFileChannel = outputFileStream.getChannel(); + } + } + + private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter { + + private final int partitionId; + private PartitionWriterStream partStream = null; + private PartitionWriterChannel partChannel = null; + + private LocalDiskShufflePartitionWriter(int partitionId) { + this.partitionId = partitionId; + } + + @Override + public OutputStream openStream() throws IOException { + if (partStream == null) { + if (outputFileChannel != null) { + throw new IllegalStateException("Requested an output channel for a previous write but" + + " now an output stream has been requested. Should not be using both channels" + + " and streams to write."); + } + initStream(); + partStream = new PartitionWriterStream(partitionId); + } + return partStream; + } + + @Override + public Optional openChannelWrapper() throws IOException { + if (partChannel == null) { + if (partStream != null) { + throw new IllegalStateException("Requested an output stream for a previous write but" + + " now an output channel has been requested. 
Should not be using both channels" + + " and streams to write."); + } + initChannel(); + partChannel = new PartitionWriterChannel(partitionId); + } + return Optional.of(partChannel); + } + + @Override + public long getNumBytesWritten() { + if (partChannel != null) { + try { + return partChannel.getCount(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } else if (partStream != null) { + return partStream.getCount(); + } else { + // Assume an empty partition if stream and channel are never created + return 0; + } + } + } + + private class PartitionWriterStream extends OutputStream { + private final int partitionId; + private int count = 0; + private boolean isClosed = false; + + PartitionWriterStream(int partitionId) { + this.partitionId = partitionId; + } + + public int getCount() { + return count; + } + + @Override + public void write(int b) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(b); + count++; + } + + @Override + public void write(byte[] buf, int pos, int length) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(buf, pos, length); + count += length; + } + + @Override + public void close() { + isClosed = true; + partitionLengths[partitionId] = count; + } + + private void verifyNotClosed() { + if (isClosed) { + throw new IllegalStateException("Attempting to write to a closed block output stream."); + } + } + } + + private class PartitionWriterChannel implements WritableByteChannelWrapper { + + private final int partitionId; + + PartitionWriterChannel(int partitionId) { + this.partitionId = partitionId; + } + + public long getCount() throws IOException { + long writtenPosition = outputFileChannel.position(); + return writtenPosition - currChannelPosition; + } + + @Override + public WritableByteChannel channel() { + return outputFileChannel; + } + + @Override + public void close() throws IOException { + partitionLengths[partitionId] = getCount(); + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a852a06be9125..23607e7ad975f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} -import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -770,10 +770,10 @@ package object config { .createWithDefault(false) private[spark] val SHUFFLE_IO_PLUGIN_CLASS = - ConfigBuilder("spark.shuffle.io.plugin.class") + ConfigBuilder("spark.shuffle.sort.io.plugin.class") .doc("Name of the class to use for shuffle IO.") .stringConf - .createWithDefault(classOf[DefaultShuffleDataIO].getName) + .createWithDefault(classOf[LocalDiskShuffleDataIO].getName) private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index dbfbe6d689007..0308b94fd14ba 100644 --- 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -22,8 +22,6 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import org.apache.spark._ - -import org.apache.spark.api.shuffle.ShuffleExecutorComponents import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} @@ -148,6 +146,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, + shuffleBlockResolver, context.taskMemoryManager(), unsafeShuffleHandle, mapId, diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 1cd7296e9de53..8da45ac6261b2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -383,14 +383,15 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int], - taskContext: TaskContext)( - iter: Iterator[(Int, Int)]): Option[MapStatus] = { - TaskContext.setTaskContext(taskContext) - val files = writer.write(iter) - val status = writer.stop(true) - TaskContext.unset - status + writer: ShuffleWriter[Int, Int], + taskContext: TaskContext)( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + try { + val files = writer.write(iter) + writer.stop(true) + } finally { + TaskContext.unset() + } } val interleaver = new InterleaveIterators( data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) @@ -412,6 +413,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) + TaskContext.unset() val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index eed9c651ac66d..16fcb89a32fce 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -64,39 +64,42 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte override def beforeEach(): Unit = { super.beforeEach() + MockitoAnnotations.initMocks(this) tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) 
taskMetrics = new TaskMetrics - MockitoAnnotations.initMocks(this) shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( shuffleId = 0, numMaps = 2, dependency = dependency ) + val memoryManager = new TestMemoryManager(conf) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) when(dependency.partitioner).thenReturn(new HashPartitioner(7)) when(dependency.serializer).thenReturn(new JavaSerializer(conf)) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - doAnswer(new Answer[Void] { - def answer(invocationOnMock: InvocationOnMock): Void = { - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { invocationOnMock => + val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null } - }).when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) - when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.getDiskWriter( any[BlockId], any[File], any[SerializerInstance], anyInt(), - any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[DiskBlockObjectWriter] { - override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { + any[ShuffleWriteMetrics])) + .thenAnswer { invocation => val args = invocation.getArguments val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( @@ -106,48 +109,24 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(3).asInstanceOf[Int], syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], - blockId = args(0).asInstanceOf[BlockId] - ) + blockId = args(0).asInstanceOf[BlockId]) } - }) - when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer[(TempShuffleBlockId, File)] { - override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { - val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = new File(tempDir, blockId.name) - blockIdToFileMap.put(blockId, file) - temporaryFilesCreated += file - (blockId, file) - } - }) - when(diskBlockManager.getFile(any[BlockId])).thenAnswer( - new Answer[File] { - override def answer(invocation: InvocationOnMock): File = { - blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get - } - }) - val memoryManager = new TestMemoryManager(conf) - val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + when(diskBlockManager.createTempShuffleBlock()) + .thenAnswer { _ => + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = new File(tempDir, blockId.name) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated += file + (blockId, file) + } - TaskContext.setTaskContext(new TaskContextImpl( - stageId = 0, - stageAttemptNumber = 0, - partitionId = 0, - taskAttemptId = Random.nextInt(10000), - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - localProperties = new Properties, - metricsSystem = null, - taskMetrics = taskMetrics)) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer { invocation => + 
blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) + } - shuffleExecutorComponents = new DefaultShuffleExecutorComponents( - conf, - blockManager, - mapOutputTracker, - serializerManager, - blockResolver) + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, blockManager, blockResolver) } override def afterEach(): Unit = { @@ -166,11 +145,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, // MapTaskAttemptId conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents - ) + shuffleExecutorComponents) + writer.write(Iterator.empty) writer.stop( /* success = */ true) assert(writer.getPartitionLengths.sum === 0) @@ -184,55 +163,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } - test("write with some empty partitions") { - val transferConf = conf.clone.set("spark.file.transferTo", "false") - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - shuffleHandle, - 0, // MapId - taskContext.taskAttemptId(), - transferConf, - taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents - ) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - } - - test("write with some empty partitions with transferTo") { - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - shuffleHandle, - 0, // MapId - taskContext.taskAttemptId(), - conf, - taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents - ) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) + Seq(true, false).foreach { transferTo => + test(s"write with some empty partitions - transferTo $transferTo") { + val transferConf = conf.clone.set("spark.file.transferTo", transferTo.toString) + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + 0, // MapId + 0L, + transferConf, + 
taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } } test("only generate temp shuffle file for non-empty partition") { @@ -253,11 +208,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents - ) + shuffleExecutorComponents) intercept[SparkException] { writer.write(records) @@ -276,11 +230,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents - ) + shuffleExecutorComponents) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala new file mode 100644 index 0000000000000..5693b9824523a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{File, FileInputStream} +import java.nio.channels.FileChannel +import java.nio.file.Files +import java.util.Arrays + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.Mock +import org.mockito.Mockito.when +import org.mockito.MockitoAnnotations +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.util.Utils + +class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) + private var blockResolver: IndexShuffleBlockResolver = _ + + private val NUM_PARTITIONS = 4 + private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p => + if (p == 3) { + Array.emptyByteArray + } else { + (0 to p * 10).map(_ + p).map(_.toByte).toArray + } + }.toArray + + private val partitionLengths = data.map(_.length) + + private var tempFile: File = _ + private var mergedOutputFile: File = _ + private var tempDir: File = _ + private var partitionSizesInMergedFile: Array[Long] = _ + private var conf: SparkConf = _ + private var mapOutputWriter: LocalDiskShuffleMapOutputWriter = _ + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def beforeEach(): Unit = { + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir() + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + tempFile = File.createTempFile("tempfile", "", tempDir) + partitionSizesInMergedFile = null + conf = new SparkConf() + .set("spark.app.id", "example.spark.app") + .set("spark.shuffle.unsafe.file.output.buffer", "16k") + when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { invocationOnMock => + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + mergedOutputFile.delete() + tmp.renameTo(mergedOutputFile) + } + null + } + mapOutputWriter = new LocalDiskShuffleMapOutputWriter( + 0, + 0, + NUM_PARTITIONS, + blockResolver, + conf) + } + + test("writing to an outputstream") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val stream = writer.openStream() + data(p).foreach { i => stream.write(i) } + stream.close() + intercept[IllegalStateException] { + stream.write(p) + } + assert(writer.getNumBytesWritten === data(p).length) + } + verifyWrittenRecords() + } + + test("writing to a channel") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val outputTempFile = File.createTempFile("channelTemp", "", tempDir) + Files.write(outputTempFile.toPath, data(p)) + val tempFileInput = new FileInputStream(outputTempFile) + val channel = writer.openChannelWrapper() + Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput => + Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper => + assert(channelWrapper.channel().isInstanceOf[FileChannel], + "Underlying channel should be a file channel") + Utils.copyFileStreamNIO( + tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length) + } + } + 
assert(writer.getNumBytesWritten === data(p).length, + s"Partition $p does not have the correct number of bytes.") + } + verifyWrittenRecords() + } + + private def readRecordsFromFile() = { + val mergedOutputBytes = Files.readAllBytes(mergedOutputFile.toPath) + val result = (0 until NUM_PARTITIONS).map { part => + val startOffset = data.slice(0, part).map(_.length).sum + val partitionSize = data(part).length + Arrays.copyOfRange(mergedOutputBytes, startOffset, startOffset + partitionSize) + }.toArray + result + } + + private def verifyWrittenRecords(): Unit = { + mapOutputWriter.commitAllPartitions() + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile()) + } +} From 2f756088d8b13438c393bb5426076793a50c471b Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 26 Aug 2019 10:39:29 -0700 Subject: [PATCH 03/14] [SPARK-28607][CORE][SHUFFLE] Don't store partition lengths twice The shuffle writer API introduced in SPARK-28209 has a flaw that leads to a memory usage regression - we ended up tracking the partition lengths in two places. Here, we modify the API slightly to avoid redundant tracking. The implementation of the shuffle writer plugin is now responsible for tracking the lengths of partitions, and propagating this back up to the higher shuffle writer as part of the commitAllPartitions API. Existing unit tests. Closes #25341 from mccheah/dont-redundantly-store-part-lengths. Authored-by: mcheah Signed-off-by: Marcelo Vanzin --- .../api/MapOutputWriterCommitMessage.java | 35 +++++++++ .../shuffle/api/ShuffleMapOutputWriter.java | 13 +++- .../sort/BypassMergeSortShuffleWriter.java | 78 +++++++++---------- .../shuffle/sort/UnsafeShuffleWriter.java | 43 ++++------ .../io/LocalDiskShuffleMapOutputWriter.java | 3 +- .../shuffle/sort/SortShuffleWriter.scala | 9 ++- .../util/collection/ExternalSorter.scala | 12 +-- ...LocalDiskShuffleMapOutputWriterSuite.scala | 6 +- 8 files changed, 110 insertions(+), 89 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java new file mode 100644 index 0000000000000..e07efd57cc07f --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java @@ -0,0 +1,35 @@ +package org.apache.spark.shuffle.api; + +import java.util.Optional; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.storage.BlockManagerId; + +@Experimental +public final class MapOutputWriterCommitMessage { + + private final long[] partitionLengths; + private final Optional location; + + private MapOutputWriterCommitMessage(long[] partitionLengths, Optional location) { + this.partitionLengths = partitionLengths; + this.location = location; + } + + public static MapOutputWriterCommitMessage of(long[] partitionLengths) { + return new MapOutputWriterCommitMessage(partitionLengths, Optional.empty()); + } + + public static MapOutputWriterCommitMessage of( + long[] partitionLengths, java.util.Optional location) { + return new MapOutputWriterCommitMessage(partitionLengths, location); + } + + public long[] getPartitionLengths() { + return partitionLengths; + } + + public Optional getLocation() { + return location; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 9135293636e90..8fcc73ba3c9b2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -51,15 +51,24 @@ public interface ShuffleMapOutputWriter { /** * Commits the writes done by all partition writers returned by all calls to this object's - * {@link #getPartitionWriter(int)}. + * {@link #getPartitionWriter(int)}, and returns a bundle of metadata associated with the + * behavior of the write. *
<p>
* This should ensure that the writes conducted by this module's partition writers are * available to downstream reduce tasks. If this method throws any exception, this module's * {@link #abort(Throwable)} method will be invoked before propagating the exception. *
<p>
* This can also close any resources and clean up temporary state if necessary. + *
<p>
+ * The returned array should contain two sets of metadata: + * + * 1. For each partition from (0) to (numPartitions - 1), the number of bytes written by + * the partition writer for that partition id. + * + * 2. If the partition data was stored on the local disk of this executor, also provide + * the block manager id where these bytes can be fetched from. */ - Optional commitAllPartitions() throws IOException; + MapOutputWriterCommitMessage commitAllPartitions() throws IOException; /** * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d6cc1d500e3d1..94ad5fc66185b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,13 +21,10 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.OutputStream; -import java.nio.channels.Channels; import java.nio.channels.FileChannel; import java.util.Optional; import javax.annotation.Nullable; -import org.apache.spark.api.java.Optional; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import scala.None$; import scala.Option; import scala.Product2; @@ -42,6 +39,7 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; @@ -97,7 +95,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private DiskBlockObjectWriter[] partitionWriters; private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; - private long[] partitionLengths; + private MapOutputWriterCommitMessage commitMessage; /** * Are we in the process of stopping? 
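For illustration, the revised commitAllPartitions contract above can be satisfied by a minimal, hypothetical plugin-side implementation such as the sketch below. ExampleMapOutputWriter and its fields are illustrative names only, not part of this patch; the interface methods and the MapOutputWriterCommitMessage factories are the ones introduced above.

    import java.io.IOException;
    import java.util.Optional;

    import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
    import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
    import org.apache.spark.shuffle.api.ShufflePartitionWriter;
    import org.apache.spark.storage.BlockManagerId;

    // Hypothetical plugin implementation: the plugin itself tracks how many bytes each
    // partition writer produced and reports the lengths exactly once, via the commit message.
    public final class ExampleMapOutputWriter implements ShuffleMapOutputWriter {

      private final long[] partitionLengths;
      private final BlockManagerId location; // null if the bytes are not served by a block manager

      ExampleMapOutputWriter(int numPartitions, BlockManagerId location) {
        this.partitionLengths = new long[numPartitions];
        this.location = location;
      }

      @Override
      public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException {
        // Return a writer whose close() records its byte count into
        // partitionLengths[reducePartitionId]; omitted here for brevity.
        throw new UnsupportedOperationException("omitted in this sketch");
      }

      @Override
      public MapOutputWriterCommitMessage commitAllPartitions() throws IOException {
        // No separate length array is handed back to Spark any more; everything the
        // scheduler needs travels in the commit message.
        return MapOutputWriterCommitMessage.of(partitionLengths, Optional.ofNullable(location));
      }

      @Override
      public void abort(Throwable error) throws IOException {
        // Delete any partial output here.
      }
    }

Spark core then derives the MapStatus solely from getPartitionLengths() and getLocation(), as the writer hunks in this patch show.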
Because map tasks can call stop() with success = true @@ -122,7 +120,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.mapId = mapId; this.mapTaskAttemptId = mapTaskAttemptId; this.shuffleId = dep.shuffleId(); - this.mapTaskAttemptId = mapTaskAttemptId; this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; @@ -137,11 +134,11 @@ public void write(Iterator> records) throws IOException { .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); try { if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; - mapOutputWriter.commitAllPartitions(); + commitMessage = mapOutputWriter.commitAllPartitions(); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), - partitionLengths); + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), + mapTaskAttemptId); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -173,9 +170,11 @@ public void write(Iterator> records) throws IOException { } } - partitionLengths = writePartitionedData(mapOutputWriter); - mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + commitMessage = writePartitionedData(mapOutputWriter); + mapStatus = MapStatus$.MODULE$.apply( + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), + mapTaskAttemptId); } catch (Exception e) { try { mapOutputWriter.abort(e); @@ -189,7 +188,7 @@ public void write(Iterator> records) throws IOException { @VisibleForTesting long[] getPartitionLengths() { - return partitionLengths; + return commitMessage.getPartitionLengths(); } /** @@ -197,42 +196,39 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). */ - private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { + private MapOutputWriterCommitMessage writePartitionedData( + ShuffleMapOutputWriter mapOutputWriter) throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; - if (partitionWriters == null) { - // We were passed an empty iterator - return lengths; - } - final long writeStartTime = System.nanoTime(); - try { - for (int i = 0; i < numPartitions; i++) { - final File file = partitionWriterSegments[i].file(); - ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); - if (file.exists()) { - if (transferToEnabled) { - // Using WritableByteChannelWrapper to make resource closing consistent between - // this implementation and UnsafeShuffleWriter. - Optional maybeOutputChannel = writer.openChannelWrapper(); - if (maybeOutputChannel.isPresent()) { - writePartitionedDataWithChannel(file, maybeOutputChannel.get()); + if (partitionWriters != null) { + final long writeStartTime = System.nanoTime(); + try { + for (int i = 0; i < numPartitions; i++) { + final File file = partitionWriterSegments[i].file(); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); + if (file.exists()) { + if (transferToEnabled) { + // Using WritableByteChannelWrapper to make resource closing consistent between + // this implementation and UnsafeShuffleWriter. 
+ Optional maybeOutputChannel = writer.openChannelWrapper(); + if (maybeOutputChannel.isPresent()) { + writePartitionedDataWithChannel(file, maybeOutputChannel.get()); + } else { + writePartitionedDataWithStream(file, writer); + } } else { writePartitionedDataWithStream(file, writer); } - } else { - writePartitionedDataWithStream(file, writer); - } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } } } - lengths[i] = writer.getNumBytesWritten(); + } finally { + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } - } finally { - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + partitionWriters = null; } - partitionWriters = null; - return lengths; + return mapOutputWriter.commitAllPartitions(); } private void writePartitionedDataWithChannel( diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 441718126bc92..745f4785ce01b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -23,9 +23,6 @@ import java.nio.channels.FileChannel; import java.util.Iterator; -import org.apache.spark.api.java.Optional; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.storage.BlockManagerId; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -40,10 +37,12 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.SupportsTransferTo; +import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; @@ -222,11 +221,10 @@ void closeAndWriteOutput() throws IOException { mapId, taskContext.taskAttemptId(), partitioner.numPartitions()); - final long[] partitionLengths; - Optional location; + MapOutputWriterCommitMessage commitMessage; try { try { - partitionLengths = mergeSpills(spills, mapWriter); + mergeSpills(spills, mapWriter); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { @@ -234,7 +232,7 @@ void closeAndWriteOutput() throws IOException { } } } - location = mapWriter.commitAllPartitions(); + commitMessage = mapWriter.commitAllPartitions(); } catch (Exception e) { try { mapWriter.abort(e); @@ -244,7 +242,9 @@ void closeAndWriteOutput() throws IOException { throw e; } mapStatus = MapStatus$.MODULE$.apply( - location.orNull(), partitionLengths, taskContext.attemptNumber()); + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), + taskContext.attemptNumber()); } @VisibleForTesting @@ -276,7 +276,7 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. 
*/ - private long[] mergeSpills(SpillInfo[] spills, + private void mergeSpills(SpillInfo[] spills, ShuffleMapOutputWriter mapWriter) throws IOException { final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); @@ -285,12 +285,8 @@ private long[] mergeSpills(SpillInfo[] spills, final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); - final int numPartitions = partitioner.numPartitions(); - long[] partitionLengths = new long[numPartitions]; try { - if (spills.length == 0) { - return partitionLengths; - } else { + if (spills.length > 0) { // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. If we use the slow merge path, // then the final output file's size won't necessarily be equal to the sum of the spill @@ -307,14 +303,14 @@ private long[] mergeSpills(SpillInfo[] spills, // that doesn't need to interpret the spilled bytes. if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter); + mergeSpillsWithTransferTo(spills, mapWriter); } else { logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null); + mergeSpillsWithFileStream(spills, mapWriter, null); } } else { logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that @@ -322,7 +318,6 @@ private long[] mergeSpills(SpillInfo[] spills, // to be counted as shuffle write, but this will lead to double-counting of the final // SpillInfo's bytes. writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - return partitionLengths; } } catch (IOException e) { throw e; @@ -345,12 +340,11 @@ private long[] mergeSpills(SpillInfo[] spills, * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithFileStream( + private void mergeSpillsWithFileStream( SpillInfo[] spills, ShuffleMapOutputWriter mapWriter, @Nullable CompressionCodec compressionCodec) throws IOException { final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; boolean threwException = true; @@ -395,7 +389,6 @@ private long[] mergeSpillsWithFileStream( Closeables.close(partitionOutput, copyThrewExecption); } long numBytesWritten = writer.getNumBytesWritten(); - partitionLengths[partition] = numBytesWritten; writeMetrics.incBytesWritten(numBytesWritten); } threwException = false; @@ -406,7 +399,6 @@ private long[] mergeSpillsWithFileStream( Closeables.close(stream, threwException); } } - return partitionLengths; } /** @@ -418,11 +410,10 @@ private long[] mergeSpillsWithFileStream( * @param mapWriter the map output writer to use for output. 
* @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithTransferTo( + private void mergeSpillsWithTransferTo( SpillInfo[] spills, ShuffleMapOutputWriter mapWriter) throws IOException { final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; @@ -455,7 +446,6 @@ private long[] mergeSpillsWithTransferTo( Closeables.close(partitionChannel, copyThrewExecption); } long numBytes = writer.getNumBytesWritten(); - partitionLengths[partition] = numBytes; writeMetrics.incBytesWritten(numBytes); } threwException = false; @@ -467,7 +457,6 @@ private long[] mergeSpillsWithTransferTo( Closeables.close(spillInputChannels[i], threwException); } } - return partitionLengths; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index add4634a61fb5..7fc19b1270a46 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -96,10 +96,11 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I } @Override - public void commitAllPartitions() throws IOException { + public long[] commitAllPartitions() throws IOException { cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + return partitionLengths; } @Override diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index f0d3368d0a58d..626f5fd91c291 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -66,9 +66,12 @@ private[spark] class SortShuffleWriter[K, V, C]( // (see SPARK-3570). 
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) - val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - val location = mapOutputWriter.commitAllPartitions - mapStatus = MapStatus(location.orNull, partitionLengths, context.taskAttemptId()) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + val commitMessage = mapOutputWriter.commitAllPartitions + mapStatus = MapStatus( + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths, + context.taskAttemptId()) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 0c1af50e73fcf..2f967a3cdfae0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -729,9 +729,7 @@ private[spark] class ExternalSorter[K, V, C]( * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ def writePartitionedMapOutput( - shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { - // Track location of each range in the map output - val lengths = new Array[Long](numPartitions) + shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Unit = { if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer @@ -757,9 +755,6 @@ private[spark] class ExternalSorter[K, V, C]( partitionPairsWriter.close() } } - if (partitionWriter != null) { - lengths(partitionId) = partitionWriter.getNumBytesWritten - } } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. 
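Taken together with the SortShuffleWriter hunk above, the task-side write path after this change can be summarized by the following sketch. It is illustrative Java, not part of the patch; the wrapper method and its parameters are assumed scaffolding, while the API calls mirror what the patch does.

    import java.io.IOException;

    import org.apache.spark.scheduler.MapStatus;
    import org.apache.spark.scheduler.MapStatus$;
    import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
    import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
    import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
    import org.apache.spark.shuffle.api.ShufflePartitionWriter;

    final class WritePathSketch {
      // Illustrative only: the writer no longer keeps a long[] of its own; the lengths and
      // the optional location both come back from commitAllPartitions().
      static MapStatus writeAndCommit(
          ShuffleExecutorComponents components,
          int shuffleId,
          int mapId,
          long mapTaskAttemptId,
          int numPartitions) throws IOException {
        ShuffleMapOutputWriter mapWriter = components.createMapOutputWriter(
            shuffleId, mapId, mapTaskAttemptId, numPartitions);
        try {
          for (int p = 0; p < numPartitions; p++) {
            ShufflePartitionWriter partitionWriter = mapWriter.getPartitionWriter(p);
            // ... stream partition p's serialized records through partitionWriter.openStream() ...
          }
          MapOutputWriterCommitMessage commitMessage = mapWriter.commitAllPartitions();
          return MapStatus$.MODULE$.apply(
              commitMessage.getLocation().orElse(null),
              commitMessage.getPartitionLengths(),
              mapTaskAttemptId);
        } catch (Exception e) {
          try {
            mapWriter.abort(e);
          } catch (Exception inner) {
            e.addSuppressed(inner);
          }
          throw e;
        }
      }
    }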
@@ -791,17 +786,12 @@ private[spark] class ExternalSorter[K, V, C]( partitionPairsWriter.close() } } - if (partitionWriter != null) { - lengths(id) = partitionWriter.getNumBytesWritten - } } } context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - - lengths } def stop(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index 5693b9824523a..5156cc2cc47a6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -102,7 +102,6 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA intercept[IllegalStateException] { stream.write(p) } - assert(writer.getNumBytesWritten === data(p).length) } verifyWrittenRecords() } @@ -122,8 +121,6 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length) } } - assert(writer.getNumBytesWritten === data(p).length, - s"Partition $p does not have the correct number of bytes.") } verifyWrittenRecords() } @@ -139,8 +136,9 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA } private def verifyWrittenRecords(): Unit = { - mapOutputWriter.commitAllPartitions() + val committedLengths = mapOutputWriter.commitAllPartitions() assert(partitionSizesInMergedFile === partitionLengths) + assert(committedLengths === partitionLengths) assert(mergedOutputFile.length() === partitionLengths.sum) assert(data === readRecordsFromFile()) } From 3797cb2847e5eda44e841ce22713cb711795684b Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 30 Aug 2019 09:43:07 -0700 Subject: [PATCH 04/14] [SPARK-28571][CORE][SHUFFLE] Use the shuffle writer plugin for the SortShuffleWriter Use the shuffle writer APIs introduced in SPARK-28209 in the sort shuffle writer. Existing unit tests were changed to use the plugin instead, and they used the local disk version to ensure that there were no regressions. Closes #25342 from mccheah/shuffle-writer-refactor-sort-shuffle-writer. Lead-authored-by: mcheah Co-authored-by: mccheah Signed-off-by: Marcelo Vanzin --- .../shuffle/ShufflePartitionPairsWriter.scala | 126 ++++++++++++++++++ .../shuffle/sort/SortShuffleWriter.scala | 4 +- .../spark/storage/DiskBlockObjectWriter.scala | 2 +- .../util/collection/ExternalSorter.scala | 37 ++--- .../spark/util/collection/PairsWriter.scala | 5 + .../shuffle/sort/SortShuffleWriterSuite.scala | 108 +++++++++++++++ 6 files changed, 262 insertions(+), 20 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala new file mode 100644 index 0000000000000..a988c5e126a76 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{Closeable, IOException, OutputStream} + +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.api.ShufflePartitionWriter +import org.apache.spark.storage.BlockId +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.PairsWriter + +/** + * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an + * arbitrary partition writer instead of writing to local disk through the block manager. + */ +private[spark] class ShufflePartitionPairsWriter( + partitionWriter: ShufflePartitionWriter, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + blockId: BlockId, + writeMetrics: ShuffleWriteMetricsReporter) + extends PairsWriter with Closeable { + + private var isClosed = false + private var partitionStream: OutputStream = _ + private var wrappedStream: OutputStream = _ + private var objOut: SerializationStream = _ + private var numRecordsWritten = 0 + private var curNumBytesWritten = 0L + + override def write(key: Any, value: Any): Unit = { + if (isClosed) { + throw new IOException("Partition pairs writer is already closed.") + } + if (objOut == null) { + open() + } + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + private def open(): Unit = { + try { + partitionStream = partitionWriter.openStream + wrappedStream = serializerManager.wrapStream(blockId, partitionStream) + objOut = serializerInstance.serializeStream(wrappedStream) + } catch { + case e: Exception => + Utils.tryLogNonFatalError { + close() + } + throw e + } + } + + override def close(): Unit = { + if (!isClosed) { + Utils.tryWithSafeFinally { + Utils.tryWithSafeFinally { + objOut = closeIfNonNull(objOut) + // Setting these to null will prevent the underlying streams from being closed twice + // just in case any stream's close() implementation is not idempotent. + wrappedStream = null + partitionStream = null + } { + // Normally closing objOut would close the inner streams as well, but just in case there + // was an error in initialization etc. we make sure we clean the other streams up too. + Utils.tryWithSafeFinally { + wrappedStream = closeIfNonNull(wrappedStream) + // Same as above - if wrappedStream closes then assume it closes underlying + // partitionStream and don't close again in the finally + partitionStream = null + } { + partitionStream = closeIfNonNull(partitionStream) + } + } + updateBytesWritten() + } { + isClosed = true + } + } + } + + private def closeIfNonNull[T <: Closeable](closeable: T): T = { + if (closeable != null) { + closeable.close() + } + null.asInstanceOf[T] + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. 
+ */ + private def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incRecordsWritten(1) + + if (numRecordsWritten % 16384 == 0) { + updateBytesWritten() + } + } + + private def updateBytesWritten(): Unit = { + val numBytesWritten = partitionWriter.getNumBytesWritten + val bytesWrittenDiff = numBytesWritten - curNumBytesWritten + writeMetrics.incBytesWritten(bytesWrittenDiff) + curNumBytesWritten = numBytesWritten + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 626f5fd91c291..0082b4c9c6b24 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -67,9 +67,9 @@ private[spark] class SortShuffleWriter[K, V, C]( val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - val commitMessage = mapOutputWriter.commitAllPartitions + val commitMessage = mapOutputWriter.commitAllPartitions() mapStatus = MapStatus( - commitMessage.getLocation().orElse(null), + commitMessage.getLocation.orElse(null), commitMessage.getPartitionLengths, context.taskAttemptId()) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index f9f4e3594e4f9..758621c52495b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -234,7 +234,7 @@ private[spark] class DiskBlockObjectWriter( /** * Writes a key-value pair. */ - def write(key: Any, value: Any) { + override def write(key: Any, value: Any) { if (!streamOpen) { open() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 2f967a3cdfae0..1c8334be9a2bb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -24,14 +24,18 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import org.apache.spark._ +import com.google.common.io.{ByteStreams, Closeables} -import org.apache.spark.api.shuffle.ShufflePartitionWriter +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} +import org.apache.spark.shuffle.ShufflePartitionPairsWriter +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} +import org.apache.spark.util.{Utils => TryUtils} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -676,9 +680,9 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project. 
- * We should figure out an alternative way to test that so that we can remove this otherwise - * unused code path. + * TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL + * project. We should figure out an alternative way to test that so that we can remove this + * otherwise unused code path. */ def writePartitionedFile( blockId: BlockId, @@ -729,7 +733,10 @@ private[spark] class ExternalSorter[K, V, C]( * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ def writePartitionedMapOutput( - shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Unit = { + shuffleId: Int, + mapId: Int, + mapOutputWriter: ShuffleMapOutputWriter): Unit = { + var nextPartitionId = 0 if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer @@ -738,7 +745,7 @@ private[spark] class ExternalSorter[K, V, C]( val partitionId = it.nextPartition() var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null - try { + TryUtils.tryWithSafeFinally { partitionWriter = mapOutputWriter.getPartitionWriter(partitionId) val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) partitionPairsWriter = new ShufflePartitionPairsWriter( @@ -750,25 +757,20 @@ private[spark] class ExternalSorter[K, V, C]( while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(partitionPairsWriter) } - } finally { + } { if (partitionPairsWriter != null) { partitionPairsWriter.close() } } + nextPartitionId = partitionId + 1 } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { - // The contract for the plugin is that we will ask for a writer for every partition - // even if it's empty. However, the external sorter will return non-contiguous - // partition ids. So this loop "backfills" the empty partitions that form the gaps. - - // The algorithm as a whole is correct because the partition ids are returned by the - // iterator in ascending order. val blockId = ShuffleBlockId(shuffleId, mapId, id) var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null - try { + TryUtils.tryWithSafeFinally { partitionWriter = mapOutputWriter.getPartitionWriter(id) partitionPairsWriter = new ShufflePartitionPairsWriter( partitionWriter, @@ -781,11 +783,12 @@ private[spark] class ExternalSorter[K, V, C]( partitionPairsWriter.write(elem._1, elem._2) } } - } finally { - if (partitionPairsWriter!= null) { + } { + if (partitionPairsWriter != null) { partitionPairsWriter.close() } } + nextPartitionId = id + 1 } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala index 9d7c209f242e1..05ed72c3e3778 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala @@ -17,6 +17,11 @@ package org.apache.spark.util.collection +/** + * An abstraction of a consumer of key-value pairs, primarily used when + * persisting partitioned data, either through the shuffle writer plugins + * or via DiskBlockObjectWriter. 
+ */ private[spark] trait PairsWriter { def write(key: Any, value: Any): Unit diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..0dd6040808f9e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Mockito._ +import org.scalatest.Matchers + +import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.Utils + + +class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers { + + @Mock(answer = RETURNS_SMART_NULLS) + private var blockManager: BlockManager = _ + + private val shuffleId = 0 + private val numMaps = 5 + private var shuffleHandle: BaseShuffleHandle[Int, Int, Int] = _ + private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + private val serializer = new JavaSerializer(conf) + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ + + override def beforeEach(): Unit = { + super.beforeEach() + MockitoAnnotations.initMocks(this) + val partitioner = new Partitioner() { + def numPartitions = numMaps + def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions) + } + shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.partitioner).thenReturn(partitioner) + when(dependency.serializer).thenReturn(serializer) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency) + } + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, blockManager, shuffleBlockResolver) + } + + override def afterAll(): Unit = { + try { + shuffleBlockResolver.stop() + } finally { + super.afterAll() + } + } + + test("write empty iterator") { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val writer = new SortShuffleWriter[Int, Int, Int]( + shuffleBlockResolver, + shuffleHandle, + mapId = 1, + context, + shuffleExecutorComponents) + writer.write(Iterator.empty) + 
writer.stop(success = true) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 1) + val writeMetrics = context.taskMetrics().shuffleWriteMetrics + assert(!dataFile.exists()) + assert(writeMetrics.bytesWritten === 0) + assert(writeMetrics.recordsWritten === 0) + } + + test("write with some records") { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val records = List[(Int, Int)]((1, 2), (2, 3), (4, 4), (6, 5)) + val writer = new SortShuffleWriter[Int, Int, Int]( + shuffleBlockResolver, + shuffleHandle, + mapId = 2, + context, + shuffleExecutorComponents) + writer.write(records.toIterator) + writer.stop(success = true) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2) + val writeMetrics = context.taskMetrics().shuffleWriteMetrics + assert(dataFile.exists()) + assert(dataFile.length() === writeMetrics.bytesWritten) + assert(records.size === writeMetrics.recordsWritten) + } +} From 5a9727128e421e60421c5d96164fde48b50d946c Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 10 Sep 2019 16:44:57 -0700 Subject: [PATCH 05/14] [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. --- .../api/ShuffleExecutorComponents.java | 37 +++ .../SingleSpillShuffleMapOutputWriter.java | 37 +++ .../shuffle/sort/UnsafeShuffleWriter.java | 218 +++++++++++------- .../LocalDiskShuffleExecutorComponents.java | 14 ++ .../io/LocalDiskShuffleMapOutputWriter.java | 24 +- .../LocalDiskSingleSpillMapOutputWriter.java | 55 +++++ .../spark/internal/config/package.scala | 2 +- .../shuffle/sort/SortShuffleManager.scala | 1 - .../sort/UnsafeShuffleWriterSuite.java | 35 +-- 9 files changed, 310 insertions(+), 113 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java index 8f3b6671c9482..0e6e90c404c16 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -18,7 +18,12 @@ package org.apache.spark.shuffle.api; import java.io.IOException; +<<<<<<< HEAD import java.util.Map; +||||||| parent of a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. +======= +import java.util.Optional; +>>>>>>> a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. import org.apache.spark.annotation.Private; @@ -40,7 +45,14 @@ public interface ShuffleExecutorComponents { /** * Called once per map task to create a writer that will be responsible for persisting all the * partitioned bytes written by that map task. +<<<<<<< HEAD * @param shuffleId Unique identifier for the shuffle the map task is a part of +||||||| parent of a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. + * @param shuffleId Unique identifier for the shuffle the map task is a part of +======= + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of +>>>>>>> a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. * @param mapId Within the shuffle, the identifier of the map task * @param mapTaskAttemptId Identifier of the task attempt. 
Multiple attempts of the same map task * with the same (shuffleId, mapId) pair can be distinguished by the @@ -53,6 +65,7 @@ ShuffleMapOutputWriter createMapOutputWriter( int mapId, long mapTaskAttemptId, int numPartitions) throws IOException; +<<<<<<< HEAD /** * Returns an underlying {@link Iterable} that will iterate @@ -64,4 +77,28 @@ Iterable getPartitionReaders(Iterable blockMetada default boolean shouldWrapPartitionReaderStream() { return true; } +||||||| parent of a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. +======= + + /** + * An optional extension for creating a map output writer that can optimize the transfer of a + * single partition file, as the entire result of a map task, to the backing store. + *
<p>
+ * Most implementations should return the default {@link Optional#empty()} to indicate that + * they do not support this optimization. This primarily is for backwards-compatibility in + * preserving an optimization in the local disk shuffle storage implementation. + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId Within the shuffle, the identifier of the map task + * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. + */ + default Optional createSingleFileMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId) throws IOException { + return Optional.empty(); + } +>>>>>>> a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..bddb97bdf0d7e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.File; +import java.io.IOException; + +import org.apache.spark.annotation.Private; + +/** + * Optional extension for partition writing that is optimized for transferring a single + * file to the backing store. + */ +@Private +public interface SingleSpillShuffleMapOutputWriter { + + /** + * Transfer a file that contains the bytes of all the partitions written by this map task. 
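A hypothetical plugin-side sketch of this optional extension is shown below. ExampleSingleSpillWriter is an illustrative name, not part of this patch; the method shape follows the interface as it is declared here.

    import java.io.File;
    import java.io.IOException;

    import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
    import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;

    // Hypothetical example: a backend that can absorb the map task's single spill file as-is
    // implements the extension, avoiding the partition-by-partition merge path entirely.
    public final class ExampleSingleSpillWriter implements SingleSpillShuffleMapOutputWriter {

      @Override
      public MapOutputWriterCommitMessage transferMapSpillFile(
          File mapSpillFile,
          long[] partitionLengths) throws IOException {
        // Move or upload mapSpillFile to the backing store here; the file already has the final
        // layout, so the lengths recorded during the spill can be reported back unchanged.
        return MapOutputWriterCommitMessage.of(partitionLengths);
      }
    }

The corresponding ShuffleExecutorComponents would return Optional.of(new ExampleSingleSpillWriter()) from createSingleFileMapOutputWriter, so that UnsafeShuffleWriter can take the single-spill path when exactly one spill was produced.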
+ */ + MapOutputWriterCommitMessage transferMapSpillFile( + File mapOutputFile, long[] partitionLengths) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 745f4785ce01b..1ec9f590c0bbb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -18,9 +18,11 @@ package org.apache.spark.shuffle.sort; import java.nio.channels.Channels; +import java.util.Optional; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; import java.util.Iterator; import scala.Option; @@ -37,12 +39,6 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.shuffle.api.SupportsTransferTo; -import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; @@ -55,8 +51,15 @@ import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -215,31 +218,15 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents - .createMapOutputWriter( - shuffleId, - mapId, - taskContext.taskAttemptId(), - partitioner.numPartitions()); - MapOutputWriterCommitMessage commitMessage; + final MapOutputWriterCommitMessage commitMessage; try { - try { - mergeSpills(spills, mapWriter); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && !spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); - } + commitMessage = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); } } - commitMessage = mapWriter.commitAllPartitions(); - } catch (Exception e) { - try { - mapWriter.abort(e); - } catch (Exception innerE) { - logger.error("Failed to abort the Map Output Writer", innerE); - } - throw e; } mapStatus = MapStatus$.MODULE$.apply( commitMessage.getLocation().orElse(null), @@ -276,52 +263,94 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. 
*/ - private void mergeSpills(SpillInfo[] spills, - ShuffleMapOutputWriter mapWriter) throws IOException { + private MapOutputWriterCommitMessage mergeSpills(SpillInfo[] spills) throws IOException { + MapOutputWriterCommitMessage commitMessage; + if (spills.length == 0) { + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter( + shuffleId, + mapId, + taskContext.taskAttemptId(), + partitioner.numPartitions()); + return mapWriter.commitAllPartitions(); + } else if (spills.length == 1) { + Optional maybeSingleFileWriter = + shuffleExecutorComponents.createSingleFileMapOutputWriter( + shuffleId, mapId, taskContext.taskAttemptId()); + if (maybeSingleFileWriter.isPresent()) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + long[] partitionLengths = spills[0].partitionLengths; + return maybeSingleFileWriter.get().transferMapSpillFile( + spills[0].file, partitionLengths); + } else { + commitMessage = mergeSpillsUsingStandardWriter(spills); + } + } else { + commitMessage = mergeSpillsUsingStandardWriter(spills); + } + return commitMessage; + } + + private MapOutputWriterCommitMessage mergeSpillsUsingStandardWriter( + SpillInfo[] spills) throws IOException { + MapOutputWriterCommitMessage commitMessage; final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = - (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE()); + (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE()); final boolean fastMergeIsSupported = !compressionEnabled || - CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter( + shuffleId, + mapId, + taskContext.taskAttemptId(), + partitioner.numPartitions()); try { - if (spills.length > 0) { - // There are multiple spills to merge, so none of these spill files' lengths were counted - // towards our shuffle write count or shuffle write time. If we use the slow merge path, - // then the final output file's size won't necessarily be equal to the sum of the spill - // files' sizes. To guard against this case, we look at the output file's actual size when - // computing shuffle bytes written. - // - // We allow the individual merge methods to report their own IO times since different merge - // strategies use different IO techniques. We count IO during merge towards the shuffle - // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" - // branch in ExternalSorter. - if (fastMergeEnabled && fastMergeIsSupported) { - // Compression is disabled or we are using an IO compression codec that supports - // decompression of concatenated compressed streams, so we can perform a fast spill merge - // that doesn't need to interpret the spilled bytes. 
- if (transferToEnabled && !encryptionEnabled) { - logger.debug("Using transferTo-based fast merge"); - mergeSpillsWithTransferTo(spills, mapWriter); - } else { - logger.debug("Using fileStream-based fast merge"); - mergeSpillsWithFileStream(spills, mapWriter, null); - } + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // write time, which appears to be consistent with the "not bypassing merge-sort" branch in + // ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled && !encryptionEnabled) { + logger.debug("Using transferTo-based fast merge"); + mergeSpillsWithTransferTo(spills, mapWriter); } else { - logger.debug("Using slow merge"); - mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + logger.debug("Using fileStream-based fast merge"); + mergeSpillsWithFileStream(spills, mapWriter, null); } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. - writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + } else { + logger.debug("Using slow merge"); + mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. 
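    // Illustrative arithmetic for the double-counting note above (example figures, not from
    // this patch): suppose two earlier spills hold 64 MB in total (counted as bytes spilled,
    // not as shuffle write) and the final in-memory batch is flushed to a 16 MB spill file
    // that was already counted as shuffle write when it was produced. The merge then writes
    // the full 80 MB of map output and adds 80 MB to the shuffle-write metric, reaching
    // 96 MB; subtracting the last SpillInfo's 16 MB restores the correct total of 80 MB.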
+ writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + commitMessage = mapWriter.commitAllPartitions(); + } catch (Exception e) { + try { + mapWriter.abort(e); + } catch (Exception e2) { + logger.warn("Failed to abort writing the map output.", e2); + e.addSuppressed(e2); } - } catch (IOException e) { throw e; } + return commitMessage; } /** @@ -355,11 +384,10 @@ private void mergeSpillsWithFileStream( inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewExecption = true; + boolean copyThrewException = true; ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - OutputStream partitionOutput = null; + OutputStream partitionOutput = writer.openStream(); try { - partitionOutput = writer.openStream(); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); @@ -369,6 +397,7 @@ private void mergeSpillsWithFileStream( if (partitionLengthInSpill > 0) { InputStream partitionInputStream = null; + boolean copySpillThrewException = true; try { partitionInputStream = new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); @@ -379,14 +408,16 @@ private void mergeSpillsWithFileStream( partitionInputStream); } ByteStreams.copy(partitionInputStream, partitionOutput); + copySpillThrewException = false; } finally { - partitionInputStream.close(); + Closeables.close(partitionInputStream, copySpillThrewException); } } - copyThrewExecption = false; + copyThrewException = false; } + copyThrewException = false; } finally { - Closeables.close(partitionOutput, copyThrewExecption); + Closeables.close(partitionOutput, copyThrewException); } long numBytesWritten = writer.getNumBytesWritten(); writeMetrics.incBytesWritten(numBytesWritten); @@ -423,27 +454,26 @@ private void mergeSpillsWithTransferTo( spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); } for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewExecption = true; + boolean copyThrewException = true; ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - TransferrableWritableByteChannel partitionChannel = null; + WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper() + .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer))); try { - partitionChannel = writer instanceof SupportsTransferTo ? 
- ((SupportsTransferTo) writer).openTransferrableChannel() - : new DefaultTransferrableWritableByteChannel( - Channels.newChannel(writer.openStream())); for (int i = 0; i < spills.length; i++) { - long partitionLengthInSpill = 0L; - partitionLengthInSpill += spills[i].partitionLengths[partition]; + long partitionLengthInSpill = spills[i].partitionLengths[partition]; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - partitionChannel.transferFrom( - spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill); + Utils.copyFileStreamNIO( + spillInputChannel, + resolvedChannel.channel(), + spillInputChannelPositions[i], + partitionLengthInSpill); + copyThrewException = false; spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } - copyThrewExecption = false; } finally { - Closeables.close(partitionChannel, copyThrewExecption); + Closeables.close(resolvedChannel, copyThrewException); } long numBytes = writer.getNumBytesWritten(); writeMetrics.incBytesWritten(numBytes); @@ -485,4 +515,30 @@ public Option stop(boolean success) { } } } + + private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) { + try { + return writer.openStream(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper { + private final WritableByteChannel channel; + + StreamFallbackChannelWrapper(OutputStream fallbackStream) { + this.channel = Channels.newChannel(fallbackStream); + } + + @Override + public WritableByteChannel channel() { + return channel; + } + + @Override + public void close() throws IOException { + channel.close(); + } + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java index f32306d4c37c7..2e44c2acda301 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.sort.io; import java.util.Map; +import java.util.Optional; import com.google.common.annotations.VisibleForTesting; @@ -26,6 +27,7 @@ import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.storage.BlockManager; public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { @@ -70,4 +72,16 @@ public ShuffleMapOutputWriter createMapOutputWriter( return new LocalDiskShuffleMapOutputWriter( shuffleId, mapId, numPartitions, blockResolver, sparkConf); } + + @Override + public Optional createSingleFileMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)); + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java 
index 7fc19b1270a46..444cdc4270ecd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -24,8 +24,8 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; import java.nio.channels.WritableByteChannel; - import java.util.Optional; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,6 +54,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final int bufferSize; private int lastPartitionId = -1; private long currChannelPosition; + private long bytesWrittenToMergedFile = 0L; private final File outputFile; private File outputTempFile; @@ -97,6 +98,18 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I @Override public long[] commitAllPartitions() throws IOException { + // Check the position after the transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + outputFileChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " + + "kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " + + "to unexpected behavior when using transferTo. You can set " + + "spark.file.transferTo=false to disable this NIO feature."); + } cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); @@ -133,11 +146,10 @@ private void initStream() throws IOException { } private void initChannel() throws IOException { - if (outputFileStream == null) { - outputFileStream = new FileOutputStream(outputTempFile, true); - } + // This file needs to be opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. if (outputFileChannel == null) { - outputFileChannel = outputFileStream.getChannel(); + outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel(); } } @@ -227,6 +239,7 @@ public void write(byte[] buf, int pos, int length) throws IOException { public void close() { isClosed = true; partitionLengths[partitionId] = count; + bytesWrittenToMergedFile += count; } private void verifyNotClosed() { @@ -257,6 +270,7 @@ public WritableByteChannel channel() { @Override public void close() throws IOException { partitionLengths[partitionId] = getCount(); + bytesWrittenToMergedFile += partitionLengths[partitionId]; } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java new file mode 100644 index 0000000000000..6b0a797a61b52 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; + +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.util.Utils; + +public class LocalDiskSingleSpillMapOutputWriter + implements SingleSpillShuffleMapOutputWriter { + + private final int shuffleId; + private final int mapId; + private final IndexShuffleBlockResolver blockResolver; + + public LocalDiskSingleSpillMapOutputWriter( + int shuffleId, + int mapId, + IndexShuffleBlockResolver blockResolver) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.blockResolver = blockResolver; + } + + @Override + public void transferMapSpillFile( + File mapSpillFile, + long[] partitionLengths) throws IOException { + // The map spill file already has the proper format, and it contains all of the partition data. + // So just transfer it directly to the destination without any merging. + File outputFile = blockResolver.getDataFile(shuffleId, mapId); + File tempFile = Utils.tempFileWith(outputFile); + Files.move(mapSpillFile.toPath(), tempFile.toPath()); + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 23607e7ad975f..833db06420d4d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -951,7 +951,7 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val SHUFFLE_UNDAFE_FAST_MERGE_ENABLE = + private[spark] val SHUFFLE_UNSAFE_FAST_MERGE_ENABLE = ConfigBuilder("spark.shuffle.unsafe.fastMergeEnabled") .doc("Whether to perform a fast spill merge.") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0308b94fd14ba..162cf9f0d420a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -146,7 +146,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, - shuffleBlockResolver, context.taskMemoryManager(), unsafeShuffleHandle, mapId, diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 698a7f72a722c..80006f3cd9201 100644 --- 
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -22,6 +22,7 @@ import java.nio.file.Files; import java.util.*; +import org.mockito.stubbing.Answer; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -35,7 +36,6 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.stubbing.Answer; import org.apache.spark.HashPartitioner; import org.apache.spark.MapOutputTracker; @@ -57,7 +57,7 @@ import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents; +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -141,8 +141,7 @@ public void setUp() throws IOException { }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - - Answer renameTempAnswer = invocationOnMock -> { + Answer renameTempAnswer = invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; File tmp = (File) invocationOnMock.getArguments()[3]; if (!mergedOutputFile.delete()) { @@ -175,14 +174,11 @@ public void setUp() throws IOException { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); - - TaskContext$.MODULE$.setTaskContext(taskContext); } - private UnsafeShuffleWriter createWriter( - boolean transferToEnabled) throws IOException { + private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter<>( + return new UnsafeShuffleWriter( blockManager, taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), @@ -190,12 +186,7 @@ private UnsafeShuffleWriter createWriter( taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleExecutorComponents( - conf, - blockManager, - mapOutputTracker, - serializerManager, - shuffleBlockResolver)); + new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver)); } private void assertSpillFilesWereCleanedUp() { @@ -421,7 +412,7 @@ public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Except @Test public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { - conf.set(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE(), false); + conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false); testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); } @@ -546,21 +537,15 @@ public void testPeakMemoryUsed() throws Exception { final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; taskMemoryManager = spy(taskMemoryManager); when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); - final UnsafeShuffleWriter writer = - new UnsafeShuffleWriter<>( + final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleExecutorComponents( - conf, - blockManager, - mapOutputTracker, - serializerManager, - shuffleBlockResolver)); + 
new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver)); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. From 80549624b8367260323d9d55019ba195e44d392c Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 10 Sep 2019 17:42:02 -0700 Subject: [PATCH 06/14] Resolve build issues and remaining semantic conflicts --- .../api/MapOutputWriterCommitMessage.java | 7 +- .../api/ShuffleExecutorComponents.java | 15 +- .../shuffle/sort/io/DefaultShuffleDataIO.java | 43 --- .../io/DefaultShuffleExecutorComponents.java | 101 ------- .../io/DefaultShuffleMapOutputWriter.java | 269 ------------------ .../sort/io/LocalDiskShuffleDataIO.java | 8 + .../LocalDiskShuffleExecutorComponents.java | 45 ++- .../io/LocalDiskShuffleMapOutputWriter.java | 9 +- .../LocalDiskSingleSpillMapOutputWriter.java | 10 +- ...cala => LocalDiskShuffleReadSupport.scala} | 2 +- .../sort/UnsafeShuffleWriterSuite.java | 12 +- .../DAGSchedulerShufflePluginSuite.scala | 11 +- .../BlockStoreShuffleReaderSuite.scala | 4 +- .../ShuffleDriverComponentsSuite.scala | 4 +- .../BlockStoreShuffleReaderBenchmark.scala | 4 +- ...ypassMergeSortShuffleWriterBenchmark.scala | 8 +- .../BypassMergeSortShuffleWriterSuite.scala | 7 +- .../sort/SortShuffleWriterBenchmark.scala | 7 +- .../shuffle/sort/SortShuffleWriterSuite.scala | 15 +- .../sort/UnsafeShuffleWriterBenchmark.scala | 7 +- .../DefaultShuffleMapOutputWriterSuite.scala | 230 --------------- 21 files changed, 122 insertions(+), 696 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java rename core/src/main/scala/org/apache/spark/shuffle/io/{DefaultShuffleReadSupport.scala => LocalDiskShuffleReadSupport.scala} (99%) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java index e07efd57cc07f..dc51f8962206f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java @@ -11,7 +11,8 @@ public final class MapOutputWriterCommitMessage { private final long[] partitionLengths; private final Optional location; - private MapOutputWriterCommitMessage(long[] partitionLengths, Optional location) { + private MapOutputWriterCommitMessage( + long[] partitionLengths, Optional location) { this.partitionLengths = partitionLengths; this.location = location; } @@ -21,8 +22,8 @@ public static MapOutputWriterCommitMessage of(long[] partitionLengths) { } public static MapOutputWriterCommitMessage of( - long[] partitionLengths, java.util.Optional location) { - return new MapOutputWriterCommitMessage(partitionLengths, location); + long[] partitionLengths, BlockManagerId location) { + return new MapOutputWriterCommitMessage(partitionLengths, Optional.of(location)); } public long[] getPartitionLengths() { diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java 
index 0e6e90c404c16..94c07009f3180 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -18,12 +18,9 @@ package org.apache.spark.shuffle.api; import java.io.IOException; -<<<<<<< HEAD +import java.io.InputStream; import java.util.Map; -||||||| parent of a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. -======= import java.util.Optional; ->>>>>>> a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. import org.apache.spark.annotation.Private; @@ -45,14 +42,8 @@ public interface ShuffleExecutorComponents { /** * Called once per map task to create a writer that will be responsible for persisting all the * partitioned bytes written by that map task. -<<<<<<< HEAD - * @param shuffleId Unique identifier for the shuffle the map task is a part of -||||||| parent of a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. - * @param shuffleId Unique identifier for the shuffle the map task is a part of -======= * * @param shuffleId Unique identifier for the shuffle the map task is a part of ->>>>>>> a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. * @param mapId Within the shuffle, the identifier of the map task * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task * with the same (shuffleId, mapId) pair can be distinguished by the @@ -65,7 +56,6 @@ ShuffleMapOutputWriter createMapOutputWriter( int mapId, long mapTaskAttemptId, int numPartitions) throws IOException; -<<<<<<< HEAD /** * Returns an underlying {@link Iterable} that will iterate @@ -77,8 +67,6 @@ Iterable getPartitionReaders(Iterable blockMetada default boolean shouldWrapPartitionReaderStream() { return true; } -||||||| parent of a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. -======= /** * An optional extension for creating a map output writer that can optimize the transfer of a @@ -100,5 +88,4 @@ default Optional createSingleFileMapOutputWri long mapTaskAttemptId) throws IOException { return Optional.empty(); } ->>>>>>> a80c04ed1d... [SPARK-28570][CORE][SHUFFLE] Make UnsafeShuffleWriter use the new API. } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java deleted file mode 100644 index a6faa6ac52ca6..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io; - -import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.api.ShuffleDriverComponents; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.shuffle.api.ShuffleDataIO; -import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; - -public class DefaultShuffleDataIO implements ShuffleDataIO { - - private final SparkConf sparkConf; - - public DefaultShuffleDataIO(SparkConf sparkConf) { - this.sparkConf = sparkConf; - } - - @Override - public ShuffleExecutorComponents executor() { - return new DefaultShuffleExecutorComponents(sparkConf); - } - - @Override - public ShuffleDriverComponents driver() { - return new DefaultShuffleDriverComponents(); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java deleted file mode 100644 index 77edba8642728..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort.io; - -import com.google.common.annotations.VisibleForTesting; -import java.io.IOException; -import java.io.InputStream; -import org.apache.spark.MapOutputTracker; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkEnv; -import org.apache.spark.TaskContext; -import org.apache.spark.shuffle.api.ShuffleBlockInfo; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.serializer.SerializerManager; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; -import org.apache.spark.storage.BlockManager; - -import java.util.Map; - -public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { - - private final SparkConf sparkConf; - // Submodule for the read side for shuffles - implemented in Scala for ease of - // compatibility with previously written code. 
- private DefaultShuffleReadSupport shuffleReadSupport; - private BlockManager blockManager; - private IndexShuffleBlockResolver blockResolver; - - public DefaultShuffleExecutorComponents(SparkConf sparkConf) { - this.sparkConf = sparkConf; - } - - @VisibleForTesting - public DefaultShuffleExecutorComponents( - SparkConf sparkConf, - BlockManager blockManager, - MapOutputTracker mapOutputTracker, - SerializerManager serializerManager, - IndexShuffleBlockResolver blockResolver) { - this.sparkConf = sparkConf; - this.blockManager = blockManager; - this.blockResolver = blockResolver; - this.shuffleReadSupport = new DefaultShuffleReadSupport( - blockManager, mapOutputTracker, serializerManager, sparkConf); - } - - @Override - public void initializeExecutor(String appId, String execId, Map extraConfigs) { - blockManager = SparkEnv.get().blockManager(); - MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker(); - SerializerManager serializerManager = SparkEnv.get().serializerManager(); - blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); - shuffleReadSupport = new DefaultShuffleReadSupport( - blockManager, mapOutputTracker, serializerManager, sparkConf); - } - - @Override - public ShuffleMapOutputWriter createMapOutputWriter(int shuffleId, int mapId, long mapTaskAttemptId, int numPartitions) throws IOException { - checkInitialized(); - return new DefaultShuffleMapOutputWriter( - shuffleId, - mapId, - numPartitions, - blockManager.shuffleServerId(), - TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); - } - - @Override - public Iterable getPartitionReaders(Iterable blockMetadata) throws IOException { - return shuffleReadSupport.getPartitionReaders(blockMetadata); - } - - @Override - public boolean shouldWrapPartitionReaderStream() { - return false; - } - - private void checkInitialized() { - if (blockResolver == null) { - throw new IllegalStateException( - "Executor components must be initialized before getting writers."); - } - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java deleted file mode 100644 index 5c8d2d43dacf7..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io; - -import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.nio.channels.FileChannel; - -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.shuffle.api.SupportsTransferTo; -import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; -import org.apache.spark.internal.config.package$; -import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel; -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.util.Utils; - -public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { - - private static final Logger log = - LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class); - - private final int shuffleId; - private final int mapId; - private final ShuffleWriteMetricsReporter metrics; - private final IndexShuffleBlockResolver blockResolver; - private final long[] partitionLengths; - private final int bufferSize; - private int lastPartitionId = -1; - private long currChannelPosition; - private final BlockManagerId shuffleServerId; - - private final File outputFile; - private File outputTempFile; - private FileOutputStream outputFileStream; - private FileChannel outputFileChannel; - private TimeTrackingOutputStream ts; - private BufferedOutputStream outputBufferedFileStream; - - public DefaultShuffleMapOutputWriter( - int shuffleId, - int mapId, - int numPartitions, - BlockManagerId shuffleServerId, - ShuffleWriteMetricsReporter metrics, - IndexShuffleBlockResolver blockResolver, - SparkConf sparkConf) { - this.shuffleId = shuffleId; - this.mapId = mapId; - this.shuffleServerId = shuffleServerId; - this.metrics = metrics; - this.blockResolver = blockResolver; - this.bufferSize = - (int) (long) sparkConf.get( - package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; - this.partitionLengths = new long[numPartitions]; - this.outputFile = blockResolver.getDataFile(shuffleId, mapId); - this.outputTempFile = null; - } - - @Override - public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException { - if (partitionId <= lastPartitionId) { - throw new IllegalArgumentException("Partitions should be requested in increasing order."); - } - lastPartitionId = partitionId; - if (outputTempFile == null) { - outputTempFile = Utils.tempFileWith(outputFile); - } - if (outputFileChannel != null) { - currChannelPosition = outputFileChannel.position(); - } else { - currChannelPosition = 0L; - } - return new DefaultShufflePartitionWriter(partitionId); - } - - @Override - public Optional commitAllPartitions() throws IOException { - cleanUp(); - File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); - return Optional.of(shuffleServerId); - } - - @Override - public void abort(Throwable error) { - try { - cleanUp(); - } catch (Exception e) { - log.error("Unable to close appropriate underlying file stream", e); - } - if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { - log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); - } - } - - private void cleanUp() throws IOException { - if (outputBufferedFileStream != null) { - outputBufferedFileStream.close(); - } - if (outputFileChannel != null) { - outputFileChannel.close(); - } - if (outputFileStream != null) { - outputFileStream.close(); - } - } - - private void initStream() throws IOException { - if (outputFileStream == null) { - outputFileStream = new FileOutputStream(outputTempFile, true); - ts = new TimeTrackingOutputStream(metrics, outputFileStream); - } - if (outputBufferedFileStream == null) { - outputBufferedFileStream = new BufferedOutputStream(ts, bufferSize); - } - } - - private void initChannel() throws IOException { - if (outputFileStream == null) { - outputFileStream = new FileOutputStream(outputTempFile, true); - } - if (outputFileChannel == null) { - outputFileChannel = outputFileStream.getChannel(); - } - } - - private class DefaultShufflePartitionWriter implements SupportsTransferTo { - - private final int partitionId; - private PartitionWriterStream partStream = null; - private PartitionWriterChannel partChannel = null; - - private DefaultShufflePartitionWriter(int partitionId) { - this.partitionId = partitionId; - } - - @Override - public OutputStream openStream() throws IOException { - if (partStream == null) { - if (outputFileChannel != null) { - throw new IllegalStateException("Requested an output channel for a previous write but" + - " now an output stream has been requested. Should not be using both channels" + - " and streams to write."); - } - initStream(); - partStream = new PartitionWriterStream(partitionId); - } - return partStream; - } - - @Override - public TransferrableWritableByteChannel openTransferrableChannel() throws IOException { - if (partChannel == null) { - if (partStream != null) { - throw new IllegalStateException("Requested an output stream for a previous write but" + - " now an output channel has been requested. 
Should not be using both channels" + - " and streams to write."); - } - initChannel(); - partChannel = new PartitionWriterChannel(partitionId); - } - return partChannel; - } - - @Override - public long getNumBytesWritten() { - if (partChannel != null) { - try { - return partChannel.getCount(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } else if (partStream != null) { - return partStream.getCount(); - } else { - // Assume an empty partition if stream and channel are never created - return 0; - } - } - } - - private class PartitionWriterStream extends OutputStream { - private final int partitionId; - private int count = 0; - private boolean isClosed = false; - - PartitionWriterStream(int partitionId) { - this.partitionId = partitionId; - } - - public int getCount() { - return count; - } - - @Override - public void write(int b) throws IOException { - verifyNotClosed(); - outputBufferedFileStream.write(b); - count++; - } - - @Override - public void write(byte[] buf, int pos, int length) throws IOException { - verifyNotClosed(); - outputBufferedFileStream.write(buf, pos, length); - count += length; - } - - @Override - public void close() { - isClosed = true; - partitionLengths[partitionId] = count; - } - - private void verifyNotClosed() { - if (isClosed) { - throw new IllegalStateException("Attempting to write to a closed block output stream."); - } - } - } - - private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel { - - private final int partitionId; - - PartitionWriterChannel(int partitionId) { - super(outputFileChannel); - this.partitionId = partitionId; - } - - public long getCount() throws IOException { - long writtenPosition = outputFileChannel.position(); - return writtenPosition - currChannelPosition; - } - - @Override - public void close() throws IOException { - partitionLengths[partitionId] = getCount(); - } - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java index cabcb171ac23a..2db32a1f30860 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java @@ -18,8 +18,10 @@ package org.apache.spark.shuffle.sort.io; import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.api.ShuffleDriverComponents; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleDataIO; +import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; /** * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle @@ -33,8 +35,14 @@ public LocalDiskShuffleDataIO(SparkConf sparkConf) { this.sparkConf = sparkConf; } + @Override + public ShuffleDriverComponents driver() { + return new DefaultShuffleDriverComponents(); + } + @Override public ShuffleExecutorComponents executor() { return new LocalDiskShuffleExecutorComponents(sparkConf); } + } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java index 2e44c2acda301..c8d70d72eb02e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -17,22 +17,30 @@ package 
org.apache.spark.shuffle.sort.io; +import java.io.InputStream; import java.util.Map; import java.util.Optional; import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.MapOutputTracker; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.shuffle.api.ShuffleBlockInfo; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport; import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockManagerId; public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { private final SparkConf sparkConf; + private LocalDiskShuffleReadSupport shuffleReadSupport; + private BlockManagerId shuffleServerId; private BlockManager blockManager; private IndexShuffleBlockResolver blockResolver; @@ -44,10 +52,16 @@ public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) { public LocalDiskShuffleExecutorComponents( SparkConf sparkConf, BlockManager blockManager, - IndexShuffleBlockResolver blockResolver) { + MapOutputTracker mapOutputTracker, + SerializerManager serializerManager, + IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId) { this.sparkConf = sparkConf; this.blockManager = blockManager; this.blockResolver = blockResolver; + this.shuffleServerId = shuffleServerId; + this.shuffleReadSupport = new LocalDiskShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); } @Override @@ -56,7 +70,12 @@ public void initializeExecutor(String appId, String execId, Map if (blockManager == null) { throw new IllegalStateException("No blockManager available from the SparkEnv."); } + shuffleServerId = blockManager.shuffleServerId(); blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker(); + SerializerManager serializerManager = SparkEnv.get().serializerManager(); + shuffleReadSupport = new LocalDiskShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); } @Override @@ -70,7 +89,12 @@ public ShuffleMapOutputWriter createMapOutputWriter( "Executor components must be initialized before getting writers."); } return new LocalDiskShuffleMapOutputWriter( - shuffleId, mapId, numPartitions, blockResolver, sparkConf); + shuffleId, + mapId, + numPartitions, + blockResolver, + shuffleServerId, + sparkConf); } @Override @@ -82,6 +106,21 @@ public Optional createSingleFileMapOutputWrit throw new IllegalStateException( "Executor components must be initialized before getting writers."); } - return Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)); + return Optional.of(new LocalDiskSingleSpillMapOutputWriter( + shuffleId, mapId, blockResolver, shuffleServerId)); + } + + @Override + public Iterable getPartitionReaders(Iterable blockMetadata) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting readers."); + } + return shuffleReadSupport.getPartitionReaders(blockMetadata); + } + + @Override + public boolean shouldWrapPartitionReaderStream() { + return false; } } diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index 444cdc4270ecd..a4566cf908b79 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -26,6 +26,8 @@ import java.nio.channels.WritableByteChannel; import java.util.Optional; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,6 +54,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final IndexShuffleBlockResolver blockResolver; private final long[] partitionLengths; private final int bufferSize; + private final BlockManagerId shuffleServerId; private int lastPartitionId = -1; private long currChannelPosition; private long bytesWrittenToMergedFile = 0L; @@ -67,6 +70,7 @@ public LocalDiskShuffleMapOutputWriter( int mapId, int numPartitions, IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId, SparkConf sparkConf) { this.shuffleId = shuffleId; this.mapId = mapId; @@ -74,6 +78,7 @@ public LocalDiskShuffleMapOutputWriter( this.bufferSize = (int) (long) sparkConf.get( package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.shuffleServerId = shuffleServerId; this.partitionLengths = new long[numPartitions]; this.outputFile = blockResolver.getDataFile(shuffleId, mapId); this.outputTempFile = null; @@ -97,7 +102,7 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I } @Override - public long[] commitAllPartitions() throws IOException { + public MapOutputWriterCommitMessage commitAllPartitions() throws IOException { // Check the position after the transferTo loop to see if it is in the right position and raise an // exception if it is incorrect. The position will not be increased to the expected length // after calling transferTo in kernel version 2.6.32. This issue is described at @@ -113,7 +118,7 @@ public long[] commitAllPartitions() throws IOException { cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ?
outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); - return partitionLengths; + return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index 6b0a797a61b52..219f9ee1296dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -22,7 +22,9 @@ import java.nio.file.Files; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.storage.BlockManagerId; import org.apache.spark.util.Utils; public class LocalDiskSingleSpillMapOutputWriter @@ -31,18 +33,21 @@ public class LocalDiskSingleSpillMapOutputWriter private final int shuffleId; private final int mapId; private final IndexShuffleBlockResolver blockResolver; + private final BlockManagerId shuffleServerId; public LocalDiskSingleSpillMapOutputWriter( int shuffleId, int mapId, - IndexShuffleBlockResolver blockResolver) { + IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId) { this.shuffleId = shuffleId; this.mapId = mapId; this.blockResolver = blockResolver; + this.shuffleServerId = shuffleServerId; } @Override - public void transferMapSpillFile( + public MapOutputWriterCommitMessage transferMapSpillFile( File mapSpillFile, long[] partitionLengths) throws IOException { // The map spill file already has the proper format, and it contains all of the partition data. 
@@ -51,5 +56,6 @@ public void transferMapSpillFile( File tempFile = Utils.tempFileWith(outputFile); Files.move(mapSpillFile.toPath(), tempFile.toPath()); blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala rename to core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala index 6ab14e3780572..9e1c1816d306c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala @@ -28,7 +28,7 @@ import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.shuffle.api.ShuffleBlockInfo import org.apache.spark.storage.{BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} -class DefaultShuffleReadSupport( +class LocalDiskShuffleReadSupport( blockManager: BlockManager, mapOutputTracker: MapOutputTracker, serializerManager: SerializerManager, diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 80006f3cd9201..da7f7dd143d37 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -186,7 +186,11 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver)); + new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + shuffleBlockResolver, + BlockManagerId.apply("localhost", 7077))); } private void assertSpillFilesWereCleanedUp() { @@ -545,7 +549,11 @@ public void testPeakMemoryUsed() throws Exception { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver)); + new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + shuffleBlockResolver, + BlockManagerId.apply("localhost", 7077))); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. 
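The interfaces touched above expose an optional single-spill fast path next to the generic per-partition merge path: createSingleFileMapOutputWriter returns an Optional, and LocalDiskSingleSpillMapOutputWriter simply moves the spill file into place and writes the index file. The sketch below shows roughly how a map-side writer could consume these interfaces as they appear in these patches; it is an illustrative example only, not code from the patch, and the helper name writeMapOutput plus the assumption that the caller already knows whether exactly one spill file was produced are hypothetical.

import java.io.File;
import java.io.IOException;
import java.util.Optional;

import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;

final class SingleSpillFastPathSketch {

  private SingleSpillFastPathSketch() {}

  // Hypothetical helper: commits a map task's output through the plugin API, preferring the
  // single-spill fast path when the components implementation offers one.
  static MapOutputWriterCommitMessage writeMapOutput(
      ShuffleExecutorComponents components,
      int shuffleId,
      int mapId,
      long mapTaskAttemptId,
      int numPartitions,
      File singleSpillFile,          // non-null only when the task produced exactly one spill
      long[] partitionLengths) throws IOException {
    Optional<SingleSpillShuffleMapOutputWriter> singleFileWriter =
        components.createSingleFileMapOutputWriter(shuffleId, mapId, mapTaskAttemptId);
    if (singleSpillFile != null && singleFileWriter.isPresent()) {
      // Fast path: the spill file already has the final on-disk layout, so it is handed over
      // as-is and the implementation commits it without re-copying each partition.
      return singleFileWriter.get().transferMapSpillFile(singleSpillFile, partitionLengths);
    }
    // Generic path: copy each partition's bytes through a ShufflePartitionWriter, then commit.
    ShuffleMapOutputWriter mapWriter =
        components.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
    // ... per-partition copy via mapWriter.getPartitionWriter(p) elided ...
    return mapWriter.commitAllPartitions();
  }
}

The abort/suppressed-exception handling that the real writers wrap around commitAllPartitions (as in the UnsafeShuffleWriter hunk earlier in this series) is omitted here for brevity.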
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala index 39fbc3a1b5851..9d3a52a237cbe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala @@ -17,21 +17,20 @@ package org.apache.spark.scheduler import java.util -import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} -import org.apache.spark.api.shuffle.{ShuffleDriverComponents, ShuffleExecutorComponents} +import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} import org.apache.spark.internal.config import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} -import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.BlockManagerId class PluginShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { - val defaultShuffleDataIO = new DefaultShuffleDataIO(sparkConf) + val localDiskShuffleDataIO = new LocalDiskShuffleDataIO(sparkConf) override def driver(): ShuffleDriverComponents = - new PluginShuffleDriverComponents(defaultShuffleDataIO.driver()) + new PluginShuffleDriverComponents(localDiskShuffleDataIO.driver()) - override def executor(): ShuffleExecutorComponents = defaultShuffleDataIO.executor() + override def executor(): ShuffleExecutorComponents = localDiskShuffleDataIO.executor() } class PluginShuffleDriverComponents(delegate: ShuffleDriverComponents) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 64f8cbc970d54..b707941006638 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} /** @@ -143,7 +143,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val shuffleReadSupport = - new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) + new LocalDiskShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala index e2ccb3fdce651..b571565cf4336 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -26,7 +26,7 @@ import com.google.common.collect.ImmutableMap import 
org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleMapOutputWriter} -import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { test(s"test serialization of shuffle initialization conf to executors") { @@ -57,7 +57,7 @@ class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents { - private var delegate = new DefaultShuffleExecutorComponents(sparkConf) + private var delegate = new LocalDiskShuffleExecutorComponents(sparkConf) override def initializeExecutor( appId: String, execId: String, extraConfigs: JMap[String, String]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 0a77b9f0686ac..640dde9847398 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -40,7 +40,7 @@ import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransport import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockAttemptId, ShuffleBlockId} import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} @@ -212,7 +212,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) - val readSupport = new DefaultShuffleReadSupport( + val readSupport = new LocalDiskShuffleReadSupport( blockManager, mapOutputTracker, serializerManager, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 48cb3800b698a..46888259206a9 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -18,9 +18,8 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf - import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.{DefaultShuffleExecutorComponents} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. 
@@ -49,12 +48,13 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") - val shuffleExecutorComponents = new DefaultShuffleExecutorComponents( + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( conf, blockManager, mapOutputTracker, serializerManager, - blockResolver) + blockResolver, + blockManager.shuffleServerId) val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 16fcb89a32fce..4e3f179b214fe 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -126,7 +126,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( - conf, blockManager, blockResolver) + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + BlockManagerId("localhost", 7077)) } override def afterEach(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index e6471ce2d8e93..2cb53e4bac224 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -22,7 +22,7 @@ import org.mockito.Mockito.when import org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. 
@@ -77,12 +77,13 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(taskContext) - val shuffleExecutorComponents = new DefaultShuffleExecutorComponents( + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( defaultConf, blockManager, mapOutputTracker, serializerManager, - blockResolver) + blockResolver, + blockManager.shuffleServerId) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 0dd6040808f9e..5116ad834a521 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -21,14 +21,14 @@ import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Mockito._ import org.scalatest.Matchers +import org.apache.spark.{MapOutputTracker, Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} -import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.{BlockManager, BlockManagerId} import org.apache.spark.util.Utils @@ -36,6 +36,10 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) + private var mapOutputTracker: MapOutputTracker = _ + @Mock(answer = RETURNS_SMART_NULLS) + private var serializerManager: SerializerManager = _ private val shuffleId = 0 private val numMaps = 5 @@ -60,7 +64,12 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency) } shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( - conf, blockManager, shuffleBlockResolver) + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver, + BlockManagerId("localhost", 7077)) } override def afterAll(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index 04a557cf4384a..d012bda0ffede 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. 
@@ -43,12 +43,13 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) - val shuffleExecutorComponents = new DefaultShuffleExecutorComponents( + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( conf, blockManager, mapOutputTracker, serializerManager, - blockResolver) + blockResolver, + blockManager.shuffleServerId) TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala deleted file mode 100644 index 92960ad956ce2..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io - -import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} -import java.math.BigInteger -import java.nio.ByteBuffer -import java.nio.channels.{Channels, WritableByteChannel} - -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.Mock -import org.mockito.Mockito.{doAnswer, doNothing, when} -import org.mockito.MockitoAnnotations -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfterEach -import org.apache.spark.{SparkConf, SparkFunSuite} - -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.shuffle.api.SupportsTransferTo -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ByteBufferInputStream, Utils} - -class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { - - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _ - - private val NUM_PARTITIONS = 4 - private val D_LEN = 10 - private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map { - p => (1 to D_LEN).map(_ + p).toArray }.toArray - - private var tempFile: File = _ - private var mergedOutputFile: File = _ - private var tempDir: File = _ - private var partitionSizesInMergedFile: Array[Long] = _ - private var conf: SparkConf = _ - private var mapOutputWriter: DefaultShuffleMapOutputWriter = _ - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def beforeEach(): Unit = { - MockitoAnnotations.initMocks(this) - tempDir = Utils.createTempDir(null, "test") - mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) - tempFile = File.createTempFile("tempfile", "", tempDir) - partitionSizesInMergedFile = null - conf = new SparkConf() - .set("spark.app.id", "example.spark.app") - .set("spark.shuffle.unsafe.file.output.buffer", "16k") - when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) - - doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong) - - doAnswer(new Answer[Void] { - def answer(invocationOnMock: InvocationOnMock): Void = { - partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - mergedOutputFile.delete - tmp.renameTo(mergedOutputFile) - } - null - } - }).when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) - mapOutputWriter = new DefaultShuffleMapOutputWriter( - 0, - 0, - NUM_PARTITIONS, - BlockManagerId("0", "localhost", 9099), - shuffleWriteMetrics, - blockResolver, - conf) - } - - private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = { - var startOffset = 0L - val result = new Array[Array[Int]](NUM_PARTITIONS) - (0 until NUM_PARTITIONS).foreach { p => - val partitionSize = partitionSizesInMergedFile(p).toInt - lazy val inner = new Array[Int](partitionSize) - lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize) - if (partitionSize > 0) { - val in = new FileInputStream(mergedOutputFile) - in.getChannel.position(startOffset) - val lin = new 
LimitedInputStream(in, partitionSize) - var nonEmpty = true - var count = 0 - while (nonEmpty) { - try { - val readBit = lin.read() - if (fromByte) { - innerBytebuffer.put(readBit.toByte) - } else { - inner(count) = readBit - } - count += 1 - } catch { - case _: Exception => - nonEmpty = false - } - } - in.close() - } - if (fromByte) { - result(p) = innerBytebuffer.array().sliding(4, 4).map { b => - new BigInteger(b).intValue() - }.toArray - } else { - result(p) = inner - } - startOffset += partitionSize - } - result - } - - test("writing to an outputstream") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.openStream() - data(p).foreach { i => stream.write(i)} - stream.close() - intercept[IllegalStateException] { - stream.write(p) - } - assert(writer.getNumBytesWritten() == D_LEN) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(false)) - } - - test("writing to a channel") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val numBytes = byteBuffer.remaining() - val outputTempFile = File.createTempFile("channelTemp", "", tempDir) - val outputTempFileStream = new FileOutputStream(outputTempFile) - Utils.copyStream( - new ByteBufferInputStream(byteBuffer), - outputTempFileStream, - closeStreams = true) - val tempFileInput = new FileInputStream(outputTempFile) - channel.transferFrom(tempFileInput.getChannel, 0L, numBytes) - // Bytes require * 4 - channel.close() - tempFileInput.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } - - test("copyStreams with an outputstream") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.openStream() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val in = new ByteArrayInputStream(byteBuffer.array()) - Utils.copyStream(in, stream, false, false) - in.close() - stream.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } - - test("copyStreamsWithNIO with a channel") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val out = new FileOutputStream(tempFile) - out.write(byteBuffer.array()) - out.close() - val in = new 
FileInputStream(tempFile) - channel.transferFrom(in.getChannel, 0L, byteBuffer.remaining()) - channel.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } -} From 6bd53ec228600d5358ba530811a08e8b21aad55a Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 10 Sep 2019 18:23:10 -0700 Subject: [PATCH 07/14] More build fixes --- .../api/MapOutputWriterCommitMessage.java | 4 +- .../spark/shuffle/api/ShuffleBlockInfo.java | 4 +- .../spark/shuffle/api/SupportsTransferTo.java | 53 ------------------ .../api/TransferrableWritableByteChannel.java | 54 ------------------- ...faultTransferrableWritableByteChannel.java | 52 ------------------ .../sort/io/LocalDiskShuffleDataIO.java | 4 +- .../io/LocalDiskShuffleMapOutputWriter.java | 6 +-- ... => LocalDiskShuffleDriverComponents.java} | 2 +- .../scala/org/apache/spark/SparkContext.scala | 1 - .../shuffle/sort/SortShuffleManager.scala | 2 +- .../util/collection/ExternalSorter.scala | 3 -- .../sort/UnsafeShuffleWriterSuite.java | 4 ++ .../BlockStoreShuffleReaderSuite.scala | 20 +++++-- .../BlockStoreShuffleReaderBenchmark.scala | 12 +++-- .../BypassMergeSortShuffleWriterSuite.scala | 21 +++++--- .../sort/ShuffleWriterBenchmarkBase.scala | 1 - .../shuffle/sort/SortShuffleWriterSuite.scala | 4 +- ...LocalDiskShuffleMapOutputWriterSuite.scala | 16 +++++- 18 files changed, 70 insertions(+), 193 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/SupportsTransferTo.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/TransferrableWritableByteChannel.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java rename core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/{DefaultShuffleDriverComponents.java => LocalDiskShuffleDriverComponents.java} (96%) diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java index dc51f8962206f..f406f249a9dd9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java @@ -2,10 +2,10 @@ import java.util.Optional; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; import org.apache.spark.storage.BlockManagerId; -@Experimental +@Private public final class MapOutputWriterCommitMessage { private final long[] partitionLengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java index 66270a512b0e7..72a67c76f28b5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java @@ -17,11 +17,11 @@ package org.apache.spark.shuffle.api; +import java.util.Objects; + import org.apache.spark.api.java.Optional; import org.apache.spark.storage.BlockManagerId; -import java.util.Objects; - /** * :: Experimental :: * An object defining the shuffle block and length metadata associated with the block. 
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/shuffle/api/SupportsTransferTo.java deleted file mode 100644 index ae8cb36b7e719..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/SupportsTransferTo.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.io.IOException; - -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * Indicates that partition writers can transfer bytes directly from input byte channels to - * output channels that stream data to the underlying shuffle partition storage medium. - *
<p>
- * This API is separated out for advanced users because it only needs to be used for - * specific low-level optimizations. The idea is that the returned channel can transfer bytes - * from the input file channel out to the backing storage system without copying data into - * memory. - *
<p>
- * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead. - * - * @since 3.0.0 - */ -@Experimental -public interface SupportsTransferTo extends ShufflePartitionWriter { - - /** - * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from - * input byte channels to the underlying shuffle data store. - */ - TransferrableWritableByteChannel openTransferrableChannel() throws IOException; - - /** - * Returns the number of bytes written either by this writer's output stream opened by - * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}. - */ - @Override - long getNumBytesWritten(); -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/api/TransferrableWritableByteChannel.java deleted file mode 100644 index 76e0dfd8b5a05..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/TransferrableWritableByteChannel.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.io.Closeable; -import java.io.IOException; - -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * Represents an output byte channel that can copy bytes from input file channels to some - * arbitrary storage system. - *
<p>
- * This API is provided for advanced users who can transfer bytes from a file channel to - * some output sink without copying data into memory. Most users should not need to use - * this functionality; this is primarily provided for the built-in shuffle storage backends - * that persist shuffle files on local disk. - *
<p>
- * For a simpler alternative, see {@link ShufflePartitionWriter}. - * - * @since 3.0.0 - */ -@Experimental -public interface TransferrableWritableByteChannel extends Closeable { - - /** - * Copy all bytes from the source readable byte channel into this byte channel. - * - * @param source File to transfer bytes from. Do not call anything on this channel other than - * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. - * @param transferStartPosition Start position of the input file to transfer from. - * @param numBytesToTransfer Number of bytes to transfer from the given source. - */ - void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer) - throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java deleted file mode 100644 index cb8ac86972d35..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort; - -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import org.apache.spark.shuffle.api.TransferrableWritableByteChannel; -import org.apache.spark.shuffle.api.SupportsTransferTo; -import org.apache.spark.util.Utils; - -/** - * This is used when transferTo is enabled but the shuffle plugin hasn't implemented - * {@link SupportsTransferTo}. - *
<p>
- * This default implementation exists as a convenience to the unsafe shuffle writer and - * the bypass merge sort shuffle writers. - */ -public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel { - - private final WritableByteChannel delegate; - - public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) { - this.delegate = delegate; - } - - @Override - public void transferFrom( - FileChannel source, long transferStartPosition, long numBytesToTransfer) { - Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer); - } - - @Override - public void close() throws IOException { - delegate.close(); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java index 2db32a1f30860..77fcd34f962bf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java @@ -21,7 +21,7 @@ import org.apache.spark.shuffle.api.ShuffleDriverComponents; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleDataIO; -import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents; /** * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle @@ -37,7 +37,7 @@ public LocalDiskShuffleDataIO(SparkConf sparkConf) { @Override public ShuffleDriverComponents driver() { - return new DefaultShuffleDriverComponents(); + return new LocalDiskShuffleDriverComponents(); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index a4566cf908b79..064875420c473 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -26,17 +26,17 @@ import java.nio.channels.WritableByteChannel; import java.util.Optional; -import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; -import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.storage.BlockManagerId; import org.apache.spark.internal.config.package$; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.util.Utils; /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java similarity index 96% rename from core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java rename to core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java index c6893a49ed238..183769274841c 100644 --- 
a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java @@ -26,7 +26,7 @@ import org.apache.spark.internal.config.package$; import org.apache.spark.storage.BlockManagerMaster; -public class DefaultShuffleDriverComponents implements ShuffleDriverComponents { +public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents { private BlockManagerMaster blockManagerMaster; private boolean shouldUnregisterOutputOnHostOnFetchFailure; diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8b9cf7a2e95ec..c84bc82b9a29f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -43,7 +43,6 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.conda.CondaEnvironment import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions -import org.apache.spark.api.shuffle.ShuffleDriverComponents import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 162cf9f0d420a..610c04ace3b6f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -130,7 +130,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition, context, metrics, - shuffleExecutorComponents.reads()) + shuffleExecutorComponents) } /** Get a writer for a given partition. Called on executors by map tasks. 
*/ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 1c8334be9a2bb..6f9f0414a6fcd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -24,14 +24,11 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import com.google.common.io.{ByteStreams, Closeables} import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.shuffle.ShufflePartitionPairsWriter import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index da7f7dd143d37..18f3a339e246c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -189,6 +189,8 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { new LocalDiskShuffleExecutorComponents( conf, blockManager, + mapOutputTracker, + serializerManager, shuffleBlockResolver, BlockManagerId.apply("localhost", 7077))); } @@ -552,6 +554,8 @@ public void testPeakMemoryUsed() throws Exception { new LocalDiskShuffleExecutorComponents( conf, blockManager, + mapOutputTracker, + serializerManager, shuffleBlockResolver, BlockManagerId.apply("localhost", 7077))); diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index b707941006638..a4258c0fd90b0 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Mockito.{mock, when} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Mock +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -29,7 +31,7 @@ import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} -import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} /** @@ -59,6 +61,8 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + /** * This test makes sure 
that, when data is read from a HashShuffleReader, the underlying * ManagedBuffers that contain the data are eventually released. @@ -142,15 +146,21 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val shuffleReadSupport = - new LocalDiskShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) + val shuffleExecutorComponents = + new LocalDiskShuffleExecutorComponents( + testConf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + localBlockManagerId) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, - shuffleReadSupport, + shuffleExecutorComponents, serializerManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 640dde9847398..a8246aca20baa 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -39,8 +39,9 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} +import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException, IndexShuffleBlockResolver} import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockAttemptId, ShuffleBlockId} import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} @@ -67,6 +68,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { // this is only used when initiating the BlockManager, for comms between master and executor @Mock(answer = RETURNS_SMART_NULLS) private var rpcEnv: RpcEnv = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ private var tempDir: File = _ @@ -212,11 +214,13 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) - val readSupport = new LocalDiskShuffleReadSupport( + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + defaultConf, blockManager, mapOutputTracker, serializerManager, - defaultConf) + blockResolver, + blockManager.shuffleServerId) new BlockStoreShuffleReader[String, String]( shuffleHandle, @@ -224,7 +228,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { 1, taskContext, taskContext.taskMetrics().createTempShuffleReadMetrics(), - readSupport, + shuffleExecutorComponents, serializerManager, mapOutputTracker ) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala 
b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 4e3f179b214fe..da1630e67a485 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.shuffle.sort import java.io.File -import java.util.{Properties, UUID} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -30,7 +31,6 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach -import scala.util.Random import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} @@ -84,7 +84,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte when(blockResolver.writeIndexFileAndCommit( anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) - .thenAnswer { invocationOnMock => + .thenAnswer { (invocationOnMock: InvocationOnMock) => val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] if (tmp != null) { outputFile.delete @@ -99,7 +99,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte any[SerializerInstance], anyInt(), any[ShuffleWriteMetrics])) - .thenAnswer { invocation => + .thenAnswer { (invocation: InvocationOnMock) => val args = invocation.getArguments val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( @@ -113,7 +113,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } when(diskBlockManager.createTempShuffleBlock()) - .thenAnswer { _ => + .thenAnswer { (invocationOnMock: InvocationOnMock) => val blockId = new TempShuffleBlockId(UUID.randomUUID) val file = new File(tempDir, blockId.name) blockIdToFileMap.put(blockId, file) @@ -121,7 +121,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte (blockId, file) } - when(diskBlockManager.getFile(any[BlockId])).thenAnswer { invocation => + when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) => blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) } @@ -251,4 +251,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } + + /** + * This won't be necessary with Scala 2.12 + */ + private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { + new Answer[T] { + override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala index e883eb61a2763..6decc9d4e2c84 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -30,7 +30,6 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.{HashPartitioner, MapOutputTracker, ShuffleDependency, SparkConf, TaskContext} - import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} 
import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 5116ad834a521..326831749ce09 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -21,10 +21,10 @@ import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Mockito._ import org.scalatest.Matchers -import org.apache.spark.{MapOutputTracker, Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} +import org.apache.spark.{MapOutputTracker, Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} import org.apache.spark.memory.MemoryTestingUtils -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index 5156cc2cc47a6..1216edcf78219 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -27,10 +27,13 @@ import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mock import org.mockito.Mockito.when import org.mockito.MockitoAnnotations +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -39,6 +42,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA private var blockResolver: IndexShuffleBlockResolver = _ private val NUM_PARTITIONS = 4 + private val BLOCK_MANAGER_ID = BlockManagerId("localhost", 7077) private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p => if (p == 3) { Array.emptyByteArray @@ -76,7 +80,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) when(blockResolver.writeIndexFileAndCommit( anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) - .thenAnswer { invocationOnMock => + .thenAnswer { (invocationOnMock: InvocationOnMock) => partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] if (tmp != null) { @@ -90,6 +94,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA 0, NUM_PARTITIONS, blockResolver, + BLOCK_MANAGER_ID, conf) } @@ -142,4 +147,13 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA assert(mergedOutputFile.length() === 
partitionLengths.sum) assert(data === readRecordsFromFile()) } + + /** + * This won't be necessary with Scala 2.12 + */ + private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { + new Answer[T] { + override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) + } + } } From 99b7892bec5c9509c92f4bf441f1648ec7c06a26 Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 10 Sep 2019 19:00:36 -0700 Subject: [PATCH 08/14] More build fixes --- .../scala/org/apache/spark/util/collection/ExternalSorter.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 6f9f0414a6fcd..612574e804103 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,7 +29,6 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.shuffle.ShufflePartitionPairsWriter import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{Utils => TryUtils} From 15f4357cccbe8c0001c0bc1d75448b32461ccd9f Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 11 Sep 2019 11:44:22 -0700 Subject: [PATCH 09/14] Attempt to fix build --- core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 8da45ac6261b2..6eb8251ec4002 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -413,8 +413,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) - TaskContext.unset() val readData = reader.read().toIndexedSeq + TaskContext.unset() assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) manager.unregisterShuffle(0) From e2bfe96a727fb111de5a6be2844464122618a96b Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 11 Sep 2019 14:56:31 -0700 Subject: [PATCH 10/14] More build fixes --- .../api/MapOutputWriterCommitMessage.java | 17 +++++++++++++++++ .../shuffle/BlockStoreShuffleReaderSuite.scala | 3 ++- .../LocalDiskShuffleMapOutputWriterSuite.scala | 4 +++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java index f406f249a9dd9..5a1c82499b715 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.shuffle.api; import java.util.Optional; diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index a4258c0fd90b0..966a6fa9d005f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer +import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.Mock import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -68,6 +68,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext * ManagedBuffers that contain the data are eventually released. */ test("read() releases resources on completion") { + MockitoAnnotations.initMocks(this) val testConf = new SparkConf(false) // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the // shuffle code calls SparkEnv.get()). diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index 1216edcf78219..8aa9f51e09494 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -143,7 +143,9 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA private def verifyWrittenRecords(): Unit = { val committedLengths = mapOutputWriter.commitAllPartitions() assert(partitionSizesInMergedFile === partitionLengths) - assert(committedLengths === partitionLengths) + assert(committedLengths.getPartitionLengths === partitionLengths) + assert(committedLengths.getLocation.isPresent) + assert(committedLengths.getLocation.get === BLOCK_MANAGER_ID) assert(mergedOutputFile.length() === partitionLengths.sum) assert(data === readRecordsFromFile()) } From b3d87e20e395b9c5654ad8d1c302fba44a3da7c3 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 12 Sep 2019 18:12:43 -0700 Subject: [PATCH 11/14] [SPARK-29072] Put back usage of TimeTrackingOutputStream for UnsafeShuffleWriter and ShufflePartitionPairsWriter. 
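The point of the change is to keep time spent writing through plugin-provided streams attributed to shuffle write metrics. A minimal sketch of the resulting stream layering in ShufflePartitionPairsWriter.open(), using the names introduced below and the wrapStream argument as corrected at the end of this series:

    // Innermost to outermost: plugin partition stream -> time tracking -> encryption /
    // compression (via SerializerManager) -> serialization stream used by write(key, value).
    partitionStream = partitionWriter.openStream
    timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream)
    wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream)
    objOut = serializerInstance.serializeStream(wrappedStream)

UnsafeShuffleWriter gets the same treatment on its file-stream merge path, wrapping each partition output stream in a TimeTrackingOutputStream before encryption and compression.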
--- .../spark/shuffle/sort/UnsafeShuffleWriter.java | 2 ++ .../spark/shuffle/ShufflePartitionPairsWriter.scala | 13 +++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 1ec9f590c0bbb..acb86616066a8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -58,6 +58,7 @@ import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; @@ -388,6 +389,7 @@ private void mergeSpillsWithFileStream( ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); OutputStream partitionOutput = writer.openStream(); try { + partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala index a988c5e126a76..e83254025b883 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala @@ -21,7 +21,7 @@ import java.io.{Closeable, IOException, OutputStream} import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.api.ShufflePartitionWriter -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream} import org.apache.spark.util.Utils import org.apache.spark.util.collection.PairsWriter @@ -39,6 +39,7 @@ private[spark] class ShufflePartitionPairsWriter( private var isClosed = false private var partitionStream: OutputStream = _ + private var timeTrackingStream: OutputStream = _ private var wrappedStream: OutputStream = _ private var objOut: SerializationStream = _ private var numRecordsWritten = 0 @@ -59,6 +60,7 @@ private[spark] class ShufflePartitionPairsWriter( private def open(): Unit = { try { partitionStream = partitionWriter.openStream + timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream) wrappedStream = serializerManager.wrapStream(blockId, partitionStream) objOut = serializerInstance.serializeStream(wrappedStream) } catch { @@ -78,6 +80,7 @@ private[spark] class ShufflePartitionPairsWriter( // Setting these to null will prevent the underlying streams from being closed twice // just in case any stream's close() implementation is not idempotent. 
wrappedStream = null + timeTrackingStream = null partitionStream = null } { // Normally closing objOut would close the inner streams as well, but just in case there @@ -86,9 +89,15 @@ private[spark] class ShufflePartitionPairsWriter( wrappedStream = closeIfNonNull(wrappedStream) // Same as above - if wrappedStream closes then assume it closes underlying // partitionStream and don't close again in the finally + timeTrackingStream = null partitionStream = null } { - partitionStream = closeIfNonNull(partitionStream) + Utils.tryWithSafeFinally { + timeTrackingStream = closeIfNonNull(timeTrackingStream) + partitionStream = null + } { + partitionStream = closeIfNonNull(partitionStream) + } } } updateBytesWritten() From 182f62c0e3581ac6736849c550ff777851b88c5a Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 12 Sep 2019 18:21:45 -0700 Subject: [PATCH 12/14] Address comments --- .../util/collection/ExternalSorter.scala | 4 +- .../ShufflePartitionPairsWriter.scala | 91 ------------------- 2 files changed, 1 insertion(+), 94 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 612574e804103..b5421b9d977f8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -30,6 +30,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} +import org.apache.spark.shuffle.ShufflePartitionPairsWriter import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{Utils => TryUtils} @@ -732,7 +733,6 @@ private[spark] class ExternalSorter[K, V, C]( shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Unit = { - var nextPartitionId = 0 if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer @@ -758,7 +758,6 @@ private[spark] class ExternalSorter[K, V, C]( partitionPairsWriter.close() } } - nextPartitionId = partitionId + 1 } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. @@ -784,7 +783,6 @@ private[spark] class ExternalSorter[K, V, C]( partitionPairsWriter.close() } } - nextPartitionId = id + 1 } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala deleted file mode 100644 index 62f17a8e3cfbd..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.{Closeable, FilterOutputStream, OutputStream} - -import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter -import org.apache.spark.shuffle.api.ShufflePartitionWriter -import org.apache.spark.storage.BlockId - -/** - * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an - * arbitrary partition writer instead of writing to local disk through the block manager. - */ -private[spark] class ShufflePartitionPairsWriter( - partitionWriter: ShufflePartitionWriter, - serializerManager: SerializerManager, - serializerInstance: SerializerInstance, - blockId: BlockId, - writeMetrics: ShuffleWriteMetricsReporter) - extends PairsWriter with Closeable { - - private var isOpen = false - private var partitionStream: OutputStream = _ - private var wrappedStream: OutputStream = _ - private var objOut: SerializationStream = _ - private var numRecordsWritten = 0 - private var curNumBytesWritten = 0L - - override def write(key: Any, value: Any): Unit = { - if (!isOpen) { - open() - isOpen = true - } - objOut.writeKey(key) - objOut.writeValue(value) - writeMetrics.incRecordsWritten(1) - } - - private def open(): Unit = { - partitionStream = partitionWriter.openStream - wrappedStream = serializerManager.wrapStream(blockId, partitionStream) - objOut = serializerInstance.serializeStream(wrappedStream) - } - - override def close(): Unit = { - if (isOpen) { - objOut.close() - objOut = null - wrappedStream = null - partitionStream = null - isOpen = false - updateBytesWritten() - } - } - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. 
- */ - private def recordWritten(): Unit = { - numRecordsWritten += 1 - writeMetrics.incRecordsWritten(1) - - if (numRecordsWritten % 16384 == 0) { - updateBytesWritten() - } - } - - private def updateBytesWritten(): Unit = { - val numBytesWritten = partitionWriter.getNumBytesWritten - val bytesWrittenDiff = numBytesWritten - curNumBytesWritten - writeMetrics.incBytesWritten(bytesWrittenDiff) - curNumBytesWritten = numBytesWritten - } -} From d4bcd9392bec6a44608feba5f4c404a3220e895e Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 12 Sep 2019 18:22:56 -0700 Subject: [PATCH 13/14] Import ordering --- .../scala/org/apache/spark/util/collection/ExternalSorter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b5421b9d977f8..1216a45415a74 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,8 +29,8 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.shuffle.ShufflePartitionPairsWriter +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{Utils => TryUtils} From f27e0fdc8c315b9617d1fedc6aadd9dba55a7a91 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 12 Sep 2019 18:27:03 -0700 Subject: [PATCH 14/14] Fix stream reference --- .../org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala index e83254025b883..e0affb858c359 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala @@ -61,7 +61,7 @@ private[spark] class ShufflePartitionPairsWriter( try { partitionStream = partitionWriter.openStream timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream) - wrappedStream = serializerManager.wrapStream(blockId, partitionStream) + wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream) objOut = serializerInstance.serializeStream(wrappedStream) } catch { case e: Exception =>