diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java deleted file mode 100644 index dd7c0ac7320cb..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * An interface for launching Shuffle related components - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleDataIO { - String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; - - ShuffleDriverComponents driver(); - ShuffleExecutorComponents executor(); -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java deleted file mode 100644 index a5fa032bf651d..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -import java.util.Map; - -/** - * :: Experimental :: - * An interface for building shuffle support for Executors - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleExecutorComponents { - void initializeExecutor(String appId, String execId, Map extraConfigs); - - ShuffleWriteSupport writes(); - - ShuffleReadSupport reads(); -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java deleted file mode 100644 index 025fc096faaad..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import java.io.IOException; - -import org.apache.spark.annotation.Experimental; -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; - -/** - * :: Experimental :: - * An interface for creating and managing shuffle partition writers - * - * @since 3.0.0 - */ -@Experimental -public interface ShuffleMapOutputWriter { - ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException; - - Optional commitAllPartitions() throws IOException; - - void abort(Throwable error) throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java deleted file mode 100644 index 83947bd4d6fa4..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import org.apache.spark.annotation.Experimental; - -import java.io.IOException; -import java.io.InputStream; - -/** - * :: Experimental :: - * An interface for reading shuffle records. 
- * @since 3.0.0 - */ -@Experimental -public interface ShuffleReadSupport { - /** - * Returns an underlying {@link Iterable} that will iterate - * through shuffle data, given an iterable for the shuffle blocks to fetch. - */ - Iterable getPartitionReaders(Iterable blockMetadata) - throws IOException; - - default boolean shouldWrapStream() { - return true; - } -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java deleted file mode 100644 index 866b61d0bafd9..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import java.io.IOException; - -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * Indicates that partition writers can transfer bytes directly from input byte channels to - * output channels that stream data to the underlying shuffle partition storage medium. - *
<p>
- * This API is separated out for advanced users because it only needs to be used for - * specific low-level optimizations. The idea is that the returned channel can transfer bytes - * from the input file channel out to the backing storage system without copying data into - * memory. - *
<p>
- * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead. - * - * @since 3.0.0 - */ -@Experimental -public interface SupportsTransferTo extends ShufflePartitionWriter { - - /** - * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from - * input byte channels to the underlying shuffle data store. - */ - TransferrableWritableByteChannel openTransferrableChannel() throws IOException; - - /** - * Returns the number of bytes written either by this writer's output stream opened by - * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}. - */ - @Override - long getNumBytesWritten(); -} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java deleted file mode 100644 index 18234d7c4c944..0000000000000 --- a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.shuffle; - -import java.io.Closeable; -import java.io.IOException; - -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import org.apache.spark.annotation.Experimental; - -/** - * :: Experimental :: - * Represents an output byte channel that can copy bytes from input file channels to some - * arbitrary storage system. - *
<p>
- * This API is provided for advanced users who can transfer bytes from a file channel to - * some output sink without copying data into memory. Most users should not need to use - * this functionality; this is primarily provided for the built-in shuffle storage backends - * that persist shuffle files on local disk. - *
<p>
- * For a simpler alternative, see {@link ShufflePartitionWriter}. - * - * @since 3.0.0 - */ -@Experimental -public interface TransferrableWritableByteChannel extends Closeable { - - /** - * Copy all bytes from the source readable byte channel into this byte channel. - * - * @param source File to transfer bytes from. Do not call anything on this channel other than - * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. - * @param transferStartPosition Start position of the input file to transfer from. - * @param numBytesToTransfer Number of bytes to transfer from the given source. - */ - void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer) - throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java new file mode 100644 index 0000000000000..5a1c82499b715 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.util.Optional; + +import org.apache.spark.annotation.Private; +import org.apache.spark.storage.BlockManagerId; + +@Private +public final class MapOutputWriterCommitMessage { + + private final long[] partitionLengths; + private final Optional location; + + private MapOutputWriterCommitMessage( + long[] partitionLengths, Optional location) { + this.partitionLengths = partitionLengths; + this.location = location; + } + + public static MapOutputWriterCommitMessage of(long[] partitionLengths) { + return new MapOutputWriterCommitMessage(partitionLengths, Optional.empty()); + } + + public static MapOutputWriterCommitMessage of( + long[] partitionLengths, BlockManagerId location) { + return new MapOutputWriterCommitMessage(partitionLengths, Optional.of(location)); + } + + public long[] getPartitionLengths() { + return partitionLengths; + } + + public Optional getLocation() { + return location; + } +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java similarity index 98% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java index 34daf2c137a12..72a67c76f28b5 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; + +import java.util.Objects; import org.apache.spark.api.java.Optional; import org.apache.spark.storage.BlockManagerId; -import java.util.Objects; - /** * :: Experimental :: * An object defining the shuffle block and length metadata associated with the block. diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java new file mode 100644 index 0000000000000..5126f0c3577f8 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for plugging in modules for storing and reading temporary shuffle data. + *
<p>
+ * This is the root of a plugin system for storing shuffle bytes to arbitrary storage + * backends in the sort-based shuffle algorithm implemented by the + * {@link org.apache.spark.shuffle.sort.SortShuffleManager}. If another shuffle algorithm is + * needed instead of sort-based shuffle, one should implement + * {@link org.apache.spark.shuffle.ShuffleManager} instead. + *
<p>
+ * A single instance of this module is loaded per process in the Spark application. + * The default implementation reads and writes shuffle data from the local disks of + * the executor, and is the implementation of shuffle file storage that has remained + * consistent throughout most of Spark's history. + *
<p>
+ * Alternative implementations of shuffle data storage can be loaded via setting + * spark.shuffle.sort.io.plugin.class. + * @since 3.0.0 + */ +@Private +public interface ShuffleDataIO { + + String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; + + ShuffleDriverComponents driver(); + + /** + * Called once on executor processes to bootstrap the shuffle data storage modules that + * are only invoked on the executors. + */ + ShuffleExecutorComponents executor(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java similarity index 97% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java rename to core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java index 8b54968f9b134..cbc59bc7b6a05 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; import java.io.IOException; import java.util.Map; diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java new file mode 100644 index 0000000000000..94c07009f3180 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import java.util.Optional; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for building shuffle support for Executors. + * + * @since 3.0.0 + */ +@Private +public interface ShuffleExecutorComponents { + + /** + * Called once per executor to bootstrap this module with state that is specific to + * that executor, specifically the application ID and executor ID. + */ + void initializeExecutor(String appId, String execId, Map extraConfigs); + + /** + * Called once per map task to create a writer that will be responsible for persisting all the + * partitioned bytes written by that map task. + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId Within the shuffle, the identifier of the map task + * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. 
+ * @param numPartitions The number of partitions that will be written by the map task. Some of + * these partitions may be empty. + */ + ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) throws IOException; + + /** + * Returns an underlying {@link Iterable} that will iterate + * through shuffle data, given an iterable for the shuffle blocks to fetch. + */ + Iterable getPartitionReaders(Iterable blockMetadata) + throws IOException; + + default boolean shouldWrapPartitionReaderStream() { + return true; + } + + /** + * An optional extension for creating a map output writer that can optimize the transfer of a + * single partition file, as the entire result of a map task, to the backing store. + *
<p>
+ * Most implementations should return the default {@link Optional#empty()} to indicate that + * they do not support this optimization. This primarily is for backwards-compatibility in + * preserving an optimization in the local disk shuffle storage implementation. + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId Within the shuffle, the identifier of the map task + * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. + */ + default Optional createSingleFileMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId) throws IOException { + return Optional.empty(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..8fcc73ba3c9b2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * A top-level writer that returns child writers for persisting the output of a map task, + * and then commits all of the writes as one atomic operation. + * + * @since 3.0.0 + */ +@Private +public interface ShuffleMapOutputWriter { + + /** + * Creates a writer that can open an output stream to persist bytes targeted for a given reduce + * partition id. + *
<p>
+ * The chunk corresponds to bytes in the given reduce partition. This will not be called twice + * for the same partition within any given map task. The partition identifier will be in the + * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was + * provided upon the creation of this map output writer via + * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. + *
<p>
+ * Calls to this method will be invoked with monotonically increasing reducePartitionIds; each + * call to this method will be called with a reducePartitionId that is strictly greater than + * the reducePartitionIds given to any previous call to this method. This method is not + * guaranteed to be called for every partition id in the above described range. In particular, + * no guarantees are made as to whether or not this method will be called for empty partitions. + */ + ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException; + + /** + * Commits the writes done by all partition writers returned by all calls to this object's + * {@link #getPartitionWriter(int)}, and returns a bundle of metadata associated with the + * behavior of the write. + *
<p>
+ * This should ensure that the writes conducted by this module's partition writers are + * available to downstream reduce tasks. If this method throws any exception, this module's + * {@link #abort(Throwable)} method will be invoked before propagating the exception. + *
<p>
+ * This can also close any resources and clean up temporary state if necessary. + *
<p>
+ * The returned array should contain two sets of metadata: + * + * 1. For each partition from (0) to (numPartitions - 1), the number of bytes written by + * the partition writer for that partition id. + * + * 2. If the partition data was stored on the local disk of this executor, also provide + * the block manager id where these bytes can be fetched from. + */ + MapOutputWriterCommitMessage commitAllPartitions() throws IOException; + + /** + * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. + *
<p>
+ * This should invalidate the results of writing bytes. This can also close any resources and + * clean up temporary state if necessary. + */ + void abort(Throwable error) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java new file mode 100644 index 0000000000000..928875156a70f --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; +import java.util.Optional; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for opening streams to persist partition bytes to a backing data store. + *
<p>
+ * This writer stores bytes for one (mapper, reducer) pair, corresponding to one shuffle + * block. + * + * @since 3.0.0 + */ +@Private +public interface ShufflePartitionWriter { + + /** + * Open and return an {@link OutputStream} that can write bytes to the underlying + * data store. + *
<p>
+ * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The output stream will only be used to write the bytes for this + * partition. The map task closes this output stream upon writing all the bytes for this + * block, or if the write fails for any reason. + *
<p>
+ * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same OutputStream instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link OutputStream#close()} does not close the resource, since it will be reused across + * partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + */ + OutputStream openStream() throws IOException; + + /** + * Opens and returns a {@link WritableByteChannelWrapper} for transferring bytes from + * input byte channels to the underlying shuffle data store. + *
<p>
+ * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The channel will only be used to write the bytes for this + * partition. The map task closes this channel upon writing all the bytes for this + * block, or if the write fails for any reason. + *
<p>
+ * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same channel instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel + * will be reused across partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + *
<p>
+ * This method is primarily for advanced optimizations where bytes can be copied from the input + * spill files to the output channel without copying data into memory. If such optimizations are + * not supported, the implementation should return {@link Optional#empty()}. By default, the + * implementation returns {@link Optional#empty()}. + *
<p>
+ * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the + * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure + * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()}, + * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + */ + default Optional openChannelWrapper() throws IOException { + return Optional.empty(); + } + + /** + * Returns the number of bytes written either by this writer's output stream opened by + * {@link #openStream()} or the byte channel opened by {@link #openChannelWrapper()}. + *
<p>
+ * This can be different from the number of bytes given by the caller. For example, the + * stream might compress or encrypt the bytes before persisting the data to the backing + * data store. + */ + long getNumBytesWritten(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java similarity index 62% rename from core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java rename to core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java index 7ee1d8a554073..bddb97bdf0d7e 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java @@ -15,23 +15,23 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; +import java.io.File; import java.io.IOException; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for deploying a shuffle map output writer - * - * @since 3.0.0 + * Optional extension for partition writing that is optimized for transferring a single + * file to the backing store. */ -@Experimental -public interface ShuffleWriteSupport { - ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) throws IOException; +@Private +public interface SingleSpillShuffleMapOutputWriter { + + /** + * Transfer a file that contains the bytes of all the partitions written by this map task. + */ + MapOutputWriterCommitMessage transferMapSpillFile( + File mapOutputFile, long[] partitionLengths) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java similarity index 59% rename from core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java rename to core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java index 74c928b0b9c8f..a204903008a51 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java @@ -15,30 +15,28 @@ * limitations under the License. */ -package org.apache.spark.api.shuffle; +package org.apache.spark.shuffle.api; -import java.io.IOException; -import java.io.OutputStream; +import java.io.Closeable; +import java.nio.channels.WritableByteChannel; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Private; /** - * :: Experimental :: - * An interface for giving streams / channels for shuffle writes. + * :: Private :: + * + * A thin wrapper around a {@link WritableByteChannel}. + *
<p>
+ * This is primarily provided for the local disk shuffle implementation to provide a + * {@link java.nio.channels.FileChannel} that keeps the channel open across partition writes. * * @since 3.0.0 */ -@Experimental -public interface ShufflePartitionWriter { - - /** - * Opens and returns an underlying {@link OutputStream} that can write bytes to the underlying - * data store. - */ - OutputStream openStream() throws IOException; +@Private +public interface WritableByteChannelWrapper extends Closeable { /** - * Get the number of bytes written by this writer's stream returned by {@link #openStream()}. + * The underlying channel to write bytes into. */ - long getNumBytesWritten(); + WritableByteChannel channel(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 63aee8ad50da3..94ad5fc66185b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,11 +21,10 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.OutputStream; -import java.nio.channels.Channels; import java.nio.channels.FileChannel; +import java.util.Optional; import javax.annotation.Nullable; -import org.apache.spark.api.java.Optional; import scala.None$; import scala.Option; import scala.Product2; @@ -40,11 +39,11 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.SupportsTransferTo; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -90,13 +89,13 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int mapId; private final long mapTaskAttemptId; private final Serializer serializer; - private final ShuffleWriteSupport shuffleWriteSupport; + private final ShuffleExecutorComponents shuffleExecutorComponents; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; - private long[] partitionLengths; + private MapOutputWriterCommitMessage commitMessage; /** * Are we in the process of stopping? 
Because map tasks can call stop() with success = true @@ -112,34 +111,33 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { long mapTaskAttemptId, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleWriteSupport shuffleWriteSupport) { + ShuffleExecutorComponents shuffleExecutorComponents) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; - this.shuffleId = dep.shuffleId(); this.mapTaskAttemptId = mapTaskAttemptId; + this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); - this.shuffleWriteSupport = shuffleWriteSupport; + this.shuffleExecutorComponents = shuffleExecutorComponents; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport + ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); try { if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; - Optional location = mapOutputWriter.commitAllPartitions(); + commitMessage = mapOutputWriter.commitAllPartitions(); mapStatus = MapStatus$.MODULE$.apply( - location.orNull(), - partitionLengths, + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), mapTaskAttemptId); return; } @@ -172,14 +170,17 @@ public void write(Iterator> records) throws IOException { } } - partitionLengths = writePartitionedData(mapOutputWriter); - Optional location = mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(location.orNull(), partitionLengths, mapTaskAttemptId); + commitMessage = writePartitionedData(mapOutputWriter); + mapStatus = MapStatus$.MODULE$.apply( + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), + mapTaskAttemptId); } catch (Exception e) { try { mapOutputWriter.abort(e); } catch (Exception e2) { logger.error("Failed to abort the writer after failing to write map output.", e2); + e.addSuppressed(e2); } throw e; } @@ -187,7 +188,7 @@ public void write(Iterator> records) throws IOException { @VisibleForTesting long[] getPartitionLengths() { - return partitionLengths; + return commitMessage.getPartitionLengths(); } /** @@ -195,61 +196,75 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). 
*/ - private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { + private MapOutputWriterCommitMessage writePartitionedData( + ShuffleMapOutputWriter mapOutputWriter) throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; - if (partitionWriters == null) { - // We were passed an empty iterator - return lengths; - } - final long writeStartTime = System.nanoTime(); - try { - for (int i = 0; i < numPartitions; i++) { - final File file = partitionWriterSegments[i].file(); - ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); - if (file.exists()) { - boolean copyThrewException = true; - if (transferToEnabled) { - FileInputStream in = new FileInputStream(file); - TransferrableWritableByteChannel outputChannel = null; - try (FileChannel inputChannel = in.getChannel()) { - if (writer instanceof SupportsTransferTo) { - outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel(); + if (partitionWriters != null) { + final long writeStartTime = System.nanoTime(); + try { + for (int i = 0; i < numPartitions; i++) { + final File file = partitionWriterSegments[i].file(); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); + if (file.exists()) { + if (transferToEnabled) { + // Using WritableByteChannelWrapper to make resource closing consistent between + // this implementation and UnsafeShuffleWriter. + Optional maybeOutputChannel = writer.openChannelWrapper(); + if (maybeOutputChannel.isPresent()) { + writePartitionedDataWithChannel(file, maybeOutputChannel.get()); } else { - // Use default transferrable writable channel anyways in order to have parity with - // UnsafeShuffleWriter. - outputChannel = new DefaultTransferrableWritableByteChannel( - Channels.newChannel(writer.openStream())); + writePartitionedDataWithStream(file, writer); } - outputChannel.transferFrom(inputChannel, 0L, inputChannel.size()); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - Closeables.close(outputChannel, copyThrewException); + } else { + writePartitionedDataWithStream(file, writer); } - } else { - FileInputStream in = new FileInputStream(file); - OutputStream outputStream = null; - try { - outputStream = writer.openStream(); - Utils.copyStream(in, outputStream, false, false); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - Closeables.close(outputStream, copyThrewException); + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); } } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); - } } - lengths[i] = writer.getNumBytesWritten(); + } finally { + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + } + return mapOutputWriter.commitAllPartitions(); + } + + private void writePartitionedDataWithChannel( + File file, + WritableByteChannelWrapper outputChannel) throws IOException { + boolean copyThrewException = true; + try { + FileInputStream in = new FileInputStream(file); + try (FileChannel inputChannel = in.getChannel()) { + Utils.copyFileStreamNIO( + inputChannel, outputChannel.channel(), 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } finally { + Closeables.close(outputChannel, copyThrewException); + } + } + + private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer) + 
throws IOException { + boolean copyThrewException = true; + FileInputStream in = new FileInputStream(file); + OutputStream outputStream; + try { + outputStream = writer.openStream(); + try { + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(outputStream, copyThrewException); } } finally { - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + Closeables.close(in, copyThrewException); } - partitionWriters = null; - return lengths; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java deleted file mode 100644 index 64ce851e392d2..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort; - -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; -import org.apache.spark.util.Utils; - -/** - * This is used when transferTo is enabled but the shuffle plugin hasn't implemented - * {@link org.apache.spark.api.shuffle.SupportsTransferTo}. - *
<p>
- * This default implementation exists as a convenience to the unsafe shuffle writer and - * the bypass merge sort shuffle writers. - */ -public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel { - - private final WritableByteChannel delegate; - - public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) { - this.delegate = delegate; - } - - @Override - public void transferFrom( - FileChannel source, long transferStartPosition, long numBytesToTransfer) { - Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer); - } - - @Override - public void close() throws IOException { - delegate.close(); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9627f1151f837..acb86616066a8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -18,13 +18,13 @@ package org.apache.spark.shuffle.sort; import java.nio.channels.Channels; +import java.util.Optional; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; import java.util.Iterator; -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -39,11 +39,6 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.api.shuffle.SupportsTransferTo; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; @@ -56,8 +51,16 @@ import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -74,7 +77,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; - private final ShuffleWriteSupport shuffleWriteSupport; + private final ShuffleExecutorComponents shuffleExecutorComponents; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -111,7 +114,7 @@ public UnsafeShuffleWriter( TaskContext taskContext, SparkConf sparkConf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleWriteSupport shuffleWriteSupport) throws IOException { + ShuffleExecutorComponents 
shuffleExecutorComponents) { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -127,7 +130,7 @@ public UnsafeShuffleWriter( this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = writeMetrics; - this.shuffleWriteSupport = shuffleWriteSupport; + this.shuffleExecutorComponents = shuffleExecutorComponents; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -216,34 +219,20 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport - .createMapOutputWriter(shuffleId, - mapId, - taskContext.taskAttemptId(), - partitioner.numPartitions()); - final long[] partitionLengths; - Optional location; + final MapOutputWriterCommitMessage commitMessage; try { - try { - partitionLengths = mergeSpills(spills, mapWriter); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && !spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); - } + commitMessage = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); } } - location = mapWriter.commitAllPartitions(); - } catch (Exception e) { - try { - mapWriter.abort(e); - } catch (Exception innerE) { - logger.error("Failed to abort the Map Output Writer", innerE); - } - throw e; } mapStatus = MapStatus$.MODULE$.apply( - location.orNull(), partitionLengths, taskContext.attemptNumber()); + commitMessage.getLocation().orElse(null), + commitMessage.getPartitionLengths(), + taskContext.attemptNumber()); } @VisibleForTesting @@ -275,57 +264,94 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, - ShuffleMapOutputWriter mapWriter) throws IOException { + private MapOutputWriterCommitMessage mergeSpills(SpillInfo[] spills) throws IOException { + MapOutputWriterCommitMessage commitMessage; + if (spills.length == 0) { + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter( + shuffleId, + mapId, + taskContext.taskAttemptId(), + partitioner.numPartitions()); + return mapWriter.commitAllPartitions(); + } else if (spills.length == 1) { + Optional maybeSingleFileWriter = + shuffleExecutorComponents.createSingleFileMapOutputWriter( + shuffleId, mapId, taskContext.taskAttemptId()); + if (maybeSingleFileWriter.isPresent()) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. 
+ long[] partitionLengths = spills[0].partitionLengths; + return maybeSingleFileWriter.get().transferMapSpillFile( + spills[0].file, partitionLengths); + } else { + commitMessage = mergeSpillsUsingStandardWriter(spills); + } + } else { + commitMessage = mergeSpillsUsingStandardWriter(spills); + } + return commitMessage; + } + + private MapOutputWriterCommitMessage mergeSpillsUsingStandardWriter( + SpillInfo[] spills) throws IOException { + MapOutputWriterCommitMessage commitMessage; final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = - (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE()); + (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE()); final boolean fastMergeIsSupported = !compressionEnabled || - CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); - final int numPartitions = partitioner.numPartitions(); - long[] partitionLengths = new long[numPartitions]; + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter( + shuffleId, + mapId, + taskContext.taskAttemptId(), + partitioner.numPartitions()); try { - if (spills.length == 0) { - return partitionLengths; - } else { - // There are multiple spills to merge, so none of these spill files' lengths were counted - // towards our shuffle write count or shuffle write time. If we use the slow merge path, - // then the final output file's size won't necessarily be equal to the sum of the spill - // files' sizes. To guard against this case, we look at the output file's actual size when - // computing shuffle bytes written. - // - // We allow the individual merge methods to report their own IO times since different merge - // strategies use different IO techniques. We count IO during merge towards the shuffle - // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" - // branch in ExternalSorter. - if (fastMergeEnabled && fastMergeIsSupported) { - // Compression is disabled or we are using an IO compression codec that supports - // decompression of concatenated compressed streams, so we can perform a fast spill merge - // that doesn't need to interpret the spilled bytes. - if (transferToEnabled && !encryptionEnabled) { - logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter); - } else { - logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null); - } + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. 
We count IO during merge towards the shuffle + // write time, which appears to be consistent with the "not bypassing merge-sort" branch in + // ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled && !encryptionEnabled) { + logger.debug("Using transferTo-based fast merge"); + mergeSpillsWithTransferTo(spills, mapWriter); } else { - logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + logger.debug("Using fileStream-based fast merge"); + mergeSpillsWithFileStream(spills, mapWriter, null); } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. - writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - return partitionLengths; + } else { + logger.debug("Using slow merge"); + mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + commitMessage = mapWriter.commitAllPartitions(); + } catch (Exception e) { + try { + mapWriter.abort(e); + } catch (Exception e2) { + logger.warn("Failed to abort writing the map output.", e2); + e.addSuppressed(e2); } - } catch (IOException e) { throw e; } + return commitMessage; } /** @@ -344,12 +370,11 @@ private long[] mergeSpills(SpillInfo[] spills, * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. 
*/ - private long[] mergeSpillsWithFileStream( + private void mergeSpillsWithFileStream( SpillInfo[] spills, ShuffleMapOutputWriter mapWriter, @Nullable CompressionCodec compressionCodec) throws IOException { final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; boolean threwException = true; @@ -360,11 +385,11 @@ private long[] mergeSpillsWithFileStream( inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewExecption = true; + boolean copyThrewException = true; ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - OutputStream partitionOutput = null; + OutputStream partitionOutput = writer.openStream(); try { - partitionOutput = writer.openStream(); + partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); @@ -374,6 +399,7 @@ private long[] mergeSpillsWithFileStream( if (partitionLengthInSpill > 0) { InputStream partitionInputStream = null; + boolean copySpillThrewException = true; try { partitionInputStream = new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); @@ -384,17 +410,18 @@ private long[] mergeSpillsWithFileStream( partitionInputStream); } ByteStreams.copy(partitionInputStream, partitionOutput); + copySpillThrewException = false; } finally { - partitionInputStream.close(); + Closeables.close(partitionInputStream, copySpillThrewException); } } - copyThrewExecption = false; + copyThrewException = false; } + copyThrewException = false; } finally { - Closeables.close(partitionOutput, copyThrewExecption); + Closeables.close(partitionOutput, copyThrewException); } long numBytesWritten = writer.getNumBytesWritten(); - partitionLengths[partition] = numBytesWritten; writeMetrics.incBytesWritten(numBytesWritten); } threwException = false; @@ -405,7 +432,6 @@ private long[] mergeSpillsWithFileStream( Closeables.close(stream, threwException); } } - return partitionLengths; } /** @@ -417,11 +443,10 @@ private long[] mergeSpillsWithFileStream( * @param mapWriter the map output writer to use for output. * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithTransferTo( + private void mergeSpillsWithTransferTo( SpillInfo[] spills, ShuffleMapOutputWriter mapWriter) throws IOException { final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; @@ -431,30 +456,28 @@ private long[] mergeSpillsWithTransferTo( spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); } for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewExecption = true; + boolean copyThrewException = true; ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - TransferrableWritableByteChannel partitionChannel = null; + WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper() + .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer))); try { - partitionChannel = writer instanceof SupportsTransferTo ? 
- ((SupportsTransferTo) writer).openTransferrableChannel() - : new DefaultTransferrableWritableByteChannel( - Channels.newChannel(writer.openStream())); for (int i = 0; i < spills.length; i++) { - long partitionLengthInSpill = 0L; - partitionLengthInSpill += spills[i].partitionLengths[partition]; + long partitionLengthInSpill = spills[i].partitionLengths[partition]; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - partitionChannel.transferFrom( - spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill); + Utils.copyFileStreamNIO( + spillInputChannel, + resolvedChannel.channel(), + spillInputChannelPositions[i], + partitionLengthInSpill); + copyThrewException = false; spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } - copyThrewExecption = false; } finally { - Closeables.close(partitionChannel, copyThrewExecption); + Closeables.close(resolvedChannel, copyThrewException); } long numBytes = writer.getNumBytesWritten(); - partitionLengths[partition] = numBytes; writeMetrics.incBytesWritten(numBytes); } threwException = false; @@ -466,7 +489,6 @@ private long[] mergeSpillsWithTransferTo( Closeables.close(spillInputChannels[i], threwException); } } - return partitionLengths; } @Override @@ -495,4 +517,30 @@ public Option stop(boolean success) { } } } + + private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) { + try { + return writer.openStream(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper { + private final WritableByteChannel channel; + + StreamFallbackChannelWrapper(OutputStream fallbackStream) { + this.channel = Channels.newChannel(fallbackStream); + } + + @Override + public WritableByteChannel channel() { + return channel; + } + + @Override + public void close() throws IOException { + channel.close(); + } + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java deleted file mode 100644 index 3b5f9670d64d2..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io; - -import org.apache.spark.MapOutputTracker; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkEnv; -import org.apache.spark.api.shuffle.ShuffleExecutorComponents; -import org.apache.spark.api.shuffle.ShuffleReadSupport; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; -import org.apache.spark.serializer.SerializerManager; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; -import org.apache.spark.storage.BlockManager; - -import java.util.Map; - -public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { - - private final SparkConf sparkConf; - private BlockManager blockManager; - private IndexShuffleBlockResolver blockResolver; - private MapOutputTracker mapOutputTracker; - private SerializerManager serializerManager; - - public DefaultShuffleExecutorComponents(SparkConf sparkConf) { - this.sparkConf = sparkConf; - } - - @Override - public void initializeExecutor(String appId, String execId, Map extraConfigs) { - blockManager = SparkEnv.get().blockManager(); - mapOutputTracker = SparkEnv.get().mapOutputTracker(); - serializerManager = SparkEnv.get().serializerManager(); - blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); - } - - @Override - public ShuffleWriteSupport writes() { - checkInitialized(); - return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); - } - - @Override - public ShuffleReadSupport reads() { - checkInitialized(); - return new DefaultShuffleReadSupport(blockManager, - mapOutputTracker, - serializerManager, - sparkConf); - } - - private void checkInitialized() { - if (blockResolver == null) { - throw new IllegalStateException( - "Executor components must be initialized before getting writers."); - } - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java similarity index 61% rename from core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java rename to core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java index 7c124c1fe68bc..77fcd34f962bf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java @@ -18,26 +18,31 @@ package org.apache.spark.shuffle.sort.io; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.ShuffleDriverComponents; -import org.apache.spark.api.shuffle.ShuffleExecutorComponents; -import org.apache.spark.api.shuffle.ShuffleDataIO; -import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleDataIO; +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents; -public class DefaultShuffleDataIO implements ShuffleDataIO { +/** + * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle + * storage and index file functionality that has historically been used from Spark 2.4 and earlier. 
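As a usage note for the renamed plugin entry point: the sketch below shows a minimal custom ShuffleDataIO that simply delegates to LocalDiskShuffleDataIO and is enabled through the renamed spark.shuffle.sort.io.plugin.class setting. The class and package names are hypothetical; only the constructor-takes-SparkConf convention and the driver()/executor() methods come from this patch.

package com.example.shuffle; // hypothetical package

import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.api.ShuffleDataIO;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO;

// Delegates to the built-in local-disk implementation; a real plugin would
// substitute its own driver and executor components here.
public class DelegatingShuffleDataIO implements ShuffleDataIO {
  private final LocalDiskShuffleDataIO delegate;

  public DelegatingShuffleDataIO(SparkConf conf) {
    this.delegate = new LocalDiskShuffleDataIO(conf);
  }

  @Override
  public ShuffleDriverComponents driver() {
    return delegate.driver();
  }

  @Override
  public ShuffleExecutorComponents executor() {
    return delegate.executor();
  }
}
// Enable with: spark.shuffle.sort.io.plugin.class=com.example.shuffle.DelegatingShuffleDataIO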
+ */ +public class LocalDiskShuffleDataIO implements ShuffleDataIO { private final SparkConf sparkConf; - public DefaultShuffleDataIO(SparkConf sparkConf) { + public LocalDiskShuffleDataIO(SparkConf sparkConf) { this.sparkConf = sparkConf; } @Override - public ShuffleExecutorComponents executor() { - return new DefaultShuffleExecutorComponents(sparkConf); + public ShuffleDriverComponents driver() { + return new LocalDiskShuffleDriverComponents(); } @Override - public ShuffleDriverComponents driver() { - return new DefaultShuffleDriverComponents(); + public ShuffleExecutorComponents executor() { + return new LocalDiskShuffleExecutorComponents(sparkConf); } + } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java new file mode 100644 index 0000000000000..c8d70d72eb02e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.InputStream; +import java.util.Map; +import java.util.Optional; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.MapOutputTracker; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.shuffle.api.ShuffleBlockInfo; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockManagerId; + +public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { + + private final SparkConf sparkConf; + private LocalDiskShuffleReadSupport shuffleReadSupport; + private BlockManagerId shuffleServerId; + private BlockManager blockManager; + private IndexShuffleBlockResolver blockResolver; + + public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @VisibleForTesting + public LocalDiskShuffleExecutorComponents( + SparkConf sparkConf, + BlockManager blockManager, + MapOutputTracker mapOutputTracker, + SerializerManager serializerManager, + IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId) { + this.sparkConf = sparkConf; + this.blockManager = blockManager; + this.blockResolver = blockResolver; + this.shuffleServerId = shuffleServerId; + this.shuffleReadSupport = new LocalDiskShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); + } + + @Override + public void initializeExecutor(String appId, String execId, Map extraConfigs) { + blockManager = SparkEnv.get().blockManager(); + if (blockManager == null) { + throw new IllegalStateException("No blockManager available from the SparkEnv."); + } + shuffleServerId = blockManager.shuffleServerId(); + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker(); + SerializerManager serializerManager = SparkEnv.get().serializerManager(); + shuffleReadSupport = new LocalDiskShuffleReadSupport( + blockManager, mapOutputTracker, serializerManager, sparkConf); + } + + @Override + public ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return new LocalDiskShuffleMapOutputWriter( + shuffleId, + mapId, + numPartitions, + blockResolver, + shuffleServerId, + sparkConf); + } + + @Override + public Optional createSingleFileMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return Optional.of(new LocalDiskSingleSpillMapOutputWriter( + shuffleId, mapId, blockResolver, shuffleServerId)); + } + + @Override + public Iterable getPartitionReaders(Iterable blockMetadata) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting readers."); + } + return 
shuffleReadSupport.getPartitionReaders(blockMetadata); + } + + @Override + public boolean shouldWrapPartitionReaderStream() { + return false; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java similarity index 68% rename from core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index ad55b3db377f6..064875420c473 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -23,73 +23,73 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Optional; -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.SparkConf; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.api.shuffle.SupportsTransferTo; -import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; -import org.apache.spark.internal.config.package$; -import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel; -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; -public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { +/** + * Implementation of {@link ShuffleMapOutputWriter} that replicates the functionality of shuffle + * persisting shuffle data to local disk alongside index files, identical to Spark's historic + * canonical shuffle storage mechanism. 
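To make the writer-side contract concrete, here is a hedged sketch of how a caller is expected to drive any ShuffleMapOutputWriter in this patch: partition writers are requested in increasing partition id order, bytes go through openStream(), and the attempt is either committed or aborted. The helper class and the byte-array input are hypothetical; the API calls match the interfaces shown in this diff.

import java.io.OutputStream;

import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;

final class MapOutputWriterSketch {
  // partitionBytes[p] holds the already-serialized bytes for reduce partition p
  // (hypothetical input).
  static MapOutputWriterCommitMessage writeAll(
      ShuffleMapOutputWriter mapWriter, byte[][] partitionBytes) throws Exception {
    try {
      for (int p = 0; p < partitionBytes.length; p++) {
        // Partition writers must be requested in increasing partition id order.
        ShufflePartitionWriter writer = mapWriter.getPartitionWriter(p);
        try (OutputStream out = writer.openStream()) {
          out.write(partitionBytes[p]);
        }
      }
      // Commit reports the per-partition lengths and the location for the MapStatus.
      return mapWriter.commitAllPartitions();
    } catch (Exception e) {
      mapWriter.abort(e);
      throw e;
    }
  }
}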
+ */ +public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { private static final Logger log = - LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class); + LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); private final int shuffleId; private final int mapId; - private final ShuffleWriteMetricsReporter metrics; private final IndexShuffleBlockResolver blockResolver; private final long[] partitionLengths; private final int bufferSize; + private final BlockManagerId shuffleServerId; private int lastPartitionId = -1; private long currChannelPosition; - private final BlockManagerId shuffleServerId; + private long bytesWrittenToMergedFile = 0L; private final File outputFile; private File outputTempFile; private FileOutputStream outputFileStream; private FileChannel outputFileChannel; - private TimeTrackingOutputStream ts; private BufferedOutputStream outputBufferedFileStream; - public DefaultShuffleMapOutputWriter( + public LocalDiskShuffleMapOutputWriter( int shuffleId, int mapId, int numPartitions, - BlockManagerId shuffleServerId, - ShuffleWriteMetricsReporter metrics, IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId, SparkConf sparkConf) { this.shuffleId = shuffleId; this.mapId = mapId; - this.shuffleServerId = shuffleServerId; - this.metrics = metrics; this.blockResolver = blockResolver; this.bufferSize = (int) (long) sparkConf.get( package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.shuffleServerId = shuffleServerId; this.partitionLengths = new long[numPartitions]; this.outputFile = blockResolver.getDataFile(shuffleId, mapId); this.outputTempFile = null; } @Override - public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException { - if (partitionId <= lastPartitionId) { + public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException { + if (reducePartitionId <= lastPartitionId) { throw new IllegalArgumentException("Partitions should be requested in increasing order."); } - lastPartitionId = partitionId; + lastPartitionId = reducePartitionId; if (outputTempFile == null) { outputTempFile = Utils.tempFileWith(outputFile); } @@ -98,24 +98,32 @@ public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOExcep } else { currChannelPosition = 0L; } - return new DefaultShufflePartitionWriter(partitionId); + return new LocalDiskShufflePartitionWriter(reducePartitionId); } @Override - public Optional commitAllPartitions() throws IOException { + public MapOutputWriterCommitMessage commitAllPartitions() throws IOException { + // Check the position after transferTo loop to see if it is in the right position and raise a + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + outputFileChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " + + " kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " + + "to unexpected behavior when using transferTo. You can set " + + "spark.file.transferTo=false to disable this NIO feature."); + } cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); - return Optional.of(shuffleServerId); + return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); } @Override - public void abort(Throwable error) { - try { - cleanUp(); - } catch (Exception e) { - log.error("Unable to close appropriate underlying file stream", e); - } + public void abort(Throwable error) throws IOException { + cleanUp(); if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); } @@ -136,29 +144,27 @@ private void cleanUp() throws IOException { private void initStream() throws IOException { if (outputFileStream == null) { outputFileStream = new FileOutputStream(outputTempFile, true); - ts = new TimeTrackingOutputStream(metrics, outputFileStream); } if (outputBufferedFileStream == null) { - outputBufferedFileStream = new BufferedOutputStream(ts, bufferSize); + outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize); } } private void initChannel() throws IOException { - if (outputFileStream == null) { - outputFileStream = new FileOutputStream(outputTempFile, true); - } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. if (outputFileChannel == null) { - outputFileChannel = outputFileStream.getChannel(); + outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel(); } } - private class DefaultShufflePartitionWriter implements SupportsTransferTo { + private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter { private final int partitionId; private PartitionWriterStream partStream = null; private PartitionWriterChannel partChannel = null; - private DefaultShufflePartitionWriter(int partitionId) { + private LocalDiskShufflePartitionWriter(int partitionId) { this.partitionId = partitionId; } @@ -177,7 +183,7 @@ public OutputStream openStream() throws IOException { } @Override - public TransferrableWritableByteChannel openTransferrableChannel() throws IOException { + public Optional openChannelWrapper() throws IOException { if (partChannel == null) { if (partStream != null) { throw new IllegalStateException("Requested an output stream for a previous write but" + @@ -187,7 +193,7 @@ public TransferrableWritableByteChannel openTransferrableChannel() throws IOExce initChannel(); partChannel = new PartitionWriterChannel(partitionId); } - return partChannel; + return Optional.of(partChannel); } @Override @@ -238,6 +244,7 @@ public void write(byte[] buf, int pos, int length) throws IOException { public void close() { isClosed = true; partitionLengths[partitionId] = count; + bytesWrittenToMergedFile += count; } private void verifyNotClosed() { @@ -247,12 +254,11 @@ private void verifyNotClosed() { } } - private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel { + private class PartitionWriterChannel implements WritableByteChannelWrapper { private final int partitionId; PartitionWriterChannel(int partitionId) { - super(outputFileChannel); this.partitionId = partitionId; } @@ -261,9 +267,15 @@ public long getCount() throws IOException { return writtenPosition - currChannelPosition; } + @Override + public WritableByteChannel channel() { + return outputFileChannel; + } + @Override public void close() throws IOException { partitionLengths[partitionId] = 
getCount(); + bytesWrittenToMergedFile += partitionLengths[partitionId]; } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java similarity index 52% rename from core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java rename to core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index d6210f045840b..219f9ee1296dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -17,36 +17,45 @@ package org.apache.spark.shuffle.sort.io; -import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; -import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; -import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; + import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.util.Utils; -public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { +public class LocalDiskSingleSpillMapOutputWriter + implements SingleSpillShuffleMapOutputWriter { - private final SparkConf sparkConf; + private final int shuffleId; + private final int mapId; private final IndexShuffleBlockResolver blockResolver; private final BlockManagerId shuffleServerId; - public DefaultShuffleWriteSupport( - SparkConf sparkConf, + public LocalDiskSingleSpillMapOutputWriter( + int shuffleId, + int mapId, IndexShuffleBlockResolver blockResolver, BlockManagerId shuffleServerId) { - this.sparkConf = sparkConf; + this.shuffleId = shuffleId; + this.mapId = mapId; this.blockResolver = blockResolver; this.shuffleServerId = shuffleServerId; } @Override - public ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) { - return new DefaultShuffleMapOutputWriter( - shuffleId, mapId, numPartitions, shuffleServerId, - TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); + public MapOutputWriterCommitMessage transferMapSpillFile( + File mapSpillFile, + long[] partitionLengths) throws IOException { + // The map spill file already has the proper format, and it contains all of the partition data. + // So just transfer it directly to the destination without any merging. 
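The single-spill fast path referenced by this comment is taken by UnsafeShuffleWriter earlier in the patch; the sketch below condenses that decision. The helper class and parameter names are hypothetical, while createSingleFileMapOutputWriter and transferMapSpillFile are the calls introduced by this diff.

import java.io.File;
import java.io.IOException;
import java.util.Optional;

import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;

final class SingleSpillFastPathSketch {
  // Returns empty when the plugin offers no single-file writer, in which case the caller
  // falls back to the standard ShuffleMapOutputWriter merge path.
  static Optional<MapOutputWriterCommitMessage> tryTransferSingleSpill(
      ShuffleExecutorComponents executorComponents,
      int shuffleId, int mapId, long mapTaskAttemptId,
      File spillFile, long[] partitionLengths) throws IOException {
    Optional<SingleSpillShuffleMapOutputWriter> singleFileWriter =
        executorComponents.createSingleFileMapOutputWriter(shuffleId, mapId, mapTaskAttemptId);
    if (singleFileWriter.isPresent()) {
      // The lone spill file is already in the final on-disk format, so hand it over as-is.
      return Optional.of(
          singleFileWriter.get().transferMapSpillFile(spillFile, partitionLengths));
    }
    return Optional.empty();
  }
}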
+ File outputFile = blockResolver.getDataFile(shuffleId, mapId); + File tempFile = Utils.tempFileWith(outputFile); + Files.move(mapSpillFile.toPath(), tempFile.toPath()); + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java similarity index 93% rename from core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java rename to core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java index c6f43b91f90a0..183769274841c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java @@ -22,11 +22,11 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.SparkEnv; -import org.apache.spark.api.shuffle.ShuffleDriverComponents; +import org.apache.spark.shuffle.api.ShuffleDriverComponents; import org.apache.spark.internal.config.package$; import org.apache.spark.storage.BlockManagerMaster; -public class DefaultShuffleDriverComponents implements ShuffleDriverComponents { +public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents { private BlockManagerMaster blockManagerMaster; private boolean shouldUnregisterOutputOnHostOnFetchFailure; diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bcd47ba0c29c1..98232380cc266 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -23,11 +23,11 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled import scala.collection.JavaConverters._ -import org.apache.spark.api.shuffle.ShuffleDriverComponents import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f359022716571..c84bc82b9a29f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -43,7 +43,6 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.conda.CondaEnvironment import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} @@ -58,6 +57,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend import 
org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.status.{AppStatusSource, AppStatusStore} import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a852a06be9125..833db06420d4d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} -import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -770,10 +770,10 @@ package object config { .createWithDefault(false) private[spark] val SHUFFLE_IO_PLUGIN_CLASS = - ConfigBuilder("spark.shuffle.io.plugin.class") + ConfigBuilder("spark.shuffle.sort.io.plugin.class") .doc("Name of the class to use for shuffle IO.") .stringConf - .createWithDefault(classOf[DefaultShuffleDataIO].getName) + .createWithDefault(classOf[LocalDiskShuffleDataIO].getName) private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") @@ -951,7 +951,7 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val SHUFFLE_UNDAFE_FAST_MERGE_ENABLE = + private[spark] val SHUFFLE_UNSAFE_FAST_MERGE_ENABLE = ConfigBuilder("spark.shuffle.unsafe.fastMergeEnabled") .doc("Whether to perform a fast spill merge.") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index a20a849cc6421..e614dbc8c9542 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,17 +17,14 @@ package org.apache.spark.shuffle -import java.io.InputStream - import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.api.java.Optional -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleExecutorComponents} import org.apache.spark.storage.ShuffleBlockAttemptId import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -42,7 +39,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, - shuffleReadSupport: ShuffleReadSupport, + shuffleExecutorComponents: ShuffleExecutorComponents, serializerManager: SerializerManager = SparkEnv.get.serializerManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, sparkConf: SparkConf = SparkEnv.get.conf) @@ -57,7 +54,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the 
combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val streamsIterator = - shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { + shuffleExecutorComponents.getPartitionReaders(new Iterable[ShuffleBlockInfo] { override def iterator: Iterator[ShuffleBlockInfo] = { mapOutputTracker .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) @@ -76,18 +73,18 @@ private[spark] class BlockStoreShuffleReader[K, C]( } }.asJava).iterator() - val retryingWrappedStreams = streamsIterator.asScala.map(readSupportStream => { - if (shuffleReadSupport.shouldWrapStream()) { + val retryingWrappedStreams = streamsIterator.asScala.map(rawReaderStream => { + if (shuffleExecutorComponents.shouldWrapPartitionReaderStream()) { if (compressShuffle) { compressionCodec.compressedInputStream( - serializerManager.wrapForEncryption(readSupportStream)) + serializerManager.wrapForEncryption(rawReaderStream)) } else { - serializerManager.wrapForEncryption(readSupportStream) + serializerManager.wrapForEncryption(rawReaderStream) } } else { // The default implementation checks for corrupt streams, so it will already have // decompressed/decrypted the bytes - readSupportStream + rawReaderStream } }) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala new file mode 100644 index 0000000000000..e0affb858c359 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{Closeable, IOException, OutputStream} + +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.api.ShufflePartitionWriter +import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream} +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.PairsWriter + +/** + * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an + * arbitrary partition writer instead of writing to local disk through the block manager. 
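The BlockStoreShuffleReader change above keys decompression and decryption off shouldWrapPartitionReaderStream(); the sketch below restates that caller-side logic in Java. The functional parameters stand in for the SerializerManager and CompressionCodec wrapping done in the Scala code and are hypothetical.

import java.io.InputStream;
import java.util.function.UnaryOperator;

import org.apache.spark.shuffle.api.ShuffleExecutorComponents;

final class ReaderStreamWrappingSketch {
  static InputStream maybeWrap(
      ShuffleExecutorComponents executorComponents,
      InputStream rawReaderStream,
      UnaryOperator<InputStream> wrapForEncryption,
      UnaryOperator<InputStream> decompress) {
    if (executorComponents.shouldWrapPartitionReaderStream()) {
      // The plugin hands back raw bytes, so Spark applies decryption and then decompression.
      return decompress.apply(wrapForEncryption.apply(rawReaderStream));
    }
    // The plugin already returns ready-to-read streams; the local-disk implementation does this.
    return rawReaderStream;
  }
}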
+ */ +private[spark] class ShufflePartitionPairsWriter( + partitionWriter: ShufflePartitionWriter, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + blockId: BlockId, + writeMetrics: ShuffleWriteMetricsReporter) + extends PairsWriter with Closeable { + + private var isClosed = false + private var partitionStream: OutputStream = _ + private var timeTrackingStream: OutputStream = _ + private var wrappedStream: OutputStream = _ + private var objOut: SerializationStream = _ + private var numRecordsWritten = 0 + private var curNumBytesWritten = 0L + + override def write(key: Any, value: Any): Unit = { + if (isClosed) { + throw new IOException("Partition pairs writer is already closed.") + } + if (objOut == null) { + open() + } + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + private def open(): Unit = { + try { + partitionStream = partitionWriter.openStream + timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream) + wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream) + objOut = serializerInstance.serializeStream(wrappedStream) + } catch { + case e: Exception => + Utils.tryLogNonFatalError { + close() + } + throw e + } + } + + override def close(): Unit = { + if (!isClosed) { + Utils.tryWithSafeFinally { + Utils.tryWithSafeFinally { + objOut = closeIfNonNull(objOut) + // Setting these to null will prevent the underlying streams from being closed twice + // just in case any stream's close() implementation is not idempotent. + wrappedStream = null + timeTrackingStream = null + partitionStream = null + } { + // Normally closing objOut would close the inner streams as well, but just in case there + // was an error in initialization etc. we make sure we clean the other streams up too. + Utils.tryWithSafeFinally { + wrappedStream = closeIfNonNull(wrappedStream) + // Same as above - if wrappedStream closes then assume it closes underlying + // partitionStream and don't close again in the finally + timeTrackingStream = null + partitionStream = null + } { + Utils.tryWithSafeFinally { + timeTrackingStream = closeIfNonNull(timeTrackingStream) + partitionStream = null + } { + partitionStream = closeIfNonNull(partitionStream) + } + } + } + updateBytesWritten() + } { + isClosed = true + } + } + } + + private def closeIfNonNull[T <: Closeable](closeable: T): T = { + if (closeable != null) { + closeable.close() + } + null.asInstanceOf[T] + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. 
+ */ + private def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incRecordsWritten(1) + + if (numRecordsWritten % 16384 == 0) { + updateBytesWritten() + } + } + + private def updateBytesWritten(): Unit = { + val numBytesWritten = partitionWriter.getNumBytesWritten + val bytesWrittenDiff = numBytesWritten - curNumBytesWritten + writeMetrics.incBytesWritten(bytesWrittenDiff) + curNumBytesWritten = numBytesWritten + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala similarity index 90% rename from core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala rename to core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala index e18097c2c590a..9e1c1816d306c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala @@ -22,17 +22,17 @@ import java.io.InputStream import scala.collection.JavaConverters._ import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} -import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.shuffle.api.ShuffleBlockInfo +import org.apache.spark.storage.{BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} -class DefaultShuffleReadSupport( +class LocalDiskShuffleReadSupport( blockManager: BlockManager, mapOutputTracker: MapOutputTracker, serializerManager: SerializerManager, - conf: SparkConf) extends ShuffleReadSupport { + conf: SparkConf) { private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) @@ -41,7 +41,7 @@ class DefaultShuffleReadSupport( private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) - override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): + def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = { val iterableToReturn = if (blockMetadata.asScala.isEmpty) { @@ -70,8 +70,6 @@ class DefaultShuffleReadSupport( } iterableToReturn.asJava } - - override def shouldWrapStream(): Boolean = false } private class ShuffleBlockFetcherIterable( diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index c364c8d08db20..610c04ace3b6f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -22,9 +22,9 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import org.apache.spark._ -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} import org.apache.spark.util.Utils 
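ShufflePartitionPairsWriter above reports bytes by diffing the cumulative count from ShufflePartitionWriter.getNumBytesWritten rather than counting bytes itself. A minimal Java sketch of that delta-based accounting follows; the class name is hypothetical and the 16384-record flush interval is taken from the Scala code.

import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;

final class DeltaByteAccountingSketch {
  private long reportedBytes = 0L;
  private long recordCount = 0L;

  // Called once per record; bytes are only reconciled periodically because
  // getNumBytesWritten is cumulative for the partition writer.
  void onRecordWritten(ShufflePartitionWriter writer, ShuffleWriteMetricsReporter metrics) {
    metrics.incRecordsWritten(1);
    recordCount++;
    if (recordCount % 16384 == 0) {
      updateBytesWritten(writer, metrics);
    }
  }

  // Also called on close, so the final partial batch is accounted for.
  void updateBytesWritten(ShufflePartitionWriter writer, ShuffleWriteMetricsReporter metrics) {
    long current = writer.getNumBytesWritten();
    metrics.incBytesWritten(current - reportedBytes);
    reportedBytes = current;
  }
}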
/** @@ -130,7 +130,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition, context, metrics, - shuffleExecutorComponents.reads()) + shuffleExecutorComponents) } /** Get a writer for a given partition. Called on executors by map tasks. */ @@ -152,7 +152,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context, env.conf, metrics, - shuffleExecutorComponents.writes()) + shuffleExecutorComponents) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, @@ -161,10 +161,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context.taskAttemptId(), env.conf, metrics, - shuffleExecutorComponents.writes()) + shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes()) + shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 26f3f2267d44d..0082b4c9c6b24 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -18,10 +18,10 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ -import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -29,7 +29,7 @@ private[spark] class SortShuffleWriter[K, V, C]( handle: BaseShuffleHandle[K, V, C], mapId: Int, context: TaskContext, - writeSupport: ShuffleWriteSupport) + shuffleExecutorComponents: ShuffleExecutorComponents) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency @@ -64,11 +64,14 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
- val mapOutputWriter = writeSupport.createMapOutputWriter( + val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) - val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - val location = mapOutputWriter.commitAllPartitions - mapStatus = MapStatus(location.orNull, partitionLengths, context.taskAttemptId()) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + val commitMessage = mapOutputWriter.commitAllPartitions() + mapStatus = MapStatus( + commitMessage.getLocation.orElse(null), + commitMessage.getPartitionLengths, + context.taskAttemptId()) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index f9f4e3594e4f9..758621c52495b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -234,7 +234,7 @@ private[spark] class DiskBlockObjectWriter( /** * Writes a key-value pair. */ - def write(key: Any, value: Any) { + override def write(key: Any, value: Any) { if (!streamOpen) { open() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 14d34e1c47c8e..1216a45415a74 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,11 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ -import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ +import org.apache.spark.shuffle.ShufflePartitionPairsWriter +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} +import org.apache.spark.util.{Utils => TryUtils} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -675,9 +677,9 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project. - * We should figure out an alternative way to test that so that we can remove this otherwise - * unused code path. + * TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL + * project. We should figure out an alternative way to test that so that we can remove this + * otherwise unused code path. 
*/ def writePartitionedFile( blockId: BlockId, @@ -728,9 +730,9 @@ private[spark] class ExternalSorter[K, V, C]( * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ def writePartitionedMapOutput( - shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { - // Track location of each range in the map output - val lengths = new Array[Long](numPartitions) + shuffleId: Int, + mapId: Int, + mapOutputWriter: ShuffleMapOutputWriter): Unit = { if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer @@ -739,7 +741,7 @@ private[spark] class ExternalSorter[K, V, C]( val partitionId = it.nextPartition() var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null - try { + TryUtils.tryWithSafeFinally { partitionWriter = mapOutputWriter.getPartitionWriter(partitionId) val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) partitionPairsWriter = new ShufflePartitionPairsWriter( @@ -751,28 +753,19 @@ private[spark] class ExternalSorter[K, V, C]( while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(partitionPairsWriter) } - } finally { + } { if (partitionPairsWriter != null) { partitionPairsWriter.close() } } - if (partitionWriter != null) { - lengths(partitionId) = partitionWriter.getNumBytesWritten - } } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { - // The contract for the plugin is that we will ask for a writer for every partition - // even if it's empty. However, the external sorter will return non-contiguous - // partition ids. So this loop "backfills" the empty partitions that form the gaps. - - // The algorithm as a whole is correct because the partition ids are returned by the - // iterator in ascending order. val blockId = ShuffleBlockId(shuffleId, mapId, id) var partitionWriter: ShufflePartitionWriter = null var partitionPairsWriter: ShufflePartitionPairsWriter = null - try { + TryUtils.tryWithSafeFinally { partitionWriter = mapOutputWriter.getPartitionWriter(id) partitionPairsWriter = new ShufflePartitionPairsWriter( partitionWriter, @@ -785,22 +778,17 @@ private[spark] class ExternalSorter[K, V, C]( partitionPairsWriter.write(elem._1, elem._2) } } - } finally { - if (partitionPairsWriter!= null) { + } { + if (partitionPairsWriter != null) { partitionPairsWriter.close() } } - if (partitionWriter != null) { - lengths(id) = partitionWriter.getNumBytesWritten - } } } context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - - lengths } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala index 9d7c209f242e1..05ed72c3e3778 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala @@ -17,6 +17,11 @@ package org.apache.spark.util.collection +/** + * An abstraction of a consumer of key-value pairs, primarily used when + * persisting partitioned data, either through the shuffle writer plugins + * or via DiskBlockObjectWriter. 
+ */ private[spark] trait PairsWriter { def write(key: Any, value: Any): Unit diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala deleted file mode 100644 index 8538a78b377c8..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.{Closeable, FilterOutputStream, OutputStream} - -import org.apache.spark.api.shuffle.ShufflePartitionWriter -import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter -import org.apache.spark.storage.BlockId - -/** - * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an - * arbitrary partition writer instead of writing to local disk through the block manager. - */ -private[spark] class ShufflePartitionPairsWriter( - partitionWriter: ShufflePartitionWriter, - serializerManager: SerializerManager, - serializerInstance: SerializerInstance, - blockId: BlockId, - writeMetrics: ShuffleWriteMetricsReporter) - extends PairsWriter with Closeable { - - private var isOpen = false - private var partitionStream: OutputStream = _ - private var wrappedStream: OutputStream = _ - private var objOut: SerializationStream = _ - private var numRecordsWritten = 0 - private var curNumBytesWritten = 0L - - override def write(key: Any, value: Any): Unit = { - if (!isOpen) { - open() - isOpen = true - } - objOut.writeKey(key) - objOut.writeValue(value) - writeMetrics.incRecordsWritten(1) - } - - private def open(): Unit = { - partitionStream = partitionWriter.openStream - wrappedStream = serializerManager.wrapStream(blockId, partitionStream) - objOut = serializerInstance.serializeStream(wrappedStream) - } - - override def close(): Unit = { - if (isOpen) { - objOut.close() - objOut = null - wrappedStream = null - partitionStream = null - isOpen = false - updateBytesWritten() - } - } - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. 
- */ - private def recordWritten(): Unit = { - numRecordsWritten += 1 - writeMetrics.incRecordsWritten(1) - - if (numRecordsWritten % 16384 == 0) { - updateBytesWritten() - } - } - - private def updateBytesWritten(): Unit = { - val numBytesWritten = partitionWriter.getNumBytesWritten - val bytesWrittenDiff = numBytesWritten - curNumBytesWritten - writeMetrics.incBytesWritten(bytesWrittenDiff) - curNumBytesWritten = numBytesWritten - } -} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 4c2e6ac6474da..18f3a339e246c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -38,6 +38,7 @@ import org.mockito.MockitoAnnotations; import org.apache.spark.HashPartitioner; +import org.apache.spark.MapOutputTracker; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; @@ -56,7 +57,7 @@ import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport; +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -87,6 +88,8 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; + @Mock(answer = RETURNS_SMART_NULLS) MapOutputTracker mapOutputTracker; + @Mock(answer = RETURNS_SMART_NULLS) SerializerManager serializerManager; @After public void tearDown() { @@ -138,8 +141,7 @@ public void setUp() throws IOException { }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - - Answer renameTempAnswer = invocationOnMock -> { + Answer renameTempAnswer = invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; File tmp = (File) invocationOnMock.getArguments()[3]; if (!mergedOutputFile.delete()) { @@ -172,23 +174,25 @@ public void setUp() throws IOException { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); - - TaskContext$.MODULE$.setTaskContext(taskContext); } - private UnsafeShuffleWriter createWriter( - boolean transferToEnabled) throws IOException { + private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter<>( + return new UnsafeShuffleWriter( blockManager, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId()) - ); + new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver, + BlockManagerId.apply("localhost", 7077))); } private void assertSpillFilesWereCleanedUp() { @@ -414,7 +418,7 @@ public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Except @Test public 
void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { - conf.set(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE(), false); + conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false); testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); } @@ -539,16 +543,21 @@ public void testPeakMemoryUsed() throws Exception { final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; taskMemoryManager = spy(taskMemoryManager); when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); - final UnsafeShuffleWriter writer = - new UnsafeShuffleWriter<>( + final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); + new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver, + BlockManagerId.apply("localhost", 7077))); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 1cd7296e9de53..6eb8251ec4002 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -383,14 +383,15 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int], - taskContext: TaskContext)( - iter: Iterator[(Int, Int)]): Option[MapStatus] = { - TaskContext.setTaskContext(taskContext) - val files = writer.write(iter) - val status = writer.stop(true) - TaskContext.unset - status + writer: ShuffleWriter[Int, Int], + taskContext: TaskContext)( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + try { + val files = writer.write(iter) + writer.stop(true) + } finally { + TaskContext.unset() + } } val interleaver = new InterleaveIterators( data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) @@ -413,6 +414,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq + TaskContext.unset() assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) manager.unregisterShuffle(0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala index 68bc5c2961e2d..9d3a52a237cbe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala @@ -19,18 +19,18 @@ package org.apache.spark.scheduler import java.util import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} import org.apache.spark.internal.config import org.apache.spark.rdd.RDD -import 
org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.BlockManagerId class PluginShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { - val defaultShuffleDataIO = new DefaultShuffleDataIO(sparkConf) + val localDiskShuffleDataIO = new LocalDiskShuffleDataIO(sparkConf) override def driver(): ShuffleDriverComponents = - new PluginShuffleDriverComponents(defaultShuffleDataIO.driver()) + new PluginShuffleDriverComponents(localDiskShuffleDataIO.driver()) - override def executor(): ShuffleExecutorComponents = defaultShuffleDataIO.executor() + override def executor(): ShuffleExecutorComponents = localDiskShuffleDataIO.executor() } class PluginShuffleDriverComponents(delegate: ShuffleDriverComponents) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 64f8cbc970d54..966a6fa9d005f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Mockito.{mock, when} +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -29,7 +31,7 @@ import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} /** @@ -59,11 +61,14 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying * ManagedBuffers that contain the data are eventually released. */ test("read() releases resources on completion") { + MockitoAnnotations.initMocks(this) val testConf = new SparkConf(false) // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the // shuffle code calls SparkEnv.get()). 
@@ -142,15 +147,21 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val shuffleReadSupport = - new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) + val shuffleExecutorComponents = + new LocalDiskShuffleExecutorComponents( + testConf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + localBlockManagerId) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, - shuffleReadSupport, + shuffleExecutorComponents, serializerManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala index 0abfa4d8d8413..b571565cf4336 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.shuffle -import java.util +import java.io.InputStream +import java.lang.{Iterable => JIterable} +import java.util.{Map => JMap} import com.google.common.collect.ImmutableMap -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleMapOutputWriter} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { test(s"test serialization of shuffle initialization conf to executors") { @@ -43,7 +44,7 @@ class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext } class TestShuffleDriverComponents extends ShuffleDriverComponents { - override def initializeApplication(): util.Map[String, String] = + override def initializeApplication(): JMap[String, String] = ImmutableMap.of("test-key", "test-value") } @@ -55,21 +56,29 @@ class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { } class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents { - override def initializeExecutor(appId: String, execId: String, - extraConfigs: util.Map[String, String]): Unit = { + + private var delegate = new LocalDiskShuffleExecutorComponents(sparkConf) + + override def initializeExecutor( + appId: String, execId: String, extraConfigs: JMap[String, String]): Unit = { assert(extraConfigs.get("test-key") == "test-value") + delegate.initializeExecutor(appId, execId, extraConfigs) + } + + override def createMapOutputWriter( + shuffleId: Int, + mapId: Int, + mapTaskAttemptId: Long, + numPartitions: Int): ShuffleMapOutputWriter = { + delegate.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions) } - override def writes(): ShuffleWriteSupport = { - val blockManager = SparkEnv.get.blockManager - val blockResolver = new 
IndexShuffleBlockResolver(sparkConf, blockManager) - new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId) + override def getPartitionReaders( + blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] = { + delegate.getPartitionReaders(blockMetadata) } - override def reads(): ShuffleReadSupport = { - val blockManager = SparkEnv.get.blockManager - val mapOutputTracker = SparkEnv.get.mapOutputTracker - val serializerManager = SparkEnv.get.serializerManager - new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf) + override def shouldWrapPartitionReaderStream(): Boolean = { + delegate.shouldWrapPartitionReaderStream() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 0a77b9f0686ac..a8246aca20baa 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -39,8 +39,9 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} -import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockAttemptId, ShuffleBlockId} import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} @@ -67,6 +68,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { // this is only used when initiating the BlockManager, for comms between master and executor @Mock(answer = RETURNS_SMART_NULLS) private var rpcEnv: RpcEnv = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ private var tempDir: File = _ @@ -212,11 +214,13 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { when(dependency.aggregator).thenReturn(aggregator) when(dependency.keyOrdering).thenReturn(sorter) - val readSupport = new DefaultShuffleReadSupport( + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + defaultConf, blockManager, mapOutputTracker, serializerManager, - defaultConf) + blockResolver, + blockManager.shuffleServerId) new BlockStoreShuffleReader[String, String]( shuffleHandle, @@ -224,7 +228,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { 1, taskContext, taskContext.taskMetrics().createTempShuffleReadMetrics(), - readSupport, + shuffleExecutorComponents, serializerManager, mapOutputTracker ) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index dbcf09400c97e..46888259206a9 100644 --- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -19,8 +19,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. @@ -49,9 +48,13 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") - val shuffleWriteSupport = - new DefaultShuffleWriteSupport( - conf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + blockManager.shuffleServerId) val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, @@ -60,7 +63,7 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase taskContext.taskAttemptId(), conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport + shuffleExecutorComponents ) shuffleWriter diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index bd241b5ebfaef..da1630e67a485 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.shuffle.sort import java.io.File -import java.util.{Properties, UUID} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -30,15 +31,14 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach -import scala.util.Random import org.apache.spark._ -import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -49,11 +49,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ + @Mock(answer = RETURNS_SMART_NULLS) private var serializerManager: 
SerializerManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var mapOutputTracker: MapOutputTracker = _ private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ - private var writeSupport: ShuffleWriteSupport = _ + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() @@ -62,39 +64,42 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte override def beforeEach(): Unit = { super.beforeEach() + MockitoAnnotations.initMocks(this) tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = new TaskMetrics - MockitoAnnotations.initMocks(this) shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( shuffleId = 0, numMaps = 2, dependency = dependency ) + val memoryManager = new TestMemoryManager(conf) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) when(dependency.partitioner).thenReturn(new HashPartitioner(7)) when(dependency.serializer).thenReturn(new JavaSerializer(conf)) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - doAnswer(new Answer[Void] { - def answer(invocationOnMock: InvocationOnMock): Void = { - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { (invocationOnMock: InvocationOnMock) => + val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null } - }).when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) - when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.getDiskWriter( any[BlockId], any[File], any[SerializerInstance], anyInt(), - any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[DiskBlockObjectWriter] { - override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { + any[ShuffleWriteMetrics])) + .thenAnswer { (invocation: InvocationOnMock) => val args = invocation.getArguments val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( @@ -104,44 +109,29 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(3).asInstanceOf[Int], syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], - blockId = args(0).asInstanceOf[BlockId] - ) + blockId = args(0).asInstanceOf[BlockId]) } - }) - when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer[(TempShuffleBlockId, File)] { - override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { - val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = new File(tempDir, blockId.name) - blockIdToFileMap.put(blockId, file) - temporaryFilesCreated += file - (blockId, file) - } - }) - when(diskBlockManager.getFile(any[BlockId])).thenAnswer( - new Answer[File] { - override def answer(invocation: InvocationOnMock): File = { - blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get - } - }) - val memoryManager = new TestMemoryManager(conf) - 
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) - when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + when(diskBlockManager.createTempShuffleBlock()) + .thenAnswer { (invocationOnMock: InvocationOnMock) => + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = new File(tempDir, blockId.name) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated += file + (blockId, file) + } - TaskContext.setTaskContext(new TaskContextImpl( - stageId = 0, - stageAttemptNumber = 0, - partitionId = 0, - taskAttemptId = Random.nextInt(10000), - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - localProperties = new Properties, - metricsSystem = null, - taskMetrics = taskMetrics)) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) => + blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) + } - writeSupport = - new DefaultShuffleWriteSupport(conf, blockResolver, blockManager.shuffleServerId) + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + BlockManagerId("localhost", 7077)) } override def afterEach(): Unit = { @@ -160,11 +150,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, // MapTaskAttemptId conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) + shuffleExecutorComponents) + writer.write(Iterator.empty) writer.stop( /* success = */ true) assert(writer.getPartitionLengths.sum === 0) @@ -178,55 +168,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } - test("write with some empty partitions") { - val transferConf = conf.clone.set("spark.file.transferTo", "false") - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - shuffleHandle, - 0, // MapId - taskContext.taskAttemptId(), - transferConf, - taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - } - - test("write with some empty partitions with transferTo") { - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - shuffleHandle, - 0, // MapId - taskContext.taskAttemptId(), - conf, - taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - 
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) + Seq(true, false).foreach { transferTo => + test(s"write with some empty partitions - transferTo $transferTo") { + val transferConf = conf.clone.set("spark.file.transferTo", transferTo.toString) + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + 0, // MapId + 0L, + transferConf, + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } } test("only generate temp shuffle file for non-empty partition") { @@ -247,11 +213,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) + shuffleExecutorComponents) intercept[SparkException] { writer.write(records) @@ -270,11 +235,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - taskContext.taskAttemptId(), + 0L, conf, taskContext.taskMetrics().shuffleWriteMetrics, - writeSupport - ) + shuffleExecutorComponents) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { @@ -287,4 +251,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } + + /** + * This won't be necessary with Scala 2.12 + */ + private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { + new Answer[T] { + override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala index 26b92e5203b50..6decc9d4e2c84 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -29,7 +29,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random -import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.{HashPartitioner, MapOutputTracker, ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.benchmark.{Benchmark, 
BenchmarkBase} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} @@ -50,9 +50,11 @@ abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { @Mock(answer = RETURNS_SMART_NULLS) protected var taskContext: TaskContext = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEnv: RpcEnv = _ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + // only used to retrieve info about the maps at the beginning, doesn't affect perf + @Mock(answer = RETURNS_SMART_NULLS) protected var mapOutputTracker: MapOutputTracker = _ protected val defaultConf: SparkConf = new SparkConf(loadDefaults = false) - protected val serializer: Serializer = new KryoSerializer(defaultConf) + protected val serializer: Serializer = new KryoSerializer(defaultConf) protected val partitioner: HashPartitioner = new HashPartitioner(10) protected val serializerManager: SerializerManager = new SerializerManager(serializer, defaultConf) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 7e7a86b3e6b2a..2cb53e4bac224 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -22,8 +22,7 @@ import org.mockito.Mockito.when import org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. @@ -78,16 +77,20 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(taskContext) - val writeSupport = - new DefaultShuffleWriteSupport( - defaultConf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + defaultConf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + blockManager.shuffleServerId) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, shuffleHandle, 0, taskContext, - writeSupport) + shuffleExecutorComponents) shuffleWriter } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..326831749ce09 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Mockito._ +import org.scalatest.Matchers + +import org.apache.spark.{MapOutputTracker, Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents +import org.apache.spark.storage.{BlockManager, BlockManagerId} +import org.apache.spark.util.Utils + + +class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers { + + @Mock(answer = RETURNS_SMART_NULLS) + private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) + private var mapOutputTracker: MapOutputTracker = _ + @Mock(answer = RETURNS_SMART_NULLS) + private var serializerManager: SerializerManager = _ + + private val shuffleId = 0 + private val numMaps = 5 + private var shuffleHandle: BaseShuffleHandle[Int, Int, Int] = _ + private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + private val serializer = new JavaSerializer(conf) + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ + + override def beforeEach(): Unit = { + super.beforeEach() + MockitoAnnotations.initMocks(this) + val partitioner = new Partitioner() { + def numPartitions = numMaps + def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions) + } + shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.partitioner).thenReturn(partitioner) + when(dependency.serializer).thenReturn(serializer) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency) + } + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + shuffleBlockResolver, + BlockManagerId("localhost", 7077)) + } + + override def afterAll(): Unit = { + try { + shuffleBlockResolver.stop() + } finally { + super.afterAll() + } + } + + test("write empty iterator") { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val writer = new SortShuffleWriter[Int, Int, Int]( + shuffleBlockResolver, + shuffleHandle, + mapId = 1, + context, + shuffleExecutorComponents) + writer.write(Iterator.empty) + writer.stop(success = true) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 1) + val writeMetrics = context.taskMetrics().shuffleWriteMetrics + assert(!dataFile.exists()) + assert(writeMetrics.bytesWritten === 0) + assert(writeMetrics.recordsWritten === 0) + } + + test("write with some records") { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val records = List[(Int, Int)]((1, 2), (2, 3), (4, 4), (6, 5)) + val writer = new SortShuffleWriter[Int, Int, 
Int]( + shuffleBlockResolver, + shuffleHandle, + mapId = 2, + context, + shuffleExecutorComponents) + writer.write(records.toIterator) + writer.stop(success = true) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2) + val writeMetrics = context.taskMetrics().shuffleWriteMetrics + assert(dataFile.exists()) + assert(dataFile.length() === writeMetrics.bytesWritten) + assert(records.size === writeMetrics.recordsWritten) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index b09ccb334e4f1..d012bda0ffede 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark -import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents /** * Benchmark to measure performance for aggregate primitives. @@ -44,9 +43,13 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) - val shuffleWriteSupport = - new DefaultShuffleWriteSupport( - conf, blockResolver, BlockManagerId("0", "localhost", 7077, None)) + val shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, + blockManager, + mapOutputTracker, + serializerManager, + blockResolver, + blockManager.shuffleServerId) TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( @@ -57,7 +60,7 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport) + shuffleExecutorComponents) } def writeBenchmarkWithSmallDataset(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala deleted file mode 100644 index 3ccb549912782..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io - -import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} -import java.math.BigInteger -import java.nio.ByteBuffer -import java.nio.channels.{Channels, WritableByteChannel} - -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} -import org.mockito.Mock -import org.mockito.Mockito.{doAnswer, doNothing, when} -import org.mockito.MockitoAnnotations -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.api.shuffle.SupportsTransferTo -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ByteBufferInputStream, Utils} - -class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { - - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _ - - private val NUM_PARTITIONS = 4 - private val D_LEN = 10 - private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map { - p => (1 to D_LEN).map(_ + p).toArray }.toArray - - private var tempFile: File = _ - private var mergedOutputFile: File = _ - private var tempDir: File = _ - private var partitionSizesInMergedFile: Array[Long] = _ - private var conf: SparkConf = _ - private var mapOutputWriter: DefaultShuffleMapOutputWriter = _ - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def beforeEach(): Unit = { - MockitoAnnotations.initMocks(this) - tempDir = Utils.createTempDir(null, "test") - mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) - tempFile = File.createTempFile("tempfile", "", tempDir) - partitionSizesInMergedFile = null - conf = new SparkConf() - .set("spark.app.id", "example.spark.app") - .set("spark.shuffle.unsafe.file.output.buffer", "16k") - when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) - - doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong) - - doAnswer(new Answer[Void] { - def answer(invocationOnMock: InvocationOnMock): Void = { - partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - mergedOutputFile.delete - tmp.renameTo(mergedOutputFile) - } - null - } - }).when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) - mapOutputWriter = new DefaultShuffleMapOutputWriter( - 0, - 0, - NUM_PARTITIONS, - BlockManagerId("0", "localhost", 9099), - shuffleWriteMetrics, - blockResolver, - conf) - } - - private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = { - var startOffset = 0L - val result = new Array[Array[Int]](NUM_PARTITIONS) - (0 until NUM_PARTITIONS).foreach { p => - val partitionSize = partitionSizesInMergedFile(p).toInt - lazy val inner = new Array[Int](partitionSize) - lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize) - if (partitionSize > 0) { - val in = new FileInputStream(mergedOutputFile) - in.getChannel.position(startOffset) - val lin = new 
LimitedInputStream(in, partitionSize) - var nonEmpty = true - var count = 0 - while (nonEmpty) { - try { - val readBit = lin.read() - if (fromByte) { - innerBytebuffer.put(readBit.toByte) - } else { - inner(count) = readBit - } - count += 1 - } catch { - case _: Exception => - nonEmpty = false - } - } - in.close() - } - if (fromByte) { - result(p) = innerBytebuffer.array().sliding(4, 4).map { b => - new BigInteger(b).intValue() - }.toArray - } else { - result(p) = inner - } - startOffset += partitionSize - } - result - } - - test("writing to an outputstream") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.openStream() - data(p).foreach { i => stream.write(i)} - stream.close() - intercept[IllegalStateException] { - stream.write(p) - } - assert(writer.getNumBytesWritten() == D_LEN) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(false)) - } - - test("writing to a channel") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val numBytes = byteBuffer.remaining() - val outputTempFile = File.createTempFile("channelTemp", "", tempDir) - val outputTempFileStream = new FileOutputStream(outputTempFile) - Utils.copyStream( - new ByteBufferInputStream(byteBuffer), - outputTempFileStream, - closeStreams = true) - val tempFileInput = new FileInputStream(outputTempFile) - channel.transferFrom(tempFileInput.getChannel, 0L, numBytes) - // Bytes require * 4 - channel.close() - tempFileInput.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } - - test("copyStreams with an outputstream") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.openStream() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val in = new ByteArrayInputStream(byteBuffer.array()) - Utils.copyStream(in, stream, false, false) - in.close() - stream.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } - - test("copyStreamsWithNIO with a channel") { - (0 until NUM_PARTITIONS).foreach{ p => - val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() - val byteBuffer = ByteBuffer.allocate(D_LEN * 4) - val intBuffer = byteBuffer.asIntBuffer() - intBuffer.put(data(p)) - val out = new FileOutputStream(tempFile) - out.write(byteBuffer.array()) - out.close() - val in = new 
FileInputStream(tempFile) - channel.transferFrom(in.getChannel, 0L, byteBuffer.remaining()) - channel.close() - assert(writer.getNumBytesWritten == D_LEN * 4) - } - mapOutputWriter.commitAllPartitions() - val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray - assert(partitionSizesInMergedFile === partitionLengths) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile(true)) - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala new file mode 100644 index 0000000000000..8aa9f51e09494 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{File, FileInputStream} +import java.nio.channels.FileChannel +import java.nio.file.Files +import java.util.Arrays + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.Mock +import org.mockito.Mockito.when +import org.mockito.MockitoAnnotations +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) + private var blockResolver: IndexShuffleBlockResolver = _ + + private val NUM_PARTITIONS = 4 + private val BLOCK_MANAGER_ID = BlockManagerId("localhost", 7077) + private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p => + if (p == 3) { + Array.emptyByteArray + } else { + (0 to p * 10).map(_ + p).map(_.toByte).toArray + } + }.toArray + + private val partitionLengths = data.map(_.length) + + private var tempFile: File = _ + private var mergedOutputFile: File = _ + private var tempDir: File = _ + private var partitionSizesInMergedFile: Array[Long] = _ + private var conf: SparkConf = _ + private var mapOutputWriter: LocalDiskShuffleMapOutputWriter = _ + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def beforeEach(): Unit = { + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir() + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + tempFile = File.createTempFile("tempfile", "", tempDir) + 
partitionSizesInMergedFile = null + conf = new SparkConf() + .set("spark.app.id", "example.spark.app") + .set("spark.shuffle.unsafe.file.output.buffer", "16k") + when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { (invocationOnMock: InvocationOnMock) => + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + mergedOutputFile.delete() + tmp.renameTo(mergedOutputFile) + } + null + } + mapOutputWriter = new LocalDiskShuffleMapOutputWriter( + 0, + 0, + NUM_PARTITIONS, + blockResolver, + BLOCK_MANAGER_ID, + conf) + } + + test("writing to an outputstream") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val stream = writer.openStream() + data(p).foreach { i => stream.write(i) } + stream.close() + intercept[IllegalStateException] { + stream.write(p) + } + } + verifyWrittenRecords() + } + + test("writing to a channel") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val outputTempFile = File.createTempFile("channelTemp", "", tempDir) + Files.write(outputTempFile.toPath, data(p)) + val tempFileInput = new FileInputStream(outputTempFile) + val channel = writer.openChannelWrapper() + Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput => + Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper => + assert(channelWrapper.channel().isInstanceOf[FileChannel], + "Underlying channel should be a file channel") + Utils.copyFileStreamNIO( + tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length) + } + } + } + verifyWrittenRecords() + } + + private def readRecordsFromFile() = { + val mergedOutputBytes = Files.readAllBytes(mergedOutputFile.toPath) + val result = (0 until NUM_PARTITIONS).map { part => + val startOffset = data.slice(0, part).map(_.length).sum + val partitionSize = data(part).length + Arrays.copyOfRange(mergedOutputBytes, startOffset, startOffset + partitionSize) + }.toArray + result + } + + private def verifyWrittenRecords(): Unit = { + val committedLengths = mapOutputWriter.commitAllPartitions() + assert(partitionSizesInMergedFile === partitionLengths) + assert(committedLengths.getPartitionLengths === partitionLengths) + assert(committedLengths.getLocation.isPresent) + assert(committedLengths.getLocation.get === BLOCK_MANAGER_ID) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile()) + } + + /** + * This won't be necessary with Scala 2.12 + */ + private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { + new Answer[T] { + override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) + } + } +}
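
Note for reviewers: the delegating-plugin pattern that TestShuffleDataIO, TestShuffleExecutorComponents and PluginShuffleDataIO exercise in the test changes above can be summarized by the sketch below. It is written only against the shapes visible in this diff (org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleMapOutputWriter, ShuffleBlockInfo} and the LocalDiskShuffleDataIO / LocalDiskShuffleExecutorComponents constructors the tests call); the class names MyShuffleDataIO and MyShuffleExecutorComponents are illustrative, and any interface methods not shown in this diff may also need to be implemented.

import java.io.InputStream
import java.lang.{Iterable => JIterable}
import java.util.{Map => JMap}

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleDataIO, ShuffleDriverComponents,
  ShuffleExecutorComponents, ShuffleMapOutputWriter}
import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleDataIO, LocalDiskShuffleExecutorComponents}

// Hypothetical plugin that keeps the built-in local-disk behaviour while letting the
// driver ship extra configuration to executors, mirroring TestShuffleDataIO above.
class MyShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO {
  private val localDisk = new LocalDiskShuffleDataIO(sparkConf)

  // Reuse the local-disk driver components, as PluginShuffleDataIO does.
  override def driver(): ShuffleDriverComponents = localDisk.driver()

  override def executor(): ShuffleExecutorComponents =
    new MyShuffleExecutorComponents(sparkConf)
}

class MyShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents {
  // Delegate all real I/O to the built-in local-disk implementation.
  private val delegate = new LocalDiskShuffleExecutorComponents(sparkConf)

  override def initializeExecutor(
      appId: String, execId: String, extraConfigs: JMap[String, String]): Unit = {
    // extraConfigs carries whatever ShuffleDriverComponents.initializeApplication() returned.
    delegate.initializeExecutor(appId, execId, extraConfigs)
  }

  override def createMapOutputWriter(
      shuffleId: Int,
      mapId: Int,
      mapTaskAttemptId: Long,
      numPartitions: Int): ShuffleMapOutputWriter =
    delegate.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions)

  override def getPartitionReaders(
      blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] =
    delegate.getPartitionReaders(blockMetadata)

  override def shouldWrapPartitionReaderStream(): Boolean =
    delegate.shouldWrapPartitionReaderStream()
}

As suggested by the SHUFFLE_IO_PLUGIN_CLASS import in ShuffleDriverComponentsSuite, such a plugin would presumably be enabled by pointing that config entry at the plugin class, e.g. sparkConf.set(SHUFFLE_IO_PLUGIN_CLASS, classOf[MyShuffleDataIO].getName); the suite body that does this is not part of the diff, so treat the exact wiring as an assumption.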