Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-25299] ShuffleLocation/FetchFailed integrations with scheduler #548

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package org.apache.spark.api.shuffle;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.api.java.Optional;

import java.io.Serializable;
import java.util.List;

/**
* Represents metadata about where shuffle blocks were written in a single map task.
Expand All @@ -35,5 +37,30 @@ public interface MapShuffleLocations extends Serializable {
/**
* Get the location for a given shuffle block written by this map task.
*/
ShuffleLocation getLocationForBlock(int reduceId);
List<ShuffleLocation> getLocationsForBlock(int reduceId);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned this on the doc, but I'm skeptical about supporting different locations for each (map, reduce) block, instead of just replicating the entire output of one map task to the same places. I don't think I properly understood that part even before this change ... I'll need to look through this more carefully to figure out what the effect of that would be, in particular how much bookkeeping is required on the driver.


/**
 * Mark a location for a block in this map output as unreachable, and thus partitions can no
 * longer be fetched from that location.
 * <p>
 * This is called by the scheduler when it detects that a block could not be fetched from the
 * file server located at this host and port. An empty port matches every location on the
 * given host, regardless of port.
 * <p>
 * This should return true if there exists a data loss from the removal of this shuffle
 * location. Otherwise, if all partitions can still be fetched from alternative locations,
 * this should return false.
 */
boolean invalidateShuffleLocation(String host, Optional<Integer> port);

/**
 * Mark all locations within this MapShuffleLocations with this execId as unreachable.
 * <p>
 * This is called by the scheduler when it detects that an executor cannot be reached to
 * fetch file data.
 * <p>
 * This should return true if there exists a data loss from the removal of shuffle locations
 * with this execId. Otherwise, if all partitions can still be fetched from alternative
 * locations, this should return false.
 */
boolean invalidateShuffleLocation(String executorId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.api.shuffle;

import org.apache.spark.api.java.Optional;

import java.util.Objects;

/**
Expand All @@ -31,10 +29,10 @@ public class ShuffleBlockInfo {
private final int mapId;
private final int reduceId;
private final long length;
private final Optional<ShuffleLocation> shuffleLocation;
private final ShuffleLocation[] shuffleLocation;

public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length,
Optional<ShuffleLocation> shuffleLocation) {
ShuffleLocation[] shuffleLocation) {
this.shuffleId = shuffleId;
this.mapId = mapId;
this.reduceId = reduceId;
Expand All @@ -58,7 +56,7 @@ public long getLength() {
return length;
}

public Optional<ShuffleLocation> getShuffleLocation() {
public ShuffleLocation[] getShuffleLocation() {
return shuffleLocation;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,12 @@ public interface ShuffleDriverComponents {
void cleanupApplication() throws IOException;

void removeShuffleData(int shuffleId, boolean blocking) throws IOException;

/**
 * Whether to unregister other map statuses on the same hosts or executors
 * when a shuffle task returns a {@link org.apache.spark.FetchFailed}.
 * <p>
 * Conservative default: a single fetch failure does not invalidate other map outputs
 * unless an implementation opts in by overriding this method.
 */
default boolean unregisterOtherMapStatusesOnFetchFailure() {
    return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,33 @@

package org.apache.spark.api.shuffle;

import org.apache.spark.api.java.Optional;

/**
* Marker interface representing a location of a shuffle block. Implementations of shuffle readers
* and writers are expected to cast this down to an implementation-specific representation.
*/
public interface ShuffleLocation {}
public abstract class ShuffleLocation {

  /** The host on which the shuffle block is located. */
  public abstract String host();

  /** The port of the file server serving the shuffle block. */
  public abstract int port();

  /**
   * The executor on which the ShuffleLocation is located. Returns {@link Optional#empty()} if
   * the location is not associated with an executor.
   */
  public Optional<String> execId() {
    return Optional.empty();
  }

  @Override
  public String toString() {
    // Renders "ShuffleLocation host:port", appending the executor id when one is known.
    String rendered = "ShuffleLocation " + host() + ":" + port();
    if (!execId().isPresent()) {
      return rendered;
    }
    return rendered + " (execId: " + execId().get() + ")";
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@

package org.apache.spark.shuffle.sort;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;

import com.google.common.collect.Lists;
import org.apache.spark.api.java.Optional;
import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;
import org.apache.spark.storage.BlockManagerId;

import java.util.List;
import java.util.Objects;

public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation {
public class DefaultMapShuffleLocations extends ShuffleLocation implements MapShuffleLocations {

/**
* We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be
Expand All @@ -45,18 +49,34 @@ public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) {
});

private final BlockManagerId location;
@JsonIgnore
private final List<ShuffleLocation> locationsArray;

public DefaultMapShuffleLocations(BlockManagerId blockManagerId) {
this.location = blockManagerId;
this.locationsArray = Lists.newArrayList(this);
}

/**
 * Returns the canonical DefaultMapShuffleLocations for the given BlockManagerId,
 * memoized through the loading cache so repeated lookups share one instance.
 */
public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) {
    return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId);
}

@Override
public ShuffleLocation getLocationForBlock(int reduceId) {
return this;
public List<ShuffleLocation> getLocationsForBlock(int reduceId) {
return locationsArray;
}

@Override
public boolean invalidateShuffleLocation(String host, Optional<Integer> port) {
  // This single-replica location is lost iff the host matches and, when a port
  // was supplied, the port matches as well.
  if (!this.host().equals(host)) {
    return false;
  }
  return !port.isPresent() || this.port() == port.get().intValue();
}

@Override
public boolean invalidateShuffleLocation(String executorId) {
    // Single-replica location: data is lost exactly when the failed executor owns it.
    return location.executorId().equals(executorId);
}

public BlockManagerId getBlockManagerId() {
Expand All @@ -73,4 +93,19 @@ public boolean equals(Object other) {
public int hashCode() {
    // Identity is fully determined by the wrapped BlockManagerId (consistent with equals).
    return Objects.hashCode(location);
}

@Override
public String host() {
    // Delegates to the wrapped BlockManagerId.
    return location.host();
}

@Override
public int port() {
    // Delegates to the wrapped BlockManagerId.
    return location.port();
}

@Override
public Optional<String> execId() {
    // A BlockManagerId always carries its owning executor's id, so this is never empty.
    return Optional.of(location.executorId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ public void removeShuffleData(int shuffleId, boolean blocking) throws IOExceptio
blockManagerMaster.removeShuffle(shuffleId, blocking);
}

/**
 * This implementation keeps shuffle data with the executor's block manager (see
 * removeShuffleData above), so a fetch failure suggests sibling map outputs on the
 * same host/executor are likely unreachable too — ask the scheduler to unregister them.
 */
@Override
public boolean unregisterOtherMapStatusesOnFetchFailure() {
    return true;
}

private void checkInitialized() {
if (blockManagerMaster == null) {
throw new IllegalStateException("Driver components must be initialized before using");
Expand Down
66 changes: 47 additions & 19 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.api.java.Optional
import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation}
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -102,11 +103,23 @@ private class ShuffleStatus(numPartitions: Int) {
* This is a no-op if there is no registered map output or if the registered output is from a
* different block manager.
*/
def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
def removeMapOutput(mapId: Int, shuffleLocations: Seq[ShuffleLocation]): Unit = synchronized {
  if (mapStatuses(mapId) != null) {
    // An empty location list unconditionally discards the map output.
    var shouldDelete = shuffleLocations.isEmpty
    // Deliberately not exists(): invalidateShuffleLocation must run for EVERY location so
    // each one is marked unreachable, and the output is dropped if ANY invalidation reports
    // data loss. (Previously the flag was overwritten each iteration, so only the last
    // location's result was honored.)
    shuffleLocations.foreach { location =>
      val lostData = mapStatuses(mapId)
        .mapShuffleLocations
        .invalidateShuffleLocation(location.host(), Optional.of(location.port()))
      shouldDelete = shouldDelete || lostData
    }
    if (shouldDelete) {
      _numAvailableOutputs -= 1
      mapStatuses(mapId) = null
      invalidateSerializedMapOutputStatusCache()
    }
  }
}

Expand All @@ -115,7 +128,14 @@ private class ShuffleStatus(numPartitions: Int) {
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
removeOutputsByFilter(x => x.host == host)
for (mapId <- 0 until mapStatuses.length) {
if (mapStatuses(mapId) != null &&
mapStatuses(mapId).mapShuffleLocations.invalidateShuffleLocation(host, Optional.empty())) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
}
}
}

/**
Expand All @@ -124,7 +144,14 @@ private class ShuffleStatus(numPartitions: Int) {
* still registered with that execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = synchronized {
removeOutputsByFilter(x => x.executorId == execId)
for (mapId <- 0 until mapStatuses.length) {
if (mapStatuses(mapId) != null &&
mapStatuses(mapId).mapShuffleLocations.invalidateShuffleLocation(execId)) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
}
}
}

/**
Expand Down Expand Up @@ -283,7 +310,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging

// For testing
def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)
}

Expand All @@ -297,7 +324,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* describing the shuffle blocks that are stored at that block manager.
*/
def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])]

/**
* Deletes map output status information for the specified shuffle stage.
Expand Down Expand Up @@ -424,10 +451,10 @@ private[spark] class MapOutputTrackerMaster(
}

/** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
def unregisterMapOutput(shuffleId: Int, mapId: Int, shuffleLocations: Seq[ShuffleLocation]) {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeMapOutput(mapId, bmAddress)
shuffleStatus.removeMapOutput(mapId, shuffleLocations)
incrementEpoch()
case None =>
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
Expand Down Expand Up @@ -647,7 +674,7 @@ private[spark] class MapOutputTrackerMaster(
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
Expand Down Expand Up @@ -684,7 +711,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr

// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
: Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
Expand Down Expand Up @@ -873,9 +900,9 @@ private[spark] object MapOutputTracker extends Logging {
shuffleId: Int,
startPartition: Int,
endPartition: Int,
statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
statuses: Array[MapStatus]): Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = {
assert (statuses != null)
val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]]
val splitsByAddress = new HashMap[Seq[ShuffleLocation], ListBuffer[(BlockId, Long)]]
for ((status, mapId) <- statuses.iterator.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
Expand All @@ -885,12 +912,13 @@ private[spark] object MapOutputTracker extends Logging {
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
if (status.mapShuffleLocations == null) {
splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) +=
if (status.mapShuffleLocations == null
|| status.mapShuffleLocations.getLocationsForBlock(part).isEmpty) {
splitsByAddress.getOrElseUpdate(Seq.empty, ListBuffer()) +=
((ShuffleBlockId(shuffleId, mapId, part), size))
} else {
val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part)
splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) +=
val shuffleLocations = status.mapShuffleLocations.getLocationsForBlock(part)
splitsByAddress.getOrElseUpdate(shuffleLocations.asScala, ListBuffer()) +=
((ShuffleBlockId(shuffleId, mapId, part), size))
}
}
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ class SparkContext(config: SparkConf) extends SafeLogging {
_dagScheduler = ds
}

private[spark] def shuffleDriverComponents: ShuffleDriverComponents =
_shuffleDriverComponents

/**
* A unique identifier for the Spark application.
* Its format depends on the scheduler implementation.
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark
import java.io.{ObjectInputStream, ObjectOutputStream}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.shuffle.ShuffleLocation
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.storage.BlockManagerId
Expand Down Expand Up @@ -81,14 +82,14 @@ case object Resubmitted extends TaskFailedReason {
*/
@DeveloperApi
case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleLocation: Seq[ShuffleLocation], // Note that shuffleLocation cannot be null
shuffleId: Int,
mapId: Int,
reduceId: Int,
message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
val bmAddressString = if (shuffleLocation == null) "null" else shuffleLocation.toString
s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
s"message=\n$message\n)"
}
Expand Down
Loading