[Feat] Compatible with TensorFlow 2.15 #386

Merged · 4 commits · Feb 29, 2024
@@ -273,32 +273,32 @@ Status ParseJsonConfig(const std::string *const redis_config_abs_dir,
} \
}

#define ReadArrayJsonToParams(json_key_name, json_val_type) \
{ \
json_hangar_it = json_hangar.find(#json_key_name); \
if (json_hangar_it != json_hangar.end()) { \
if (json_hangar_it->second->type == json_array) { \
redis_connection_params->json_key_name.clear(); \
for (unsigned i = 0; i < json_hangar_it->second->u.array.length; \
++i) { \
value_depth1 = json_hangar_it->second->u.array.values[i]; \
if (value_depth1->type == json_##json_val_type) { \
redis_connection_params->redis_host_port.push_back( \
value_depth1->u.json_val_type); \
} else { \
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
" array"; \
return ReturnInvalidArgumentStatus( \
" should be json " #json_val_type " array"); \
} \
} \
} else { \
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
" array"; \
return ReturnInvalidArgumentStatus(" should be json " #json_val_type \
" array"); \
} \
} \
#define ReadArrayJsonToParams(json_key_name, json_val_type) \
{ \
json_hangar_it = json_hangar.find(#json_key_name); \
if (json_hangar_it != json_hangar.end()) { \
if (json_hangar_it->second->type == json_array) { \
redis_connection_params->json_key_name.clear(); \
for (unsigned i = 0; i < json_hangar_it->second->u.array.length; \
++i) { \
value_depth1 = json_hangar_it->second->u.array.values[i]; \
if (value_depth1->type == json_##json_val_type) { \
redis_connection_params->redis_host_port.push_back( \
value_depth1->u.json_val_type); \
} else { \
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
" array"; \
return ReturnInvalidArgumentStatus( \
#json_key_name " should be json " #json_val_type " array"); \
} \
} \
} else { \
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
" array"; \
return ReturnInvalidArgumentStatus( \
#json_key_name " should be json " #json_val_type " array"); \
} \
} \
}
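
As a reading aid, here is a hedged Python mirror of what the updated macro enforces once instantiated as `ReadArrayJsonToParams(redis_host_port, integer)`; the helper name `read_int_array` and the sample config are illustrative, not part of the PR. The key must map to a JSON array whose elements all have the declared type, and the error string now names the offending key.

```python
# Hedged sketch, not part of the PR: a Python mirror of the check the macro
# performs after this change; the error message now includes the key name.
import json


def read_int_array(json_hangar: dict, key: str) -> list:
  """Return json_hangar[key] if it is a list of ints, else raise."""
  if key not in json_hangar:
    return []
  value = json_hangar[key]
  if not isinstance(value, list) or not all(isinstance(v, int) for v in value):
    raise ValueError(key + " should be json integer array")
  return value


config = json.loads('{"redis_host_ip": ["127.0.0.1"], "redis_host_port": [6379]}')
print(read_int_array(config, "redis_host_port"))  # [6379]
```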

#define ReadStringArrayJsonToParams(json_key_name) \
@@ -599,8 +599,12 @@ class RedisTableOfTensors final : public LookupInterface {
hscan_reply->elements > 1) {
kvs_reply = hscan_reply->element[1];
// fill Tensor keys and values
if (kvs_reply->elements < 2 && cursor == 0) {
// Find nothing in Redis
break;
}
if constexpr (!std::is_same<V, tstring>::value) {
if (kvs_reply->element[0]->len !=
if (kvs_reply->element[1]->len !=
runtime_value_dim_ * sizeof(V)) {
return errors::InvalidArgument(
"Embedding dim in Redis server is not equal to the OP "
@@ -1035,8 +1039,12 @@ class RedisTableOfTensors final : public LookupInterface {
}
kvs_reply = hscan_reply->element[1];
// fill Tensor keys and values
if (kvs_reply->elements < 2 && cursor == 0) {
// Find nothing in Redis
break;
}
if constexpr (!std::is_same<V, tstring>::value) {
if (kvs_reply->element[0]->len != runtime_value_dim_ * sizeof(V)) {
if (kvs_reply->element[1]->len != runtime_value_dim_ * sizeof(V)) {
return errors::InvalidArgument(
"Embedding dim in Redis server is not equal to the OP runtime "
"dim.");
@@ -1146,8 +1154,12 @@ class RedisTableOfTensors final : public LookupInterface {
}
kvs_reply = hscan_reply->element[1];
// fill Tensor keys and values
if (kvs_reply->elements < 2 && cursor == 0) {
// Find nothing in Redis
break;
}
if constexpr (!std::is_same<V, tstring>::value) {
if (kvs_reply->element[0]->len != runtime_value_dim_ * sizeof(V)) {
if (kvs_reply->element[1]->len != runtime_value_dim_ * sizeof(V)) {
return errors::InvalidArgument(
"Embedding dim in Redis server is not equal to the OP runtime "
"dim.");
@@ -29,7 +29,11 @@

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
try: # tf version >= 2.14.0
from tensorflow.python.distribute import distribute_lib as distribute_ctx
assert hasattr(distribute_ctx, 'has_strategy')
except:
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import values_util
from tensorflow.python.framework import ops
from tensorflow.python.eager import tape
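
For context, a hedged, standalone illustration of why this shim keeps callers working on both sides of the 2.14 move; the `MirroredStrategy` below is only for demonstration and is not part of the PR.

```python
# Hedged sketch, not part of the PR: the shimmed module exposes the same
# helpers (has_strategy / get_strategy) under either import path.
import tensorflow as tf

try:  # tf >= 2.14 merged the context helpers into distribute_lib
  from tensorflow.python.distribute import distribute_lib as distribute_ctx
  assert hasattr(distribute_ctx, 'has_strategy')
except (ImportError, AssertionError):
  from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx

strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])
with strategy.scope():
  print(distribute_ctx.has_strategy())  # True inside the scope
print(distribute_ctx.has_strategy())    # False outside the scope
```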
@@ -54,6 +54,10 @@
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
try: # tf version >= 2.14.0
from tensorflow.python.ops.array_ops_stack import stack
except:
from tensorflow.python.ops.array_ops import stack
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import server_lib
@@ -309,7 +313,7 @@ def test_max_norm_nontrivial(self):
embedding = de.embedding_lookup(embeddings, ids, max_norm=2.0)
norms = math_ops.sqrt(
math_ops.reduce_sum(embedding_no_norm * embedding_no_norm, axis=1))
normalized = embedding_no_norm / array_ops.stack([norms, norms], axis=1)
normalized = embedding_no_norm / stack([norms, norms], axis=1)
self.assertAllCloseAccordingToType(embedding.eval(),
2 * self.evaluate(normalized))
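
The `stack` shim in isolation, as a hedged sketch (the tensors are illustrative): either import path yields the op behind `tf.stack`, so the normalization line above behaves identically on old and new TF.

```python
# Hedged sketch, not part of the PR: `stack` moved to array_ops_stack in newer
# TF releases; both paths produce the same op as tf.stack.
import tensorflow as tf

try:  # tf >= 2.14
  from tensorflow.python.ops.array_ops_stack import stack
except ImportError:
  from tensorflow.python.ops.array_ops import stack

norms = tf.constant([1.0, 2.0, 3.0])
stacked = stack([norms, norms], axis=1)  # shape (3, 2)
print(tf.reduce_all(stacked == tf.stack([norms, norms], axis=1)).numpy())  # True
```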

@@ -57,7 +57,10 @@
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training.tracking import util as track_util
try: # tf version >= 2.14.0
from tensorflow.python.checkpoint.checkpoint import Checkpoint
except:
from tensorflow.python.training.tracking.util import Checkpoint
from tensorflow.python.util import compat
from tensorflow_estimator.python.estimator import estimator
from tensorflow_estimator.python.estimator import estimator_lib
@@ -1326,13 +1329,13 @@ def _loss_fn():
*sorted(zip(keys1, vals1), key=lambda x: x[0], reverse=False))
slot_keys_and_vals1 = [sv.export() for sv in model1.slot_vars]

ckpt1 = track_util.Checkpoint(model=model1, optimizer=model1.optmz)
ckpt1 = Checkpoint(model=model1, optimizer=model1.optmz)
ckpt_dir = self.get_temp_dir()
model_path = ckpt1.save(ckpt_dir)
del model1

model2 = TestModel()
ckpt2 = track_util.Checkpoint(model=model2, optimizer=model2.optmz)
ckpt2 = Checkpoint(model=model2, optimizer=model2.optmz)
model2.train(features) # Pre-build trace before restore.
ckpt2.restore(model_path)
loss2 = model2(features)
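
A hedged, self-contained round trip through the shimmed `Checkpoint`; the toy module and checkpoint prefix below are illustrative and are not the test's `TestModel`.

```python
# Hedged sketch, not part of the PR: save/restore through whichever module
# provides Checkpoint on the installed TF version.
import tempfile
import tensorflow as tf

try:  # tf >= 2.14 hosts the tracking utilities under python/checkpoint
  from tensorflow.python.checkpoint.checkpoint import Checkpoint
except ImportError:
  from tensorflow.python.training.tracking.util import Checkpoint


class ToyModel(tf.Module):

  def __init__(self):
    self.w = tf.Variable([1.0, 2.0])


m1 = ToyModel()
prefix = Checkpoint(model=m1).save(tempfile.mkdtemp() + "/toy")

m2 = ToyModel()
m2.w.assign([0.0, 0.0])
Checkpoint(model=m2).restore(prefix)
print(m2.w.numpy())  # [1. 2.] after restore
```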