Minor refactor: rename the 'lower bound batch threads' transform to a more generic 'reconfig batch op'. It makes no logical changes.

PiperOrigin-RevId: 636898956
tensorflower-gardener committed May 28, 2024
1 parent 092d33a commit c1cfd89
Showing 11 changed files with 133 additions and 86 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/tfrt/BUILD
@@ -208,12 +208,12 @@ cc_library(
"transforms/deduplicate_if_result_pass.cc",
"transforms/fuse_tpu_compile_and_execute_ops.cc",
"transforms/insert_tensor_copy.cc",
"transforms/lower_bound_batch_threads.cc",
"transforms/lower_saved_model.cc",
"transforms/merge_tf_if_ops.cc",
"transforms/optimize.cc",
"transforms/optimize_tf_control_flow_side_effect.cc",
"transforms/passes.cc",
"transforms/reconfig_batch_op.cc",
"transforms/remove_device_attribute.cc",
"transforms/remove_tf_if_const_args.cc",
"transforms/reorder_assert.cc",
@@ -1,4 +1,4 @@
// RUN: tf-tfrt-opt -split-input-file -tfrt-lower-bound-batch-threads="tfrt-min-num-batch-threads=2" %s | FileCheck %s --dump-input=always
// RUN: tf-tfrt-opt -split-input-file -tfrt-reconfig-batch-op="tfrt-min-num-batch-threads=2" %s | FileCheck %s --dump-input=always

// -----

4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/tfrt/transforms/passes.cc
@@ -118,8 +118,8 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper(
pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());

// Lower bound on the number of batch threads in `tf.BatchFunction`.
pm.addPass(tfrt_compiler::CreateLowerBoundBatchThreadsPass(
options.min_num_batch_threads));
pm.addPass(tfrt_compiler::CreateReconfigBatchOpPass(
{.min_num_batch_threads = options.min_num_batch_threads}));

// Deduplicate functions invoked by tf.BatchFunction with the same
// shared_name
8 changes: 6 additions & 2 deletions tensorflow/compiler/mlir/tfrt/transforms/passes.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_

#include <cstdint>
#include <memory>

#include "llvm/Support/CommandLine.h"
@@ -67,8 +68,11 @@ std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass();

// Create a pass to lower bound the number of threads in tf.BatchFunction.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerBoundBatchThreadsPass(int64_t min_num_batch_threads);
struct ReconfigBatchOpPassOptions {
int64_t min_num_batch_threads = 1;
};
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateReconfigBatchOpPass(
ReconfigBatchOpPassOptions options);

// Create a pass to fuse the TPU Ops for TFRT.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
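For context, a minimal sketch of how a caller builds the renamed pass with the new options struct. CreateReconfigBatchOpPass and ReconfigBatchOpPassOptions come from the diff above; the helper itself is hypothetical:

#include <cstdint>

#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

// Hypothetical helper: appends the reconfig pass to an existing pipeline.
void AddBatchReconfig(mlir::PassManager& pm, int64_t min_threads) {
  // The designated initializer mirrors the call site in passes.cc and lets
  // new knobs be added to ReconfigBatchOpPassOptions without churning
  // every caller.
  pm.addPass(tensorflow::tfrt_compiler::CreateReconfigBatchOpPass(
      {.min_num_batch_threads = min_threads}));
}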
tensorflow/compiler/mlir/tfrt/transforms/reconfig_batch_op.cc
@@ -31,34 +31,31 @@ namespace tensorflow {
namespace tfrt_compiler {
namespace {

class LowerBoundBatchThreadsPass
: public mlir::PassWrapper<LowerBoundBatchThreadsPass,
class ReconfigBatchOpPass
: public mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>> {
public:
explicit LowerBoundBatchThreadsPass(uint64_t min_num_batch_threads)
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
explicit ReconfigBatchOpPass(ReconfigBatchOpPassOptions options)
: mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>>() {
min_num_batch_threads_ = min_num_batch_threads;
min_num_batch_threads_ = options.min_num_batch_threads;
}
LowerBoundBatchThreadsPass()
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
ReconfigBatchOpPass()
: mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>>() {}
LowerBoundBatchThreadsPass(const LowerBoundBatchThreadsPass& other)
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
ReconfigBatchOpPass(const ReconfigBatchOpPass& other)
: mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>>(other) {}

LowerBoundBatchThreadsPass& operator=(
const LowerBoundBatchThreadsPass& other) = delete;
ReconfigBatchOpPass& operator=(const ReconfigBatchOpPass& other) = delete;

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerBoundBatchThreadsPass)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReconfigBatchOpPass)

private:
llvm::StringRef getArgument() const final {
return "tfrt-lower-bound-batch-threads";
}
llvm::StringRef getArgument() const final { return "tfrt-reconfig-batch-op"; }

llvm::StringRef getDescription() const final {
return "Lower bound batch threads for batch ops.";
return "Reconfig batch op such as num_batch_threads.";
}

void runOnOperation() override {
@@ -82,12 +79,12 @@ class LowerBoundBatchThreadsPass

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerBoundBatchThreadsPass(int64_t min_num_batch_threads) {
return std::make_unique<LowerBoundBatchThreadsPass>(min_num_batch_threads);
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateReconfigBatchOpPass(
ReconfigBatchOpPassOptions options) {
return std::make_unique<ReconfigBatchOpPass>(options);
}

static mlir::PassRegistration<LowerBoundBatchThreadsPass> register_pass;
static mlir::PassRegistration<ReconfigBatchOpPass> register_pass;

} // namespace tfrt_compiler
} // namespace tensorflow
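Because PassRegistration exposes the pass under the string returned by getArgument(), it can also be instantiated from MLIR's textual pipeline syntax. A hedged sketch assuming stock MLIR APIs (parsePassPipeline is standard MLIR; the option string mirrors the test's RUN line above; the driver function is hypothetical):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Hypothetical driver: runs the pass by its registered textual name.
mlir::LogicalResult RunReconfig(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // "tfrt-reconfig-batch-op" comes from getArgument(); the option name
  // matches the RUN line in the test file above.
  if (mlir::failed(mlir::parsePassPipeline(
          "builtin.module(tfrt-reconfig-batch-op"
          "{tfrt-min-num-batch-threads=2})",
          pm))) {
    return mlir::failure();
  }
  return pm.run(module);
}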
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
@@ -129,6 +129,7 @@ cc_library(
"//xla:status",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
44 changes: 31 additions & 13 deletions third_party/xla/xla/service/gpu/fusions/custom.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>
@@ -64,6 +63,7 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

@@ -172,7 +172,8 @@ absl::Status CollectSliceInfo(
const BufferAssignment& buffer_assignment,
const HloInstruction& fusion_instr,
absl::Span<HloInstruction*> slice_instrs,
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>&
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>&
offset_buffer_indices,
std::vector<std::optional<Shape>>& orig_shapes,
std::vector<std::optional<Shape>>& sliced_shapes,
@@ -183,15 +184,30 @@
return absl::OkStatus();
}

std::vector<BufferAllocation::Slice> offset_slices;
std::vector<std::variant<int64_t, BufferAllocation::Slice>> offset_slices;
for (auto idx_op : slice_instr->index_operands()) {
const auto* param = Cast<HloParameterInstruction>(idx_op);
TF_ASSIGN_OR_RETURN(
auto offset_slice,
GetAllocationSlice(buffer_assignment,
fusion_instr.operand(param->parameter_number()),
/*index=*/{}));
offset_slices.push_back(offset_slice);
const auto* offset_param = fusion_instr.operand(param->parameter_number());

if (auto* cst_offset = DynCast<HloConstantInstruction>(offset_param)) {
auto s32_scalar = ShapeUtil::MakeShape(PrimitiveType::S32, {});
auto s64_scalar = ShapeUtil::MakeShape(PrimitiveType::S64, {});

if (cst_offset->shape() == s32_scalar) {
offset_slices.emplace_back() = cst_offset->literal().data<int32_t>()[0];
} else if (cst_offset->shape() == s64_scalar) {
offset_slices.emplace_back() = cst_offset->literal().data<int64_t>()[0];
} else {
return absl::InternalError(
absl::StrCat("Unsupported constant offset shape: ",
cst_offset->shape().ToString()));
}

} else {
TF_ASSIGN_OR_RETURN(offset_slices.emplace_back(),
GetAllocationSlice(buffer_assignment, offset_param,
/*index=*/{}));
}
}
offset_buffer_indices[arg_idx] = std::move(offset_slices);
orig_shapes[arg_idx] = slice_instr->operand(0)->shape();
@@ -256,7 +272,8 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
const BufferAssignment& buffer_assignment =
ir_emitter_context.buffer_assignment();

std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices(4, std::nullopt);
std::vector<std::optional<Shape>> orig_shapes(4, std::nullopt);
std::vector<std::optional<Shape>> sliced_shapes(4, std::nullopt);
@@ -379,7 +396,7 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake,
slice_out_fake, slice_workspace_fake, deterministic_ops));

std::vector<std::optional<const BufferAllocation::Slice>> arguments{
std::vector<std::optional<BufferAllocation::Slice>> arguments{
lhs_slice, rhs_slice, output, workspace};

thunk = std::make_unique<AddressComputationThunk>(
@@ -435,15 +452,16 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
num_args += ShapeUtil::GetLeafCount(operand->shape());
});

std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices(num_args, std::nullopt);
std::vector<std::optional<Shape>> orig_shapes(num_args, std::nullopt);
std::vector<std::optional<Shape>> sliced_shapes(num_args, std::nullopt);
std::vector<std::optional<uint64_t>> offset_byte_sizes(num_args,
std::nullopt);

std::vector<HloInstruction*> slice_instrs(num_args, nullptr);
std::vector<std::optional<const BufferAllocation::Slice>> arguments;
std::vector<std::optional<BufferAllocation::Slice>> arguments;

unsigned arg_idx = 0;
// TODO(vuson): add test for custom call with token-typed operands
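Net effect of the CollectSliceInfo change: each offset component is now either an immediate int64_t (pulled out of a scalar S32/S64 constant at compile time) or a BufferAllocation::Slice to be read back from the device at run time. A self-contained sketch of that discrimination, using a stand-in Slice type rather than the real XLA class:

#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

struct Slice { int id; };  // stand-in for BufferAllocation::Slice

using OffsetSource = std::variant<int64_t, Slice>;

// Constant offsets can be forwarded directly; slice offsets would need a
// device-to-host copy, as ExecuteOnStream does in the runtime change below.
void DescribeOffsets(const std::vector<OffsetSource>& offsets) {
  for (const OffsetSource& offset : offsets) {
    if (const int64_t* constant = std::get_if<int64_t>(&offset)) {
      std::cout << "constant offset " << *constant << ", no transfer\n";
    } else {
      std::cout << "slice #" << std::get<Slice>(offset).id
                << ", schedule a d2h transfer\n";
    }
  }
}

int main() {
  DescribeOffsets({int64_t{4}, Slice{7}});  // one constant, one device value
}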
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/runtime/BUILD
@@ -337,6 +337,7 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
)
third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc
@@ -20,10 +20,10 @@ limitations under the License.
#include <memory>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "llvm/ADT/STLExtras.h"
#include "xla/service/buffer_assignment.h"
@@ -38,16 +38,18 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/memory_allocation.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {

AddressComputationThunk::AddressComputationThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
std::vector<std::optional<const BufferAllocation::Slice>> arguments,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations,
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
@@ -151,28 +153,47 @@ absl::Status AddressComputationThunk::ExecuteOnStream(
std::vector<int64_t> slice_starts;
slice_starts.reserve(dst_shape.rank());

// Number of issued d2h transfers to copy offset values from device to host.
int64_t num_transfers = 0;

// Get offset for `argument_idx`-th argument, which has `dst_shape.rank()`
// components.
for (auto [offset_idx, values] : llvm::enumerate(llvm::zip(
*offset_slice, src_shape.dimensions(), dst_shape.dimensions()))) {
auto [slice, src_dim, dst_dim] = values;
se::DeviceMemoryBase offset_src =
orig_allocations.GetDeviceAddress(slice);
int64_t* offset_dst = &offsets_base[argument_idx + offset_idx];
// Copy the `offset_idx`-th component of the offset for the
// `argument_idx`-th argument from device to host.
TF_RETURN_IF_ERROR(
stream.Memcpy(offset_dst, offset_src, offset_byte_size.value()));

if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) {
return absl::InternalError(absl::StrFormat(
"Failed to retrieve all slice offset values on stream %p: %s",
&stream, blocked.message()));

if (int64_t* const_offset = std::get_if<int64_t>(&slice)) {
// Forward slice offsets that are known constant values
offsets_base[argument_idx + offset_idx] = *const_offset;
} else {
// Transfer slice offset value from device to host.
se::DeviceMemoryBase offset_src = orig_allocations.GetDeviceAddress(
std::get<BufferAllocation::Slice>(slice));
int64_t* offset_dst = &offsets_base[argument_idx + offset_idx];

// Copy the `offset_idx`-th component of the offset for the
// `argument_idx`-th argument from device to host.
TF_RETURN_IF_ERROR(
stream.Memcpy(offset_dst, offset_src, offset_byte_size.value()));
++num_transfers;
}
// Clamp start indices:
// start_indices[i] = min(max(start_indices[i], 0),
// operand.dimension_size[i] - size_indices[i])
auto start_index = std::min(std::max(*offset_dst, 0L), src_dim - dst_dim);
}

// Wait for the completion of all transfers.
if (num_transfers > 0) {
VLOG(2) << "Wait for completion of " << num_transfers << " transfer(s)";
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
}

// Clamp start indices:
// start_indices[i] = min(max(start_indices[i], 0),
// operand.dimension_size[i] - size_indices[i])
for (auto [offset_idx, values] : llvm::enumerate(
llvm::zip(src_shape.dimensions(), dst_shape.dimensions()))) {
auto [src_dim, dst_dim] = values;
int64_t start_index =
std::min(std::max(offsets_base[argument_idx + offset_idx], 0L),
src_dim - dst_dim);
slice_starts.push_back(start_index);
}

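The clamp applied after the transfers matches XLA's dynamic-slice semantics. A small self-contained sketch of the formula, with made-up dimension values:

#include <algorithm>
#include <cassert>
#include <cstdint>

// start_indices[i] = min(max(start_indices[i], 0),
//                        operand.dimension_size[i] - size_indices[i])
int64_t ClampSliceStart(int64_t offset, int64_t src_dim, int64_t dst_dim) {
  return std::min(std::max(offset, int64_t{0}), src_dim - dst_dim);
}

int main() {
  // Slicing 3 elements out of a dimension of size 8:
  assert(ClampSliceStart(-2, 8, 3) == 0);  // negative offsets clamp to 0
  assert(ClampSliceStart(4, 8, 3) == 4);   // in-range offsets pass through
  assert(ClampSliceStart(7, 8, 3) == 5);   // past-the-end clamps to 8 - 3
}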
third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <variant>
#include <vector>

#include "absl/base/thread_annotations.h"
@@ -29,6 +30,7 @@ limitations under the License.
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/shape.h"
#include "xla/status.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream_executor.h"
@@ -45,9 +47,10 @@ class AddressComputationThunk : public Thunk {
public:
AddressComputationThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
std::vector<std::optional<const BufferAllocation::Slice>> arguments,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_,
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
Expand All @@ -65,10 +68,10 @@ class AddressComputationThunk : public Thunk {

private:
std::unique_ptr<SequentialThunk> embedded_thunk_;
std::vector<std::optional<const BufferAllocation::Slice>>
embedded_thunk_arguments_;
std::vector<std::optional<BufferAllocation::Slice>> embedded_thunk_arguments_;
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_;
std::vector<std::optional<std::vector<BufferAllocation::Slice>>>
std::vector<std::optional<
std::vector<std::variant<int64_t, BufferAllocation::Slice>>>>
offset_buffer_indices_;
std::vector<std::optional<Shape>> orig_shapes_;
std::vector<std::optional<Shape>> sliced_shapes_;