Embedding layer looks up embedding from PS (#1435)
* Embedding layer looks up embedding from PS

* rename variable

* follow comments

* minor edits

* rename worker_mnist_test.py
mhaoli committed Nov 11, 2019
1 parent 067f3e6 commit ee9b9b7
Showing 4 changed files with 102 additions and 14 deletions.
8 changes: 6 additions & 2 deletions elasticdl/python/common/hash_utils.py
@@ -1,6 +1,10 @@
import hashlib


def string_to_id(name, num):
def string_to_id(name, bucket_num):
h = hashlib.sha256(name.encode("utf-8"))
return int(h.hexdigest(), base=32) % num
return int(h.hexdigest(), base=32) % bucket_num


def int_to_id(number, bucket_num):
return number % bucket_num
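
These two helpers decide which parameter server (PS) instance owns a piece of state: string_to_id hashes a variable or layer name into a bucket, and the new int_to_id maps an integer embedding id to a bucket. A minimal sketch of how a caller might pick a PS shard (the instance count is made up for illustration):

    from elasticdl.python.common.hash_utils import int_to_id, string_to_id

    num_ps = 3  # hypothetical number of PS instances

    # Non-embedding variables are routed by hashing the variable name.
    ps_for_var = string_to_id("dense_1/kernel", num_ps)

    # Individual embedding vectors are routed by their integer id.
    ps_for_id = int_to_id(42, num_ps)

    assert 0 <= ps_for_var < num_ps and 0 <= ps_for_id < num_ps
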
17 changes: 16 additions & 1 deletion elasticdl/python/elasticdl/layers/embedding.py
@@ -61,7 +61,7 @@ def __init__(
self.combiner = combiner
self.embedding_service_endpoint = embedding_service_endpoint
self.tape = None
self.lookup_func = None
self._lookup_embedding_func = None

self._embedding_and_ids_eagerly = []

@@ -132,6 +132,10 @@ def get_key(name_list):
def lookup_embedding(self, unique_ids):
ids = unique_ids.numpy()
self._check_id_valid(ids)
if self._lookup_embedding_func:
embedding_vectors = self._lookup_embedding_func(self._name, ids)
return embedding_vectors

keys = [Embedding.get_key([self._name, id]) for id in ids]
(
embedding_vectors,
@@ -299,6 +303,17 @@ def set_tape(self, tape):
def set_endpoint(self, endpoint):
self.embedding_service_endpoint = endpoint

def set_lookup_embedding_func(self, func):
"""Sets function for looking up embeddings in the PS.
Args:
func: The function used to look up embeddings. The arguments of
are `(layer_name, embedding_id_list)`, where `layer_name` is
the name of embedding layer, and `embedding_id_list` is a list
of embedding ids to be looked up.
"""
self._lookup_embedding_func = func

@property
def embedding_and_ids(self):
"""
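
The hook keeps the layer agnostic about where embeddings live: if a lookup function has been registered, lookup_embedding delegates to it; otherwise the layer falls back to the external embedding service. A minimal sketch of the contract with a stand-in lookup function (the constructor arguments and the zero-valued vectors are purely illustrative):

    import numpy as np

    from elasticdl.python.elasticdl.layers.embedding import Embedding

    layer = Embedding(output_dim=8, name="item_embedding")

    def fake_lookup(layer_name, embedding_id_list):
        # Stand-in for Worker.pull_embedding_vector: one 8-dim vector
        # per requested id, all zeros just for illustration.
        return np.zeros((len(embedding_id_list), 8), dtype=np.float32)

    layer.set_lookup_embedding_func(fake_lookup)
    # lookup_embedding() now calls fake_lookup(layer._name, ids) instead
    # of querying the embedding service endpoint.
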
@@ -7,8 +7,9 @@

from elasticdl.proto import elasticdl_pb2
from elasticdl.python.common.constants import GRPC
from elasticdl.python.common.hash_utils import string_to_id
from elasticdl.python.common.hash_utils import int_to_id, string_to_id
from elasticdl.python.common.model_utils import get_model_spec
from elasticdl.python.ps.embedding_table import EmbeddingTable
from elasticdl.python.ps.parameter_server import ParameterServer
from elasticdl.python.tests.test_utils import PserverArgs
from elasticdl.python.worker.worker import Worker
@@ -226,6 +227,45 @@ def test_compare_mnist_train(self):
for w, l in zip(worker_results, local_results):
self.assertTupleEqual(w, l)

def test_worker_pull_embedding(self):
worker = Worker(
worker_id=0,
job_type=elasticdl_pb2.TRAINING,
minibatch_size=self._batch_size,
model_zoo=self._model_zoo_path,
model_def=self._model_def,
ps_channels=self._channel,
)

# Test looking up embedding vectors that do not exist yet
layers = ["test-2", "test-2-slot"]
ids = [3, 5, 1, 6, 10, 2, 1, 2, 4, 7, 9]
embedding_table_args = [
(layers[0], 8, "uniform", False),
(layers[1], 8, 3.3, True),
]

# Initialize embedding table objects on the PS instances
for pserver in self._pserver:
for layer, table_args in zip(layers, embedding_table_args):
pserver.parameters.embedding_params[layer] = EmbeddingTable(
*table_args
)

result_dict = {}
for layer in layers:
embedding = worker.pull_embedding_vector(layer, ids)
result_dict[layer] = embedding

for layer in layers:
expected_result = []
for embedding_id in ids:
ps_id = int_to_id(embedding_id, len(self._pserver))
table = self._pserver[ps_id].parameters.embedding_params[layer]
expected_result.append(table.get([embedding_id]))
expected_result = np.concatenate(expected_result)
self.assertTrue(np.allclose(expected_result, result_dict[layer]))


if __name__ == "__main__":
unittest.main()
49 changes: 39 additions & 10 deletions elasticdl/python/worker/worker.py
@@ -11,7 +11,7 @@
Mode,
SaveModelConfig,
)
from elasticdl.python.common.hash_utils import string_to_id
from elasticdl.python.common.hash_utils import int_to_id, string_to_id
from elasticdl.python.common.log_utils import default_logger as logger
from elasticdl.python.common.model_handler import ModelHandler
from elasticdl.python.common.model_utils import (
@@ -24,6 +24,7 @@
Tensor,
emplace_tensor_pb_from_ndarray,
serialize_tensor,
tensor_pb_to_ndarray,
)
from elasticdl.python.elasticdl.layers.embedding import Embedding
from elasticdl.python.worker.task_data_service import TaskDataService
@@ -88,6 +89,15 @@ def __init__(
max_minibatch_retry_num: The maximum number of retries for a minibatch
when its results (e.g. gradients) are not accepted by the master.
"""
self._use_multi_ps = False
if isinstance(ps_channels, list):
if len(ps_channels) > 0:
self._use_multi_ps = True
self._ps_stubs = [
elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
]
self._var_to_ps = {}

self._worker_id = worker_id
self._job_type = job_type
self._minibatch_size = minibatch_size
@@ -131,15 +141,6 @@ def __init__(
self._non_embed_grads = None
self._evaluation_result = {}

self._use_multi_ps = False
if isinstance(ps_channels, list):
if len(ps_channels) > 0:
self._use_multi_ps = True
self._ps_stubs = [
elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
]
self._var_to_ps = {}

# TODO: Multiple tests are currently using this function to initialize
# self._model, where the initialization should be done via constructor.
def set_model(self, model_inst):
@@ -162,6 +163,9 @@ def _init_embedding_layer(self):
self._embedding_layers = find_layer(self._model, Embedding)
for layer in self._embedding_layers:
layer.set_endpoint(self._embedding_service_endpoint)
if self._use_multi_ps:
layer.set_lookup_embedding_func(self.pull_embedding_vector)

self._need_embedding_layer_check = (
True if self._embedding_layers else False
)
@@ -223,6 +227,31 @@ def get_model_from_ps(self, version, method):
model_version = max(model_version, res.model.version)
self._model_version = model_version

def pull_embedding_vector(self, layer_name, embedding_ids):
"""Pulls and returns embedding vectors ordered by the embedding ids."""
ps_ids = {}
ps_ids_index = {}
for idx, embedding_id in enumerate(embedding_ids):
ps_id = int_to_id(embedding_id, len(self._ps_stubs))
ps_ids.setdefault(ps_id, []).append(embedding_id)
ps_ids_index.setdefault(ps_id, []).append(idx)

embeddings = []
index = []
for ps_id, embedding_ids in ps_ids.items():
req = elasticdl_pb2.PullEmbeddingVectorRequest()
req.name = layer_name
req.ids.extend(embedding_ids)
pb = self._ps_stubs[ps_id].pull_embedding_vector(req)
embeddings.append(tensor_pb_to_ndarray(pb))
index.extend(ps_ids_index[ps_id])
embeddings = np.concatenate(embeddings)

# Restore the embedding vectors to the order of the requested ids
new_embeddings = np.empty_like(embeddings)
new_embeddings[index] = embeddings
return new_embeddings

def get_model(self, version, method):
if self._use_multi_ps:
self.get_model_from_ps(version, method)
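
The reordering at the end of pull_embedding_vector deserves a note: responses come back grouped by PS, and `index` records where each returned row belongs in the original request, so scattering the rows back through `index` restores the caller's id order. A standalone numpy sketch of that step, with made-up shard assignments:

    import numpy as np

    # Request order: ids [3, 5, 1, 6]. Suppose ids 3 and 1 live on PS 0
    # and ids 5 and 6 live on PS 1, so rows arrive grouped by PS.
    embeddings = np.array([[3.0], [1.0],   # rows for ids 3, 1 (from PS 0)
                           [5.0], [6.0]])  # rows for ids 5, 6 (from PS 1)
    index = [0, 2, 1, 3]  # original request positions of those rows

    reordered = np.empty_like(embeddings)
    reordered[index] = embeddings
    print(reordered.ravel())  # [3. 5. 1. 6.] -- back in request order
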
