
Commit bbb680d

Refactor loop unroller pass.
Add a knob to force unroll if needed.

Reverts dbf3cd3

PiperOrigin-RevId: 636179776
farzinhoushmand authored and tensorflower-gardener committed Jun 4, 2024
1 parent 9dab91d commit bbb680d
Showing 16 changed files with 135 additions and 157 deletions.
tensorflow/core/common_runtime/eager/BUILD (1 change: 0 additions & 1 deletion)
@@ -273,7 +273,6 @@ tf_cuda_library(
"//tensorflow/core/framework:resource_base",
"@local_xla//xla/pjrt/distributed:key_value_store_interface",
"@local_xla//xla/pjrt:local_device_state",
"@local_xla//xla/pjrt/gpu:gpu_topology",
"@local_xla//xla/pjrt:pjrt_client",
"@local_xla//xla/pjrt:pjrt_compiler",
"@local_xla//xla/service/gpu:gpu_executable_run_options",
@@ -79,7 +79,6 @@ limitations under the License.
#if (defined(PLATFORM_GOOGLE) && defined(TF_PLATFORM_LINUX_X86_64))
#define TF_GPU_USE_PJRT
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/local_device_state.h"
#include "xla/pjrt/pjrt_compiler.h"
@@ -329,18 +328,17 @@ absl::Status CreateClientOnce(
// proceed.
creation_state->SetReady();
}
- auto device_topology_pair = BuildDistributedDevices(
+ auto status = BuildDistributedDevices(
platform_name, std::move(unique_local_device_states), node_id, num_nodes,
- gpu_run_options.get(), kv_store,
+ &pjrt_devices, gpu_run_options.get(), kv_store,
/*enable_mock_nccl=*/false);
- if (!device_topology_pair.ok()) {
+ if (!status.ok()) {
if (use_creation_info) {
creation_state->SetDone();
}
- return device_topology_pair.status();
+ return status;
}

pjrt_devices = std::move(device_topology_pair->first);
VLOG(2) << "Distributed devices built with size=" << pjrt_devices.size();
int i = 0;
for (const auto& pjrt_device : pjrt_devices) {
@@ -352,18 +350,6 @@ absl::Status CreateClientOnce(
}
}

std::shared_ptr<const xla::GpuTopology> gpu_topology = nullptr;
if (!device_topology_pair->second.ok()) {
LOG(INFO)
<< "Skipping creating GPU topology since multiple nodes on the same "
"host violates GPU topology assumptions. This is expected in tests "
"that use multiple threads to simulate multiple workers. If this "
"occurs in production and op execution on GPU fails, this could be "
"related.";
} else {
gpu_topology =
xla::GpuTopology::FromProto(device_topology_pair->second.value());
}
if (use_creation_info) {
std::unique_ptr<xla::PjRtClient> pjrt_client =
std::make_unique<xla::StreamExecutorGpuClient>(
@@ -372,11 +358,10 @@ absl::Status CreateClientOnce(
/*allocator=*/std::move(info->allocator),
/*host_memory_allocator=*/std::move(info->host_memory_allocator),
/*should_stage_host_to_device_transfers=*/true,
- /*gpu_run_options=*/std::move(gpu_run_options),
- std::move(gpu_topology));
+ /*gpu_run_options=*/std::move(gpu_run_options));
VLOG(2) << "PJRT GPU client with remote devices created.";
- auto status = SetPjRtClientInTFGlobalResourceManager(
- DeviceType(DEVICE_GPU), std::move(pjrt_client));
+ status = SetPjRtClientInTFGlobalResourceManager(DeviceType(DEVICE_GPU),
+ std::move(pjrt_client));
creation_state->SetDone();
return status;
} else {
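The hunks above revert CreateClientOnce to the older BuildDistributedDevices contract: a plain absl::Status plus a device-list out-parameter, instead of an absl::StatusOr wrapping a (devices, topology) pair. The sketch below is a minimal, self-contained model of that calling-convention difference, not the actual XLA code; Device, TopologyProto, and both function names are illustrative stand-ins.

```cpp
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Stand-ins for the real PjRt types; the names are illustrative only.
struct Device { int id; };
struct TopologyProto { std::string payload; };

// Shape before this revert: devices and a (possibly failed) topology proto
// come back together inside one StatusOr.
absl::StatusOr<std::pair<std::vector<Device>, absl::StatusOr<TopologyProto>>>
BuildDevicesReturningPair(int num_devices) {
  std::vector<Device> devices;
  for (int i = 0; i < num_devices; ++i) devices.push_back(Device{i});
  return std::make_pair(std::move(devices),
                        absl::StatusOr<TopologyProto>(TopologyProto{"topo"}));
}

// Shape restored by this revert: a plain status plus an out-parameter.
absl::Status BuildDevicesWithOutParam(int num_devices,
                                      std::vector<Device>* devices) {
  if (devices == nullptr) return absl::InvalidArgumentError("null out-param");
  for (int i = 0; i < num_devices; ++i) devices->push_back(Device{i});
  return absl::OkStatus();
}

absl::Status CallerSketch() {
  // Pre-revert call site: unpack the pair, then deal with the nested
  // topology status separately (as the removed gpu_device.cc block did).
  auto pair_or = BuildDevicesReturningPair(2);
  if (!pair_or.ok()) return pair_or.status();
  std::vector<Device> devices = std::move(pair_or->first);
  (void)devices;  // unused in this sketch

  // Post-revert call site: one status to check, devices filled in place.
  std::vector<Device> more_devices;
  return BuildDevicesWithOutParam(2, &more_devices);
}
```

The same trade-off shows up in GetStreamExecutorGpuClient further down in the diff: the StatusOr-of-pair style carries the topology result along, while the reverted style keeps the signature simpler at the cost of an extra out-parameter.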
tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc (4 changes: 2 additions & 2 deletions)
@@ -310,8 +310,8 @@ AotCompileToGpuPjRtExecutable(
xla::Compiler::TargetConfig gpu_config(gpu_target_config);
xla::StreamExecutorGpuCompiler pjrt_gpu_compiler;
// Create a trivial topology, which won't be used.
- xla::StreamExecutorGpuTopologyDescription topology(xla::CudaId(),
- xla::CudaName(), nullptr);
+ xla::StreamExecutorGpuTopologyDescription topology(
+ xla::CudaId(), xla::CudaName(), "fake_device", {0});
xla::CompileOptions pjrt_options =
GetPjRtCompileOptions(options, **compilation_result);
pjrt_options.target_config = gpu_config;
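The only change in this file is the shape of the trivial-topology constructor call: with the GpuTopology plumbing reverted, the description is built from a platform-version string and an explicit device-id list rather than a (null) GpuTopology pointer. Below is a simplified stand-in class sketching the two shapes, under the assumption that the real constructors roughly mirror these parameters; none of these names are the actual XLA declarations.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins; the real xla::StreamExecutorGpuTopologyDescription
// and xla::GpuTopology carry more state. All names here are illustrative.
struct GpuTopologyModel {
  std::vector<int> device_ids;
};

struct TopologyDescriptionModel {
  // Shape used before the revert: built from a prebuilt topology object
  // (which the AOT path passed as nullptr, since it never reads it).
  TopologyDescriptionModel(int platform_id, std::string platform_name,
                           std::shared_ptr<const GpuTopologyModel> topology)
      : platform_id(platform_id),
        platform_name(std::move(platform_name)),
        device_ids(topology ? topology->device_ids : std::vector<int>{}) {}

  // Shape restored by the revert: a platform-version string plus explicit
  // device ids, matching the "fake_device", {0} call in the diff.
  TopologyDescriptionModel(int platform_id, std::string platform_name,
                           std::string platform_version,
                           std::vector<int> device_ids)
      : platform_id(platform_id),
        platform_name(std::move(platform_name)),
        platform_version(std::move(platform_version)),
        device_ids(std::move(device_ids)) {}

  int platform_id = 0;
  std::string platform_name;
  std::string platform_version;
  std::vector<int> device_ids;
};

// Mirrors the post-revert call site: a throwaway topology that AOT
// compilation never consults.
TopologyDescriptionModel MakeTrivialTopology() {
  return TopologyDescriptionModel(/*platform_id=*/0, "cuda", "fake_device", {0});
}
```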
@@ -168,30 +168,37 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter& rewriter) const override {
- auto operandType = mlir::dyn_cast<RankedTensorType>(op.getArg().getType());
+ auto operandType = dyn_cast<RankedTensorType>(op.getArg().getType());
if (!operandType)
return rewriter.notifyMatchFailure(op, "expected ranked operand");

// Produce an MHLO equivalent of this shape::ShapeOfOp.
// This is a very laborious representation because MHLO is currently lacking
// convenient tools to express this.
SmallVector<Value> sizesI32x1;
for (auto i = 0; i < operandType.getRank(); ++i) {
auto sizeI32 =
rewriter.create<GetDimensionSizeOp>(op.getLoc(), op.getArg(), i);
auto sizeI32x1 = rewriter.create<ReshapeOp>(
op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
sizeI32);
sizesI32x1.push_back(sizeI32x1);
Value shapeI32;
if (operandType.getRank() > 0) {
SmallVector<Value> sizesI32x1;
for (auto i = 0; i < operandType.getRank(); ++i) {
auto sizeI32 =
rewriter.create<GetDimensionSizeOp>(op.getLoc(), op.getArg(), i);
auto sizeI32x1 = rewriter.create<ReshapeOp>(
op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
sizeI32);
sizesI32x1.push_back(sizeI32x1);
}
shapeI32 = rewriter.create<ConcatenateOp>(op.getLoc(), sizesI32x1,
/*dimension=*/0);
} else {
shapeI32 = rewriter.create<ConstantOp>(
op.getLoc(), DenseElementsAttr::get(
RankedTensorType::get({0}, rewriter.getI32Type()),
ArrayRef<Attribute>()));
}
auto shapeI32 =
rewriter.create<mhlo::ConcatenateOp>(op.getLoc(), sizesI32x1,
/*dimension=*/0);

// Cast result from tensor<Nxi32> to tensor<Nxindex>.
// This will error out if the result is !shape.shape.
auto shapeIndex = castToIndex(rewriter, op.getLoc(), shapeI32);
- if (!shapeIndex || shapeIndex.getType() != op.getResult().getType())
+ if (!shapeIndex || shapeIndex.getType() != op.getType())
return rewriter.notifyMatchFailure(op, "cast to index failed");
rewriter.replaceOp(op, shapeIndex);
return success();
@@ -358,3 +358,15 @@ func.func @tensor_extract_dynamic(%arg0: tensor<?x3xindex>) -> index {
%0 = tensor.extract %arg0[%c1, %c2] : tensor<?x3xindex>
return %0 : index
}

+ // -----
+
+ // CHECK-LABEL: func @shape_of_zero_ranked_tensor
+ func.func @shape_of_zero_ranked_tensor(%arg0 : tensor<i32>) -> tensor<0xindex> {
+ // CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi32>
+ // CHECK-NEXT: %[[RES_DIM0_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONST]] : tensor<0xi32> to tensor<0xindex>
+ // CHECK-NEXT: return %[[RES_DIM0_INDEX]] : tensor<0xindex>
+ %0 = shape.shape_of %arg0 : tensor<i32> -> tensor<0xindex>
+ func.return %0 : tensor<0xindex>
+ }

third_party/xla/xla/pjrt/c/BUILD (1 change: 0 additions & 1 deletion)
@@ -269,7 +269,6 @@ cc_library(
"//xla/pjrt:pjrt_device_description",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt/gpu:gpu_helpers",
"//xla/pjrt/gpu:gpu_topology",
"//xla/pjrt/gpu:se_gpu_pjrt_client",
"//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler
"//xla/python:custom_partition_callback",
third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc (10 changes: 1 addition & 9 deletions)
@@ -42,7 +42,6 @@ limitations under the License.
#include "xla/pjrt/c/pjrt_c_api_stream_extension.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
@@ -188,16 +187,9 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create(
device_ids.push_back(executor->device_ordinal());
}
auto gpu_target_config = xla::Compiler::TargetConfig(executor);
- // TODO(b/341334898): Create a single-host GPU topology. Will be updated for
- // multi-host support in the future.
- auto gpu_topology = std::make_shared<const xla::GpuTopology>(
- device_ids, description.name(),
- /*num_slices=*/1,
- /*num_hosts_per_slice=*/1,
- /*num_devices_per_host=*/device_ids.size());
auto pjrt_topology =
std::make_unique<xla::StreamExecutorGpuTopologyDescription>(
- xla::CudaId(), xla::CudaName(), std::move(gpu_topology),
+ xla::CudaId(), xla::CudaName(), description.name(), device_ids,
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>{
{"target_config",
gpu_target_config.ToProto().SerializeAsString()}});
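The removed block had built a single-host GpuTopology (one slice, one host per slice, all visible devices on that host) and handed it to the topology description; after the revert, the description takes the platform-version string and device ids directly. A small self-contained sketch of that single-host assumption, with an illustrative struct standing in for the real xla::GpuTopology:

```cpp
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for xla::GpuTopology; field names follow the removed
// call site, but this is not the real class definition.
struct SingleHostGpuTopology {
  std::vector<int> device_ids;
  std::string platform_version;
  int num_slices = 1;            // exactly one slice
  int num_hosts_per_slice = 1;   // exactly one host in that slice
  int num_devices_per_host = 0;  // every device the host can see
};

SingleHostGpuTopology MakeSingleHostTopology(std::vector<int> device_ids,
                                             std::string platform_version) {
  SingleHostGpuTopology topo;
  topo.num_devices_per_host = static_cast<int>(device_ids.size());
  topo.device_ids = std::move(device_ids);
  topo.platform_version = std::move(platform_version);
  return topo;
}
```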
third_party/xla/xla/pjrt/distributed/topology_util.cc (4 changes: 0 additions & 4 deletions)
@@ -206,10 +206,6 @@ absl::StatusOr<GpuTopologyProto> BuildGpuTopology(
for (int i = 0; i < global_topology.nodes_size(); ++i) {
const LocalTopologyProto& local_topology = global_topology.nodes(i);

- if (local_topology.devices_size() == 0) {
- return absl::InternalError("Local topology has no devices.");
- }

slice_id_to_node_ids[local_topology.devices(0).slice_index()].push_back(
local_topology.node_id());

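For context on the removed guard: BuildGpuTopology groups node ids by the slice index of each node's first device, so `local_topology.devices(0)` is only well-defined when the node reports at least one device; the now-removed check had turned the empty case into an internal error. A minimal sketch of that grouping step, using plain structs (NodeModel and DeviceModel are illustrative names) rather than the real protos:

```cpp
#include <map>
#include <vector>

struct DeviceModel {
  int slice_index = 0;
};

struct NodeModel {
  int node_id = 0;
  std::vector<DeviceModel> devices;
};

// Groups node ids by the slice index of each node's first device, mirroring
// the loop in BuildGpuTopology. Nodes without devices are skipped here; the
// check removed in the diff instead returned an internal error for that case.
std::map<int, std::vector<int>> GroupNodesBySlice(
    const std::vector<NodeModel>& nodes) {
  std::map<int, std::vector<int>> slice_id_to_node_ids;
  for (const NodeModel& node : nodes) {
    if (node.devices.empty()) continue;  // avoid reading devices[0] of nothing
    slice_id_to_node_ids[node.devices[0].slice_index].push_back(node.node_id);
  }
  return slice_id_to_node_ids;
}
```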
third_party/xla/xla/pjrt/gpu/BUILD (2 changes: 0 additions & 2 deletions)
@@ -44,7 +44,6 @@ cc_library(
":gpu_helpers",
":gpu_metrics",
":gpu_topology",
":gpu_topology_proto_cc",
"//xla:literal",
"//xla:shape_util",
"//xla:status_macros",
@@ -300,7 +299,6 @@ xla_cc_test(
"requires-gpu-nvidia",
] + if_google(["config-cuda-only"]),
deps = [
":gpu_topology",
":se_gpu_pjrt_client",
":se_gpu_pjrt_compiler",
"//xla:test",
third_party/xla/xla/pjrt/gpu/gpu_topology.cc (3 changes: 1 addition & 2 deletions)
@@ -16,7 +16,6 @@ limitations under the License.
#include "xla/pjrt/gpu/gpu_topology.h"

#include <memory>
#include <string>
#include <vector>

namespace xla {
@@ -34,7 +33,7 @@ std::unique_ptr<const GpuTopology> GpuTopology::FromProto(
GpuTopologyProto GpuTopology::ToProto() const {
GpuTopologyProto proto;
proto.mutable_device_ids()->Add(device_ids().begin(), device_ids().end());
- proto.set_platform_version(std::string(platform_version()));
+ proto.set_platform_version(platform_version());
proto.set_num_slices(num_slices());
proto.set_num_hosts_per_slice(num_hosts_per_slice());
proto.set_num_devices_per_host(num_devices_per_host());
third_party/xla/xla/pjrt/gpu/gpu_topology.h (2 changes: 1 addition & 1 deletion)
@@ -53,7 +53,7 @@ class GpuTopology {
const GpuTopologyProto& proto);
GpuTopologyProto ToProto() const;

- std::string_view platform_version() const { return platform_version_; }
+ std::string platform_version() const { return platform_version_; }
int32_t num_slices() const { return num_slices_; }
int32_t num_hosts_per_slice() const { return num_hosts_per_slice_; }
int32_t num_devices_per_host() const { return num_devices_per_host_; }
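The accessor change above swaps a std::string_view return for a by-value std::string. Both expose the stored platform version; the view avoids a copy but is only valid while the owning object (and its member string) is alive, whereas the by-value form always copies and can safely outlive it. A short illustration, independent of the XLA class:

```cpp
#include <string>
#include <string_view>
#include <utility>

class VersionHolder {
 public:
  explicit VersionHolder(std::string version) : version_(std::move(version)) {}

  // View: no copy, but it dangles if it outlives this object.
  std::string_view version_view() const { return version_; }

  // Value: always copies, safe to keep after the holder is gone.
  std::string version_copy() const { return version_; }

 private:
  std::string version_;
};
```

This is also why the matching ToProto change in gpu_topology.cc no longer needs the explicit std::string(...) wrap around the accessor.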
third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc (42 changes: 16 additions & 26 deletions)
@@ -55,8 +55,6 @@ limitations under the License.
#include "xla/pjrt/distributed/topology_util.h"
#include "xla/pjrt/event_pool.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/gpu/gpu_topology.pb.h"
#include "xla/pjrt/local_device_state.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
@@ -486,15 +484,14 @@ StreamExecutorGpuClient::StreamExecutorGpuClient(
int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tsl::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options,
std::shared_ptr<const GpuTopology> gpu_topology)
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
: xla::PjRtStreamExecutorClient(
platform_name, client, std::move(devices), process_index,
std::move(allocator), std::move(host_memory_allocator),
should_stage_host_to_device_transfers, std::move(gpu_run_options)),
topology_(xla::StreamExecutorGpuTopologyDescription::Create(
tsl::Fingerprint64(platform_name), platform_name,
- std::move(gpu_topology))) {
+ devices_.back()->device_kind(), devices_)) {
for (auto* device : addressable_devices()) {
// Use the device id to construct a globally unique memory space id. We do
// not promise that memory space ids and device ids are the same.
@@ -943,15 +940,15 @@ GetStreamExecutorGpuDeviceAllocator(

} // namespace

absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
absl::Status BuildDistributedDevices(
std::string_view platform_name,
std::map<int, std::unique_ptr<LocalDeviceState>> local_device_states,
int node_id, int num_nodes,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
gpu::GpuExecutableRunOptions* gpu_executable_run_options,
std::shared_ptr<KeyValueStoreInterface> kv_store, bool enable_mock_nccl,
absl::Duration get_local_topology_timeout,
absl::Duration get_global_topology_timeout) {
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
LocalTopologyProto local_topology;
local_topology.set_node_id(node_id);
std::string boot_id_str;
@@ -1011,7 +1008,7 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
device_proto.name(), device_proto.vendor(),
device_proto.compute_capability(), device_proto.core_count(),
node.node_id(), device_proto.slice_index());
- devices.push_back(std::move(device));
+ devices->push_back(std::move(device));
}
}
for (const auto& device : local_device_states) {
@@ -1029,8 +1026,7 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
});
}
#endif // GOOGLE_CUDA

- return std::make_pair(std::move(devices), BuildGpuTopology(global_topology));
+ return absl::OkStatus();
}

std::string MakeComputeCapabilityString(const se::DeviceDescription* desc) {
@@ -1152,6 +1148,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
auto host_memory_allocator =
GetGpuHostAllocator(local_device_states.begin()->second->executor());

std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
auto gpu_run_options = std::make_unique<gpu::GpuExecutableRunOptions>();
if (options.enable_mock_nccl) {
gpu_run_options->set_enable_mock_nccl_collectives();
@@ -1161,29 +1158,22 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
kv_store = std::make_shared<InMemoryKeyValueStore>();
}
TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr);
- TF_ASSIGN_OR_RETURN(
- DeviceTopologyPair device_topology_pair,
- BuildDistributedDevices(pjrt_platform_name,
- std::move(local_device_states), options.node_id,
- options.num_nodes, gpu_run_options.get(),
- kv_store, options.enable_mock_nccl));
- if (!device_topology_pair.second.ok()) {
- return device_topology_pair.second.status();
- }
- auto gpu_topology = std::shared_ptr<const GpuTopology>(
- GpuTopology::FromProto(device_topology_pair.second.value()));
+ TF_RETURN_IF_ERROR(BuildDistributedDevices(
+ pjrt_platform_name, std::move(local_device_states), options.node_id,
+ options.num_nodes, &devices, gpu_run_options.get(), kv_store,
+ options.enable_mock_nccl));

return std::unique_ptr<PjRtClient>(std::make_unique<StreamExecutorGpuClient>(
- pjrt_platform_name, xla_client, std::move(device_topology_pair.first),
- options.node_id, std::move(allocator), std::move(host_memory_allocator),
- options.should_stage_host_to_device_transfers, std::move(gpu_run_options),
- std::move(gpu_topology)));
+ pjrt_platform_name, xla_client, std::move(devices), options.node_id,
+ std::move(allocator), std::move(host_memory_allocator),
+ options.should_stage_host_to_device_transfers,
+ std::move(gpu_run_options)));
}

absl::StatusOr<std::string> StreamExecutorGpuTopologyDescription::Serialize()
const {
std::string result;
if (!tsl::SerializeToStringDeterministic(gpu_topology_->ToProto(), &result)) {
if (!tsl::SerializeToStringDeterministic(gpu_topology_.ToProto(), &result)) {
return absl::InternalError("Failed to serialize gpu_topology");
}
return result;
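In GetStreamExecutorGpuClient the revert swaps TF_ASSIGN_OR_RETURN (unwrap a StatusOr into a fresh variable, or propagate its error) for TF_RETURN_IF_ERROR (early-return on a non-OK plain status), matching the new status-returning BuildDistributedDevices. The sketch below models the two patterns with simplified local macros rather than the real TSL ones; MakeDeviceIds, FillDeviceIds, and BuildClientSketch are illustrative names.

```cpp
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Simplified local stand-ins for the TSL TF_RETURN_IF_ERROR /
// TF_ASSIGN_OR_RETURN macros used in the diff; the real macros handle more
// edge cases and expand differently.
#define RETURN_IF_ERROR_SKETCH(expr)   \
  do {                                 \
    absl::Status _status = (expr);     \
    if (!_status.ok()) return _status; \
  } while (false)

#define ASSIGN_OR_RETURN_SKETCH(lhs, expr)           \
  auto lhs##_or = (expr);                            \
  if (!lhs##_or.ok()) return lhs##_or.status();      \
  auto lhs = std::move(lhs##_or).value();

absl::StatusOr<std::vector<int>> MakeDeviceIds() {
  return std::vector<int>{0, 1};
}

absl::Status FillDeviceIds(std::vector<int>* out) {
  out->assign({0, 1});
  return absl::OkStatus();
}

absl::Status BuildClientSketch() {
  // Pre-revert shape of the call site: unwrap a StatusOr into a new variable.
  ASSIGN_OR_RETURN_SKETCH(ids_from_statusor, MakeDeviceIds());
  (void)ids_from_statusor;

  // Post-revert shape: the callee returns a plain status and fills an
  // out-parameter, so a simple early return on error is enough.
  std::vector<int> ids;
  RETURN_IF_ERROR_SKETCH(FillDeviceIds(&ids));
  return absl::OkStatus();
}
```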
