Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

r2.13 cherry-pick: Fix TPUExecute for TPU embedding operations. Create temporary device … #60888

Merged
merged 2 commits into from Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -44,8 +44,8 @@ tsl::Status SetTpuOpsStructFns(void* library_handle) { // TENSORFLOW_STATUS_OK

TFTPU_SET_FN(ops_api_fn, TpuExecute_RuntimeInputToPaddedData);

TFTPU_SET_FN(ops_api_fn, SE_DeviceMemoryBase_FreeArray);
TFTPU_SET_FN(ops_api_fn, TpuExecute_GetTpuEmbeddingMemoryWordAddresses);
TFTPU_SET_FN(ops_api_fn, TpuExecute_GetTpuEmbeddingMemoryAllocations);
TFTPU_SET_FN(ops_api_fn, TpuExecute_FreeTpuEmbeddingMemoryAllocations);

TFTPU_SET_FN(ops_api_fn, TpuProgram_New);
TFTPU_SET_FN(ops_api_fn, TpuProgram_Free);
Expand Down
14 changes: 7 additions & 7 deletions tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h
Expand Up @@ -270,12 +270,12 @@ typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
// Exported C API entry point. Everything the call consumes or produces is
// passed through the TpuExecute_RuntimeInputToPaddedData_Params struct that
// `params` points to (the struct's fields are declared just above this
// declaration and are not repeated here).
TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
TpuExecute_RuntimeInputToPaddedData_Params* params);

// Releases the SE_DeviceMemoryBase array `addrs`. TPUExecute used this in its
// scope-exit cleanup for the array it obtained for TPU embedding memory.
// NOTE(review): in this diff view the declaration appears to be on the removed
// side, superseded by TpuExecute_FreeTpuEmbeddingMemoryAllocations -- confirm
// against the merged header.
TFTPU_CAPI_EXPORT void SE_DeviceMemoryBase_FreeArray(
SE_DeviceMemoryBase* addrs);
// Retrieves the TPU embedding memory allocations for the device identified by
// `device_ordinal`.
//
//   device_ordinal  Device to query; TPUExecute passes
//                   node_context->device_ordinal().
//   addrs           Out: receives a pointer to an array of SE_DeviceMemoryBase.
//   addrs_count     Out: presumably the number of entries in `*addrs`
//                   (inferred from the name -- TODO confirm).
//   status          Out: error reporting; TPUExecute returns early when it is
//                   not OK after the call.
//
// The caller is expected to release `*addrs` with
// TpuExecute_FreeTpuEmbeddingMemoryAllocations once it is done with it.
TFTPU_CAPI_EXPORT void TpuExecute_GetTpuEmbeddingMemoryAllocations(
int device_ordinal, SE_DeviceMemoryBase** addrs, size_t* addrs_count,
TF_Status* status);

// Retrieves TPU embedding memory addresses through a stream executor handle.
//
//   executor     Wrapper around the stream executor; TPUExecute built it from
//                stream->parent().
//   addrs        Out: receives a pointer to an array of SE_DeviceMemoryBase.
//   addrs_count  Out: presumably the number of entries in `*addrs`
//                (inferred from the name -- TODO confirm).
//   status       Out: error reporting.
//
// NOTE(review): in this diff view the declaration appears to be on the removed
// side, superseded by TpuExecute_GetTpuEmbeddingMemoryAllocations (which takes
// a device ordinal instead of an executor) -- confirm against the merged
// header.
TFTPU_CAPI_EXPORT void TpuExecute_GetTpuEmbeddingMemoryWordAddresses(
SE_StreamExecutor* executor, SE_DeviceMemoryBase** addrs,
size_t* addrs_count, TF_Status* status);
// Releases the array obtained from TpuExecute_GetTpuEmbeddingMemoryAllocations.
//
//   device_ordinal  Device the allocations belong to; TPUExecute passes the
//                   same ordinal it used for the Get call.
//   addrs           The array previously returned through the Get call's
//                   `addrs` out-parameter.
//
// NOTE(review): TPUExecute only calls this when `addrs` is non-null; whether a
// null pointer is accepted is not visible from this declaration.
TFTPU_CAPI_EXPORT void TpuExecute_FreeTpuEmbeddingMemoryAllocations(
int device_ordinal, SE_DeviceMemoryBase* addrs);

typedef struct ConfigureDistributedTpuOp_DoWork_Params {
int32_t struct_size;
Expand Down Expand Up @@ -765,8 +765,8 @@ struct TfTpu_OpsApiFn {

TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);

TFTPU_ADD_FN_IN_STRUCT(SE_DeviceMemoryBase_FreeArray);
TFTPU_ADD_FN_IN_STRUCT(TpuExecute_GetTpuEmbeddingMemoryWordAddresses);
TFTPU_ADD_FN_IN_STRUCT(TpuExecute_GetTpuEmbeddingMemoryAllocations);
TFTPU_ADD_FN_IN_STRUCT(TpuExecute_FreeTpuEmbeddingMemoryAllocations);

TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
Expand Down
18 changes: 11 additions & 7 deletions tensorflow/core/tpu/tpu_execute.cc
Expand Up @@ -512,17 +512,21 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
// (which needs to be freed once the function terminates).
SE_DeviceMemoryBase* device_memory_addrs;
size_t device_memory_addrs_count;
auto device_memory_cleanup = absl::MakeCleanup([&device_memory_addrs]() {
stream_executor::tpu::OpsApiFn()->SE_DeviceMemoryBase_FreeArrayFn(
device_memory_addrs);
});
auto device_memory_cleanup =
absl::MakeCleanup([device_memory_addrs, node_context]() {
if (device_memory_addrs != nullptr) {
stream_executor::tpu::OpsApiFn()
->TpuExecute_FreeTpuEmbeddingMemoryAllocationsFn(
node_context->device_ordinal(), device_memory_addrs);
}
});

SE_StreamExecutor executor{stream->parent()};
StatusHelper status;
stream_executor::tpu::OpsApiFn()
->TpuExecute_GetTpuEmbeddingMemoryWordAddressesFn(
&executor, &device_memory_addrs, &device_memory_addrs_count,
status.c_status);
->TpuExecute_GetTpuEmbeddingMemoryAllocationsFn(
node_context->device_ordinal(), &device_memory_addrs,
&device_memory_addrs_count, status.c_status);
if (!status.ok()) {
return status.status();
}
Expand Down