Variables from different PS have their own model_version (#1604)
* Variables from different PS have their own model_version

* fix test failure

* fix unittest

* no need to check model version
skydoorkai committed Jan 3, 2020
1 parent bf2ce91 commit f1ab88d
Showing 8 changed files with 57 additions and 25 deletions.
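
At a glance, the commit replaces the worker's single global model version with one version per parameter-server (PS) shard. A condensed, hypothetical sketch of the bookkeeping (a standalone illustration, not the actual ElasticDL classes):

# Each PS shard advances its own model version; the worker tracks one per shard.
ps_num = 3
model_versions_from_ps = [-1] * ps_num  # -1 means "nothing pulled yet"

def on_pull_response(ps_id, version):
    # After pulling from shard ps_id, remember that shard's version.
    model_versions_from_ps[ps_id] = version

def overall_version():
    # The worker's overall version is that of the freshest shard.
    return max(model_versions_from_ps)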
6 changes: 5 additions & 1 deletion elasticdl/proto/elasticdl.proto
@@ -111,6 +111,10 @@ service Master {
   rpc report_version(ReportVersionRequest) returns (google.protobuf.Empty);
 }
 
+message PullVariableRequest {
+  int32 current_model_version = 1;
+}
+
 message PullVariableResponse {
   bool model_init_status = 1;
   Model model = 2;
@@ -133,7 +137,7 @@ message PushGradientResponse {
 
 // PS service
 service Pserver {
-  rpc pull_variable(google.protobuf.Empty) returns (PullVariableResponse);
+  rpc pull_variable(PullVariableRequest) returns (PullVariableResponse);
   rpc pull_embedding_vector(PullEmbeddingVectorRequest) returns (Tensor);
   rpc push_model(Model) returns (google.protobuf.Empty);
   rpc push_embedding_info(Model) returns (google.protobuf.Empty);
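
For reference, a worker-side call to the updated RPC could look like this sketch; it assumes the generated elasticdl_pb2 / elasticdl_pb2_grpc modules and a reachable PS, and the address is hypothetical:

import grpc

from elasticdl.proto import elasticdl_pb2, elasticdl_pb2_grpc

channel = grpc.insecure_channel("localhost:2222")  # hypothetical PS address
stub = elasticdl_pb2_grpc.PserverStub(channel)

req = elasticdl_pb2.PullVariableRequest()
req.current_model_version = -1  # -1 requests a full pull of all variables
res = stub.pull_variable(req)
if res.model_init_status:
    print("PS model version:", res.model.version)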
8 changes: 7 additions & 1 deletion elasticdl/python/common/args.py
@@ -316,7 +316,13 @@ def add_train_params(parser):
         default="",
         help="The path to save the final trained model",
     )
-
+    parser.add_argument(
+        "--sync_version_tolerance",
+        type=int,
+        help="The maximum model version difference between reported gradients "
+        "and the PS that synchronous SGD can accept.",
+        default=0,
+    )
     add_bool_param(
         parser=parser,
         name="--use_async",
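
The flag's semantics, in a worked sketch (a hypothetical restatement of the PS-side check shown in servicer.py below): with the default of 0, synchronous SGD accepts only gradients computed at the PS's current version; a tolerance of k also accepts gradients up to k versions stale.

def accepts_gradient(request_version, ps_version, sync_version_tolerance=0):
    # Mirrors: reject when request.model_version < ps_version - tolerance.
    return request_version >= ps_version - sync_version_tolerance

assert accepts_gradient(10, 10)                               # current version
assert not accepts_gradient(9, 10)                            # stale by 1, default rejects
assert accepts_gradient(9, 10, sync_version_tolerance=1)      # within tolerance
assert not accepts_gradient(7, 10, sync_version_tolerance=2)  # stale by 3, rejected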
2 changes: 2 additions & 0 deletions elasticdl/python/ps/parameter_server.py
@@ -24,6 +24,7 @@ def __init__(self, args):
         self.logger = get_logger("PS", level=args.log_level.upper())
         self.grads_to_wait = args.grads_to_wait
         self.lr_staleness_modulation = args.lr_staleness_modulation
+        self.sync_version_tolerance = args.sync_version_tolerance
         self.use_async = args.use_async
         self.port = args.port
         model_module = load_module(
@@ -113,6 +114,7 @@ def prepare(self):
             self.optimizer,
             self.lr_scheduler,
             lr_staleness_modulation=self.lr_staleness_modulation,
+            sync_version_tolerance=self.sync_version_tolerance,
             use_async=self.use_async,
             evaluation_steps=self.evaluation_steps,
             master_channel=self.master_channel,
2 changes: 1 addition & 1 deletion elasticdl/python/ps/parameters.py
@@ -131,7 +131,7 @@ def init_from_model_pb(self, model_pb):
             embeddings_pb = model_pb.embedding_table_info
             self.init_embedding_params(embeddings_pb)
             self._restore_params_from_pb(tensors_pb)
-            self.version = model_pb.version
+            self.version = max(0, model_pb.version)
             self.init_status = True
             return True
         return False
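
The max(0, ...) clamp guards against negative versions: a worker that initializes the PS pushes a model stamped with its cached per-PS version, which starts at -1 (see report_variable_to_ps in worker.py below), while the PS should count versions from 0. A minimal illustration:

worker_cached_version = -1                  # worker has pulled nothing yet
ps_version = max(0, worker_cached_version)
assert ps_version == 0                      # PS version stays non-negative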
17 changes: 12 additions & 5 deletions elasticdl/python/ps/servicer.py
@@ -22,6 +22,7 @@ def __init__(
         optimizer,
         lr_scheduler="",
         lr_staleness_modulation=False,
+        sync_version_tolerance=0,
         use_async=False,
         evaluation_steps=0,
         master_channel=None,
@@ -39,6 +40,7 @@ def __init__(
         self._optimizer = optimizer
         self._lr_scheduler = lr_scheduler
         self._lr_staleness_modulation = lr_staleness_modulation
+        self._sync_version_tolerance = sync_version_tolerance
         self._use_async = use_async
         self._eval_steps = evaluation_steps
         self._checkpoint_saver = checkpoint_saver
@@ -65,10 +67,12 @@ def pull_variable(self, request, _):
         if not self._use_async:
             self._lock.acquire()
         res.model.version = self._parameters.version
-        for name, var in self._parameters.non_embedding_params.items():
-            emplace_tensor_pb_from_ndarray(
-                res.model.param, var.numpy(), name=name
-            )
+        # No need to send variables if the requester has the latest version.
+        if self._parameters.version > request.current_model_version:
+            for name, var in self._parameters.non_embedding_params.items():
+                emplace_tensor_pb_from_ndarray(
+                    res.model.param, var.numpy(), name=name
+                )
         if not self._use_async:
             self._lock.release()
         res.model_init_status = True
@@ -128,7 +132,10 @@ def push_gradient(self, request, _):
             res.model_version = self._parameters.version
             return res
         else:
-            if request.model_version < self._parameters.version:
+            if (
+                request.model_version
+                < self._parameters.version - self._sync_version_tolerance
+            ):
                 res.accepted = False
                 res.model_version = self._parameters.version
                 return res
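
Taken together, the two RPC changes make a repeat pull cheap: the worker reports the version it already holds, and when nothing is newer the PS replies with its current version but an empty parameter list. A sketch of the exchange from the worker's side, reusing the hypothetical stub from the earlier snippet:

req = elasticdl_pb2.PullVariableRequest()
req.current_model_version = 5            # version the worker already holds
res = stub.pull_variable(req)
if res.model.version > req.current_model_version:
    apply_params(res.model.param)        # hypothetical helper: update local vars
else:
    assert not res.model.param           # up to date; PS sent no parameters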
10 changes: 9 additions & 1 deletion elasticdl/python/tests/pserver_servicer_test.py
@@ -169,7 +169,8 @@ def test_pull_variable(self):
             "v0": np.random.rand(3, 2).astype(np.float32),
             "v1": np.random.rand(10, 32).astype(np.float32),
         }
-        pull_req = empty_pb2.Empty()
+        pull_req = elasticdl_pb2.PullVariableRequest()
+        pull_req.current_model_version = -1
         # try to pull variable
         res = self._stub.pull_variable(pull_req)
         # not initialized
@@ -192,6 +193,13 @@ def test_pull_variable(self):
             tensor = tensor_pb_to_ndarray(param)
             self.assertTrue(np.allclose(param0[name], tensor))
 
+        # pull variable again; expect no params since there is no newer version
+        pull_req.current_model_version = res.model.version
+        res = self._stub.pull_variable(pull_req)
+        self.assertTrue(res.model_init_status)
+        self.assertEqual(res.model.version, pull_req.current_model_version)
+        self.assertTrue(not res.model.param)
+
     def test_pull_embedding_vector(self):
         self.create_default_server_and_stub()
 
2 changes: 2 additions & 0 deletions elasticdl/python/tests/test_utils.py
@@ -40,6 +40,7 @@ def __init__(
         grads_to_wait=8,
         lr_scheduler="learning_rate_scheduler",
         lr_staleness_modulation=0,
+        sync_version_tolerance=0,
         use_async=False,
         model_zoo=None,
         model_def=None,
@@ -61,6 +62,7 @@ def __init__(
         self.grads_to_wait = grads_to_wait
         self.learning_rate_scheduler = lr_scheduler
         self.lr_staleness_modulation = lr_staleness_modulation
+        self.sync_version_tolerance = sync_version_tolerance
         self.use_async = use_async
         self.model_zoo = model_zoo
         self.model_def = model_def
35 changes: 19 additions & 16 deletions elasticdl/python/worker/worker.py
@@ -5,7 +5,6 @@
 
 import numpy as np
 import tensorflow as tf
-from google.protobuf import empty_pb2
 
 from elasticdl.proto import elasticdl_pb2, elasticdl_pb2_grpc
 from elasticdl.python.collective_ops.communicator import CollectiveCommunicator
@@ -99,6 +98,9 @@ def __init__(
                 elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
             ]
             self._var_to_ps = {}
+            self._ps_num = len(self._ps_stubs)
+        else:
+            self._ps_num = 0
         self._distribution_strategy = args.distribution_strategy
         if (
             self._distribution_strategy
@@ -154,6 +156,7 @@ def _init_from_args(self, args):
         self.set_model(model_inst)
 
         self._model_version = -1
+        self._model_versions_from_ps = [-1 for _ in range(self._ps_num)]
         self._task_data_service = TaskDataService(
             self,
             self._job_type == JobType.TRAINING_WITH_EVALUATION,
@@ -291,15 +294,15 @@ def get_task(self, task_type=None):
 
     def get_model(self):
         self._timing.start_record_time("get_model")
-        model_version = -1
         variable_future_and_id_pairs = []
-        req = empty_pb2.Empty()
         if self._use_multi_ps:
             self.init_ps_var_partition()
         for ps_id, stub in enumerate(self._ps_stubs):
             if ps_id not in self._ps_vars:
                 continue
             # async grpc call
+            req = elasticdl_pb2.PullVariableRequest()
+            req.current_model_version = self._model_versions_from_ps[ps_id]
             var_future = stub.pull_variable.future(req)
             variable_future_and_id_pairs.append((var_future, ps_id))
 
@@ -308,6 +311,8 @@ def get_model(self):
             if not res.model_init_status:
                 # push variable to ps for initialization
                 self.report_variable_to_ps(ps_id)
+                req = elasticdl_pb2.PullVariableRequest()
+                req.current_model_version = self._model_versions_from_ps[ps_id]
                 res = self._ps_stubs[ps_id].pull_variable(req)
                 if not res.model_init_status:
                     # TODO: support PS fault-tolerance
@@ -318,17 +323,17 @@
             for tensor_pb in res.model.param:
                 tensor = Tensor.from_tensor_pb(tensor_pb)
                 self._non_embed_vars[tensor.name].assign(tensor.to_ndarray())
+            self._model_versions_from_ps[ps_id] = res.model.version
 
-            model_version = max(model_version, res.model.version)
-        self._model_version = model_version
+        self._model_version = max(self._model_versions_from_ps)
         self._timing.end_record_time("get_model")
 
     def pull_embedding_vector(self, layer_name, embedding_ids):
         """Pulls and returns embedding vectors ordered by the embedding ids."""
         ps_ids = {}
         ps_ids_index = {}
         for idx, embedding_id in enumerate(embedding_ids):
-            ps_id = int_to_id(embedding_id, len(self._ps_stubs))
+            ps_id = int_to_id(embedding_id, self._ps_num)
             ps_ids.setdefault(ps_id, []).append(embedding_id)
             ps_ids_index.setdefault(ps_id, []).append(idx)
 
@@ -367,9 +372,7 @@ def init_ps_var_partition(self):
         ps_vars = {}
         for v in self._non_embed_vars.values():
             if v.name not in self._var_to_ps:
-                self._var_to_ps[v.name] = string_to_id(
-                    v.name, len(self._ps_stubs)
-                )
+                self._var_to_ps[v.name] = string_to_id(v.name, self._ps_num)
             ps_id = self._var_to_ps[v.name]
             if ps_id not in ps_vars:
                 ps_vars[ps_id] = [v]
@@ -398,11 +401,12 @@ def report_embedding_info(self):
         # tf.keras.initializers. Keep aligned between these two.
         embedding_info.initializer = "uniform"
 
-        for ps_id in range(len(self._ps_stubs)):
+        for ps_id in range(self._ps_num):
             self._ps_stubs[ps_id].push_embedding_info(model)
 
     def report_variable_to_ps(self, ps_id):
         model = elasticdl_pb2.Model()
+        model.version = self._model_versions_from_ps[ps_id]
         if ps_id in self._ps_vars:
             vars = self._ps_vars[ps_id]
             for var in vars:
@@ -413,7 +417,7 @@ def report_variable_to_ps(self, ps_id):
 
     def report_variable(self):
         # TODO: call `push_model` in parallel
-        for ps_id in range(len(self._ps_stubs)):
+        for ps_id in range(self._ps_num):
             self.report_variable_to_ps(ps_id)
 
     def _collect_edl_embedding_name_values(self):
@@ -440,8 +444,7 @@ def _collect_edl_embedding_name_values(self):
     def report_gradient_to_ps(self, grads):
         self._timing.start_record_time("report_gradient")
         reqs = [
-            elasticdl_pb2.PushGradientRequest()
-            for i in range(len(self._ps_stubs))
+            elasticdl_pb2.PushGradientRequest() for i in range(self._ps_num)
         ]
         ps_grads = {}
         non_embed_vars_n = len(self._non_embed_vars)
@@ -498,7 +501,7 @@ def report_gradient_to_ps(self, grads):
             )
 
         results = scatter_embedding_vector(
-            g_values.numpy(), g_indices.numpy(), len(self._ps_stubs)
+            g_values.numpy(), g_indices.numpy(), self._ps_num
         )
 
         for ps_id in results:
@@ -509,9 +512,9 @@
             )
 
         report_futures = []
-        for ps_id in range(len(self._ps_stubs)):
+        for ps_id in range(self._ps_num):
             req = reqs[ps_id]
-            req.model_version = self._model_version
+            req.model_version = self._model_versions_from_ps[ps_id]
             report_future = self._ps_stubs[ps_id].push_gradient.future(req)
             report_futures.append(report_future)
 
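
For background on the new _ps_num bookkeeping: variables and embedding ids are sharded across PS instances by the string_to_id and int_to_id helpers, which is why the worker must remember a version per shard. A hypothetical sketch under the assumption that both behave like stable hash-modulo functions (the real helpers may hash differently):

import zlib

def string_to_id(name, ps_num):
    # Assumption: a stable string hash modulo the shard count.
    return zlib.crc32(name.encode()) % ps_num

def int_to_id(embedding_id, ps_num):
    return embedding_id % ps_num  # assumption

ps_num = 3
for name in ["dense/kernel:0", "dense/bias:0"]:
    print(name, "-> PS shard", string_to_id(name, ps_num))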
