[Feature] Compatible with TensorFlow 2.13 and below. Also fix some HKV features. #378

Merged · 6 commits · Feb 5, 2024 · Changes from all commits
8 changes: 0 additions & 8 deletions build_deps/toolchains/gpu/cuda_configure.bzl
(drops compute capabilities 6.0 and 6.1, i.e. Pascal, from the default CUDA compute-capability lists)
@@ -52,15 +52,11 @@ _PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

_DEFAULT_CUDA_COMPUTE_CAPABILITIES = {
    "11.0": [
-       "6.0",
-       "6.1",
        "7.0",
        "7.5",
        "8.0",
    ],
    "10.0": [
-       "6.0",
-       "6.1",
        "7.0",
        "7.5",
    ],
@@ -77,8 +73,6 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(

_DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(
    {"11.{}".format(v): [
-       "6.0",
-       "6.1",
        "7.0",
        "7.5",
        "8.0",
@@ -90,8 +84,6 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(

_DEFAULT_CUDA_COMPUTE_CAPABILITIES.update(
    {"12.{}".format(v): [
-       "6.0",
-       "6.1",
        "7.0",
        "7.5",
        "8.0",
@@ -466,12 +466,12 @@ def export_to_savedmodel(model, savedmodel_dir):

  # TFRA modify the Keras save function with a patch.
  # !!!! Run save_model function in all rank !!!!
- de.keras.models.de_save_model(model,
-                               savedmodel_dir,
-                               overwrite=True,
-                               include_optimizer=True,
-                               save_traces=True,
-                               options=save_options)
+ de.keras.models.save_model(model,
+                            savedmodel_dir,
+                            overwrite=True,
+                            include_optimizer=True,
+                            save_traces=True,
+                            options=save_options)


def export_for_serving(model, export_dir):

@@ -521,7 +521,7 @@ def serve(*args, **kwargs):

  # TFRA modify the Keras save function with a patch.
  # !!!! Run save_model function in all rank !!!!
- de.keras.models.de_save_model(
+ de.keras.models.save_model(
      model,
      export_dir,
      overwrite=True,

@@ -572,7 +572,7 @@ def train():
  # horovod callback is used to broadcast the value generated by initializer of rank0.
  hvd_opt_init_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(
      root_rank=0)
- ckpt_callback = de.keras.callbacks.DEHvdModelCheckpoint(
+ ckpt_callback = de.keras.callbacks.ModelCheckpoint(
      filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
      options=save_options)
  callbacks_list = [hvd_opt_init_callback, ckpt_callback]
@@ -81,11 +81,11 @@ In addition, we also provide parameter initialization and save callback related

[`dynamic_embedding.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)

-[`dynamic_embedding.keras.callbacks.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)
+[`dynamic_embedding.keras.callbacks.ModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py)

-[`dynamic_embedding.keras.models.de_save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)
+[`dynamic_embedding.keras.models.save_model`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py)

-[`dynamic_embedding.train.DEHvdModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)
+[`dynamic_embedding.train.ModelCheckpoint`](https://github.com/tensorflow/recommenders-addons/blob/master/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py)

You could inherit the `HvdAllToAllEmbedding` class to implement a custom embedding
layer with other fixed shape output.
@@ -653,7 +653,11 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
  size_t last_hint_size_;
  size_t runtime_dim_;
  mutable mutex mu_;
+#if TF_VERSION_INTEGER >= 2130  // 2.13.0
+  gpu::TableWrapperBase<K, V>* table_ = nullptr TF_GUARDED_BY(mu_);
+#else
  gpu::TableWrapperBase<K, V>* table_ = nullptr GUARDED_BY(mu_);
+#endif
};

} // namespace lookup
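The #if/#else gate above is the compatibility idiom this PR applies throughout: TFRA's TF_VERSION_INTEGER macro (computed in its utils.h as major * 1000 + minor * 10 + patch, so 2.13.0 becomes 2130) selects between the TF_GUARDED_BY spelling and the older unprefixed GUARDED_BY alias that TF 2.13 no longer provides; the same gate later selects include paths and Status constructors. A minimal C++ sketch of the idiom follows. The TFRA_GUARDED_BY wrapper is hypothetical and purely illustrative, since the PR inlines the #if at each member declaration:

// Sketch only: a hypothetical macro centralizing the version gate.
#include <cstdint>

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

#if TF_VERSION_INTEGER >= 2130  // 2.13.0: unprefixed GUARDED_BY is gone
#define TFRA_GUARDED_BY(m) TF_GUARDED_BY(m)
#else
#define TFRA_GUARDED_BY(m) GUARDED_BY(m)
#endif

class Counter {
 public:
  void Add(int64_t delta) {
    tensorflow::mutex_lock lock(mu_);  // must hold mu_ to touch value_
    value_ += delta;
  }

 private:
  mutable tensorflow::mutex mu_;
  int64_t value_ TFRA_GUARDED_BY(mu_) = 0;  // enforced by clang's thread-safety analysis
};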
@@ -16,7 +16,6 @@ limitations under the License.

#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.h"
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h"
-#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"

#define EIGEN_USE_GPU

@@ -37,7 +36,11 @@ limitations under the License.
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
+#if TF_VERSION_INTEGER >= 2110  // 2.11.0
#include "tensorflow/compiler/xla/stream_executor/stream.h"
+#else
+#include "tensorflow/stream_executor/stream.h"
+#endif

namespace tensorflow {

@@ -187,14 +190,14 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
                       is_full_default);
        CUDA_CHECK(cudaStreamSynchronize(stream));
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
    CUDA_CHECK(cudaFreeAsync(d_status, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
  }

-  return Status::OK();
+  return TFOkStatus;
}
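TFOkStatus, which replaces Status::OK() here and in every return below, is TFRA's portability macro for the OK status; it reaches this file through lookup_table_op_hkv.h, whose diff below adds the utils.h include. A presumed sketch of its shape, assuming it papers over TF 2.10's move from Status::OK() to OkStatus():

// Presumed definition (an assumption; the real macro lives in TFRA's
// dynamic_embedding/core/utils/utils.h, which this diff does not show):
#if TF_VERSION_INTEGER >= 2100  // 2.10.0 introduced tensorflow::OkStatus()
#define TFOkStatus ::tensorflow::OkStatus()
#else
#define TFOkStatus ::tensorflow::Status::OK()
#endif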

Status FindWithExists(OpKernelContext* ctx, const Tensor& d_keys,
@@ -222,13 +225,13 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
                      (V*)(default_value.tensor_data().data()), stream,
                      is_full_default);
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
    CUDA_CHECK(cudaStreamSynchronize(stream));
  }

-  return Status::OK();
+  return TFOkStatus;
}

Status Insert(OpKernelContext* ctx, const Tensor& keys,
@@ -241,12 +244,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
      table_->upsert((const K*)keys.tensor_data().data(),
                     (const V*)(values.tensor_data().data()), len, stream);
    } catch (std::runtime_error& e) {
-     return Status(tensorflow::error::INTERNAL, e.what());
+     return gpu::ReturnInternalErrorStatus(e.what());
    }
  }
  CUDA_CHECK(cudaStreamSynchronize(stream));

-  return Status::OK();
+  return TFOkStatus;
}

Status Accum(OpKernelContext* ctx, const Tensor& keys,
@@ -260,12 +263,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
                    (const V*)(values_or_deltas.tensor_data().data()),
                    (const bool*)exists.tensor_data().data(), len, stream);
    } catch (std::runtime_error& e) {
-     return Status(tensorflow::error::INTERNAL, e.what());
+     return gpu::ReturnInternalErrorStatus(e.what());
    }
  }
  CUDA_CHECK(cudaStreamSynchronize(stream));

-  return Status::OK();
+  return TFOkStatus;
}

Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
@@ -285,14 +288,14 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
      try {
        table_->remove((const K*)d_keys, len, stream);
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
    CUDA_CHECK(cudaFreeAsync(d_keys, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
  }

-  return Status::OK();
+  return TFOkStatus;
}

Status Clear(OpKernelContext* ctx) {
@@ -302,11 +305,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
    try {
      table_->clear(stream);
    } catch (std::runtime_error& e) {
-     return Status(tensorflow::error::INTERNAL, e.what());
+     return gpu::ReturnInternalErrorStatus(e.what());
    }
  }
  CUDA_CHECK(cudaStreamSynchronize(stream));
-  return Status::OK();
+  return TFOkStatus;
}

Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
@@ -345,7 +348,7 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
        table_->upsert((const K*)d_keys, (const V*)d_values, len, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
if (keys_attr.type != cudaMemoryTypeDevice) {
@@ -355,7 +358,7 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
      CUDA_CHECK(cudaFree(d_values));
    }
  }
-  return Status::OK();
+  return TFOkStatus;
}

Status ExportValues(OpKernelContext* ctx) override {
@@ -397,12 +400,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
                        d_dump_counter, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
    CUDA_CHECK(cudaFreeAsync(d_dump_counter, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
-   return Status::OK();
+   return TFOkStatus;
}

Status ExportValuesWithScores(OpKernelContext* ctx) {
@@ -448,12 +451,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
                                  len, d_dump_counter, stream);
        CUDA_CHECK(cudaStreamSynchronize(stream));
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
    CUDA_CHECK(cudaFreeAsync(d_dump_counter, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
-   return Status::OK();
+   return TFOkStatus;
}

Status ExportKeysAndScores(OpKernelContext* ctx, size_t split_size) {
@@ -486,12 +489,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
                                  static_cast<size_t>(size), split_size,
                                  stream);
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
  }
  CUDA_CHECK(cudaStreamSynchronize(stream));
-  return Status::OK();
+  return TFOkStatus;
}

Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
@@ -507,12 +510,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
        table_->dump_to_file(fs, filepath, runtime_dim_, stream, buffer_size,
                             append_to_file);
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
    CUDA_CHECK(cudaStreamSynchronize(stream));

-   return Status::OK();
+   return TFOkStatus;
}

Status ImportValuesFromFile(OpKernelContext* ctx, const string& dirpath,
@@ -564,11 +567,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
                                     buffer_size);
        }
      } catch (std::runtime_error& e) {
-       return Status(tensorflow::error::INTERNAL, e.what());
+       return gpu::ReturnInternalErrorStatus(e.what());
      }
    }
    CUDA_CHECK(cudaStreamSynchronize(stream));
-   return Status::OK();
+   return TFOkStatus;
}

DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
@@ -580,7 +583,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
  TensorShape value_shape_;
  size_t runtime_dim_;
  mutable mutex mu_;
+#if TF_VERSION_INTEGER >= 2130  // 2.13.0
+  gpu::TableWrapper<K, V>* table_ = nullptr TF_GUARDED_BY(mu_);
+#else
  gpu::TableWrapper<K, V>* table_ = nullptr GUARDED_BY(mu_);
+#endif
};

} // namespace lookup
@@ -1041,6 +1048,7 @@ REGISTER_KERNEL(int64, int8);
REGISTER_KERNEL(int64, int32);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Eigen::half);
+REGISTER_KERNEL(int64, Eigen::bfloat16);

#undef REGISTER_KERNEL

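The single added line extends the HKV hash-table kernels to bfloat16 values. For readers unfamiliar with the pattern, a hypothetical sketch of what such a dtype-registration macro conventionally expands to in TensorFlow follows; the real macro body, op name, and kernel class in this file are not shown in the diff, so all names here are placeholders:

// Hypothetical expansion in the standard TF kernel-registration style:
#define REGISTER_KERNEL(key_dtype, value_dtype)                              \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("TFRA>ExampleHkvHashTable")                                       \
          .Device(DEVICE_GPU)                                                \
          .TypeConstraint<key_dtype>("key_dtype")                            \
          .TypeConstraint<value_dtype>("value_dtype"),                       \
      ExampleHkvHashTableOp<key_dtype, value_dtype>)

Under a macro of this shape, the one new invocation registers the whole GPU kernel family for (int64, bfloat16) tables.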
@@ -43,12 +43,21 @@ limitations under the License.
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"

namespace tensorflow {
namespace recommenders_addons {
namespace lookup {
namespace gpu {

+inline Status ReturnInternalErrorStatus(const char* const str) {
+#if TF_VERSION_INTEGER >= 2130 /* 2.13.0 */
+  return Status(absl::StatusCode::kInternal, str);
+#else
+  return Status(tensorflow::error::INTERNAL, str);
+#endif
+}
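ReturnInternalErrorStatus pins the constructor difference in one place: from TF 2.13 on, Status takes an absl::StatusCode, while earlier releases take the tensorflow::error proto enum. As a design note, the long-standing errors::Internal helper would have been a version-agnostic alternative; a sketch, not what the PR uses:

// Alternative sketch (not part of the PR): errors::Internal() builds an
// INTERNAL-code Status on both sides of the 2.13 boundary.
#include "tensorflow/core/platform/errors.h"

inline tensorflow::Status InternalErrorAlt(const char* const str) {
  return tensorflow::errors::Internal(str);
}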

template <typename K, typename V, typename S>
class KVOnlyFile : public nv::merlin::BaseKVFile<K, V, S> {
public:
@@ -173,7 +182,7 @@ class RandomKVFile : public nv::merlin::BaseKVFile<K, V, S> {
        auto has_atomic_move_ret =
            fs_->HasAtomicMove(filepath_, &has_atomic_move);
        bool need_tmp_file =
-           (has_atomic_move == false) || (has_atomic_move_ret != Status::OK());
+           (has_atomic_move == false) || (has_atomic_move_ret != TFOkStatus);

if (!need_tmp_file) {
key_tmpfilepath = key_filepath;
@@ -193,7 +202,7 @@ class RandomKVFile : public nv::merlin::BaseKVFile<K, V, S> {
            fs_->NewWritableFile(value_tmpfilepath, &value_writer_));
      }
    }
-   return Status::OK();
+   return TFOkStatus;
}

void close() {
@@ -445,9 +454,9 @@ class TableWrapper {
    try {
      table_->init(mkv_options_, allocator);
    } catch (std::runtime_error& e) {
-     return Status(tensorflow::error::INTERNAL, e.what());
+     return ReturnInternalErrorStatus(e.what());
    }
-   return Status::OK();
+   return TFOkStatus;
}

~TableWrapper() { delete table_; }
@@ -534,7 +543,7 @@ class TableWrapper {
    string valuefile = filepath + "-values";
    string scorefile = filepath + "-scores";
    bool has_scores = false;
-   Status status = Status::OK();
+   Status status = TFOkStatus;

if (is_valid_scores(keyfile, scorefile)) {
wfile.reset(new nv::merlin::LocalKVFile<K, V, uint64_t>);
@@ -585,7 +594,7 @@ class TableWrapper {
    string valuefile = filepath + "-values";
    string scorefile = filepath + "-scores";
    bool has_scores = false;
-   Status status = Status::OK();
+   Status status = TFOkStatus;

if (is_valid_scores(keyfile, scorefile)) {
rfile.reset(new nv::merlin::LocalKVFile<K, V, uint64_t>);