Commit 5dcc333
[sgd] Modify: add interface for model (#3458)
* Modify: add interface for model

* Modify: remove single quotes and build; add metrics

* Modify: flatten into a list of dicts

* Update distributed_sgd.rst

* Modify: update format with scripts/format.sh

* Update sgd_worker.py
chunyang-wen authored and ericl committed Dec 13, 2018
1 parent 0e00533 commit 5dcc333
Showing 6 changed files with 66 additions and 23 deletions.
2 changes: 1 addition & 1 deletion doc/source/distributed_sgd.rst
@@ -8,7 +8,7 @@ Ray SGD is built on top of the Ray task and actor abstractions to provide seamle…
Interface
---------

-To use Ray SGD, define a `model class <https://github.com/ray-project/ray/blob/master/python/ray/experimental/sgd/model.py>`__ with ``loss`` and ``optimizer`` attributes:
+To use Ray SGD, define a `model class <https://github.com/ray-project/ray/blob/master/python/ray/experimental/sgd/model.py>`__:

.. autoclass:: ray.experimental.sgd.Model
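
A minimal sketch of a conforming model under the new interface ("MyModel", the layer shapes, and the learning rate are illustrative, not part of this commit):

import tensorflow as tf

from ray.experimental.sgd import Model


class MyModel(Model):
    """Illustrative model: a single dense layer trained with SGD."""

    def __init__(self):
        # Inputs are fed in via get_feed_dict(); the shapes are arbitrary here.
        self.x = tf.placeholder(tf.float32, [None, 784], name="x")
        self.y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
        logits = tf.layers.dense(self.x, 10)
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.y_, logits=logits))
        self.optimizer = tf.train.GradientDescentOptimizer(1e-3)

    def get_loss(self):
        return self.loss

    def get_optimizer(self):
        return self.optimizer

    def get_feed_dict(self):
        # A real model would return a batch of training data here.
        return {}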

22 changes: 17 additions & 5 deletions python/ray/experimental/sgd/mnist_example.py
@@ -57,6 +57,7 @@ def __init__(self):

# Set seed and build layers
tf.set_random_seed(0)

self.x = tf.placeholder(tf.float32, [None, 784], name="x")
self.y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
y_conv, self.keep_prob = deepnn(self.x)
@@ -74,6 +75,15 @@ def __init__(self):
tf.argmax(y_conv, 1), tf.argmax(self.y_, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

def get_loss(self):
return self.loss

def get_optimizer(self):
return self.optimizer

def get_variables(self):
return self.variables

def get_feed_dict(self):
batch = self.mnist.train.next_batch(50)
return {
@@ -82,13 +92,14 @@ def get_feed_dict(self):
self.keep_prob: 0.5,
}

-    def test_accuracy(self):
-        return self.accuracy.eval(
+    def get_metrics(self):
+        accuracy = self.accuracy.eval(
            feed_dict={
                self.x: self.mnist.test.images,
                self.y_: self.mnist.test.labels,
                self.keep_prob: 1.0,
            })
+        return {"accuracy": accuracy}


def train_mnist(config, reporter):
@@ -101,14 +112,15 @@ def train_mnist(config, reporter):
strategy=args.strategy)

# Important: synchronize the initial weights of all model replicas
-    w0 = sgd.for_model(lambda m: m.variables.get_flat())
-    sgd.foreach_model(lambda m: m.variables.set_flat(w0))
+    w0 = sgd.for_model(lambda m: m.get_variables().get_flat())
+    sgd.foreach_model(lambda m: m.get_variables().set_flat(w0))

for i in range(args.num_iters):
if i % 10 == 0:
start = time.time()
loss = sgd.step(fetch_stats=True)["loss"]
-            acc = sgd.foreach_model(lambda model: model.test_accuracy())
+            metrics = sgd.foreach_model(lambda model: model.get_metrics())
+            acc = [m["accuracy"] for m in metrics]
print("Iter", i, "loss", loss, "accuracy", acc)
print("Time per iteration", time.time() - start)
assert len(set(acc)) == 1, ("Models out of sync", acc)
36 changes: 29 additions & 7 deletions python/ray/experimental/sgd/model.py
@@ -7,16 +7,38 @@ class Model(object):
"""Your class must implement this interface to be used with Ray SGD.
This supports any form of input pipeline: it is up to you to define it
using TensorFlow. The only requirements are that the loss and optimizer
attributes must be defined.
using TensorFlow.
For an example implementation, see tfbench/test_model.py
Attributes:
loss (tf.Tensor): Loss function to minimize.
optimizer (tf.train.Optimizer): Optimizer to use to minimize the loss.
"""

    def get_loss(self):
        """Return the loss of the model.

        Returns:
            loss (tf.Tensor): Loss function to minimize.
        """
        raise NotImplementedError(
            "get_loss of %s is not implemented" % self.__class__.__name__)

    # TODO: support more complex ways of updating gradients,
    # e.g. using different optimizers
    def get_optimizer(self):
        """Return the optimizer for the model.

        Returns:
            optimizer (tf.train.Optimizer): Optimizer used to minimize
                the loss.
        """
        raise NotImplementedError(
            "get_optimizer of %s is not implemented" % self.__class__.__name__)

    def get_metrics(self):
        """Return evaluation metrics for the model.

        Returns:
            metrics (dict): e.g. {"accuracy": accuracy}, where each value
                is numpy data.
        """
        return {}

def get_feed_dict(self):
"""Extra values to pass in when computing gradients for the loss.
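For orientation, this is the pattern by which the SGD worker consumes the interface when building its training graph, condensed from the sgd_worker.py changes below ("build_train_op" is an illustrative helper, not part of this commit):

def build_train_op(model):
    # Fetch the loss and optimizer through the accessor methods rather
    # than reading .loss/.optimizer attributes directly.
    optimizer = model.get_optimizer()
    loss = model.get_loss()
    # Drop variables that receive no gradient, as sgd_worker.py does.
    grads = [t for t in optimizer.compute_gradients(loss)
             if t[0] is not None]
    return optimizer.apply_gradients(grads)
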
1 change: 1 addition & 0 deletions python/ray/experimental/sgd/sgd.py
@@ -141,6 +141,7 @@ def foreach_model(self, fn):
Returns:
List of results from applying the function.
"""

results = ray.get([w.foreach_model.remote(fn) for w in self.workers])
out = []
for r in results:
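A usage sketch of the flattened result, assuming "sgd" is a DistributedSGD instance set up as in mnist_example.py above:

# foreach_model returns one flat list with an entry per model replica,
# so the per-replica metrics dicts can be consumed directly.
metrics = sgd.foreach_model(lambda model: model.get_metrics())
accuracies = [m["accuracy"] for m in metrics]
assert len(set(accuracies)) == 1, ("Models out of sync", accuracies)
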
21 changes: 11 additions & 10 deletions python/ray/experimental/sgd/sgd_worker.py
@@ -56,9 +56,11 @@ def __init__(self,
            with tf.variable_scope("device_%d" % device_idx):
                model = model_creator(worker_index, device_idx)
                self.models.append(model)
+                optimizer = model.get_optimizer()
+                loss = model.get_loss()
                grads = [
-                    t for t in model.optimizer.compute_gradients(
-                        model.loss) if t[0] is not None
+                    t for t in optimizer.compute_gradients(loss)
+                    if t[0] is not None
                ]
                grad_ops.append(grads)

@@ -123,7 +125,7 @@ def __init__(self,
]
for j in range(num_grads):
grad = self.per_device_grads[0][j]
-                with tf.device(self.models[0].loss.device):
+                with tf.device(self.models[0].get_loss().device):
plasma_grad = plasma.tf_plasma_op.tensor_to_plasma(
[grad],
self.plasma_in_grads_oids[j],
@@ -174,10 +176,9 @@ def __init__(self,
apply_ops = []
to_apply = unpacked_gv[0]
for ix, m in enumerate(self.models):
-            apply_ops.append(
-                m.optimizer.apply_gradients(
-                    [(g, v)
-                     for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])]))
+            apply_ops.append(m.get_optimizer().apply_gradients([
+                (g, v) for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])
+            ]))
self.apply_op = tf.group(*apply_ops)
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
@@ -209,7 +210,7 @@ def compute_gradients(self):
# averaged across all devices by allreduce.
fetches = self.sess.run(
[
-                self.models[0].loss, self.per_device_grads[0],
+                self.models[0].get_loss(), self.per_device_grads[0],
self.nccl_control_out
],
feed_dict=feed_dict)
@@ -229,7 +230,7 @@ def apply_gradients(self, avg_grads):
def compute_apply(self):
fetches = run_timeline(
self.sess,
-            [self.models[0].loss, self.apply_op, self.nccl_control_out],
+            [self.models[0].get_loss(), self.apply_op, self.nccl_control_out],
feed_dict=self._grad_feed_dict(),
name="compute_apply")
return fetches[0]
@@ -247,7 +248,7 @@ def ps_compute_apply(self,
fetch(agg_grad_shard_oids)
fetches = run_timeline(
self.sess, [
-                self.models[0].loss, self.plasma_in_grads, self.apply_op,
+                self.models[0].get_loss(), self.plasma_in_grads, self.apply_op,
self.nccl_control_out
],
feed_dict=feed_dict,
7 changes: 7 additions & 0 deletions python/ray/experimental/sgd/tfbench/test_model.py
@@ -14,6 +14,7 @@ class MockDataset():

class TFBenchModel(Model):
def __init__(self, batch=64, use_cpus=False):

image_shape = [batch, 224, 224, 3]
labels_shape = [batch]

@@ -45,5 +46,11 @@ def __init__(self, batch=64, use_cpus=False):
self.loss = tf.reduce_mean(loss, name='xentropy-loss')
self.optimizer = tf.train.GradientDescentOptimizer(1e-6)

def get_loss(self):
return self.loss

def get_optimizer(self):
return self.optimizer

def get_feed_dict(self):
return {}
