Skip to content

Commit

Permalink
[XLA:GPU][NFC] Eliminate HandleXXX(HloInstruction *) functions from IrEmitter.
Browse files Browse the repository at this point in the history

- The HandleXXX(HloInstruction *) functions are no longer needed as we now convert
   the whole HLO module to LMHLO before IR emission.
- Also removed a couple of member variables that are not needed anymore.

PiperOrigin-RevId: 374908610
Change-Id: I7d5d2c2ac6e959bfc195723c1e7617f917e23825
  • Loading branch information
jurahul authored and tensorflower-gardener committed May 20, 2021
1 parent 8ce8777 commit 996f7d5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 195 deletions.
159 changes: 2 additions & 157 deletions tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,11 +608,6 @@ StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
v, ir_emitter_context_->allocations(), constant_name);
}

// Fallback handler: lowers `hlo` to its MLIR form and emits it through the
// elemental IR emitter path.
Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitUsingElementalIrEmitter(mlir_input);
}

Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
// Replace unnested op with a fused nested op.
//
Expand Down Expand Up @@ -716,11 +711,6 @@ Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
return EmitLoopFusionFromMlir(input, unroll_factor);
}

// Converts `constant` to its MLIR form and delegates to EmitConstant.
Status IrEmitterUnnested::HandleConstant(HloInstruction* constant) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(constant));
  return EmitConstant(mlir_input);
}

Status IrEmitterUnnested::EmitConstant(MlirEmitterInput mlir_input) {
auto get_global = mlir::cast<mlir::memref::GetGlobalOp>(mlir_input.op);
auto module = get_global->getParentOfType<mlir::ModuleOp>();
Expand Down Expand Up @@ -777,11 +767,6 @@ Status IrEmitterUnnested::EmitConstant(MlirEmitterInput mlir_input) {
return Status::OK();
}

// Converts `conditional` to its MLIR form and delegates to the MLIR-based
// conditional emitter.
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(conditional));
  return EmitConditionalFromMlir(input);
}

static ConditionalThunkConfig GetConditionalThunkConfig(
mlir::lmhlo::CaseOp op, std::vector<ThunkSequence> branch_thunk_sequences) {
ConditionalThunkConfig config;
Expand Down Expand Up @@ -1079,11 +1064,6 @@ Status IrEmitterUnnested::EmitSliceToDynamicFromMlir(
return Status::OK();
}

// Converts `custom_call` to its MLIR form and delegates to the MLIR-based
// custom-call emitter.
Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(custom_call));
  return EmitCustomCallFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitCustomCallFromMlir(MlirEmitterInput input) {
using mlir::dyn_cast;
using mlir::isa;
Expand Down Expand Up @@ -1700,11 +1680,6 @@ Status IrEmitterUnnested::EmitCustomCallThunkFromMlir(MlirEmitterInput input) {
return Status::OK();
}

// Converts `fft` to its MLIR form and delegates to the FFT thunk emitter.
Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(fft));
  return EmitFftThunkFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitFftThunkFromMlir(MlirEmitterInput input) {
auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(input.op);
const Shape operand_shape = TypeToShape(fft_op.operand().getType());
Expand All @@ -1729,11 +1704,6 @@ Status IrEmitterUnnested::EmitFftThunkFromMlir(MlirEmitterInput input) {
return Status::OK();
}

// Converts `hlo` (a triangular-solve) to its MLIR form and delegates to the
// MLIR-based triangular-solve emitter.
Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitTriangularSolveFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitTriangularSolveFromMlir(MlirEmitterInput input) {
auto triangular_solve_op =
mlir::cast<mlir::lmhlo::TriangularSolveOp>(input.op);
Expand Down Expand Up @@ -1878,14 +1848,6 @@ static Status ProcessFusionForConversion(mlir::Region* region,
return Status::OK();
}

// Builds an MlirEmitterInput for `hlo`: emits the corresponding LHLO op via
// the scratch emitter and attaches the thunk info for this instruction.
// Propagates failure from the op emission.
StatusOr<MlirEmitterInput> IrEmitterUnnested::GetMlirEmitterInput(
    HloInstruction* hlo) {
  MlirEmitterInput result;
  // EmitOp may fail; on failure we return before touching thunk_info.
  TF_ASSIGN_OR_RETURN(result.op, lhlo_scratch_emitter_->EmitOp(hlo));
  result.thunk_info = GetThunkInfo(hlo);
  return result;
}

// TODO(timshen): update the comment once the HandleFusion code path deleted.
//
// This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
Expand Down Expand Up @@ -1997,11 +1959,6 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(
return Status::OK();
}

// Converts `fusion` to its MLIR form and delegates to the MLIR-based fusion
// emitter.
Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion));
  return EmitFusionFromMlir(input);
}

Status IrEmitterUnnested::EmitFusionFromMlir(MlirEmitterInput mlir_input) {
auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);

Expand Down Expand Up @@ -2218,11 +2175,6 @@ Status IrEmitterUnnested::EmitFusionFromMlir(MlirEmitterInput mlir_input) {
return EmitLoopFusionFromMlir(mlir_input);
}

// Converts `copy` to its MLIR form and delegates to the MLIR-based copy
// emitter.
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(copy));
  return EmitCopyFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitCopyFromMlir(MlirEmitterInput input) {
auto copy = mlir::cast<mlir::lmhlo::CopyOp>(input.op);
auto operand_shape = TypeToShape(copy.operand().getType());
Expand Down Expand Up @@ -2273,11 +2225,6 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
return Status::OK();
}

// Converts `reduce` to its MLIR form and delegates to the MLIR-based reduce
// emitter.
Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(reduce));
  return EmitReduceFromMlir(input);
}

Status IrEmitterUnnested::EmitReduceFromMlir(MlirEmitterInput mlir_input) {
const FusionLayoutAnalysis dummy_analysis;
if (GetHloOutputs(mlir_input.op).size() == 1 &&
Expand All @@ -2289,22 +2236,6 @@ Status IrEmitterUnnested::EmitReduceFromMlir(MlirEmitterInput mlir_input) {
return EmitUsingElementalIrEmitter(mlir_input);
}

// No code is generated for tuples: every tuple element is expected to be
// consumed directly, or through a GTE instruction, by its users.
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
  // For all tuples, we expect the elements of the tuple to be directly
  // consumed by instructions that read from that tuple either directly, or
  // through a GTE instruction. This is possible because we do not support
  // "dynamic tuples", since tuple-select is not supported. As a result, we
  // never need to materialize a tuple (which has a runtime representation of
  // an array of pointers) in memory at runtime. So there is no need to
  // generate any code for tuples.
  return Status::OK();
}

// No kernel is built for GetTupleElement; its IR is emitted inline in the
// context of each user instruction instead.
Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
  // GetTupleElement IR is emitted in the IR context of the user instruction,
  // and so we do not build a kernel for GetTupleElement instructions.
  return Status::OK();
}

Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) {
if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
return Unimplemented(
Expand All @@ -2316,12 +2247,6 @@ Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) {
return Status::OK();
}

// Converts `select_and_scatter` to its MLIR form and delegates to the
// MLIR-based select-and-scatter emitter.
Status IrEmitterUnnested::HandleSelectAndScatter(
    HloInstruction* select_and_scatter) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(select_and_scatter));
  return EmitSelectAndScatterFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitSelectAndScatterFromMlir(
MlirEmitterInput mlir_input) {
auto select_and_scatter_op =
Expand Down Expand Up @@ -2564,11 +2489,6 @@ Status IrEmitterUnnested::EmitSelectAndScatterFromMlir(
.EmitLoop(name, index_type);
}

// Converts `xla_while` to its MLIR form and delegates to the MLIR-based
// while-loop emitter.
Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_while));
  return EmitWhileFromMlir(input);
}

Status IrEmitterUnnested::EmitWhileFromMlir(MlirEmitterInput mlir_input) {
auto while_op = mlir::cast<mlir::lmhlo::WhileOp>(mlir_input.op);

Expand All @@ -2595,16 +2515,6 @@ Status IrEmitterUnnested::EmitWhileFromMlir(MlirEmitterInput mlir_input) {
return Status::OK();
}

// Rng is not handled here: an earlier pass is expected to have expanded it,
// so reaching this handler is an error.
Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
  return Unimplemented("Rng should be expanded for GPU.");
}

// Converts `rng_state` to its MLIR form and delegates to the MLIR-based
// rng-get-and-update-state emitter.
Status IrEmitterUnnested::HandleRngGetAndUpdateState(
    HloInstruction* rng_state) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(rng_state));
  return EmitRngGetAndUpdateState(mlir_input);
}

Status IrEmitterUnnested::EmitRngGetAndUpdateState(
MlirEmitterInput mlir_input) {
auto rng_op =
Expand Down Expand Up @@ -2636,11 +2546,6 @@ Status IrEmitterUnnested::EmitRngGetAndUpdateState(
return Status::OK();
}

// Converts `scatter` to its MLIR form and delegates to the MLIR-based scatter
// emitter.
Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(scatter));
  return EmitScatterFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitScatterFromMlir(MlirEmitterInput mlir_input) {
ThunkSequence thunks;

Expand Down Expand Up @@ -2999,11 +2904,6 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
return module->entry_computation();
}

// Converts `sort` to its MLIR form and delegates to the MLIR-based sort
// emitter.
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(sort));
  return EmitSortFromMlir(input);
}

Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(mlir_input.op);
MlirEmitterContext context;
Expand Down Expand Up @@ -3210,23 +3110,6 @@ Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
return Status::OK();
}

// Converts `hlo` to its MLIR form and emits a ReplicaId thunk via the shared
// replica/partition-id template.
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
                                          mlir::lmhlo::ReplicaIdOp>(mlir_input);
}

// Converts `hlo` to its MLIR form and emits a PartitionId thunk via the
// shared replica/partition-id template.
Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitReplicaOrPartitionIdFromMlir<
      PartitionIdThunk, mlir::lmhlo::PartitionIdOp>(mlir_input);
}

// Converts `hlo` to its MLIR form and delegates to the MLIR-based
// collective-permute emitter.
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitCollectivePermuteFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitCollectivePermuteFromMlir(
MlirEmitterInput input) {
auto collective_permute_op =
Expand Down Expand Up @@ -3354,29 +3237,6 @@ Status IrEmitterUnnested::EmitNcclThunkFromMlir(MlirEmitterInput input) {
return Status::OK();
}

// Converts `hlo` to its MLIR form and emits an NCCL all-gather thunk via the
// shared NCCL thunk template.
Status IrEmitterUnnested::HandleAllGather(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitNcclThunkFromMlir<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(
      mlir_input);
}

// Converts `hlo` to its MLIR form and emits an NCCL all-reduce thunk via the
// shared NCCL thunk template.
Status IrEmitterUnnested::HandleAllReduce(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitNcclThunkFromMlir<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(
      mlir_input);
}

// Converts `hlo` to its MLIR form and emits an NCCL all-to-all thunk via the
// shared NCCL thunk template.
Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(hlo));
  return EmitNcclThunkFromMlir<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(
      mlir_input);
}

// Converts `xla_infeed` to its MLIR form and delegates to the MLIR-based
// infeed emitter.
Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(xla_infeed));
  return EmitInfeedFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitInfeedFromMlir(MlirEmitterInput input) {
auto infeed_op = mlir::cast<mlir::lmhlo::InfeedOp>(input.op);

Expand All @@ -3394,11 +3254,6 @@ Status IrEmitterUnnested::EmitInfeedFromMlir(MlirEmitterInput input) {
return Status::OK();
}

// Converts `outfeed` to its MLIR form and delegates to the MLIR-based outfeed
// emitter.
Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(outfeed));
  return EmitOutfeedFromMlir(mlir_input);
}

Status IrEmitterUnnested::EmitOutfeedFromMlir(MlirEmitterInput input) {
auto outfeed_op = mlir::cast<mlir::lmhlo::OutfeedOp>(input.op);

Expand All @@ -3416,10 +3271,6 @@ Status IrEmitterUnnested::EmitOutfeedFromMlir(MlirEmitterInput input) {
return Status::OK();
}

// AfterAll is a pure ordering token: it requires no code generation, so this
// handler succeeds without emitting anything.
Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
  return Status::OK();
}

std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkForMlirImpl(
absl::string_view name, Thunk::ThunkInfo thunk_info,
absl::Span<const BufferSlice> slices,
Expand Down Expand Up @@ -4153,8 +4004,8 @@ void IrEmitterUnnested::EmitPrologueForReduction(
}
const HloInstruction* init_value = reduce_hlo->operand(1);

init_ir_value = (*fused_emitter->GetGenerator(
init_value))(IrArray::Index(b_.getInt32Ty()))
init_ir_value = (*fused_emitter->GetGenerator(init_value))(
IrArray::Index(b_.getInt32Ty()))
.ValueOrDie();
} else {
init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(
Expand Down Expand Up @@ -5880,11 +5731,5 @@ void MlirEmitterContext::SetOperation(mlir::Operation* op) {
}
}

// Bitcast emits no code; converting it still goes through GetMlirEmitterInput
// and we debug-check that no LMHLO op is produced for it.
Status IrEmitterUnnested::HandleBitcast(HloInstruction* bitcast) {
  TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(bitcast));
  DCHECK_EQ(nullptr, mlir_input.op);
  return Status::OK();
}

} // namespace gpu
} // namespace xla

0 comments on commit 996f7d5

Please sign in to comment.