Skip to content

Commit

Permalink
[xla] Enhancement to the GPU latency hiding scheduler.
Browse files Browse the repository at this point in the history
We extend the GPU latency hiding scheduler in two aspects. First, we add one
more stream resource for executing Send and Recv instructions respectively.
Second, we extend the scheduler to handle Send-done and Recv-done ordered
before Send and Recv in a computation for a pipelined while loop.

Add tests.

PiperOrigin-RevId: 616331573
  • Loading branch information
bixia1 authored and tensorflower-gardener committed Mar 16, 2024
1 parent 139f5ed commit 34bbb0b
Show file tree
Hide file tree
Showing 3 changed files with 374 additions and 83 deletions.
164 changes: 136 additions & 28 deletions third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

Expand All @@ -35,6 +36,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand Down Expand Up @@ -291,15 +293,79 @@ SchedulerConfig GetSchedulerConfig(int64_t memory_limit) {
// We use two different set of resources to model the scheduling of asynchronous
// collective operations and P2P Send and Recv operations. This corresponds to
// the fact that the runtime use a stream to run asynchronous collective
// operations and another stream to run P2P Send and Recv operations.
// operations and two other streams to run P2P Send and Recv operations.
enum class GpuResourceType {
kGpuAsyncStreamSend = 0, // The resource for P2P Send operation.
kGpuAsyncStreamRecv = 1, // The resource for P2P Recv operation.
kGpuAsyncStreamCollectives = 2, // The resource for collective operations.
kGpuAsyncStreamComputes = 3, // The resource for async compute operations.
kNumTargetResources = 4,
kGpuAsyncStreamSend0 = 0, // A resource for P2P Send operation.
kGpuAsyncStreamSend1 = 1, // Another resource for P2P Send operation.
kGpuAsyncStreamRecv0 = 2, // A resource for P2P Recv operation.
kGpuAsyncStreamRecv1 = 3, // Another resource for P2P Recv operation.
kGpuAsyncStreamCollectives = 4, // The resource for collective operations.
kGpuAsyncStreamComputes = 5, // The resource for async compute operations.
kNumTargetResources = 6,
};

// Returns the pipeline stream for a P2P instruction recorded in a frontend
// attribute.
int64_t GetPipelineStream(const HloInstruction& start) {
auto it = start.frontend_attributes().map().find(kSendRecvPipelineAttr);
if (it != start.frontend_attributes().map().end() && it->second == "1") {
return 1;
}
return 0;
}

// Returns the resource type and resource usage for a P2P instruction.
std::pair<GpuResourceType, ResourceUsageType> GetP2PResourceAndUsage(
const HloInstruction& instr, const CanonicalAsyncOp& op) {
ResourceUsageType usage;
int64_t pipeline = 0;
if (op.outer == HloOpcode::kAsyncStart) {
usage = ResourceUsageType::kResourceRelease;
pipeline = GetPipelineStream(instr);
} else {
usage = ResourceUsageType::kResourceOccupy;
// Check the operand for the Send-done or Recv-done instruction.
const HloInstruction* operand = instr.operand(0);
HloOpcode operand_opcode = operand->opcode();
if (operand_opcode == HloOpcode::kSend ||
operand_opcode == HloOpcode::kRecv) {
// Not a pipelined P2P.
pipeline = GetPipelineStream(*operand);
} else {
// A pipelined P2P. Find the corresponding start-op.
const HloSendRecvInstruction* start;
const HloGetTupleElementInstruction* gte =
DynCast<HloGetTupleElementInstruction>(operand);
int64_t tuple_index = gte->tuple_index();
if (gte->operand(0)->opcode() == HloOpcode::kWhile) {
// The op is a while-result, so the start-op should be a value in the
// while-op operands.
start = DynCast<HloSendRecvInstruction>(
gte->operand(0)->operand(0)->operand(tuple_index));
} else {
// The op is a while-body parameter, so the start-op should be a value
// in the while-body result.
const HloComputation* computation = instr.parent();
start = DynCast<HloSendRecvInstruction>(
computation->root_instruction()->operand(tuple_index));
}
pipeline = GetPipelineStream(*start);
}
}
HloOpcode opcode = op.inner;
GpuResourceType resource;
if (pipeline == 0) {
resource = opcode == HloOpcode::kSend
? GpuResourceType::kGpuAsyncStreamSend0
: GpuResourceType::kGpuAsyncStreamRecv0;
} else {
resource = opcode == HloOpcode::kSend
? GpuResourceType::kGpuAsyncStreamSend1
: GpuResourceType::kGpuAsyncStreamRecv1;
}
return {resource, usage};
}

// Base GPU async tracker that enables async tracking only for async collectives
// that are marked for async execution.
class GpuAsyncTrackerBase : public AsyncTracker {
Expand Down Expand Up @@ -346,26 +412,21 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase {
CanonicalAsyncOp op = GetCanonicalAsyncOp(instr);
if (op.outer == HloOpcode::kAsyncStart ||
op.outer == HloOpcode::kAsyncDone) {
ResourceUsageType usage = op.outer == HloOpcode::kAsyncStart
? ResourceUsageType::kResourceRelease
: ResourceUsageType::kResourceOccupy;
ResourcesVector resources;
auto add_resource = [&](GpuResourceType resource_type) {
const int64_t gpu_stream_resource = GetFirstTargetDefinedResource() +
static_cast<int64_t>(resource_type);
resources.push_back(std::make_pair(gpu_stream_resource, usage));
};

if (op.inner == HloOpcode::kSend) {
add_resource(GpuResourceType::kGpuAsyncStreamSend);
} else if (op.inner == HloOpcode::kRecv) {
add_resource(GpuResourceType::kGpuAsyncStreamRecv);
} else if (hlo_query::IsCollectiveCommunicationOp(op.inner)) {
add_resource(GpuResourceType::kGpuAsyncStreamCollectives);
ResourceUsageType usage;
GpuResourceType resource;
if (op.inner == HloOpcode::kSend || op.inner == HloOpcode::kRecv) {
std::tie(resource, usage) = GetP2PResourceAndUsage(instr, op);
} else {
add_resource(GpuResourceType::kGpuAsyncStreamComputes);
usage = op.outer == HloOpcode::kAsyncStart
? ResourceUsageType::kResourceRelease
: ResourceUsageType::kResourceOccupy;
resource = hlo_query::IsCollectiveCommunicationOp(op.inner)
? GpuResourceType::kGpuAsyncStreamCollectives
: GpuResourceType::kGpuAsyncStreamComputes;
}
return resources;
return {std::make_pair(
GetFirstTargetDefinedResource() + static_cast<int64_t>(resource),
usage)};
}
return GpuAsyncTrackerBase::GetResourcesFromInstruction(instr);
}
Expand Down Expand Up @@ -415,10 +476,14 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase {
first_target_resource + GetNumTargetDefinedResources());
switch (
static_cast<GpuResourceType>(resource_type - first_target_resource)) {
case GpuResourceType::kGpuAsyncStreamSend:
return "kGpuAsyncStreamSend";
case GpuResourceType::kGpuAsyncStreamRecv:
return "kGpuAsyncStreamRecv";
case GpuResourceType::kGpuAsyncStreamSend0:
return "kGpuAsyncStreamSend0";
case GpuResourceType::kGpuAsyncStreamSend1:
return "kGpuAsyncStreamSend1";
case GpuResourceType::kGpuAsyncStreamRecv0:
return "kGpuAsyncStreamRecv0";
case GpuResourceType::kGpuAsyncStreamRecv1:
return "kGpuAsyncStreamRecv1";
case GpuResourceType::kGpuAsyncStreamCollectives:
return "kGpuAsyncStreamCollectives";
case GpuResourceType::kGpuAsyncStreamComputes:
Expand All @@ -438,6 +503,49 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase {
first_target_resource + GetNumTargetDefinedResources());
return ResourceHazardType::kUnshareable;
}

int64_t GetNumResourcesPerInstruction(
int64_t resource_type, const HloInstruction& instr) const override {
int64_t num_resources = GpuAsyncTrackerBase::GetNumResourcesPerInstruction(
resource_type, instr);
if (num_resources <= 0 || instr.opcode() != HloOpcode::kWhile) {
return num_resources;
}
// For while-loop with pipelined Send/Recv, the while-body first releases
// the Send/Recv resource and then uses the resource. Therefore, subtract 1
// from num_resources for the relevant resource type.
int64_t first_p2p_resource =
GetFirstTargetDefinedResource() +
static_cast<int64_t>(GpuResourceType::kGpuAsyncStreamSend0);
if (resource_type < first_p2p_resource ||
resource_type > first_p2p_resource + 4) {
return num_resources;
}
auto find_instruction_for_pipeline = [&](HloOpcode opcode,
int64_t pipeline) {
for (auto operand : instr.operand(0)->operands()) {
if (operand->opcode() == opcode) {
int64_t cur_pipeline = GetPipelineStream(*operand);
if (cur_pipeline == pipeline) {
return true;
}
}
}
return false;
};
bool found;
// Look into the while-op init-values to find pipelined Send/Recv.
if (resource_type == first_p2p_resource) {
found = find_instruction_for_pipeline(HloOpcode::kSend, 0);
} else if (resource_type == first_p2p_resource + 1) {
found = find_instruction_for_pipeline(HloOpcode::kSend, 1);
} else if (resource_type == first_p2p_resource + 2) {
found = find_instruction_for_pipeline(HloOpcode::kRecv, 0);
} else {
found = find_instruction_for_pipeline(HloOpcode::kRecv, 1);
}
return num_resources - (found ? 1 : 0);
}
};

class GpuLatencyEstimator : public ApproximateLatencyEstimator {
Expand Down
Loading

0 comments on commit 34bbb0b

Please sign in to comment.