diff --git a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
deleted file mode 100644
index b0aed4d08d387..0000000000000
--- a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.api.shuffle;
-
-import org.apache.spark.annotation.Experimental;
-
-import java.io.Serializable;
-
-/**
- * Represents metadata about where shuffle blocks were written in a single map task.
- *
- * This is optionally returned by shuffle writers. The inner shuffle locations may
- * be accessed by shuffle readers. Shuffle locations are only necessary when the
- * location of shuffle blocks needs to be managed by the driver; shuffle plugins
- * may choose to use an external database or other metadata management systems to
- * track the locations of shuffle blocks instead.
- */
-@Experimental
-public interface MapShuffleLocations extends Serializable {
-
- /**
- * Get the location for a given shuffle block written by this map task.
- */
- ShuffleLocation getLocationForBlock(int reduceId);
-}
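With `MapShuffleLocations` removed, a map task's output location travels directly on `MapStatus` as an `Option[BlockManagerId]` (see the `MapStatus.scala` changes below). A minimal sketch of how a consumer reads the new field, assuming the post-patch API; `describeOutput` is a hypothetical helper:

```scala
import org.apache.spark.scheduler.MapStatus

// None means the driver is not tracking this output; a shuffle plugin is
// expected to resolve block locations through its own metadata instead.
def describeOutput(status: MapStatus): String = status.location match {
  case Some(id) => s"shuffle files served from ${id.host}:${id.port}"
  case None     => "locations tracked by the shuffle plugin, not the driver"
}
```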
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
index a312831cb6282..45effd206f797 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
@@ -18,6 +18,7 @@
package org.apache.spark.api.shuffle;
import org.apache.spark.api.java.Optional;
+import org.apache.spark.storage.BlockManagerId;
import java.util.Objects;
@@ -31,10 +32,10 @@ public class ShuffleBlockInfo {
private final int mapId;
private final int reduceId;
private final long length;
- private final Optional<ShuffleLocation> shuffleLocation;
+ private final Optional<BlockManagerId> shuffleLocation;
public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length,
- Optional<ShuffleLocation> shuffleLocation) {
+ Optional<BlockManagerId> shuffleLocation) {
this.shuffleId = shuffleId;
this.mapId = mapId;
this.reduceId = reduceId;
@@ -58,7 +59,7 @@ public long getLength() {
return length;
}
- public Optional<ShuffleLocation> getShuffleLocation() {
+ public Optional<BlockManagerId> getShuffleLocation() {
return shuffleLocation;
}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
index 6a0ec8d44fd4f..04986ad7f04f4 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
@@ -30,4 +30,8 @@ public interface ShuffleDriverComponents {
void cleanupApplication() throws IOException;
void removeShuffleData(int shuffleId, boolean blocking) throws IOException;
+
+ default boolean shouldUnregisterOutputOnHostOnFetchFailure() {
+ return false;
+ }
}
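The new default method lets plugins opt in to host-level output invalidation without touching every existing implementation. A sketch of a plugin that keeps the default; `RemoteStoreDriverComponents` and the remote-store framing are hypothetical:

```scala
import java.util.{Collections, Map => JMap}
import org.apache.spark.api.shuffle.ShuffleDriverComponents

// A plugin backed by a remote shuffle store keeps the default (false):
// losing an executor host does not invalidate data held off-cluster.
class RemoteStoreDriverComponents extends ShuffleDriverComponents {
  override def initializeApplication(): JMap[String, String] = Collections.emptyMap()
  override def cleanupApplication(): Unit = {}
  override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {}
  // shouldUnregisterOutputOnHostOnFetchFailure() is inherited and stays false.
}
```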
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
deleted file mode 100644
index d06c11b3c01ee..0000000000000
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.shuffle;
-
-/**
- * Marker interface representing a location of a shuffle block. Implementations of shuffle readers
- * and writers are expected to cast this down to an implementation-specific representation.
- */
-public interface ShuffleLocation {}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
index 062cf4ff0fba9..025fc096faaad 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
@@ -21,6 +21,7 @@
import org.apache.spark.annotation.Experimental;
import org.apache.spark.api.java.Optional;
+import org.apache.spark.storage.BlockManagerId;
/**
* :: Experimental ::
@@ -32,7 +33,7 @@
public interface ShuffleMapOutputWriter {
ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException;
- Optional<MapShuffleLocations> commitAllPartitions() throws IOException;
+ Optional<BlockManagerId> commitAllPartitions() throws IOException;
void abort(Throwable error) throws IOException;
}
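`commitAllPartitions` now answers a single question: is there a Spark-managed location the driver should track? A sketch of the two possible answers under the new contract, with illustrative executor id, host, and port:

```scala
import org.apache.spark.api.java.Optional
import org.apache.spark.storage.BlockManagerId

// 1. Output is served by a Spark-managed block manager: report its id so the
//    driver can route reducers and react to lost executors or hosts.
val managed: Optional[BlockManagerId] =
  Optional.of(BlockManagerId("exec-1", "host-a.example.com", 7337))

// 2. Output lives in a system the driver does not track: report nothing and
//    let the plugin's reader resolve locations itself.
val external: Optional[BlockManagerId] = Optional.empty()
```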
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 128b90429209e..3e622d00b3aaa 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -25,6 +25,7 @@
import java.nio.channels.FileChannel;
import javax.annotation.Nullable;
+import org.apache.spark.api.java.Optional;
import scala.None$;
import scala.Option;
import scala.Product2;
@@ -39,8 +40,6 @@
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.SupportsTransferTo;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
@@ -134,11 +133,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
try {
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
- Optional<MapShuffleLocations> blockLocs = mapOutputWriter.commitAllPartitions();
- mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(),
- blockLocs.orNull(),
- partitionLengths);
+ Optional<BlockManagerId> location = mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(Option.apply(location.orNull()), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -171,11 +167,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}
partitionLengths = writePartitionedData(mapOutputWriter);
- Optional<MapShuffleLocations> mapLocations = mapOutputWriter.commitAllPartitions();
- mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(),
- mapLocations.orNull(),
- partitionLengths);
+ Optional<BlockManagerId> location = mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(Option.apply(location.orNull()), partitionLengths);
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
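Both call sites above bridge the plugin's Java `Optional` to the Scala `Option` that `MapStatus.apply` now takes; an empty `Optional` survives the round trip because `orNull()` yields `null`. The pattern, isolated (`toScala` is a hypothetical helper):

```scala
import org.apache.spark.api.java.Optional
import org.apache.spark.storage.BlockManagerId

// Empty Optional -> null -> None; present Optional -> value -> Some(value).
def toScala(location: Optional[BlockManagerId]): Option[BlockManagerId] =
  Option.apply(location.orNull())

assert(toScala(Optional.empty()) == None)
```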
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
deleted file mode 100644
index ffd97c0f26605..0000000000000
--- a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort;
-
-import com.google.common.cache.CacheBuilder;
-import com.google.common.cache.CacheLoader;
-import com.google.common.cache.LoadingCache;
-
-import org.apache.spark.api.shuffle.MapShuffleLocations;
-import org.apache.spark.api.shuffle.ShuffleLocation;
-import org.apache.spark.storage.BlockManagerId;
-
-import java.util.Objects;
-
-public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation {
-
- /**
- * We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be
- * feasible.
- */
- private static final LoadingCache<BlockManagerId, DefaultMapShuffleLocations>
- DEFAULT_SHUFFLE_LOCATIONS_CACHE =
- CacheBuilder.newBuilder()
- .maximumSize(BlockManagerId.blockManagerIdCacheSize())
- .build(new CacheLoader<BlockManagerId, DefaultMapShuffleLocations>() {
- @Override
- public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) {
- return new DefaultMapShuffleLocations(blockManagerId);
- }
- });
-
- private final BlockManagerId location;
-
- public DefaultMapShuffleLocations(BlockManagerId blockManagerId) {
- this.location = blockManagerId;
- }
-
- public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) {
- return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId);
- }
-
- @Override
- public ShuffleLocation getLocationForBlock(int reduceId) {
- return this;
- }
-
- public BlockManagerId getBlockManagerId() {
- return location;
- }
-
- @Override
- public boolean equals(Object other) {
- return other instanceof DefaultMapShuffleLocations
- && Objects.equals(((DefaultMapShuffleLocations) other).location, location);
- }
-
- @Override
- public int hashCode() {
- return Objects.hashCode(location);
- }
-}
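`DefaultMapShuffleLocations` existed largely to intern one wrapper per `BlockManagerId`; that job now falls to `BlockManagerId`'s own cache, since its `apply` funnels every instance through `getCachedBlockManagerId`. A quick sketch of the interning behavior this relies on, with illustrative values:

```scala
import org.apache.spark.storage.BlockManagerId

val a = BlockManagerId("exec-1", "host-a", 7337)
val b = BlockManagerId("exec-1", "host-a", 7337)
assert(a eq b)  // the same cached instance, not merely an equal one
```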
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f147bd79773e1..b2c1b49370f4a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -23,6 +23,8 @@
import java.nio.channels.FileChannel;
import java.util.Iterator;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.storage.BlockManagerId;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
@@ -37,8 +39,6 @@
import org.apache.spark.*;
import org.apache.spark.annotation.Private;
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
@@ -219,7 +219,7 @@ void closeAndWriteOutput() throws IOException {
final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport
.createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
final long[] partitionLengths;
- Optional<MapShuffleLocations> mapLocations;
+ Optional<BlockManagerId> location;
try {
try {
partitionLengths = mergeSpills(spills, mapWriter);
@@ -230,7 +230,7 @@ void closeAndWriteOutput() throws IOException {
}
}
}
- mapLocations = mapWriter.commitAllPartitions();
+ location = mapWriter.commitAllPartitions();
} catch (Exception e) {
try {
mapWriter.abort(e);
@@ -239,10 +239,7 @@ void closeAndWriteOutput() throws IOException {
}
throw e;
}
- mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(),
- mapLocations.orNull(),
- partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(Option.apply(location.orNull()), partitionLengths);
}
@VisibleForTesting
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
index e83db4e4bcef6..ad55b3db377f6 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
@@ -24,21 +24,19 @@
import java.io.OutputStream;
import java.nio.channels.FileChannel;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
import org.apache.spark.api.shuffle.SupportsTransferTo;
import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
import org.apache.spark.internal.config.package$;
-import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.util.Utils;
@@ -104,11 +102,11 @@ public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOExcep
}
@Override
- public Optional<MapShuffleLocations> commitAllPartitions() throws IOException {
+ public Optional<BlockManagerId> commitAllPartitions() throws IOException {
cleanUp();
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
- return Optional.of(DefaultMapShuffleLocations.get(shuffleServerId));
+ return Optional.of(shuffleServerId);
}
@Override
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
index a3eddc8ec930e..e70369909a8f0 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
@@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableMap;
import org.apache.spark.SparkEnv;
import org.apache.spark.api.shuffle.ShuffleDriverComponents;
+import org.apache.spark.internal.config.package$;
import org.apache.spark.storage.BlockManagerMaster;
import java.io.IOException;
@@ -28,10 +29,15 @@
public class DefaultShuffleDriverComponents implements ShuffleDriverComponents {
private BlockManagerMaster blockManagerMaster;
+ private boolean shouldUnregisterOutputOnHostOnFetchFailure;
@Override
public Map<String, String> initializeApplication() {
blockManagerMaster = SparkEnv.get().blockManager().master();
+ this.shouldUnregisterOutputOnHostOnFetchFailure =
+ SparkEnv.get().blockManager().externalShuffleServiceEnabled()
+ && (boolean) SparkEnv.get().conf()
+ .get(package$.MODULE$.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE());
return ImmutableMap.of();
}
@@ -46,6 +52,11 @@ public void removeShuffleData(int shuffleId, boolean blocking) throws IOExceptio
blockManagerMaster.removeShuffle(shuffleId, blocking);
}
+ @Override
+ public boolean shouldUnregisterOutputOnHostOnFetchFailure() {
+ return shouldUnregisterOutputOnHostOnFetchFailure;
+ }
+
private void checkInitialized() {
if (blockManagerMaster == null) {
throw new IllegalStateException("Driver components must be initialized before using");
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ebddf5ff6f6e0..bc2399ea27fea 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -28,7 +28,6 @@ import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal
-import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation}
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -103,7 +102,7 @@ private class ShuffleStatus(numPartitions: Int) {
* different block manager.
*/
def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
- if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
+ if (mapStatuses(mapId) != null && mapStatuses(mapId).location.orNull == bmAddress) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
@@ -133,7 +132,8 @@ private class ShuffleStatus(numPartitions: Int) {
*/
def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
for (mapId <- 0 until mapStatuses.length) {
- if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
+ if (mapStatuses(mapId) != null && mapStatuses(mapId).location.isDefined
+ && f(mapStatuses(mapId).location.get)) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
@@ -282,9 +282,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
// For testing
- def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int)
- : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
- getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)
+ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
+ : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = {
+ getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
}
/**
@@ -296,8 +296,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
*/
- def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]
+ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])]
/**
* Deletes map output status information for the specified shuffle stage.
@@ -579,7 +579,7 @@ private[spark] class MapOutputTrackerMaster(
/**
* Return a list of locations that each have fraction of map output greater than the specified
- * threshold.
+ * threshold. Ignores shuffle blocks that have no location or no executor ID.
*
* @param shuffleId id of the shuffle
* @param reducerId id of the reduce task
@@ -608,10 +608,12 @@ private[spark] class MapOutputTrackerMaster(
// array with null entries for each output, and registerMapOutputs, which populates it
// with valid status entries. This is possible if one thread schedules a job which
// depends on an RDD which is currently being computed by another thread.
- if (status != null) {
+ // This also ignores locations that are not on executors.
+ if (status != null && status.location.isDefined
+ && status.location.get.executorId != null) {
val blockSize = status.getSizeForBlock(reducerId)
if (blockSize > 0) {
- locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
+ locs(status.location.get) = locs.getOrElse(status.location.get, 0L) + blockSize
totalOutputSize += blockSize
}
}
@@ -646,8 +648,8 @@ private[spark] class MapOutputTrackerMaster(
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
- def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
+ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
@@ -683,13 +685,12 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
private val fetching = new HashSet[Int]
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
- override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
+ override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
- MapOutputTracker.convertMapStatuses(
- shuffleId, startPartition, endPartition, statuses)
+ MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -873,9 +874,9 @@ private[spark] object MapOutputTracker extends Logging {
shuffleId: Int,
startPartition: Int,
endPartition: Int,
- statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
+ statuses: Array[MapStatus]): Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = {
assert (statuses != null)
- val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]]
+ val splitsByAddress = new HashMap[Option[BlockManagerId], ListBuffer[(BlockId, Long)]]
for ((status, mapId) <- statuses.iterator.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
@@ -885,14 +886,8 @@ private[spark] object MapOutputTracker extends Logging {
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
- if (status.mapShuffleLocations == null) {
- splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) +=
- ((ShuffleBlockId(shuffleId, mapId, part), size))
- } else {
- val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part)
- splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) +=
- ((ShuffleBlockId(shuffleId, mapId, part), size))
- }
+ splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapId, part), size))
}
}
}
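`convertMapStatuses` now groups blocks by `Option[BlockManagerId]`, so every block whose status carries no location collapses under the `None` key. A toy version of that grouping with illustrative ids and sizes:

```scala
import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

val statuses: Seq[(Option[BlockManagerId], ShuffleBlockId, Long)] = Seq(
  (Some(BlockManagerId("e1", "host-a", 7337)), ShuffleBlockId(0, 0, 0), 1024L),
  (None, ShuffleBlockId(0, 1, 0), 2048L))  // plugin-tracked output

val splitsByAddress = new HashMap[Option[BlockManagerId], ListBuffer[(BlockId, Long)]]
for ((loc, block, size) <- statuses if size != 0) {
  splitsByAddress.getOrElseUpdate(loc, ListBuffer()) += ((block, size))
}
// splitsByAddress(None) holds the blocks a plugin reader must resolve itself.
```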
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 999f180193d84..f359022716571 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -308,6 +308,8 @@ class SparkContext(config: SparkConf) extends SafeLogging {
_dagScheduler = ds
}
+ private[spark] def shuffleDriverComponents: ShuffleDriverComponents = _shuffleDriverComponents
+
/**
* A unique identifier for the Spark application.
* Its format depends on the scheduler implementation.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dd1b2595461fc..06bf23a8592cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -163,6 +163,8 @@ private[spark] class DAGScheduler(
private[scheduler] val activeJobs = new HashSet[ActiveJob]
+ private[scheduler] val shuffleDriverComponents = sc.shuffleDriverComponents
+
/**
* Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
* and its values are arrays indexed by partition numbers. Each array value is the set of
@@ -1434,14 +1436,20 @@ private[spark] class DAGScheduler(
val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
shuffleStage.pendingPartitions -= task.partitionId
val status = event.result.asInstanceOf[MapStatus]
- val execId = status.location.executorId
- logDebug("ShuffleMapTask finished on " + execId)
- if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
- logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
+ logDebug(s"ShuffleMapTask finished on ${event.taskInfo.executorId} " +
+ s"with shuffle files located at ${status.location.getOrElse("N/A")}")
+ if (status.location.isDefined && status.location.get.executorId != null) {
+ val execId = status.location.get.executorId
+ if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
+ logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
+ } else {
+ // The epoch of the task is acceptable (i.e., the task was launched after the most
+ // recent failure we're aware of for the executor), so mark the task's output as
+ // available.
+ mapOutputTracker.registerMapOutput(
+ shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+ }
} else {
- // The epoch of the task is acceptable (i.e., the task was launched after the most
- // recent failure we're aware of for the executor), so mark the task's output as
- // available.
mapOutputTracker.registerMapOutput(
shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
}
@@ -1627,21 +1635,31 @@ private[spark] class DAGScheduler(
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled &&
- unRegisterOutputOnHostOnFetchFailure) {
- // We had a fetch failure with the external shuffle service, so we
- // assume all shuffle data on the node is bad.
- Some(bmAddress.host)
+ if (bmAddress.executorId == null) {
+ if (shuffleDriverComponents.shouldUnregisterOutputOnHostOnFetchFailure()) {
+ val currentEpoch = task.epoch
+ val host = bmAddress.host
+ logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
+ mapOutputTracker.removeOutputsOnHost(host)
+ clearCacheLocs()
+ }
} else {
- // Unregister shuffle data just for one executor (we don't have any
- // reason to believe shuffle data has been lost for the entire host).
- None
+ val hostToUnregisterOutputs =
+ if (shuffleDriverComponents.shouldUnregisterOutputOnHostOnFetchFailure()) {
+ // We had a fetch failure with the external shuffle service, so we
+ // assume all shuffle data on the node is bad.
+ Some(bmAddress.host)
+ } else {
+ // Unregister shuffle data just for one executor (we don't have any
+ // reason to believe shuffle data has been lost for the entire host).
+ None
+ }
+ removeExecutorAndUnregisterOutputs(
+ execId = bmAddress.executorId,
+ fileLost = true,
+ hostToUnregisterOutputs = hostToUnregisterOutputs,
+ maybeEpoch = Some(task.epoch))
}
- removeExecutorAndUnregisterOutputs(
- execId = bmAddress.executorId,
- fileLost = true,
- hostToUnregisterOutputs = hostToUnregisterOutputs,
- maybeEpoch = Some(task.epoch))
}
}
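The restructured fetch-failure handling distinguishes host-only addresses (no executor id, e.g. a remote shuffle endpoint) from executor-backed ones. A condensed, hypothetical restatement of the branching as a pure function; the `FetchFailureAction` types are illustrative, not part of the patch:

```scala
import org.apache.spark.storage.BlockManagerId

sealed trait FetchFailureAction
case class UnregisterHost(host: String) extends FetchFailureAction
case class UnregisterExecutor(execId: String, hostToo: Option[String]) extends FetchFailureAction
case object NoAction extends FetchFailureAction

def onFetchFailure(bmAddress: BlockManagerId, unregisterHostOnFailure: Boolean): FetchFailureAction =
  if (bmAddress.executorId == null) {
    // Host-only location: there is no executor to remove.
    if (unregisterHostOnFailure) UnregisterHost(bmAddress.host) else NoAction
  } else if (unregisterHostOnFailure) {
    // External shuffle service: assume all shuffle data on the node is bad.
    UnregisterExecutor(bmAddress.executorId, hostToo = Some(bmAddress.host))
  } else {
    UnregisterExecutor(bmAddress.executorId, hostToo = None)
  }
```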
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index a61f9bd14ef2f..7ec87641a8900 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -24,9 +24,7 @@ import scala.collection.mutable
import org.roaringbitmap.RoaringBitmap
import org.apache.spark.SparkEnv
-import org.apache.spark.api.shuffle.MapShuffleLocations
import org.apache.spark.internal.config
-import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -35,17 +33,8 @@ import org.apache.spark.util.Utils
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
*/
private[spark] sealed trait MapStatus {
-
- /**
- * Locations where this task stored shuffle blocks.
- *
- * May be null if the MapOutputTracker is not tracking the location of shuffle blocks, leaving it
- * up to the implementation of shuffle plugins to do so.
- */
- def mapShuffleLocations: MapShuffleLocations
-
- /** Location where the task was run. */
- def location: BlockManagerId
+ /** Location where this task was run. */
+ def location: Option[BlockManagerId]
/**
* Estimated size for the reduce block, in bytes.
@@ -67,31 +56,11 @@ private[spark] object MapStatus {
.map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
.getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
- // A temporary concession to the fact that we only expect implementations of shuffle provided by
- // Spark to be storing shuffle locations in the driver, meaning we want to introduce as little
- // serialization overhead as possible in such default cases.
- //
- // If more similar cases arise, consider adding a serialization API for these shuffle locations.
- private val DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 0
- private val NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 1
-
- /**
- * Visible for testing.
- */
- def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
- apply(loc, DefaultMapShuffleLocations.get(loc), uncompressedSizes)
- }
-
- def apply(
- loc: BlockManagerId,
- mapShuffleLocs: MapShuffleLocations,
- uncompressedSizes: Array[Long]): MapStatus = {
+ def apply(maybeLoc: Option[BlockManagerId], uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(
- loc, mapShuffleLocs, uncompressedSizes)
+ HighlyCompressedMapStatus(maybeLoc, uncompressedSizes)
} else {
- new CompressedMapStatus(
- loc, mapShuffleLocs, uncompressedSizes)
+ new CompressedMapStatus(maybeLoc, uncompressedSizes)
}
}
@@ -122,89 +91,50 @@ private[spark] object MapStatus {
math.pow(LOG_BASE, compressedSize & 0xFF).toLong
}
}
-
- def writeLocations(
- loc: BlockManagerId,
- mapShuffleLocs: MapShuffleLocations,
- out: ObjectOutput): Unit = {
- if (mapShuffleLocs != null) {
- out.writeBoolean(true)
- if (mapShuffleLocs.isInstanceOf[DefaultMapShuffleLocations]
- && mapShuffleLocs.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId == loc) {
- out.writeByte(MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID)
- } else {
- out.writeByte(MapStatus.NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID)
- out.writeObject(mapShuffleLocs)
- }
- } else {
- out.writeBoolean(false)
- }
- loc.writeExternal(out)
- }
-
- def readLocations(in: ObjectInput): (BlockManagerId, MapShuffleLocations) = {
- if (in.readBoolean()) {
- val locId = in.readByte()
- if (locId == MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) {
- val blockManagerId = BlockManagerId(in)
- (blockManagerId, DefaultMapShuffleLocations.get(blockManagerId))
- } else {
- val mapShuffleLocations = in.readObject().asInstanceOf[MapShuffleLocations]
- val blockManagerId = BlockManagerId(in)
- (blockManagerId, mapShuffleLocations)
- }
- } else {
- val blockManagerId = BlockManagerId(in)
- (blockManagerId, null)
- }
- }
}
+
/**
* A [[MapStatus]] implementation that tracks the size of each block. Size for each block is
* represented using a single byte.
*
- * @param loc Location were the task is being executed.
- * @param mapShuffleLocs locations where the task stored its shuffle blocks - may be null.
+ * @param loc location where the task is being executed.
* @param compressedSizes size of the blocks, indexed by reduce partition id.
*/
private[spark] class CompressedMapStatus(
- private[this] var loc: BlockManagerId,
- private[this] var mapShuffleLocs: MapShuffleLocations,
+ private[this] var loc: Option[BlockManagerId],
private[this] var compressedSizes: Array[Byte])
extends MapStatus with Externalizable {
- // For deserialization only
- protected def this() = this(null, null, null.asInstanceOf[Array[Byte]])
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
- def this(
- loc: BlockManagerId,
- mapShuffleLocations: MapShuffleLocations,
- uncompressedSizes: Array[Long]) {
- this(
- loc,
- mapShuffleLocations,
- uncompressedSizes.map(MapStatus.compressSize))
+ def this(loc: Option[BlockManagerId], uncompressedSizes: Array[Long]) {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize))
}
- override def location: BlockManagerId = loc
-
- override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs
+ override def location: Option[BlockManagerId] = loc
override def getSizeForBlock(reduceId: Int): Long = {
MapStatus.decompressSize(compressedSizes(reduceId))
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- MapStatus.writeLocations(loc, mapShuffleLocs, out)
+ if (loc.isDefined) {
+ out.writeBoolean(true)
+ loc.get.writeExternal(out)
+ } else {
+ out.writeBoolean(false)
+ }
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in)
- loc = deserializedLoc
- mapShuffleLocs = deserializedMapShuffleLocs
+ if (in.readBoolean()) {
+ loc = Some(BlockManagerId(in))
+ } else {
+ loc = None
+ }
val len = in.readInt()
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
@@ -217,15 +147,13 @@ private[spark] class CompressedMapStatus(
* plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
- * @param mapShuffleLocs location where the task stored shuffle blocks - may be null
* @param numNonEmptyBlocks the number of non-empty blocks
* @param emptyBlocks a bitmap tracking which blocks are empty
* @param avgSize average size of the non-empty and non-huge blocks
* @param hugeBlockSizes sizes of huge blocks by their reduceId.
*/
private[spark] class HighlyCompressedMapStatus private (
- private[this] var loc: BlockManagerId,
- private[this] var mapShuffleLocs: MapShuffleLocations,
+ private[this] var loc: Option[BlockManagerId],
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
@@ -236,11 +164,9 @@ private[spark] class HighlyCompressedMapStatus private (
require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, null, -1, null, -1, null) // For deserialization only
-
- override def location: BlockManagerId = loc
+ protected def this() = this(null, -1, null, -1, null) // For deserialization only
- override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs
+ override def location: Option[BlockManagerId] = loc
override def getSizeForBlock(reduceId: Int): Long = {
assert(hugeBlockSizes != null)
@@ -255,7 +181,12 @@ private[spark] class HighlyCompressedMapStatus private (
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- MapStatus.writeLocations(loc, mapShuffleLocs, out)
+ if (loc.isDefined) {
+ out.writeBoolean(true)
+ loc.get.writeExternal(out)
+ } else {
+ out.writeBoolean(false)
+ }
emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
out.writeInt(hugeBlockSizes.size)
@@ -266,9 +197,11 @@ private[spark] class HighlyCompressedMapStatus private (
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in)
- loc = deserializedLoc
- mapShuffleLocs = deserializedMapShuffleLocs
+ if (in.readBoolean()) {
+ loc = Some(BlockManagerId(in))
+ } else {
+ loc = None
+ }
emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
@@ -284,10 +217,8 @@ private[spark] class HighlyCompressedMapStatus private (
}
private[spark] object HighlyCompressedMapStatus {
- def apply(
- loc: BlockManagerId,
- mapShuffleLocs: MapShuffleLocations,
- uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ def apply(loc: Option[BlockManagerId], uncompressedSizes: Array[Long])
+ : HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -327,12 +258,7 @@ private[spark] object HighlyCompressedMapStatus {
}
emptyBlocks.trim()
emptyBlocks.runOptimize()
- new HighlyCompressedMapStatus(
- loc,
- mapShuffleLocs,
- numNonEmptyBlocks,
- emptyBlocks,
- avgSize,
- hugeBlockSizes)
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
+ hugeBlockSizes)
}
}
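Both `MapStatus` implementations now serialize the optional location with the same boolean-flag framing. The pattern isolated as helpers (a sketch; the real classes inline this around their compressed size data):

```scala
import java.io.{ObjectInput, ObjectOutput}
import org.apache.spark.storage.BlockManagerId

// Write a presence flag, then the id itself only when present.
def writeLocation(loc: Option[BlockManagerId], out: ObjectOutput): Unit = loc match {
  case Some(id) => out.writeBoolean(true); id.writeExternal(out)
  case None     => out.writeBoolean(false)
}

def readLocation(in: ObjectInput): Option[BlockManagerId] =
  if (in.readBoolean()) Some(BlockManagerId(in)) else None
```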
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index ba8c92518f019..2df133dd2b13a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -31,7 +31,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSe
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput}
import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool}
-import com.esotericsoftware.kryo.serializers.{ExternalizableSerializer, JavaSerializer => KryoJavaSerializer}
+import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.roaringbitmap.RoaringBitmap
@@ -152,8 +152,6 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer())
kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer())
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
- kryo.register(classOf[CompressedMapStatus], new ExternalizableSerializer())
- kryo.register(classOf[HighlyCompressedMapStatus], new ExternalizableSerializer())
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
@@ -487,6 +485,8 @@ private[serializer] object KryoSerializer {
private val toRegister: Seq[Class[_]] = Seq(
ByteBuffer.allocate(1).getClass,
classOf[StorageLevel],
+ classOf[CompressedMapStatus],
+ classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Boolean]],
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 530c3694ad1ec..8d6745ba397d3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -28,7 +28,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
-import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -60,7 +60,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] {
override def iterator: Iterator[ShuffleBlockInfo] = {
mapOutputTracker
- .getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition)
+ .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
.flatMap { shuffleLocationInfo =>
shuffleLocationInfo._2.map { blockInfo =>
val block = blockInfo._1.asInstanceOf[ShuffleBlockId]
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 265a8acfa8d61..5518264c15136 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -68,3 +68,12 @@ private[spark] class MetadataFetchFailedException(
reduceId: Int,
message: String)
extends FetchFailedException(null, shuffleId, -1, reduceId, message)
+
+private[spark] class RemoteFetchFailedException(
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int,
+ message: String,
+ host: String,
+ port: Int)
+ extends FetchFailedException(BlockManagerId(host, port), shuffleId, mapId, reduceId, message)
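A hypothetical reader-side use of the new exception (it is `private[spark]`, so this assumes code inside the Spark package): a fetch from a host/port-only endpoint fails, and the report carries just enough location for the `DAGScheduler`'s host-level handling above. Host and port are illustrative:

```scala
import org.apache.spark.shuffle.RemoteFetchFailedException

def failRemoteFetch(shuffleId: Int, mapId: Int, reduceId: Int): Nothing =
  throw new RemoteFetchFailedException(
    shuffleId, mapId, reduceId,
    message = s"Failed to fetch shuffle $shuffleId map $mapId reduce $reduceId",
    host = "shuffle-proxy.example.com",
    port = 7337)
```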
diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
index 9b9b8508e88aa..928a6f32739fd 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -26,7 +26,6 @@ import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
import org.apache.spark.internal.config
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
-import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
class DefaultShuffleReadSupport(
@@ -93,12 +92,8 @@ private class ShuffleBlockFetcherIterable(
context,
blockManager.shuffleClient,
blockManager,
- mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1)
- .map { shuffleLocationInfo =>
- val defaultShuffleLocation = shuffleLocationInfo._1
- .get.asInstanceOf[DefaultMapShuffleLocations]
- (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2)
- },
+ mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1)
+ .map(loc => (loc._1.get, loc._2)),
serializerManager.wrapStream,
maxBytesInFlight,
maxReqsInFlight,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 1fcae684b0052..3ffafd288125d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -67,11 +67,8 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = writeSupport.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
- val mapLocations = mapOutputWriter.commitAllPartitions()
- mapStatus = MapStatus(
- blockManager.shuffleServerId,
- mapLocations.orNull(),
- partitionLengths)
+ val location = mapOutputWriter.commitAllPartitions
+ mapStatus = MapStatus(Option.apply(location.orNull), partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index d72bd6f9af6bc..8d66cbbfb7562 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -70,7 +70,12 @@ class BlockManagerId private (
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- out.writeUTF(executorId_)
+ if (executorId_ != null) {
+ out.writeBoolean(true)
+ out.writeUTF(executorId_)
+ } else {
+ out.writeBoolean(false)
+ }
out.writeUTF(host_)
out.writeInt(port_)
out.writeBoolean(topologyInfo_.isDefined)
@@ -79,7 +84,9 @@ class BlockManagerId private (
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- executorId_ = in.readUTF()
+ if (in.readBoolean()) {
+ executorId_ = in.readUTF()
+ }
host_ = in.readUTF()
port_ = in.readInt()
val isTopologyInfoAvailable = in.readBoolean()
@@ -91,8 +98,13 @@ class BlockManagerId private (
override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)"
- override def hashCode: Int =
- ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode
+ override def hashCode: Int = {
+ if (executorId != null) {
+ ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode
+ } else {
+ (host.hashCode * 41 + port) * 41 + topologyInfo.hashCode
+ }
+ }
override def equals(that: Any): Boolean = that match {
case id: BlockManagerId =>
@@ -127,20 +139,21 @@ private[spark] object BlockManagerId {
topologyInfo: Option[String] = None): BlockManagerId =
getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo))
+ def apply(host: String, port: Int): BlockManagerId =
+ getCachedBlockManagerId(new BlockManagerId(null, host, port, None))
+
def apply(in: ObjectInput): BlockManagerId = {
val obj = new BlockManagerId()
obj.readExternal(in)
getCachedBlockManagerId(obj)
}
- val blockManagerIdCacheSize = 10000
-
/**
* The max cache size is hardcoded to 10000, since the size of a BlockManagerId
* object is about 48B, the total memory cost should be below 1MB which is feasible.
*/
val blockManagerIdCache = CacheBuilder.newBuilder()
- .maximumSize(blockManagerIdCacheSize)
+ .maximumSize(10000)
.build(new CacheLoader[BlockManagerId, BlockManagerId]() {
override def load(id: BlockManagerId) = id
})
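The null-tolerant serialization plus the new two-argument `apply` make a host/port-only id round-trippable. A sketch using plain Java object streams and illustrative values:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.storage.BlockManagerId

val bytes = new ByteArrayOutputStream()
val out = new ObjectOutputStream(bytes)
BlockManagerId("host-a", 7337).writeExternal(out)  // writes the false executor flag
out.flush()

val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
val restored = BlockManagerId(in)
assert(restored.executorId == null && restored.host == "host-a" && restored.port == 7337)
```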
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 5ea0907277ebf..4c2e6ac6474da 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -172,8 +172,6 @@ public void setUp() throws IOException {
when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
- when(blockManager.shuffleServerId()).thenReturn(BlockManagerId.apply(
- "0", "localhost", 9099, Option.empty()));
TaskContext$.MODULE$.setTaskContext(taskContext);
}
@@ -189,7 +187,8 @@ private UnsafeShuffleWriter