Skip to content

Commit

Permalink
Remove unused options argument from Platform::Initialize.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642379707
  • Loading branch information
klucke authored and tensorflower-gardener committed Jun 11, 2024
1 parent 48eb446 commit 431e033
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 72 deletions.
2 changes: 1 addition & 1 deletion tensorflow/dtensor/cc/dtensor_tpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class ConfigureAndInitializeGlobalTPUOpKernel : public OpKernel {
while (!tpu_platform->Initialized() &&
(absl::Now() - start < retry_timeout)) {
VLOG(1) << "Initializaing global TPU system.";
init_status = tpu_platform->Initialize({});
init_status = tpu_platform->Initialize();
}
if (!tpu_platform->Initialized()) {
return errors::Unavailable("Unable to initialize TPU system.");
Expand Down
11 changes: 1 addition & 10 deletions third_party/xla/xla/stream_executor/platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,8 @@ StreamExecutorConfig::StreamExecutorConfig() : ordinal(-1) {}
StreamExecutorConfig::StreamExecutorConfig(int ordinal_in)
: ordinal(ordinal_in) {}

Platform::~Platform() {}

bool Platform::Initialized() const { return true; }

absl::Status Platform::Initialize(
const std::map<std::string, std::string> &platform_options) {
if (!platform_options.empty()) {
return absl::UnimplementedError(
"this platform does not support custom initialization");
}
return absl::OkStatus();
}
absl::Status Platform::Initialize() { return absl::OkStatus(); }

} // namespace stream_executor
23 changes: 4 additions & 19 deletions third_party/xla/xla/stream_executor/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct StreamExecutorConfig {
// Abstract base class for a platform registered with the PlatformManager.
class Platform {
public:
virtual ~Platform();
virtual ~Platform() = default;

// A platform ID is a unique identifier for each registered platform type -
// each platform is required to expose an ID to ensure unique registration and
Expand Down Expand Up @@ -93,14 +93,9 @@ class Platform {
// Returns true iff the platform has been initialized.
virtual bool Initialized() const;

// Initializes the platform with a custom set of options. The platform must be
// initialized before obtaining StreamExecutor objects. The interpretation of
// the platform_options argument is implementation specific. This method may
// return an error if unrecognized options are provided. If using
// PlatformManager, this method will be called automatically by
// InitializePlatformWithId/InitializePlatformWithName.
virtual absl::Status Initialize(
const std::map<std::string, std::string>& platform_options);
// Initializes the platform. The platform must be initialized before obtaining
// StreamExecutor objects.
virtual absl::Status Initialize();

// Returns a populated DeviceDescription for the device at the given ordinal.
// This should not require device initialization. Note that not all platforms
Expand Down Expand Up @@ -130,16 +125,6 @@ class Platform {
// Ownership IS transferred to the caller.
virtual absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
const StreamExecutorConfig& config) = 0;

protected:
// SE_DISALLOW_COPY_AND_ASSIGN declares a constructor, which suppresses the
// presence of the default constructor. This statement re-enables it, which
// simplifies subclassing.
Platform() = default;

private:
Platform(const Platform&) = delete;
void operator=(const Platform&) = delete;
};

} // namespace stream_executor
Expand Down
17 changes: 8 additions & 9 deletions third_party/xla/xla/stream_executor/platform_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class PlatformManagerImpl {
bool initialize_platform)
ABSL_LOCKS_EXCLUDED(mu_);

absl::StatusOr<Platform*> InitializePlatformWithId(
const Platform::Id& id, const std::map<std::string, std::string>& options)
absl::StatusOr<Platform*> InitializePlatformWithId(const Platform::Id& id)
ABSL_LOCKS_EXCLUDED(mu_);

absl::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
Expand Down Expand Up @@ -126,7 +125,7 @@ absl::StatusOr<Platform*> PlatformManagerImpl::PlatformWithName(

TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
if (initialize_platform && !platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
TF_RETURN_IF_ERROR(platform->Initialize());
}

return platform;
Expand All @@ -138,14 +137,14 @@ absl::StatusOr<Platform*> PlatformManagerImpl::PlatformWithId(

TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
if (initialize_platform && !platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
TF_RETURN_IF_ERROR(platform->Initialize());
}

return platform;
}

absl::StatusOr<Platform*> PlatformManagerImpl::InitializePlatformWithId(
const Platform::Id& id, const std::map<std::string, std::string>& options) {
const Platform::Id& id) {
absl::MutexLock lock(&mu_);

TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
Expand All @@ -154,7 +153,7 @@ absl::StatusOr<Platform*> PlatformManagerImpl::InitializePlatformWithId(
absl::StrFormat("platform with id %p is already initialized", id));
}

TF_RETURN_IF_ERROR(platform->Initialize(options));
TF_RETURN_IF_ERROR(platform->Initialize());

return platform;
}
Expand All @@ -170,7 +169,7 @@ absl::StatusOr<std::vector<Platform*>> PlatformManagerImpl::PlatformsWithFilter(
Platform* platform = entry.second;
if (filter(platform)) {
if (initialize_platform && !platform->Initialized()) {
TF_RETURN_IF_ERROR(platform->Initialize({}));
TF_RETURN_IF_ERROR(platform->Initialize());
}
platforms.push_back(platform);
}
Expand Down Expand Up @@ -245,8 +244,8 @@ PlatformManagerImpl& Impl() {
}

/*static*/ absl::StatusOr<Platform*> PlatformManager::InitializePlatformWithId(
const Platform::Id& id, const std::map<std::string, std::string>& options) {
return Impl().InitializePlatformWithId(id, options);
const Platform::Id& id) {
return Impl().InitializePlatformWithId(id);
}

/*static*/ absl::StatusOr<std::vector<Platform*>>
Expand Down
8 changes: 1 addition & 7 deletions third_party/xla/xla/stream_executor/platform_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ limitations under the License.
#define XLA_STREAM_EXECUTOR_PLATFORM_MANAGER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
Expand Down Expand Up @@ -104,15 +102,11 @@ class PlatformManager {
// Retrieves the platform registered with the given platform id (an opaque,
// comparable value provided by the Platform's Id() method).
//
// The platform will be initialized with the given options. If the platform
// was already initialized, an error will be returned.
//
// If the requested platform is not registered, an error status is returned.
// Ownership of the platform is NOT transferred to the caller --
// the PlatformManager owns the platforms in a singleton-like fashion.
static absl::StatusOr<Platform*> InitializePlatformWithId(
const Platform::Id& id,
const std::map<std::string, std::string>& options);
const Platform::Id& id);

// Retrieves the platforms satisfying the given filter, i.e. returns true.
// Returned Platforms are always initialized.
Expand Down
4 changes: 1 addition & 3 deletions third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ extern "C" {

SE_Platform* TpuPlatform_New();
void TpuPlatform_Free(SE_Platform* platform);
void TpuPlatform_Initialize(SE_Platform* platform, size_t options_size,
const char** options_key,
const char** options_value, TF_Status* status);
void TpuPlatform_Initialize(SE_Platform* platform, TF_Status* status);
bool TpuPlatform_Initialized(SE_Platform* platform);
SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
SE_StreamExecutorConfig* config,
Expand Down
23 changes: 2 additions & 21 deletions third_party/xla/xla/stream_executor/tpu/tpu_platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,10 @@ TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
return tpu_registered_platform;
}

absl::Status TpuPlatform::Initialize(
const std::map<std::string, std::string>& platform_options) {
absl::Status TpuPlatform::Initialize() {
StatusHelper status;

size_t options_size = platform_options.size();
const char** options_key =
static_cast<const char**>(malloc(sizeof(const char*) * options_size));
const char** options_value =
static_cast<const char**>(malloc(sizeof(const char*) * options_size));

size_t i = 0;
for (const auto& option : platform_options) {
options_key[i] = option.first.c_str();
options_value[i] = option.second.c_str();
i++;
}

stream_executor::tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
platform_, options_size, options_key, options_value, status.c_status);

free(options_key);
free(options_value);

platform_, status.c_status);
return status.status();
}

Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/stream_executor/tpu/tpu_platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {

bool Initialized() const override;

absl::Status Initialize(
const std::map<std::string, std::string>& platform_options) override;
absl::Status Initialize() override;

absl::Status Reset(bool only_tear_down, absl::string_view reason) override {
LOG(FATAL) << "Not yet implemented";
Expand Down

0 comments on commit 431e033

Please sign in to comment.