Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions runtime/backend/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,44 @@ namespace runtime {

PyTorchBackendInterface::~PyTorchBackendInterface() {}

// TODO(T128866626): Remove global static variables.
// We want to be able to run multiple Executor instances
// and having a global registration isn't a viable solution
// in the long term.
BackendRegistry& getBackendRegistry();
BackendRegistry& getBackendRegistry() {
static BackendRegistry backend_reg;
return backend_reg;
}
namespace {

PyTorchBackendInterface* get_backend_class(const char* name) {
return getBackendRegistry().get_backend_class(name);
}
// The max number of backends that can be registered globally.
constexpr size_t kMaxRegisteredBackends = 16;

// TODO(T128866626): Remove global static variables. We want to be able to run
// multiple Executor instances and having a global registration isn't a viable
// solution in the long term.

/// Global table of registered backends.
Backend registered_backends[kMaxRegisteredBackends];

PyTorchBackendInterface* BackendRegistry::get_backend_class(const char* name) {
for (size_t idx = 0; idx < registrationTableSize_; idx++) {
Backend backend = backend_table_[idx];
if (strcmp(backend.name_, name) == 0) {
return backend.interface_ptr_;
/// The number of backends registered in the table.
size_t num_registered_backends = 0;

} // namespace

PyTorchBackendInterface* get_backend_class(const char* name) {
for (size_t i = 0; i < num_registered_backends; i++) {
Backend backend = registered_backends[i];
if (strcmp(backend.name, name) == 0) {
return backend.backend;
}
}
return nullptr;
}

Error register_backend(const Backend& backend) {
return getBackendRegistry().register_backend(backend);
}

Error BackendRegistry::register_backend(const Backend& backend) {
if (registrationTableSize_ >= kRegistrationTableMaxSize) {
if (num_registered_backends >= kMaxRegisteredBackends) {
return Error::Internal;
}

// Check if the name already exists in the table
if (this->get_backend_class(backend.name_) != nullptr) {
if (get_backend_class(backend.name) != nullptr) {
return Error::InvalidArgument;
}

backend_table_[registrationTableSize_++] = backend;
registered_backends[num_registered_backends++] = backend;
return Error::Ok;
}

Expand Down
52 changes: 10 additions & 42 deletions runtime/backend/interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,46 +110,6 @@ class PyTorchBackendInterface {
virtual void destroy(ET_UNUSED DelegateHandle* handle) const {}
};

struct Backend {
const char* name_;
PyTorchBackendInterface* interface_ptr_;
};

// The max number of backends that can be registered in
// an app. It's hard coded to 16 because it's not estimated
// to have more than 16 backends in a system. Each table
// element has two pointers, represented by Backend struct.
// The memory overhead for this table is minimum (only a few bytes).
constexpr size_t kRegistrationTableMaxSize = 16;

class BackendRegistry {
public:
BackendRegistry() : registrationTableSize_(0) {}

/**
* Registers the Backend object (i.e. string name and PyTorchBackendInterface
* pair) so that it could be called via the name during the runtime.
* @param[in] backend Backend object of the user-defined backend delegate.
* @retval Error code representing whether registration was successful.
*/
ET_NODISCARD Error register_backend(const Backend& backend);

/**
* Returns the corresponding object pointer for a given string name.
* The mapping is populated using register_backend method.
*
* @param[in] name Name of the user-defined backend delegate.
* @retval Pointer to the appropriate object that implements
* PyTorchBackendInterface. Nullptr if it can't find anything
* with the given name.
*/
PyTorchBackendInterface* get_backend_class(const char* name);

private:
Backend backend_table_[kRegistrationTableMaxSize];
size_t registrationTableSize_;
};

/**
* Returns the corresponding object pointer for a given string name.
* The mapping is populated using register_backend method.
Expand All @@ -161,6 +121,16 @@ class BackendRegistry {
*/
PyTorchBackendInterface* get_backend_class(const char* name);

/**
* A named instance of a backend.
*/
struct Backend {
/// The name of the backend. Must match the string used in the PTE file.
const char* name;
/// The instance of the backend to use when loading and executing programs.
PyTorchBackendInterface* backend;
};

/**
* Registers the Backend object (i.e. string name and PyTorchBackendInterface
* pair) so that it could be called via the name during the runtime.
Expand All @@ -178,11 +148,9 @@ namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::Backend;
using ::executorch::runtime::BackendRegistry;
using ::executorch::runtime::CompileSpec;
using ::executorch::runtime::DelegateHandle;
using ::executorch::runtime::get_backend_class;
// using ::executorch::runtime::kRegistrationTableMaxSize;
using ::executorch::runtime::PyTorchBackendInterface;
using ::executorch::runtime::register_backend;
using ::executorch::runtime::SizedBuffer;
Expand Down
Loading