[SPARK-25299] ShuffleLocation/FetchFailed integrations with scheduler #548

Closed · wants to merge 27 commits

Changes from 13 commits
org/apache/spark/api/shuffle/MapShuffleLocations.java
@@ -17,8 +17,10 @@
package org.apache.spark.api.shuffle;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.api.java.Optional;

import java.io.Serializable;
import java.util.List;

/**
* Represents metadata about where shuffle blocks were written in a single map task.
@@ -35,5 +37,11 @@ public interface MapShuffleLocations extends Serializable {
/**
* Get the location for a given shuffle block written by this map task.
*/
ShuffleLocation getLocationForBlock(int reduceId);
List<ShuffleLocation> getLocationsForBlock(int reduceId);

Review comment:
I mentioned this on the doc, but I'm skeptical about supporting different locations for each (map, reduce) block, instead of just replicating the entire output of one map task to the same places. I don't think I properly understood that part even before this change ... I'll need to look through this more carefully to figure out what the effect of that would be, in particular how much bookkeeping is required on the driver.
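
For illustration, a minimal sketch of the replication alternative described above — the entire output of one map task written to a single replica set, so every reduce block resolves to the same list (hypothetical class, not part of this PR):

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.Optional;
import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;

// Hypothetical: all reduce blocks of this map task share one replica list,
// so the driver tracks one list per map task instead of one per block.
public class ReplicatedMapShuffleLocations implements MapShuffleLocations {
  private final List<ShuffleLocation> replicas;

  public ReplicatedMapShuffleLocations(List<ShuffleLocation> replicas) {
    this.replicas = new ArrayList<>(replicas);
  }

  @Override
  public List<ShuffleLocation> getLocationsForBlock(int reduceId) {
    return replicas; // identical for every reduceId
  }

  @Override
  public boolean removeShuffleLocation(String host, Optional<Integer> port) {
    replicas.removeIf(loc -> loc.host().equals(host)
        && (!port.isPresent() || loc.port() == port.get()));
    // Partitions are lost only once no replica remains.
    return replicas.isEmpty();
  }
}
```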


/**
* Deletes a host or a host/port combination from this MapShuffleLocations.
* Returns true if the removal of this ShuffleLocation results in missing partitions.

Review comment:
Think this can be worded more clearly:

Mark a location for a block in this map output as unreachable, and thus partitions can no longer be fetched from that location.
<p>
This is called by the scheduler when it detects that a block could not be fetched from the file server located at this host and port.
<p>
This should return true if partitions are completely lost via the loss of this shuffle location. Otherwise, if all partitions can still be fetched from alternative locations, this should return false.

For something like this it's better to be verbose, although we do expect only those who really know what they're doing to be implementing these APIs.

*/
boolean removeShuffleLocation(String host, Optional<Integer> port);
}
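
Following up on the suggested wording above, a minimal driver-side sketch of the intended contract (a hypothetical helper, assuming the rewritten semantics, not code from this PR):

```java
import org.apache.spark.api.java.Optional;
import org.apache.spark.api.shuffle.MapShuffleLocations;

public final class FetchFailureHandling {
  // Hypothetical helper: returns true when the scheduler must unregister
  // this map output because some partition has no reachable replica left.
  public static boolean onFetchFailed(
      MapShuffleLocations locations, String host, Optional<Integer> port) {
    boolean partitionsLost = locations.removeShuffleLocation(host, port);
    // false => every partition is still fetchable from another replica,
    // so the map task does not need to be re-run.
    return partitionsLost;
  }
}
```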
org/apache/spark/api/shuffle/ShuffleBlockInfo.java
@@ -17,8 +17,6 @@

package org.apache.spark.api.shuffle;

import org.apache.spark.api.java.Optional;

import java.util.Objects;

/**
@@ -31,10 +29,10 @@ public class ShuffleBlockInfo {
private final int mapId;
private final int reduceId;
private final long length;
private final Optional<ShuffleLocation> shuffleLocation;
private final ShuffleLocation[] shuffleLocation;

public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length,
Optional<ShuffleLocation> shuffleLocation) {
ShuffleLocation[] shuffleLocation) {
this.shuffleId = shuffleId;
this.mapId = mapId;
this.reduceId = reduceId;
@@ -58,7 +56,7 @@ public long getLength() {
return length;
}

public Optional<ShuffleLocation> getShuffleLocation() {
public ShuffleLocation[] getShuffleLocation() {
return shuffleLocation;
}

org/apache/spark/api/shuffle/ShuffleLocation.java
@@ -17,8 +17,33 @@

package org.apache.spark.api.shuffle;

import org.apache.spark.api.java.Optional;

/**
* Marker interface representing a location of a shuffle block. Implementations of shuffle readers
* and writers are expected to cast this down to an implementation-specific representation.
*/
public interface ShuffleLocation {}
public abstract class ShuffleLocation {
/**
* The host and port on which the shuffle block is located.
*/
public abstract String host();
public abstract int port();

/**
* The executor on which the ShuffleLocation is located. Returns {@link Optional#empty()} if
* the location is not associated with an executor.
*/
public Optional<String> execId() {
return Optional.empty();
}

@Override
public String toString() {
String shuffleLocation = String.format("ShuffleLocation %s:%d", host(), port());
if (execId().isPresent()) {
return String.format("%s (execId: %s)", shuffleLocation, execId().get());
}
return shuffleLocation;
}
}
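
As a sketch of what this abstraction now permits (a hypothetical location for an external, disaggregated shuffle service, not part of this PR): host() and port() identify the remote server, while execId() keeps its default empty value, so the inherited toString() renders with no execId suffix.

```java
import org.apache.spark.api.shuffle.ShuffleLocation;

// Hypothetical: a shuffle block served by an external storage service.
// No Spark executor backs it, so execId() stays Optional.empty().
public class RemoteStorageShuffleLocation extends ShuffleLocation {
  private final String host;
  private final int port;

  public RemoteStorageShuffleLocation(String host, int port) {
    this.host = host;
    this.port = port;
  }

  @Override
  public String host() {
    return host;
  }

  @Override
  public int port() {
    return port;
  }
}
```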
DefaultMapShuffleLocations.java
@@ -21,13 +21,17 @@
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;

import com.google.common.collect.ImmutableList;
import org.apache.spark.api.java.Optional;
import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;
import org.apache.spark.storage.BlockManagerId;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation {
public class DefaultMapShuffleLocations extends ShuffleLocation implements MapShuffleLocations {

/**
* We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be
@@ -45,18 +49,28 @@ public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) {
});

private final BlockManagerId location;
private final List<ShuffleLocation> locationsArray;

public DefaultMapShuffleLocations(BlockManagerId blockManagerId) {
this.location = blockManagerId;
this.locationsArray = ImmutableList.of(this);
}

public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) {
return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId);
}

@Override
public ShuffleLocation getLocationForBlock(int reduceId) {
return this;
public List<ShuffleLocation> getLocationsForBlock(int reduceId) {
return locationsArray;
}

@Override
public boolean removeShuffleLocation(String host, Optional<Integer> port) {
if (port.isPresent()) {
return this.host().equals(host) && this.port() == port.get();
}
return this.host().equals(host);
}

public BlockManagerId getBlockManagerId() {
@@ -73,4 +87,19 @@ public boolean equals(Object other) {
public int hashCode() {
return Objects.hashCode(location);
}

@Override
public String host() {
return location.host();
}

@Override
public int port() {
return location.port();
}

@Override
public Optional<String> execId() {
return Optional.of(location.executorId());
}
}
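
A quick usage sketch of the single-replica default above (hypothetical values; written as if placed alongside DefaultMapShuffleLocations to sidestep its package): every reduce block resolves to the one location, and removing that location's host immediately reports lost partitions.

```java
import org.apache.spark.api.java.Optional;
import org.apache.spark.storage.BlockManagerId;

public class DefaultLocationsExample {
  public static void main(String[] args) {
    // Hypothetical id; in practice this identifies a running executor.
    BlockManagerId id =
        BlockManagerId.apply("exec-1", "host-a", 7337, scala.Option.empty());
    DefaultMapShuffleLocations locs = DefaultMapShuffleLocations.get(id);

    // Single replica: every reduce block maps to the same one location.
    System.out.println(locs.getLocationsForBlock(0).size());  // 1

    // Losing that host (any port) loses partitions outright...
    System.out.println(
        locs.removeShuffleLocation("host-a", Optional.empty()));  // true
    // ...while an unrelated host/port combination loses nothing.
    System.out.println(
        locs.removeShuffleLocation("host-b", Optional.of(7337)));  // false
  }
}
```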
57 changes: 39 additions & 18 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.api.java.Optional
import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation}
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
@@ -102,11 +103,23 @@ private class ShuffleStatus(numPartitions: Int) {
* This is a no-op if there is no registered map output or if the registered output is from a
* different block manager.
*/
def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
def removeMapOutput(mapId: Int, shuffleLocations: Seq[ShuffleLocation]): Unit = synchronized {
if (mapStatuses(mapId) != null) {
var shouldDelete = false
if (shuffleLocations == null) {
shouldDelete = true
} else {
shuffleLocations.foreach { location =>
val lostPartitions = mapStatuses(mapId)
.mapShuffleLocations
.removeShuffleLocation(location.host(), Optional.of(location.port()))
// OR the results: one location leaving partitions unfetchable is enough
// to require deletion, but removeShuffleLocation must still run for
// every reported location.
shouldDelete = shouldDelete || lostPartitions
}
}
if (shouldDelete) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
}
}
}

@@ -115,7 +128,14 @@ private class ShuffleStatus(numPartitions: Int) {
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
removeOutputsByFilter(x => x.host == host)
for (mapId <- 0 until mapStatuses.length) {
if (mapStatuses(mapId) != null &&
mapStatuses(mapId).mapShuffleLocations.removeShuffleLocation(host, Optional.empty())) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
}
}
}

/**
@@ -283,7 +303,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging

// For testing
def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)
}

@@ -297,7 +317,7 @@
* describing the shuffle blocks that are stored at that block manager.
*/
def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])]

/**
* Deletes map output status information for the specified shuffle stage.
@@ -424,10 +444,10 @@ private[spark] class MapOutputTrackerMaster(
}

/** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
def unregisterMapOutput(shuffleId: Int, mapId: Int, shuffleLocations: Seq[ShuffleLocation]) {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeMapOutput(mapId, bmAddress)
shuffleStatus.removeMapOutput(mapId, shuffleLocations)
incrementEpoch()
case None =>
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
@@ -647,7 +667,7 @@ private[spark] class MapOutputTrackerMaster(
// Get block sizes by shuffle location. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
@@ -684,7 +704,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr

// Get block sizes by shuffle location. Note that zero-sized blocks are excluded in the result.
override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
@@ -873,9 +893,9 @@
shuffleId: Int,
startPartition: Int,
endPartition: Int,
statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
statuses: Array[MapStatus]): Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
assert (statuses != null)
val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]]
val splitsByAddress = new HashMap[Seq[ShuffleLocation], ListBuffer[(BlockId, Long)]]
for ((status, mapId) <- statuses.iterator.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
Expand All @@ -885,12 +905,13 @@ private[spark] object MapOutputTracker extends Logging {
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
if (status.mapShuffleLocations == null) {
splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) +=
if (status.mapShuffleLocations == null
|| status.mapShuffleLocations.getLocationsForBlock(part).isEmpty) {
splitsByAddress.getOrElseUpdate(Seq.empty, ListBuffer()) +=
((ShuffleBlockId(shuffleId, mapId, part), size))
} else {
val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part)
splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) +=
val shuffleLocations = status.mapShuffleLocations.getLocationsForBlock(part)
splitsByAddress.getOrElseUpdate(shuffleLocations.asScala, ListBuffer()) +=
((ShuffleBlockId(shuffleId, mapId, part), size))
}
}
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.io.{ObjectInputStream, ObjectOutputStream}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.shuffle.ShuffleLocation
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.storage.BlockManagerId
@@ -81,14 +82,14 @@
*/
@DeveloperApi
case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleLocation: Seq[ShuffleLocation], // Note that shuffleLocation can be null
shuffleId: Int,
mapId: Int,
reduceId: Int,
message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
val bmAddressString = if (shuffleLocation == null) "null" else shuffleLocation.toString
s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
s"message=\n$message\n)"
}
54 changes: 36 additions & 18 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -43,6 +43,7 @@ import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO
import org.apache.spark.storage._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
import org.apache.spark.util._
@@ -1478,7 +1479,7 @@
}
}

case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
case FetchFailed(shuffleLocations, shuffleId, mapId, _, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleIdToMapStage(shuffleId)

@@ -1511,7 +1512,7 @@
mapOutputTracker.unregisterAllMapOutput(shuffleId)
} else if (mapId != -1) {
// Mark the map whose fetch failed as broken in the map stage
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, shuffleLocations)
}

if (failedStage.rdd.isBarrier()) {
@@ -1626,22 +1627,39 @@
}

// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled &&
unRegisterOutputOnHostOnFetchFailure) {
// We had a fetch failure with the external shuffle service, so we
// assume all shuffle data on the node is bad.
Some(bmAddress.host)
} else {
// Unregister shuffle data just for one executor (we don't have any
// reason to believe shuffle data has been lost for the entire host).
None
}
removeExecutorAndUnregisterOutputs(
execId = bmAddress.executorId,
fileLost = true,
hostToUnregisterOutputs = hostToUnregisterOutputs,
maybeEpoch = Some(task.epoch))
if (shuffleLocations != null) {
val toRemoveHost =
if (env.conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) ==
classOf[DefaultShuffleDataIO].getName) {

Review comment:
if we're going to be doing equality tests on these classes, we should make sure they are final

env.blockManager.externalShuffleServiceEnabled &&
unRegisterOutputOnHostOnFetchFailure
} else {
true // always remove for remote shuffle storage
}

shuffleLocations.foreach(location => {
var epochAllowsRemoval = false
// If the location belonged to an executor, remove all outputs on the executor
val maybeExecId = location.execId()
// Some(task.epoch).getOrElse(...) always yields task.epoch; use it directly.
val currentEpoch = task.epoch
if (maybeExecId.isPresent) {
val execId = maybeExecId.get()
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
epochAllowsRemoval = true
blockManagerMaster.removeExecutor(execId)
mapOutputTracker.removeOutputsOnExecutor(execId)
}
} else {
// If the location doesn't belong to an executor, the epoch doesn't matter
epochAllowsRemoval = true
}

if (toRemoveHost && epochAllowsRemoval) {
mapOutputTracker.removeOutputsOnHost(location.host())
}
})
clearCacheLocs()
}
}
