diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
index 5119e34803a85..0a95193b179f7 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
@@ -29,7 +29,7 @@
  */
 @Experimental
 public interface ShuffleMapOutputWriter {
 
-  ShufflePartitionWriter getNextPartitionWriter() throws IOException;
+  ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException;
 
   void commitAllPartitions() throws IOException;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index aef133fe7d46a..16fece302c3bc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -204,7 +204,7 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
       boolean copyThrewException = true;
       ShufflePartitionWriter writer = null;
       try {
-        writer = mapOutputWriter.getNextPartitionWriter();
+        writer = mapOutputWriter.getPartitionWriter(i);
         if (!file.exists()) {
           copyThrewException = false;
         } else {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index d7a6d6450ebc0..e97e061034425 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -285,18 +285,6 @@ private long[] mergeSpills(SpillInfo[] spills,
     long[] partitionLengths = new long[numPartitions];
     try {
       if (spills.length == 0) {
-        // The contract we are working under states that we will open a partition writer for
-        // each partition, regardless of number of spills
-        for (int i = 0; i < numPartitions; i++) {
-          ShufflePartitionWriter writer = null;
-          try {
-            writer = mapWriter.getNextPartitionWriter();
-          } finally {
-            if (writer != null) {
-              writer.close();
-            }
-          }
-        }
         return partitionLengths;
       } else {
         // There are multiple spills to merge, so none of these spill files' lengths were counted
@@ -372,7 +360,7 @@ private long[] mergeSpillsWithFileStream(
         boolean copyThrewExecption = true;
         ShufflePartitionWriter writer = null;
         try {
-          writer = mapWriter.getNextPartitionWriter();
+          writer = mapWriter.getPartitionWriter(partition);
           OutputStream partitionOutput = null;
           try {
             // Shield the underlying output stream from close() calls, so that we can close the
@@ -451,7 +439,7 @@ private long[] mergeSpillsWithTransferTo(
         boolean copyThrewExecption = true;
         ShufflePartitionWriter writer = null;
         try {
-          writer = mapWriter.getNextPartitionWriter();
+          writer = mapWriter.getPartitionWriter(partition);
           WritableByteChannel channel = writer.toChannel();
           for (int i = 0; i < spills.length; i++) {
             long partitionLengthInSpill = 0L;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
index c84158e1891d7..e97eb930ba501 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
@@ -47,7 +47,7 @@ public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter {
   private final IndexShuffleBlockResolver blockResolver;
   private final long[] partitionLengths;
   private final int bufferSize;
-  private int currPartitionId = 0;
+  private int lastPartitionId = -1;
   private long currChannelPosition;
 
   private final File outputFile;
@@ -77,7 +77,11 @@ public DefaultShuffleMapOutputWriter(
   }
 
   @Override
-  public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
+  public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException {
+    if (partitionId <= lastPartitionId) {
+      throw new IllegalArgumentException("Partitions should be requested in increasing order.");
+    }
+    lastPartitionId = partitionId;
     if (outputTempFile == null) {
       outputTempFile = Utils.tempFileWith(outputFile);
     }
@@ -86,7 +90,7 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
     } else {
       currChannelPosition = 0L;
     }
-    return new DefaultShufflePartitionWriter(currPartitionId++);
+    return new DefaultShufflePartitionWriter(partitionId);
   }
 
   @Override
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 8ccc1dfc9b3f1..df5ce73b9acf1 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -721,17 +721,6 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
-  private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
-    var partitionWriter: ShufflePartitionWriter = null
-    try {
-      partitionWriter = mapOutputWriter.getNextPartitionWriter
-    } finally {
-      if (partitionWriter != null) {
-        partitionWriter.close()
-      }
-    }
-  }
-
   /**
    * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
   * to some arbitrary backing store. This is called by the SortShuffleWriter.
@@ -742,26 +731,16 @@ private[spark] class ExternalSorter[K, V, C](
       shuffleId: Int,
       mapId: Int,
       mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
     // Track location of each range in the map output
     val lengths = new Array[Long](numPartitions)
-    var nextPartitionId = 0
     if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
       val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
       while (it.hasNext()) {
         val partitionId = it.nextPartition()
-        // The contract for the plugin is that we will ask for a writer for every partition
-        // even if it's empty. However, the external sorter will return non-contiguous
-        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
-
-        // The algorithm as a whole is correct because the partition ids are returned by the
-        // iterator in ascending order.
-        for (emptyPartition <- nextPartitionId until partitionId) {
-          writeEmptyPartition(mapOutputWriter)
-        }
         var partitionWriter: ShufflePartitionWriter = null
         var partitionPairsWriter: ShufflePartitionPairsWriter = null
         try {
-          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
           val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
           partitionPairsWriter = new ShufflePartitionPairsWriter(
             partitionWriter,
@@ -783,7 +762,6 @@ private[spark] class ExternalSorter[K, V, C](
         if (partitionWriter != null) {
           lengths(partitionId) = partitionWriter.getNumBytesWritten
         }
-        nextPartitionId = partitionId + 1
       }
     } else {
       // We must perform merge-sort; get an iterator by partition and write everything directly.
@@ -794,14 +772,11 @@ private[spark] class ExternalSorter[K, V, C](
 
         // The algorithm as a whole is correct because the partition ids are returned by the
         // iterator in ascending order.
-        for (emptyPartition <- nextPartitionId until id) {
-          writeEmptyPartition(mapOutputWriter)
-        }
         val blockId = ShuffleBlockId(shuffleId, mapId, id)
         var partitionWriter: ShufflePartitionWriter = null
         var partitionPairsWriter: ShufflePartitionPairsWriter = null
         try {
-          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionWriter = mapOutputWriter.getPartitionWriter(id)
           partitionPairsWriter = new ShufflePartitionPairsWriter(
             partitionWriter,
             serializerManager,
@@ -821,16 +796,9 @@ private[spark] class ExternalSorter[K, V, C](
         if (partitionWriter != null) {
           lengths(id) = partitionWriter.getNumBytesWritten
         }
-        nextPartitionId = id + 1
       }
     }
 
-    // The iterator may have stopped short of opening a writer for every partition. So fill in the
-    // remaining empty partitions.
-    for (emptyPartition <- nextPartitionId until numPartitions) {
-      writeEmptyPartition(mapOutputWriter)
-    }
-
     context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
     context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala
index 22d52924a7c72..6a3666b4ad771 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala
@@ -133,7 +133,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("writing to an outputstream") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val stream = writer.toStream()
       data(p).foreach { i => stream.write(i)}
       stream.close()
@@ -152,7 +152,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("writing to a channel") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val channel = writer.toChannel()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()
@@ -172,7 +172,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("copyStreams with an outputstream") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
      val stream = writer.toStream()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()
@@ -193,7 +193,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
 
   test("copyStreamsWithNIO with a channel") {
     (0 until NUM_PARTITIONS).foreach{ p =>
-      val writer = mapOutputWriter.getNextPartitionWriter
+      val writer = mapOutputWriter.getPartitionWriter(p)
       val channel = writer.toChannel()
       val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
       val intBuffer = byteBuffer.asIntBuffer()
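Illustrative note (not part of the diff above): with getPartitionWriter(partitionId) the caller names the partition explicitly, so empty partitions no longer need a writer opened for them; the remaining contract, enforced by DefaultShuffleMapOutputWriter, is that partition ids must be requested in increasing order. Below is a minimal caller sketch in Java, assuming only the methods that appear elsewhere in this diff (toStream(), getNumBytesWritten(), close(), commitAllPartitions()); recordsForPartition(p) is a hypothetical helper standing in for however a map task groups its serialized records, and the fragment would live inside a map task's write path rather than compile on its own.

    // Hypothetical driver loop for a map task: partition ids increase
    // monotonically, and empty partitions are simply skipped (no more
    // "backfilling" of writers for them).
    long[] lengths = new long[numPartitions];
    for (int p = 0; p < numPartitions; p++) {
      List<byte[]> records = recordsForPartition(p);  // hypothetical source of serialized records
      if (records.isEmpty()) {
        continue;  // allowed under the new contract
      }
      ShufflePartitionWriter writer = null;
      try {
        writer = mapOutputWriter.getPartitionWriter(p);
        OutputStream out = writer.toStream();
        try {
          for (byte[] record : records) {
            out.write(record);
          }
        } finally {
          out.close();
        }
        lengths[p] = writer.getNumBytesWritten();
      } finally {
        if (writer != null) {
          writer.close();
        }
      }
    }
    mapOutputWriter.commitAllPartitions();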