Support RngBitGenerator HloInstruction with single output in the HLO -> MHLO conversion.

For the rng-bit-generator operator, the XLA HLO instruction can have two kinds of shapes: (1) tuple(output_state, output_data), and (2) output_data alone. In contrast, `mhlo::RngBitGeneratorOp` has only one result shape, (output_state, output_data). This CL supports RngBitGenerator HloInstructions with a single output in the HLO -> MHLO conversion.
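
For illustration, the two HLO forms look like this (a minimal sketch in the HLO-text style of the tests below; the instruction and operand names are hypothetical). In both cases the importer emits a single mhlo.rng_bit_generator whose results are (output_state, output_data); for the array-shaped form only the data result, index 1, is mapped back to the instruction:

    // Tuple-shaped: the instruction itself yields (output_state, output_data).
    ROOT %rng.tuple = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %state), algorithm=rng_philox

    // Array-shaped: the instruction yields only output_data; the importer
    // still creates both mhlo results and records result 1 for the instruction.
    ROOT %rng.array = u32[2,2] rng-bit-generator(u64[3] %state), algorithm=rng_default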

PiperOrigin-RevId: 620358061
ZixuanJiang authored and tensorflower-gardener committed Apr 2, 2024
1 parent 4b642c5 commit 9db671f
Showing 2 changed files with 57 additions and 11 deletions.
43 changes: 37 additions & 6 deletions third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -575,7 +575,12 @@ absl::StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
         auto new_operation,
         ImportInstructionWithLayout(instruction, operands, builder));
     if (new_operation) {
-      instruction_value_map_[instruction] = new_operation->getResult(0);
+      unsigned int idx =
+          (instruction->opcode() == HloOpcode::kRngBitGenerator &&
+           instruction->shape().IsArray())
+              ? 1
+              : 0;
+      instruction_value_map_[instruction] = new_operation->getResult(idx);
     }
   }

@@ -1643,18 +1648,44 @@ absl::StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       }
     }
     case HloOpcode::kRngBitGenerator: {
+      // HloRngBitGeneratorInstruction can have two kinds of shapes, (1)
+      // tuple(output_state, output_data), and (2) output_data.
+      // mhlo::RngBitGeneratorOp has only one shape, (output_state,
+      // output_data).
       auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
+
+      auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
+          builder_->getContext(),
+          *mlir::mhlo::symbolizeRngAlgorithm(rng_op->algorithm()));
+      attributes.push_back(
+          builder_->getNamedAttr("rng_algorithm", algorithm_attr));
+
       // Flatten the return type if they are tuple-typed.
       llvm::SmallVector<Type> flattened_ret_types;
       FlattenTupleType(result_type, flattened_ret_types);
+      if (rng_op->shape().IsArray()) {
+        TF_ASSIGN_OR_RETURN(auto state_type,
+                            ConvertShapeToType<RankedTensorType>(
+                                rng_op->operand(0)->shape(), *builder_));
+        flattened_ret_types.insert(flattened_ret_types.begin(), state_type);
+
+        if (instruction->has_sharding()) {
+          Shape tuple_shape = ShapeUtil::MakeTupleShape(
+              {rng_op->operand(0)->shape(), instruction->shape()});
+          HloSharding tuple_sharding = HloSharding::Tuple(
+              tuple_shape, {HloSharding::Replicate(), instruction->sharding()});
+          CHECK_EQ(attributes.front().getName().str(), kShardingAttr);
+          attributes.front() = builder_->getNamedAttr(
+              kShardingAttr, ConvertSharding(tuple_sharding, builder_));
+        }
+      }
       CHECK_EQ(flattened_ret_types.size(), 2);

-      auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
-          builder_->getContext(),
-          *mlir::mhlo::symbolizeRngAlgorithm(rng_op->algorithm()));
       auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
-          loc, flattened_ret_types, algorithm_attr, operands[0]);
-
+          loc, flattened_ret_types, operands[0], attributes);
+      if (rng_op->shape().IsArray()) {
+        return op.getOperation();
+      }
       return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                       result_type);
     }
25 changes: 20 additions & 5 deletions third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt
@@ -1723,14 +1723,29 @@ add {
   ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
 }

-// CHECK-LABEL: func private @rngbitgen
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>)
-%rngbitgen (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) {
+// CHECK-LABEL: func private @rngbitgen_tuple_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>) -> (tuple<tensor<3xui64>, tensor<2x2xui32>> {mhlo.sharding = "{{\{}}{maximal device=0}, {maximal device=1}}"})
+%rngbitgen_tuple_shape (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) {
   %Arg_0.1 = u64[3] parameter(0)
-  // CHECK: %[[RNG0:.+]], %[[RNG1:.+]] = "mhlo.rng_bit_generator"(%[[ARG0]]) {rng_algorithm = #mhlo.rng_algorithm<PHILOX>} : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>)
+  // CHECK: %[[RNG0:.+]], %[[RNG1:.+]] = "mhlo.rng_bit_generator"(%[[ARG0]])
+  // CHECK-SAME: mhlo.sharding = "{{\{}}{maximal device=0}, {maximal device=1}}"
+  // CHECK-SAME: rng_algorithm = #mhlo.rng_algorithm<PHILOX>
+  // CHECK-SAME: (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>)
   // CHECK: %[[TUPLE:.+]] = mhlo.tuple %[[RNG0]], %[[RNG1]] {xla_shape = "(u64[3]{0}, u32[2,2]{1,0})"} : tuple<tensor<3xui64>, tensor<2x2xui32>>
   // CHECK: return %[[TUPLE]]
-  ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox
+  ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox, sharding={{maximal device=0}, {maximal device=1}}
 }

+// CHECK-LABEL: func private @rngbitgen_array_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>) -> (tensor<2x2xui32> {mhlo.sharding = "{maximal device=0}"})
+%rngbitgen_array_shape (Arg_0.1: u64[3]) -> u32[2,2] {
+  %Arg_0.1 = u64[3] parameter(0)
+  // CHECK: %[[RNG0:.+]], %[[RNG1:.+]] = "mhlo.rng_bit_generator"(%[[ARG0]])
+  // CHECK-SAME: mhlo.sharding = "{{\{}}{replicated}, {maximal device=0}}"
+  // CHECK-SAME: rng_algorithm = #mhlo.rng_algorithm<DEFAULT>
+  // CHECK-SAME: (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>)
+  // CHECK: return %[[RNG1]]
+  ROOT %rng-bit-generator.2 = u32[2,2] rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_default, sharding={maximal device=0}
+}
+
 // CHECK-LABEL: func private @cbrt
