From ee9b9b7c3b4174c032cb8300b9889a12c0ee1b49 Mon Sep 17 00:00:00 2001
From: Minghao Li
Date: Mon, 11 Nov 2019 15:10:14 +0800
Subject: [PATCH] Embedding layer looks up embedding from PS (#1435)

* Embedding layer looks up embedding from PS

* rename variable

* follow comments

* minor edits

* rename worker_mnist_test.py
---
 elasticdl/python/common/hash_utils.py         |  8 ++-
 .../python/elasticdl/layers/embedding.py      | 17 ++++++-
 ...t.py => worker_and_PS_interaction_test.py} | 42 +++++++++++++++-
 elasticdl/python/worker/worker.py             | 49 +++++++++++++++----
 4 files changed, 102 insertions(+), 14 deletions(-)
 rename elasticdl/python/tests/{worker_mnist_test.py => worker_and_PS_interaction_test.py} (82%)

diff --git a/elasticdl/python/common/hash_utils.py b/elasticdl/python/common/hash_utils.py
index a57b1cd56..da5036af1 100644
--- a/elasticdl/python/common/hash_utils.py
+++ b/elasticdl/python/common/hash_utils.py
@@ -1,6 +1,10 @@
 import hashlib
 
 
-def string_to_id(name, num):
+def string_to_id(name, bucket_num):
     h = hashlib.sha256(name.encode("utf-8"))
-    return int(h.hexdigest(), base=32) % num
+    return int(h.hexdigest(), base=32) % bucket_num
+
+
+def int_to_id(number, bucket_num):
+    return number % bucket_num
diff --git a/elasticdl/python/elasticdl/layers/embedding.py b/elasticdl/python/elasticdl/layers/embedding.py
index 9229cc147..4b8276e5a 100644
--- a/elasticdl/python/elasticdl/layers/embedding.py
+++ b/elasticdl/python/elasticdl/layers/embedding.py
@@ -61,7 +61,7 @@ def __init__(
         self.combiner = combiner
         self.embedding_service_endpoint = embedding_service_endpoint
         self.tape = None
-        self.lookup_func = None
+        self._lookup_embedding_func = None
 
         self._embedding_and_ids_eagerly = []
 
@@ -132,6 +132,10 @@ def get_key(name_list):
     def lookup_embedding(self, unique_ids):
         ids = unique_ids.numpy()
         self._check_id_valid(ids)
+        if self._lookup_embedding_func:
+            embedding_vectors = self._lookup_embedding_func(self._name, ids)
+            return embedding_vectors
+
         keys = [Embedding.get_key([self._name, id]) for id in ids]
         (
             embedding_vectors,
@@ -299,6 +303,17 @@ def set_tape(self, tape):
     def set_endpoint(self, endpoint):
         self.embedding_service_endpoint = endpoint
 
+    def set_lookup_embedding_func(self, func):
+        """Sets the function for looking up embeddings from the PS.
+
+        Args:
+            func: The function used to look up embeddings. Its arguments
+                are `(layer_name, embedding_id_list)`, where `layer_name`
+                is the name of the embedding layer and `embedding_id_list`
+                is a list of embedding ids to be looked up.
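+
+        A minimal usage sketch; `my_lookup` here is a hypothetical
+        stand-in for a real PS lookup function:
+
+            def my_lookup(layer_name, embedding_id_list):
+                # Returns an ndarray with one row per embedding id.
+                ...
+
+            layer.set_lookup_embedding_func(my_lookup)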
+        """
+        self._lookup_embedding_func = func
+
     @property
     def embedding_and_ids(self):
         """
diff --git a/elasticdl/python/tests/worker_mnist_test.py b/elasticdl/python/tests/worker_and_PS_interaction_test.py
similarity index 82%
rename from elasticdl/python/tests/worker_mnist_test.py
rename to elasticdl/python/tests/worker_and_PS_interaction_test.py
index 9855380ce..79ad0dbae 100644
--- a/elasticdl/python/tests/worker_mnist_test.py
+++ b/elasticdl/python/tests/worker_and_PS_interaction_test.py
@@ -7,8 +7,9 @@
 from elasticdl.proto import elasticdl_pb2
 from elasticdl.python.common.constants import GRPC
-from elasticdl.python.common.hash_utils import string_to_id
+from elasticdl.python.common.hash_utils import int_to_id, string_to_id
 from elasticdl.python.common.model_utils import get_model_spec
+from elasticdl.python.ps.embedding_table import EmbeddingTable
 from elasticdl.python.ps.parameter_server import ParameterServer
 from elasticdl.python.tests.test_utils import PserverArgs
 from elasticdl.python.worker.worker import Worker
 
@@ -226,6 +227,45 @@ def test_compare_mnist_train(self):
         for w, l in zip(worker_results, local_results):
             self.assertTupleEqual(w, l)
 
+    def test_worker_pull_embedding(self):
+        worker = Worker(
+            worker_id=0,
+            job_type=elasticdl_pb2.TRAINING,
+            minibatch_size=self._batch_size,
+            model_zoo=self._model_zoo_path,
+            model_def=self._model_def,
+            ps_channels=self._channel,
+        )
+
+        # Test looking up embedding vectors that do not exist yet
+        layers = ["test-2", "test-2-slot"]
+        ids = [3, 5, 1, 6, 10, 2, 1, 2, 4, 7, 9]
+        embedding_table_args = [
+            (layers[0], 8, "uniform", False),
+            (layers[1], 8, 3.3, True),
+        ]
+
+        # Initialize the embedding table objects on every PS
+        for pserver in self._pserver:
+            for layer, table_args in zip(layers, embedding_table_args):
+                pserver.parameters.embedding_params[layer] = EmbeddingTable(
+                    *table_args
+                )
+
+        result_dict = {}
+        for layer in layers:
+            embedding = worker.pull_embedding_vector(layer, ids)
+            result_dict[layer] = embedding
+
+        for layer in layers:
+            expected_result = []
+            for embedding_id in ids:
+                ps_id = int_to_id(embedding_id, len(self._pserver))
+                table = self._pserver[ps_id].parameters.embedding_params[
+                    layer
+                ]
+                expected_result.append(table.get([embedding_id]))
+            expected_result = np.concatenate(expected_result)
+            self.assertTrue(np.allclose(expected_result, result_dict[layer]))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/elasticdl/python/worker/worker.py b/elasticdl/python/worker/worker.py
index 2968cefe8..3d1d38e42 100644
--- a/elasticdl/python/worker/worker.py
+++ b/elasticdl/python/worker/worker.py
@@ -11,7 +11,7 @@
     Mode,
     SaveModelConfig,
 )
-from elasticdl.python.common.hash_utils import string_to_id
+from elasticdl.python.common.hash_utils import int_to_id, string_to_id
 from elasticdl.python.common.log_utils import default_logger as logger
 from elasticdl.python.common.model_handler import ModelHandler
 from elasticdl.python.common.model_utils import (
@@ -24,6 +24,7 @@
     Tensor,
     emplace_tensor_pb_from_ndarray,
     serialize_tensor,
+    tensor_pb_to_ndarray,
 )
 from elasticdl.python.elasticdl.layers.embedding import Embedding
 from elasticdl.python.worker.task_data_service import TaskDataService
@@ -88,6 +89,15 @@ def __init__(
             max_minibatch_retry_num: The maximum number of retries for a
                 minibatch whose results (e.g. gradients) are not accepted
                 by the master.
""" + self._use_multi_ps = False + if isinstance(ps_channels, list): + if len(ps_channels) > 0: + self._use_multi_ps = True + self._ps_stubs = [ + elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels + ] + self._var_to_ps = {} + self._worker_id = worker_id self._job_type = job_type self._minibatch_size = minibatch_size @@ -131,15 +141,6 @@ def __init__( self._non_embed_grads = None self._evaluation_result = {} - self._use_multi_ps = False - if isinstance(ps_channels, list): - if len(ps_channels) > 0: - self._use_multi_ps = True - self._ps_stubs = [ - elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels - ] - self._var_to_ps = {} - # TODO: Multiple tests are currently using this function to initialize # self._model, where the initialization should be done via constructor. def set_model(self, model_inst): @@ -162,6 +163,9 @@ def _init_embedding_layer(self): self._embedding_layers = find_layer(self._model, Embedding) for layer in self._embedding_layers: layer.set_endpoint(self._embedding_service_endpoint) + if self._use_multi_ps: + layer.set_lookup_embedding_func(self.pull_embedding_vector) + self._need_embedding_layer_check = ( True if self._embedding_layers else False ) @@ -223,6 +227,31 @@ def get_model_from_ps(self, version, method): model_version = max(model_version, res.model.version) self._model_version = model_version + def pull_embedding_vector(self, layer_name, embedding_ids): + """Pulls and returns embedding vectors ordered by the embedding ids.""" + ps_ids = {} + ps_ids_index = {} + for idx, embedding_id in enumerate(embedding_ids): + ps_id = int_to_id(embedding_id, len(self._ps_stubs)) + ps_ids.setdefault(ps_id, []).append(embedding_id) + ps_ids_index.setdefault(ps_id, []).append(idx) + + embeddings = [] + index = [] + for ps_id, embedding_ids in ps_ids.items(): + req = elasticdl_pb2.PullEmbeddingVectorRequest() + req.name = layer_name + req.ids.extend(embedding_ids) + pb = self._ps_stubs[ps_id].pull_embedding_vector(req) + embeddings.append(tensor_pb_to_ndarray(pb)) + index.extend(ps_ids_index[ps_id]) + embeddings = np.concatenate(embeddings) + + # adjust the order of embedding vectors + new_embeddings = np.empty_like(embeddings) + new_embeddings[index] = embeddings + return new_embeddings + def get_model(self, version, method): if self._use_multi_ps: self.get_model_from_ps(version, method)