From 46dabe138ffb82e53dba261ba4ebebb8d0b5c78a Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 14 May 2018 19:54:45 -0700
Subject: [PATCH 01/17] check in sgd api

---
 .../ray/experimental/sgd/chrome_timeline.py   |  86 +++
 python/ray/experimental/sgd/sgd.py            | 481 ++++++++++++++
 python/ray/experimental/sgd/test_model.py    |  44 ++
 python/ray/experimental/sgd/test_sgd.py      |  26 +
 .../ray/experimental/sgd/tfbench/__init__.py |   0
 .../ray/experimental/sgd/tfbench/allreduce.py | 590 ++++++++++++++++++
 .../sgd/tfbench/convnet_builder.py            | 467 ++++++++++++++
 python/ray/experimental/sgd/tfbench/model.py  | 114 ++++
 .../experimental/sgd/tfbench/model_config.py  |  61 ++
 .../sgd/tfbench/modified_allreduce.py         | 171 +++++
 .../experimental/sgd/tfbench/resnet_model.py  | 346 ++++++++++
 11 files changed, 2386 insertions(+)
 create mode 100644 python/ray/experimental/sgd/chrome_timeline.py
 create mode 100644 python/ray/experimental/sgd/sgd.py
 create mode 100644 python/ray/experimental/sgd/test_model.py
 create mode 100644 python/ray/experimental/sgd/test_sgd.py
 create mode 100644 python/ray/experimental/sgd/tfbench/__init__.py
 create mode 100644 python/ray/experimental/sgd/tfbench/allreduce.py
 create mode 100644 python/ray/experimental/sgd/tfbench/convnet_builder.py
 create mode 100644 python/ray/experimental/sgd/tfbench/model.py
 create mode 100644 python/ray/experimental/sgd/tfbench/model_config.py
 create mode 100644 python/ray/experimental/sgd/tfbench/modified_allreduce.py
 create mode 100644 python/ray/experimental/sgd/tfbench/resnet_model.py

diff --git a/python/ray/experimental/sgd/chrome_timeline.py b/python/ray/experimental/sgd/chrome_timeline.py
new file mode 100644
index 0000000000000..41b34b9d019ea
--- /dev/null
+++ b/python/ray/experimental/sgd/chrome_timeline.py
@@ -0,0 +1,86 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ray
+import json
+import time
+
+
+class Timeline(object):
+    def __init__(self, tid):
+        self.events = []
+        self.offset = 0
+        self.start_time = self.time()
+        self.tid = tid
+
+    def patch_ray(self):
+        orig_log = ray.worker.log
+
+        def custom_log(event_type, kind, *args, **kwargs):
+            orig_log(event_type, kind, *args, **kwargs)
+            if kind == ray.worker.LOG_SPAN_START:
+                self.start(event_type)
+            elif kind == ray.worker.LOG_SPAN_END:
+                self.end(event_type)
+            elif kind == ray.worker.LOG_SPAN_POINT:
+                self.event(event_type)
+
+        ray.worker.log = custom_log
+
+    def time(self):
+        return time.time() + self.offset
+
+    def reset(self):
+        self.events = []
+        self.start_time = self.time()
+
+    def start(self, name):
+        self.events.append((self.tid, "B", name, self.time()))
+
+    def end(self, name):
+        self.events.append((self.tid, "E", name, self.time()))
+
+    def event(self, name):
+        now = self.time()
+        self.events.append((self.tid, "B", name, now))
+        self.events.append((self.tid, "E", name, now + .0001))
+
+    def merge(self, other):
+        if other.start_time < self.start_time:
+            self.start_time = other.start_time
+        self.events.extend(other.events)
+        self.events.sort(key=lambda e: e[3])
+
+    def chrome_trace_format(self, filename):
+        out = []
+        for tid, ph, name, t in self.events:
+            ts = int((t - self.start_time) * 1000000)
+            out.append({
+                "name": name,
+                "tid": tid,
+                "pid": tid,
+                "ph": ph,
+                "ts": ts,
+            })
+        with open(filename, "w") as f:
+            f.write(json.dumps(out))
+        print("Wrote chrome timeline to", filename)
+
+
+if __name__ == "__main__":
+    a = Timeline(1)
+    b = Timeline(2)
+    a.start("hi")
+    time.sleep(.1)
+    b.start("bye")
a.start("hi3") + time.sleep(.1) + a.end("hi3") + b.end("bye") + time.sleep(.1) + a.end("hi") + b.start("b1") + b.end("b1") + a.merge(b) + a.chrome_trace_format("test.json") diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py new file mode 100644 index 0000000000000..9358c275ef515 --- /dev/null +++ b/python/ray/experimental/sgd/sgd.py @@ -0,0 +1,481 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import random +import ray +import time + +from tensorflow.python.client import timeline +import numpy as np +import tensorflow as tf +import tensorflow.contrib.nccl as nccl +import tensorflow.contrib.slim as slim + +from chrome_timeline import Timeline +from tfbench import allreduce + + +def fetch(oids): + for o in oids: + plasma_id = ray.pyarrow.plasma.ObjectID(o) + ray.worker.global_worker.plasma_client.fetch([plasma_id]) + + +def run_timeline(sess, ops, feed_dict={}, write_timeline=False, name=""): + if write_timeline: + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + fetches = sess.run( + ops, options=run_options, run_metadata=run_metadata, + feed_dict=feed_dict) + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + outf = "timeline-{}-{}.json".format(name, os.getpid()) + trace_file = open(outf, "w") + print("wrote tf timeline to", os.path.abspath(outf)) + trace_file.write(trace.generate_chrome_trace_format()) + else: + fetches = sess.run(ops, feed_dict=feed_dict) + return fetches + + +class SGDWorker(object): + def __init__(self, + worker_index, + model_creator, + all_reduce_alg="simple", + num_devices=1, + use_cpus=False, + max_bytes=0, + plasma_op=False, + verbose=False): + self.worker_index = worker_index + assert num_devices > 0 + + # TODO(ekl) support custom session + tf_session_args = { + "device_count": {"CPU": num_devices}, + "log_device_placement": False, + "gpu_options": tf.GPUOptions(force_gpu_compatible=True), + "inter_op_parallelism_threads": 128, + } + config_proto = tf.ConfigProto(**tf_session_args) + self.sess = tf.Session(config=config_proto) + self.models = [] + grad_ops = [] + + if use_cpus: + device_tmpl = "/cpu:%d" + else: + device_tmpl = "/gpu:%d" + for device_idx in range(num_devices): + device = device_tmpl % device_idx + with tf.device(device): + with tf.variable_scope("device_%d" % device_idx): + model = model_creator(worker_index, device_idx) + self.models.append(model) + model.grads = [ + t for t in model.optimizer.compute_gradients( + model.loss) + if t[0] is not None] + grad_ops.append(model.grads) + + if num_devices == 1: + assert not max_bytes, "Not supported with 1 GPU" + self.packed_grads_and_vars = grad_ops + else: + if max_bytes: + from tfbench import modified_allreduce + self.packed_grads_and_vars, packing_vals = ( + modified_allreduce.sum_gradients_all_reduce( + "", grad_ops, 1, all_reduce_alg, 1, + list(range(num_devices)), + agg_small_grads_max_bytes=max_bytes)) + else: + self.packed_grads_and_vars = ( + allreduce.sum_gradients_all_reduce( + "", grad_ops, 1, all_reduce_alg, 1, + list(range(num_devices)))) + self.per_device_grads = [ + list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars] + assert(len(self.per_device_grads) == num_devices) + self.num_grads = num_grads = len(self.packed_grads_and_vars[0]) + if max_bytes: + print("Packed grads => {} tensors".format(num_grads)) + + # Ops for reading grads with the right control deps + nccl_noops = [] + for j in 
range(num_grads)[::-1]: + with tf.control_dependencies( + nccl_noops + [dev_grad[j] + for dev_grad in self.per_device_grads]): + nccl_noops = [tf.no_op()] + + # You must fetch this otherwise the NCCL allreduce will hang + self.nccl_control_out = tf.group(*nccl_noops) + + round_robin_devices = False + if plasma_op: + store_socket = ( + ray.worker.global_worker.plasma_client.store_socket_name) + manager_socket = ( + ray.worker.global_worker.plasma_client.manager_socket_name) + memcpy_plasma_module = tf.load_op_library( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "ops/memcpy_plasma_op.so")) + + # For fetching grads -> plasma + self.plasma_in_grads = [] + self.plasma_in_grads_oids = [ + tf.placeholder(shape=[], dtype=tf.string) + for _ in range(num_grads)] + ix = 0 + for j in range(num_grads): + grad = self.per_device_grads[ix][j] + if round_robin_devices: + ix += 1 # round robin assignment + ix %= num_devices + with tf.device(self.models[ix].device): + plasma_grad = memcpy_plasma_module.tensor_to_plasma( + [grad], + self.plasma_in_grads_oids[j], + plasma_store_socket_name=store_socket, + plasma_manager_socket_name=manager_socket) + self.plasma_in_grads.append(plasma_grad) + + # For applying grads <- plasma + unpacked_gv = [] + self.plasma_out_grads_oids = [ + tf.placeholder(shape=[], dtype=tf.string) + for _ in range(num_grads)] + packed_plasma_grads = [] + ix = 0 + for j in range(num_grads): + with tf.device(self.plasma_in_grads[j].device): + with tf.control_dependencies([self.plasma_in_grads[j]]): + grad_ph = memcpy_plasma_module.plasma_to_tensor( + self.plasma_out_grads_oids[j], + plasma_store_socket_name=store_socket, + plasma_manager_socket_name=manager_socket) + grad_ph = tf.reshape( + grad_ph, self.packed_grads_and_vars[0][j][0].shape) + print("Packed tensor", grad_ph) + packed_plasma_grads.append(grad_ph) + for i in range(num_devices): + per_device = [] + for j, (g, v) in enumerate(self.packed_grads_and_vars[i]): + grad_ph = packed_plasma_grads[j] + per_device.append((grad_ph, v)) + unpacked_gv.append(per_device) + + if max_bytes: + unpacked_gv = allreduce.unpack_small_tensors( + unpacked_gv, packing_vals) + + elif max_bytes: + unpacked_gv = allreduce.unpack_small_tensors( + self.packed_grads_and_vars, packing_vals) + else: + unpacked_gv = self.packed_grads_and_vars + + # Same shape as packed_grads_and_vars + assert len(unpacked_gv) == num_devices + assert len(unpacked_gv[0][0]) == 2 + + apply_ops = [] + to_apply = unpacked_gv[0] + for ix, m in enumerate(self.models): + apply_ops.append(m.optimizer.apply_gradients( + [(g, v) + for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])])) + self.apply_op = tf.group(*apply_ops) + init_op = tf.group(tf.global_variables_initializer(), + tf.local_variables_initializer()) + self.sess.run(init_op) + + def compute_gradients(self, verbose): + start = time.time() + fetches = self.sess.run( + [self.models[0].loss, self.per_device_grads[0], + self.nccl_control_out]) + if verbose: + print("compute grad interior time", time.time() - start) + return fetches + + def apply_gradients(self, avg_grads, verbose): + start = time.time() + result = { + g: avg_grads[i] for (i, g) in enumerate(self.per_device_grads[0]) + } + self.sess.run(self.apply_op, feed_dict=result) + if verbose: + print("apply grad interior time", time.time() - start) + + def ps_compute_apply( + self, out_grad_shard_oids, agg_grad_shard_oids, + tl_name="ps_compute_apply", write_timeline=False): + feed_dict = { + ph: oid + for (ph, oid) + in zip(self.plasma_in_grads_oids, 
out_grad_shard_oids) + } + feed_dict.update({ + ph: oid + for (ph, oid) + in zip(self.plasma_out_grads_oids, agg_grad_shard_oids) + }) + fetch(agg_grad_shard_oids) + run_timeline( + self.sess, + [self.plasma_in_grads, self.apply_op, self.nccl_control_out], + feed_dict=feed_dict, + write_timeline=write_timeline) + + def num_grad_shards(self): + return self.num_grads + + def shard_shapes(self): + main_gv = self.packed_grads_and_vars[0] + return [g.shape for g, _ in main_gv] + + def ip(self): + return ray.services.get_node_ip_address() + + +class ParameterServer(object): + def __init__(self, num_workers, tid): + self.num_sgd_workers = num_workers + self.acc_counter = 0 + self.timeline = Timeline(tid) + self.timeline.patch_ray() + + def set_tid(self, tid): + self.timeline.tid = tid + + def get_time(self): + return time.time() + self.timeline.offset + + def set_time(self, ref_time): + self.timeline.offset = ref_time - time.time() + + def initialize(self, shard_shape): + self.accumulated = np.zeros(shard_shape, dtype=np.float32) + + def mark(self): + self.timeline.event("mark") + + def prefetch(self, oids): + self.timeline.reset() + self.timeline.start("prefetch") + fetch(oids) + self.timeline.end("prefetch") + + def add_spinwait(self, grad_shard_ids): + self.timeline.start("add_spinwait") + plasma_ids = [ray.pyarrow.plasma.ObjectID(x) for x in grad_shard_ids] + while plasma_ids: + for p in plasma_ids: + if ray.worker.global_worker.plasma_client.contains(p): + self.timeline.start("get_buffers") + [raw_grads] = ( + ray.worker.global_worker.plasma_client.get_buffers( + [p])) + grads = np.frombuffer(raw_grads, dtype=np.float32) + self.accumulated += grads + self.acc_counter += 1 + self.timeline.end("get_buffers") + plasma_ids.remove(p) + break + self.timeline.end("add_spinwait") + + def add(self, grad_shard_id): + self.timeline.start("add") + # self.timeline.start("add_wait") + # ray.wait([ray.local_scheduler.ObjectID(grad_shard_id)]) + # self.timeline.end("add_wait") + self.timeline.start("get_buffers") + oid = ray.pyarrow.plasma.ObjectID(grad_shard_id) + [raw_grads] = ray.worker.global_worker.plasma_client.get_buffers([oid]) + grads = np.frombuffer(raw_grads, dtype=np.float32) + self.timeline.end("get_buffers") + self.accumulated += grads + self.acc_counter += 1 + self.timeline.end("add") + + def get(self, object_id): + self.timeline.start("get") + client = ray.worker.global_worker.plasma_client + assert self.acc_counter == self.num_sgd_workers, self.acc_counter + oid = ray.pyarrow.plasma.ObjectID(object_id) + buff = client.create( + oid, self.accumulated.nbytes) + wrapper = np.frombuffer(buff, dtype=np.float32) + np.copyto(wrapper, self.accumulated) + client.seal(oid) + self.accumulated = np.zeros_like(self.accumulated) + self.acc_counter = 0 + self.timeline.end("get") + + def get_timeline(self): + return self.timeline + + def ip(self): + return ray.services.get_node_ip_address() + + def pin(self, cpu_id): + try: + import psutil + p = psutil.Process() + p.cpu_affinity([cpu_id]) + print("Setting CPU Affinity to: ", cpu_id) + except Exception as e: + print(e) + + +def average_gradients(grads): + out = [] + for grad_list in zip(*grads): + out.append(np.mean(grad_list, axis=0)) + return out + + +def do_sgd_step(actors, verbose): + start = time.time() + fetches = ray.get([a.compute_gradients.remote(verbose) for a in actors]) + losses = [f[0] for f in fetches] + grads = [f[1] for f in fetches] + if verbose: + print("compute all grads time", time.time() - start) + start = time.time() + if len(actors) 
== 1: + assert len(grads) == 1 + avg_grad = grads[0] + else: + avg_grad = average_gradients(grads) + if verbose: + print("grad reduce time", time.time() - start) + start = time.time() + ray.get([a.apply_gradients.remote(avg_grad, verbose) for a in actors]) + if verbose: + print("apply all grads time", time.time() - start) + return np.mean(losses) + + +def distributed_sgd_step(actors, ps_list, verbose, write_timeline): + # Preallocate object ids that actors will write gradient shards to + grad_shard_oids_list = [ + [np.random.bytes(20) for _ in ps_list] + for _ in actors + ] + print("generated grad oids") + + # Preallocate object ids that param servers will write new weights to + accum_shard_ids = [np.random.bytes(20) for _ in ps_list] + print("generated accum oids") + + # Kick off the fused compute grad / update weights tf run for each actor + for actor, grad_shard_oids in zip(actors, grad_shard_oids_list): + actor.ps_compute_apply.remote(grad_shard_oids, accum_shard_ids, + write_timeline=write_timeline) + print("Launched all ps_compute_applys on all actors") + + # Issue prefetch ops + for j, (ps, weight_shard_oid) in list( + enumerate(zip(ps_list, accum_shard_ids)))[::-1]: + to_fetch = [] + for grad_shard_oids in grad_shard_oids_list: + to_fetch.append(grad_shard_oids[j]) + random.shuffle(to_fetch) + ps.prefetch.remote(to_fetch) + print("Launched all prefetch ops") + + # Aggregate the gradients produced by the actors. These operations + # run concurrently with the actor methods above. + ps_gets = [] + for j, (ps, weight_shard_oid) in list( + enumerate(zip(ps_list, accum_shard_ids)))[::-1]: + ps.add_spinwait.remote([gs[j] for gs in grad_shard_oids_list]) + ps_gets.append(ps.get.remote(weight_shard_oid)) + print("Launched all aggregate ops") + + if verbose: + timelines = [ps.get_timeline.remote() for ps in ps_list] + print("launched timeline gets") + timelines = ray.get(timelines) + t0 = timelines[0] + for t in timelines[1:]: + t0.merge(t) + t0.chrome_trace_format("ps_timeline.json") + else: + # Wait for at least the ps gets to finish + ray.get(ps_gets) + + +def roundrobin_ps(ps_cls, sgd_workers, shard_shapes, spread_ps): + worker_ips = ray.get([w.ip.remote() for w in sgd_workers]) + num_ips = len(set(worker_ips)) + num_workers = len(sgd_workers) + min_placed = np.ceil(len(shard_shapes) / num_ips) + from collections import Counter, defaultdict + tid_counter = [0] + + def create_ps(): + tid_counter[0] += 1 + return RemotePS.remote(num_workers, tid_counter[0]) + + ip_mapping = defaultdict(list) + + while (any(len(v) < min_placed for v in ip_mapping.values()) + or (len(ip_mapping) < num_ips)): + print("generating new ps, ip map so far", ip_mapping) + new_ps = create_ps() + ps_ip = ray.get(new_ps.ip.remote()) + if spread_ps and ps_ip in worker_ips: + print("ignoring ps that is on same node as worker") + elif not spread_ps and ps_ip not in worker_ips: + print("ignoring ps that NOT on same node as some worker") + else: + ip_mapping[ps_ip] += [new_ps] + + final_list = [] + candidates = list(ip_mapping.values()) + for i, s in enumerate(shard_shapes): + ps = candidates[i % num_ips][i // num_ips] + final_list += [ps] + ps.initialize.remote(s) + + for ps in sum(candidates, []): + if ps not in final_list: + ps.__ray_terminate__.remote(ps._ray_actor_id.id()) + print("removing a ps...") + else: + print("saving ps...") + + print("Final PS balance: ", Counter(ray.get([ps.ip.remote() for ps in final_list]))) + for i, ps in enumerate(final_list): + ps.set_tid.remote(i) + return final_list + + +class 
DistributedSGD(object): + def __init__( + self, model_creator, num_workers, devices_per_worker, use_cpus): + self.model_creator = model_creator + if use_cpus: + requests = {"num_cpus": devices_per_worker} + else: + requests = {"num_gpus": devices_per_worker} + RemoteSGDWorker = ray.remote(**requests)(SGDWorker) + self.workers = [] + for worker_index in range(num_workers): + print("Creating worker", worker_index) + self.workers.append( + RemoteSGDWorker.remote( + worker_index, model_creator, + num_devices=devices_per_worker, use_cpus=use_cpus, + verbose=True)) + + def step(self): + return do_sgd_step(self.workers, True) diff --git a/python/ray/experimental/sgd/test_model.py b/python/ray/experimental/sgd/test_model.py new file mode 100644 index 0000000000000..445746bd432f5 --- /dev/null +++ b/python/ray/experimental/sgd/test_model.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tfbench import model_config + + +class MockDataset(): + name = "synthetic" + + +class TFBenchModel(object): + def __init__(self, batch=64, use_cpus=False): + image_shape = [batch, 224, 224, 3] + labels_shape = [batch] + + # Synthetic image should be within [0, 255]. + images = tf.truncated_normal( + image_shape, + dtype=tf.float32, + mean=127, + stddev=60, + name='synthetic_images') + + # Minor hack to avoid H2D copy when using synthetic data + self.inputs = tf.contrib.framework.local_variable( + images, name='gpu_cached_images') + self.labels = tf.random_uniform( + labels_shape, + minval=0, + maxval=999, + dtype=tf.int32, + name='synthetic_labels') + + self.model = model_config.get_model_config("resnet101", MockDataset()) + logits, aux = self.model.build_network( + self.inputs, data_format=use_cpus and "NHWC" or "NCHW") + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=self.labels) + self.loss = tf.reduce_mean(loss, name='xentropy-loss') + self.optimizer = tf.train.GradientDescentOptimizer(1e-6) diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py new file mode 100644 index 0000000000000..aff28c3194256 --- /dev/null +++ b/python/ray/experimental/sgd/test_sgd.py @@ -0,0 +1,26 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray + +import argparse +import numpy as np +import tensorflow as tf + +from test_model import TFBenchModel +from sgd import DistributedSGD + + +if __name__ == "__main__": + ray.init() + + model_creator = ( + lambda i, j: TFBenchModel(batch=1, use_cpus=True)) + + sgd = DistributedSGD( + model_creator, num_workers=2, devices_per_worker=2, use_cpus=True) + + for _ in range(100): + loss = sgd.step() + print("Current loss", loss) diff --git a/python/ray/experimental/sgd/tfbench/__init__.py b/python/ray/experimental/sgd/tfbench/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/ray/experimental/sgd/tfbench/allreduce.py b/python/ray/experimental/sgd/tfbench/allreduce.py new file mode 100644 index 0000000000000..06da6c577a623 --- /dev/null +++ b/python/ray/experimental/sgd/tfbench/allreduce.py @@ -0,0 +1,590 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for allreduce."""
+
+from __future__ import print_function
+
+import collections as pycoll
+import re
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.contrib import nccl
+from tensorflow.contrib.all_reduce.python import all_reduce
+
+AllReduceSpecTuple = pycoll.namedtuple('AllReduceSpecTuple', 'alg shards limit')
+
+
+def parse_general_int(s):
+  """Parse integer with power-of-2 suffix, e.g. 32k."""
+  mo = re.match(r'(\d+)([KkMGT]?)$', s)
+  if mo:
+    i, suffix = mo.group(1, 2)
+    v = int(i)
+    if suffix:
+      if suffix == 'K' or suffix == 'k':
+        v *= 1024
+      elif suffix == 'M':
+        v *= (1024 * 1024)
+      elif suffix == 'G':
+        v *= (1024 * 1024 * 1024)
+      elif suffix == 'T':
+        v *= (1024 * 1024 * 1024 * 1024)
+      else:
+        raise ValueError('invalid integer string %s' % s)
+    return v
+  else:
+    v = int(s)
+    return v
+
+
+def parse_all_reduce_spec(all_reduce_spec):
+  """Parse all_reduce_spec.
+
+  Args:
+    all_reduce_spec: a string specifying a combination of all-reduce
+      algorithms to apply for gradient reduction.
+
+  Returns:
+    a list of AllReduceSpecTuple.
+
+  Raises:
+    ValueError: all_reduce_spec is not well-formed.
+
+  An all_reduce_spec has BNF form:
+     int ::= positive whole number
+     g_int ::= int[KkMGT]?
+     alg_spec ::= alg | alg#int
+     range_spec ::= alg_spec | alg_spec/alg_spec
+     spec ::= range_spec | range_spec:g_int:range_spec
+
+  Not all syntactically correct specifications are supported.
+  Examples of supported all_reduce_spec strings, with semantics explained:
+
+  'xring' == apply ring all-reduce to all tensors
+  'xring#2' == apply ring all-reduce to all tensors, using two simultaneous
+          transfer rings, each operating on 1/2 of each tensor.
+  'nccl' == apply NCCL all-reduce to all tensors (only works within
+          a single worker process where all devices are GPUs)
+  'nccl/xring' == apply NCCL all-reduce to all tensors within each worker
+          to produce at least one full-reduced (locally) value,
+          then apply ring all-reduce to one such value from each
+          worker, then apply NCCL broadcast to propagate those globally
+          reduced values back to every device within each worker.
+  'pscpu' == Shuffle reduce using worker CPUs as the gather devices: each
+          distributed tensor is reduced by copying all instances to
+          one of the worker CPUs, computing the reduction there, then
+          copying back to each participating device.  Tensor reductions
+          are assigned to specific CPUs round-robin.
+  'psgpu#4' == Arrange all GPUs across all workers into groups of 4.
+          Each distributed tensor is shuffle reduced against one
+          such group of 4 GPUs, selected round-robin.  That is, each
+          tensor is split across 4 shards for the reduction.
+  'pscpu:2k:pscpu#2:64k:xring' == Apply single-shard pscpu to
+          tensors of size <= 2048 elements, apply 2-shard pscpu to
+          tensors up to size 64k elements, apply xring to larger tensors.
+  'pscpu/pscpu#2' == Use shuffle gather to locally reduce each tensor on
+          the worker's CPU, then use 2-shard shuffle to reduce those
+          locally reduced tensors across workers (on the worker CPUs), then
+          scatter the globally reduced values locally from each worker CPU.
+  """
+  range_parts = all_reduce_spec.split(':') + ['-1']
+  if len(range_parts) % 2:
+    raise ValueError('all_reduce_spec not well formed: %s' % all_reduce_spec)
+  limit = 0
+  spec = []
+  alg = None
+  shards = 1
+  for i, range_part in enumerate(range_parts):
+    if i % 2 == 1:
+      try:
+        limit = parse_general_int(range_part)
+        spec.append(AllReduceSpecTuple(alg=alg, shards=shards, limit=limit))
+      except ValueError:
+        raise ValueError('all_reduce_spec (%s) contains non-integer range %s' %
+                         (all_reduce_spec, range_part))
+    else:
+      alg = range_part
+      alg_parts = range_part.split('#')
+      alg = alg_parts[0]
+      if len(alg_parts) > 1:
+        try:
+          shards = int(alg_parts[1])
+        except ValueError:
+          raise ValueError('all_reduce_spec (%s) contains non-integer '
+                           'shards %s' % (all_reduce_spec, alg_parts[1]))
+      else:
+        shards = 1
+      if alg not in [
+          'nccl', 'nccl/xring', 'nccl/rechd', 'nccl/pscpu', 'xring', 'pscpu',
+          'psgpu', 'pscpu/pscpu'
+      ]:
+        raise ValueError('all_reduce_spec (%s) contains invalid alg %s' %
+                         (all_reduce_spec, alg))
+  return spec
+
+
+def build_all_reduce_device_prefixes(job_name, num_tasks):
+  """Build list of device prefix names for all_reduce.
+
+  Args:
+    job_name: 'worker', 'ps' or 'localhost'.
+    num_tasks: number of jobs across which device names should be generated.
+
+  Returns:
+    A list of device name prefix strings. Each element spells out the full
+    host name without adding the device.
+    e.g. '/job:worker/task:0'
+  """
+  if job_name != 'localhost':
+    return ['/job:%s/task:%d' % (job_name, d) for d in range(0, num_tasks)]
+  else:
+    assert num_tasks == 1
+    return ['/job:%s' % job_name]
+
+
+def group_device_names(devices, group_size):
+  """Group device names into groups of group_size.
+
+  Args:
+    devices: list of strings naming devices.
+    group_size: int >= 1
+
+  Returns:
+    list of lists of devices, where each inner list is group_size long,
+    and each device appears at least once in an inner list.  If
+    len(devices) % group_size = 0 then each device will appear
+    exactly once.
+
+  Raises:
+    ValueError: group_size > len(devices)
+  """
+  num_devices = len(devices)
+  if group_size > num_devices:
+    raise ValueError('only %d devices, but group_size=%d' % (num_devices,
+                                                             group_size))
+  num_groups = (
+      num_devices // group_size + (1 if (num_devices % group_size != 0) else 0))
+  groups = [[] for i in range(num_groups)]
+  for i in range(0, num_groups * group_size):
+    groups[i % num_groups].append(devices[i % num_devices])
+  return groups
+
+
+def split_grads_by_size(threshold_size, device_grads):
+  """Break gradients into two sets according to tensor size.
+
+  Args:
+    threshold_size: int size cutoff for small vs large tensor.
+    device_grads: List of lists of (gradient, variable) tuples.  The outer
+      list is over devices.  The inner list is over individual gradients.
+
+  Returns:
+    small_grads: Subset of device_grads where shape is <= threshold_size
+      elements.
+    large_grads: Subset of device_grads where shape is > threshold_size
+      elements.
+ """ + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads + + +def build_reduce_sum(scaled_grads): + stacked = tf.parallel_stack(values=scaled_grads) + reduced = tf.reduce_sum(stacked, 0) + return [reduced] * len(scaled_grads) + +def build_trivial_sum(scaled_grads): + return scaled_grads + +def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, + check_inf_nan): + """Calculate the average gradient for a shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + grad_and_vars: A list or tuple of (gradient, variable) tuples. Each + (gradient, variable) pair within the outer list represents the gradient + of the variable calculated for a single tower, and the number of pairs + equals the number of towers. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + grads = [g for g, _ in grad_and_vars] + grad = tf.add_n(grads) + + if use_mean and len(grads) > 1: + grad = tf.multiply(grad, 1.0 / len(grads)) + + v = grad_and_vars[0][1] + if check_inf_nan: + has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.is_finite(grads))) + return (grad, v), has_nan_or_inf + else: + return (grad, v), None + + +def aggregate_gradients_using_copy_with_device_selection( + tower_grads, avail_devices, use_mean=True, check_inf_nan=False): + """Aggregate gradients, controlling device for the aggregation. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over towers. The inner list is over individual gradients. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: If true, check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + agg_grads = [] + has_nan_or_inf_list = [] + for i, single_grads in enumerate(zip(*tower_grads)): + with tf.device(avail_devices[i % len(avail_devices)]): + grad_and_var, has_nan_or_inf = aggregate_single_gradient_using_copy( + single_grads, use_mean, check_inf_nan) + agg_grads.append(grad_and_var) + has_nan_or_inf_list.append(has_nan_or_inf) + return agg_grads + + +def sum_grad_and_var_all_reduce(grad_and_vars, + num_workers, + alg, + gpu_indices, + aux_devices=None, + num_shards=1): + """Apply all-reduce algorithm over specified gradient tensors.""" + with tf.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... 
, (grad0_gpuN, var0_gpuN)) + scaled_grads = [g for g, _ in grad_and_vars] + if alg == 'nccl': + summed_grads = nccl.all_sum(scaled_grads) + elif alg == 'simple': + summed_grads = build_reduce_sum(scaled_grads) + elif alg == 'trivial': + summed_grads = build_trivial_sum(scaled_grads) + elif alg == 'xring': + summed_grads = all_reduce.build_ring_all_reduce( + scaled_grads, num_workers, num_shards, gpu_indices, tf.add) + elif alg == 'nccl/xring': + summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, + tf.add) + elif alg == 'nccl/rechd': + summed_grads = all_reduce.build_nccl_then_recursive_hd( + scaled_grads, tf.add) + elif alg == 'nccl/pscpu': + summed_grads = all_reduce.build_nccl_then_shuffle( + scaled_grads, aux_devices, tf.add, tf.add_n) + elif alg == 'pscpu/pscpu': + summed_grads = all_reduce.build_shuffle_then_shuffle( + scaled_grads, + aux_devices, + # TODO(tucker): devise a way of better specifying the device set + # for the second level. + [aux_devices[0]], + tf.add_n) + elif alg in ['pscpu', 'psgpu']: + summed_grads = all_reduce.build_shuffle_all_reduce( + scaled_grads, aux_devices, tf.add_n) + else: + raise ValueError('unsupported all_reduce alg: ', alg) + + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result + + +def contains_any(haystack, needles): + """Tests if any needle is a substring of haystack. + + Args: + haystack: a string + needles: list of strings + + Returns: + True if any element of needles is a substring of haystack, + False otherwise. + """ + for n in needles: + if n in haystack: + return True + return False + + +def sum_gradients_all_reduce(dev_prefixes, + tower_grads, + num_workers, + alg, + num_shards, + gpu_indices, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """Apply all-reduce algorithm over specified gradient tensors. + + Args: + dev_prefixes: list of prefix strings to use to generate PS device names. + tower_grads: the gradients to reduce. + num_workers: number of worker processes across entire job. + alg: the all-reduce algorithm to apply. + num_shards: alg-specific sharding factor. + gpu_indices: indices of local GPUs in order usable for ring-reduce. + agg_small_grads_max_bytes: largest tensor eligible for aggregation, + in number of bytes. + agg_small_grads_max_group: largest permitted aggregation of small + tensors. 
+ + Returns: + list of reduced tensors + """ + alg_contains_shuffle = contains_any(alg, ['pscpu', 'psgpu']) + is_hierarchical = '/' in alg + if 'pscpu' in alg: + aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] + elif 'psgpu' in alg: + aux_devices = [ + prefix + '/gpu:%d' % i + for i in range(len(gpu_indices)) + for prefix in dev_prefixes + ] + else: + aux_devices = ['/job:localhost/cpu:0'] + aux_device_groups = group_device_names(aux_devices, num_shards + if alg_contains_shuffle else 1) + group_index = 0 + if agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: + tower_grads, packing = pack_small_tensors( + tower_grads, + max_bytes=agg_small_grads_max_bytes, + max_group=agg_small_grads_max_group) + else: + packing = None + reduced_gv_list = [] + for grad_and_vars in zip(*tower_grads): + reduced_gv_list.append( + sum_grad_and_var_all_reduce( + grad_and_vars, num_workers, alg, gpu_indices, aux_devices + if is_hierarchical else aux_device_groups[group_index], num_shards)) + group_index = (group_index + 1) % len(aux_device_groups) + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + if packing: + new_tower_grads = unpack_small_tensors(new_tower_grads, packing) + return new_tower_grads + + +def extract_ranges(index_list, range_size_limit=32): + """Extract consecutive ranges and singles from index_list. + + Args: + index_list: List of monotone increasing non-negative integers. + range_size_limit: Largest size range to return. If a larger + consecutive range exists it will be returned as multiple + ranges. + + Returns: + ranges, singles where ranges is a list of [first, last] pairs of + consecutive elements in index_list, and singles is all of the + other elements, in original order. + """ + if not index_list: + return [], [] + first = index_list[0] + last = first + ranges = [] + singles = [] + for i in index_list[1:]: + if i == last + 1 and (last - first) <= range_size_limit: + last = i + else: + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + first = i + last = i + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + return ranges, singles + + +GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes') + + +def pack_range(key, packing, grad_vars, rng): + """Form the concatenation of a specified range of gradient tensors. + + Args: + key: Value under which to store meta-data in packing that will be used + later to restore the grad_var list structure. + packing: Dict holding data describing packed ranges of small tensors. + grad_vars: List of (grad, var) pairs for one tower. + rng: A pair of integers giving the first, last indices of a consecutive + range of tensors to be packed. + + Returns: + A tensor that is the concatenation of all the specified small tensors. + """ + to_pack = grad_vars[rng[0]:rng[1] + 1] + members = [] + variables = [] + restore_shapes = [] + with tf.name_scope('pack'): + for g, v in to_pack: + variables.append(v) + restore_shapes.append(g.shape) + with tf.device(g.device): + members.append(tf.reshape(g, [-1])) + packing[key] = GradPackTuple( + indices=range(rng[0], rng[1] + 1), + vars=variables, + shapes=restore_shapes) + with tf.device(members[0].device): + return tf.concat(members, 0) + + +def unpack_grad_tuple(gv, gpt): + """Unpack a previously packed collection of gradient tensors. + + Args: + gv: A (grad, var) pair to be unpacked. + gpt: A GradPackTuple describing the packing operation that produced gv. 
+ + Returns: + A list of (grad, var) pairs corresponding to the values that were + originally packed into gv, maybe following subsequent operations like + reduction. + """ + elt_widths = [x.num_elements() for x in gpt.shapes] + with tf.device(gv[0][0].device): + with tf.name_scope('unpack'): + splits = tf.split(gv[0], elt_widths) + unpacked_gv = [] + for idx, s in enumerate(splits): + unpacked_gv.append((tf.reshape(s, gpt.shapes[idx]), gpt.vars[idx])) + return unpacked_gv + + +def pack_small_tensors(tower_grads, max_bytes=0, max_group=0): + """Concatenate small gradient tensors together for reduction. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. + max_bytes: Int giving max number of bytes in a tensor that + may be considered small. + max_group: Int giving max number of small tensors that may be + concatenated into one new tensor. + + Returns: + new_tower_grads, packing where new_tower_grads is identical to + tower_grads except that all feasible small_tensors have been removed + from their places and concatenated into larger tensors that are + now in the front of the list for each tower, and packing contains + the data necessary to restore the tower_grads structure. + + Look through the first tower for gradients of the same type (float), + and small size, that are all sequential. For each such group, + replace by a new tensor that is a flattened concatenation. Note + that the corresponding variable will be absent, which doesn't matter + because it isn't used during all-reduce. + + Requires: + Every gv_list in towers must have isomorphic structure including identical + tensor sizes and types. + """ + small_indices = [] + large_indices = [] + for idx, (g, _) in enumerate(tower_grads[0]): + if g.dtype == tf.float32 and (4 * g.shape.num_elements()) <= max_bytes: + small_indices.append(idx) + else: + large_indices.append(idx) + small_ranges, small_singles = extract_ranges( + small_indices, range_size_limit=max_group) + large_indices = sorted(large_indices + small_singles) + num_gv = len(tower_grads[0]) + packing = {} + if small_ranges: + new_tower_grads = [] + for dev_idx, gv_list in enumerate(tower_grads): + assert len(gv_list) == num_gv + new_gv_list = [] + for r in small_ranges: + key = '%d:%d' % (dev_idx, len(new_gv_list)) + new_gv_list.append((pack_range(key, packing, gv_list, r), + 'packing_var_placeholder')) + for i in large_indices: + new_gv_list.append(gv_list[i]) + new_tower_grads.append(new_gv_list) + return new_tower_grads, packing + else: + return tower_grads, None + + +def unpack_small_tensors(tower_grads, packing): + """Undo the structure alterations to tower_grads done by pack_small_tensors. + + Args: + tower_grads: List of List of (grad, var) tuples. + packing: A dict generated by pack_small_tensors describing the changes + it made to tower_grads. + + Returns: + new_tower_grads: identical to tower_grads except that concatentations + of small tensors have been split apart and returned to their original + positions, paired with their original variables. 
+ """ + if not packing: + return tower_grads + new_tower_grads = [] + num_devices = len(tower_grads) + num_packed = len(packing.keys()) // num_devices + for dev_idx, gv_list in enumerate(tower_grads): + new_gv_list = gv_list[num_packed:] + for i in xrange(0, num_packed): + k = '%d:%d' % (dev_idx, i) + gpt = packing[k] + gv = unpack_grad_tuple(gv_list[i], gpt) + for gi, idx in enumerate(gpt.indices): + assert idx == gpt.indices[gi] + new_gv_list.insert(idx, gv[gi]) + new_tower_grads.append(new_gv_list) + return new_tower_grads diff --git a/python/ray/experimental/sgd/tfbench/convnet_builder.py b/python/ray/experimental/sgd/tfbench/convnet_builder.py new file mode 100644 index 0000000000000..d0cc2755e0bc1 --- /dev/null +++ b/python/ray/experimental/sgd/tfbench/convnet_builder.py @@ -0,0 +1,467 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""CNN builder.""" + +from __future__ import print_function + +from collections import defaultdict +import contextlib + +import numpy as np + +import tensorflow as tf + +from tensorflow.python.layers import convolutional as conv_layers +from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import pooling as pooling_layers +from tensorflow.python.training import moving_averages + + +class ConvNetBuilder(object): + """Builder of cnn net.""" + + def __init__(self, + input_op, + input_nchan, + phase_train, + use_tf_layers, + data_format='NCHW', + dtype=tf.float32, + variable_dtype=tf.float32): + self.top_layer = input_op + self.top_size = input_nchan + self.phase_train = phase_train + self.use_tf_layers = use_tf_layers + self.data_format = data_format + self.dtype = dtype + self.variable_dtype = variable_dtype + self.counts = defaultdict(lambda: 0) + self.use_batch_norm = False + self.batch_norm_config = {} # 'decay': 0.997, 'scale': True} + self.channel_pos = ('channels_last' + if data_format == 'NHWC' else 'channels_first') + self.aux_top_layer = None + self.aux_top_size = 0 + + def get_custom_getter(self): + """Returns a custom getter that this class's methods must be called under. + + All methods of this class must be called under a variable scope that was + passed this custom getter. Example: + + ```python + network = ConvNetBuilder(...) + with tf.variable_scope('cg', custom_getter=network.get_custom_getter()): + network.conv(...) + # Call more methods of network here + ``` + + Currently, this custom getter only does anything if self.use_tf_layers is + True. In that case, it causes variables to be stored as dtype + self.variable_type, then casted to the requested dtype, instead of directly + storing the variable as the requested dtype. 
+ """ + def inner_custom_getter(getter, *args, **kwargs): + """Custom getter that forces variables to have type self.variable_type.""" + if not self.use_tf_layers: + return getter(*args, **kwargs) + requested_dtype = kwargs['dtype'] + if not (requested_dtype == tf.float32 and + self.variable_dtype == tf.float16): + # Only change the variable dtype if doing so does not decrease variable + # precision. + kwargs['dtype'] = self.variable_dtype + var = getter(*args, **kwargs) + # This if statement is needed to guard the cast, because batch norm + # assigns directly to the return value of this custom getter. The cast + # makes the return value not a variable so it cannot be assigned. Batch + # norm variables are always in fp32 so this if statement is never + # triggered for them. + if var.dtype.base_dtype != requested_dtype: + var = tf.cast(var, requested_dtype) + return var + return inner_custom_getter + + @contextlib.contextmanager + def switch_to_aux_top_layer(self): + """Context that construct cnn in the auxiliary arm.""" + if self.aux_top_layer is None: + raise RuntimeError('Empty auxiliary top layer in the network.') + saved_top_layer = self.top_layer + saved_top_size = self.top_size + self.top_layer = self.aux_top_layer + self.top_size = self.aux_top_size + yield + self.aux_top_layer = self.top_layer + self.aux_top_size = self.top_size + self.top_layer = saved_top_layer + self.top_size = saved_top_size + + def get_variable(self, name, shape, dtype, cast_dtype, *args, **kwargs): + # TODO(reedwm): Currently variables and gradients are transferred to other + # devices and machines as type `dtype`, not `cast_dtype`. In particular, + # this means in fp16 mode, variables are transferred as fp32 values, not + # fp16 values, which uses extra bandwidth. + var = tf.get_variable(name, shape, dtype, *args, **kwargs) + return tf.cast(var, cast_dtype) + + def _conv2d_impl(self, input_layer, num_channels_in, filters, kernel_size, + strides, padding, kernel_initializer): + if self.use_tf_layers: + return conv_layers.conv2d(input_layer, filters, kernel_size, strides, + padding, self.channel_pos, + kernel_initializer=kernel_initializer, + use_bias=False) + else: + weights_shape = [kernel_size[0], kernel_size[1], num_channels_in, filters] + # We use the name 'conv2d/kernel' so the variable has the same name as its + # tf.layers equivalent. This way, if a checkpoint is written when + # self.use_tf_layers == True, it can be loaded when + # self.use_tf_layers == False, and vice versa. 
+ weights = self.get_variable('conv2d/kernel', weights_shape, + self.variable_dtype, self.dtype, + initializer=kernel_initializer) + if self.data_format == 'NHWC': + strides = [1] + strides + [1] + else: + strides = [1, 1] + strides + return tf.nn.conv2d(input_layer, weights, strides, padding, + data_format=self.data_format) + + def conv(self, + num_out_channels, + k_height, + k_width, + d_height=1, + d_width=1, + mode='SAME', + input_layer=None, + num_channels_in=None, + use_batch_norm=None, + stddev=None, + activation='relu', + bias=0.0): + """Construct a conv2d layer on top of cnn.""" + if input_layer is None: + input_layer = self.top_layer + if num_channels_in is None: + num_channels_in = self.top_size + kernel_initializer = None + if stddev is not None: + kernel_initializer = tf.truncated_normal_initializer(stddev=stddev) + name = 'conv' + str(self.counts['conv']) + self.counts['conv'] += 1 + with tf.variable_scope(name): + strides = [1, d_height, d_width, 1] + if self.data_format == 'NCHW': + strides = [strides[0], strides[3], strides[1], strides[2]] + if mode != 'SAME_RESNET': + conv = self._conv2d_impl(input_layer, num_channels_in, num_out_channels, + kernel_size=[k_height, k_width], + strides=[d_height, d_width], padding=mode, + kernel_initializer=kernel_initializer) + else: # Special padding mode for ResNet models + if d_height == 1 and d_width == 1: + conv = self._conv2d_impl(input_layer, num_channels_in, + num_out_channels, + kernel_size=[k_height, k_width], + strides=[d_height, d_width], padding='SAME', + kernel_initializer=kernel_initializer) + else: + rate = 1 # Unused (for 'a trous' convolutions) + kernel_height_effective = k_height + (k_height - 1) * (rate - 1) + pad_h_beg = (kernel_height_effective - 1) // 2 + pad_h_end = kernel_height_effective - 1 - pad_h_beg + kernel_width_effective = k_width + (k_width - 1) * (rate - 1) + pad_w_beg = (kernel_width_effective - 1) // 2 + pad_w_end = kernel_width_effective - 1 - pad_w_beg + padding = [[0, 0], [pad_h_beg, pad_h_end], + [pad_w_beg, pad_w_end], [0, 0]] + if self.data_format == 'NCHW': + padding = [padding[0], padding[3], padding[1], padding[2]] + input_layer = tf.pad(input_layer, padding) + conv = self._conv2d_impl(input_layer, num_channels_in, + num_out_channels, + kernel_size=[k_height, k_width], + strides=[d_height, d_width], padding='VALID', + kernel_initializer=kernel_initializer) + if use_batch_norm is None: + use_batch_norm = self.use_batch_norm + if not use_batch_norm: + if bias is not None: + biases = self.get_variable('biases', [num_out_channels], + self.variable_dtype, self.dtype, + initializer=tf.constant_initializer(bias)) + biased = tf.reshape( + tf.nn.bias_add(conv, biases, data_format=self.data_format), + conv.get_shape()) + else: + biased = conv + else: + self.top_layer = conv + self.top_size = num_out_channels + biased = self.batch_norm(**self.batch_norm_config) + if activation == 'relu': + conv1 = tf.nn.relu(biased) + elif activation == 'linear' or activation is None: + conv1 = biased + elif activation == 'tanh': + conv1 = tf.nn.tanh(biased) + else: + raise KeyError('Invalid activation type \'%s\'' % activation) + self.top_layer = conv1 + self.top_size = num_out_channels + return conv1 + + def _pool(self, + pool_name, + pool_function, + k_height, + k_width, + d_height, + d_width, + mode, + input_layer, + num_channels_in): + """Construct a pooling layer.""" + if input_layer is None: + input_layer = self.top_layer + else: + self.top_size = num_channels_in + name = pool_name + str(self.counts[pool_name]) + 
self.counts[pool_name] += 1 + if self.use_tf_layers: + pool = pool_function( + input_layer, [k_height, k_width], [d_height, d_width], + padding=mode, + data_format=self.channel_pos, + name=name) + else: + if self.data_format == 'NHWC': + ksize = [1, k_height, k_width, 1] + strides = [1, d_height, d_width, 1] + else: + ksize = [1, 1, k_height, k_width] + strides = [1, 1, d_height, d_width] + pool = tf.nn.max_pool(input_layer, ksize, strides, padding=mode, + data_format=self.data_format, name=name) + self.top_layer = pool + return pool + + def mpool(self, + k_height, + k_width, + d_height=2, + d_width=2, + mode='VALID', + input_layer=None, + num_channels_in=None): + """Construct a max pooling layer.""" + return self._pool('mpool', pooling_layers.max_pooling2d, k_height, k_width, + d_height, d_width, mode, input_layer, num_channels_in) + + def apool(self, + k_height, + k_width, + d_height=2, + d_width=2, + mode='VALID', + input_layer=None, + num_channels_in=None): + """Construct an average pooling layer.""" + return self._pool('apool', pooling_layers.average_pooling2d, k_height, + k_width, d_height, d_width, mode, input_layer, + num_channels_in) + + def reshape(self, shape, input_layer=None): + if input_layer is None: + input_layer = self.top_layer + self.top_layer = tf.reshape(input_layer, shape) + self.top_size = shape[-1] # HACK This may not always work + return self.top_layer + + def affine(self, + num_out_channels, + input_layer=None, + num_channels_in=None, + bias=0.0, + stddev=None, + activation='relu'): + if input_layer is None: + input_layer = self.top_layer + if num_channels_in is None: + num_channels_in = self.top_size + name = 'affine' + str(self.counts['affine']) + self.counts['affine'] += 1 + with tf.variable_scope(name): + init_factor = 2. if activation == 'relu' else 1. 
+ stddev = stddev or np.sqrt(init_factor / num_channels_in) + kernel = self.get_variable( + 'weights', [num_channels_in, num_out_channels], + self.variable_dtype, self.dtype, + initializer=tf.truncated_normal_initializer(stddev=stddev)) + biases = self.get_variable('biases', [num_out_channels], + self.variable_dtype, self.dtype, + initializer=tf.constant_initializer(bias)) + logits = tf.nn.xw_plus_b(input_layer, kernel, biases) + if activation == 'relu': + affine1 = tf.nn.relu(logits, name=name) + elif activation == 'linear' or activation is None: + affine1 = logits + else: + raise KeyError('Invalid activation type \'%s\'' % activation) + self.top_layer = affine1 + self.top_size = num_out_channels + return affine1 + + def inception_module(self, name, cols, input_layer=None, in_size=None): + if input_layer is None: + input_layer = self.top_layer + if in_size is None: + in_size = self.top_size + name += str(self.counts[name]) + self.counts[name] += 1 + with tf.variable_scope(name): + col_layers = [] + col_layer_sizes = [] + for c, col in enumerate(cols): + col_layers.append([]) + col_layer_sizes.append([]) + for l, layer in enumerate(col): + ltype, args = layer[0], layer[1:] + kwargs = { + 'input_layer': input_layer, + 'num_channels_in': in_size + } if l == 0 else {} + if ltype == 'conv': + self.conv(*args, **kwargs) + elif ltype == 'mpool': + self.mpool(*args, **kwargs) + elif ltype == 'apool': + self.apool(*args, **kwargs) + elif ltype == 'share': # Share matching layer from previous column + self.top_layer = col_layers[c - 1][l] + self.top_size = col_layer_sizes[c - 1][l] + else: + raise KeyError( + 'Invalid layer type for inception module: \'%s\'' % ltype) + col_layers[c].append(self.top_layer) + col_layer_sizes[c].append(self.top_size) + catdim = 3 if self.data_format == 'NHWC' else 1 + self.top_layer = tf.concat([layers[-1] for layers in col_layers], catdim) + self.top_size = sum([sizes[-1] for sizes in col_layer_sizes]) + return self.top_layer + + def spatial_mean(self, keep_dims=False): + name = 'spatial_mean' + str(self.counts['spatial_mean']) + self.counts['spatial_mean'] += 1 + axes = [1, 2] if self.data_format == 'NHWC' else [2, 3] + self.top_layer = tf.reduce_mean( + self.top_layer, axes, keep_dims=keep_dims, name=name) + return self.top_layer + + def dropout(self, keep_prob=0.5, input_layer=None): + if input_layer is None: + input_layer = self.top_layer + else: + self.top_size = None + name = 'dropout' + str(self.counts['dropout']) + with tf.variable_scope(name): + if not self.phase_train: + keep_prob = 1.0 + if self.use_tf_layers: + dropout = core_layers.dropout(input_layer, 1. - keep_prob) + else: + dropout = tf.nn.dropout(input_layer, keep_prob) + self.top_layer = dropout + return dropout + + def _batch_norm_without_layers(self, input_layer, decay, use_scale, epsilon): + """Batch normalization on `input_layer` without tf.layers.""" + # We make this function as similar as possible to the + # tf.contrib.layers.batch_norm, to minimize the differences between using + # layers and not using layers. 
+ shape = input_layer.shape + num_channels = shape[3] if self.data_format == 'NHWC' else shape[1] + beta = self.get_variable('beta', [num_channels], tf.float32, tf.float32, + initializer=tf.zeros_initializer()) + if use_scale: + gamma = self.get_variable('gamma', [num_channels], tf.float32, + tf.float32, initializer=tf.ones_initializer()) + else: + gamma = tf.constant(1.0, tf.float32, [num_channels]) + # For moving variables, we use tf.get_variable instead of self.get_variable, + # since self.get_variable returns the result of tf.cast which we cannot + # assign to. + moving_mean = tf.get_variable('moving_mean', [num_channels], + tf.float32, + initializer=tf.zeros_initializer(), + trainable=False) + moving_variance = tf.get_variable('moving_variance', [num_channels], + tf.float32, + initializer=tf.ones_initializer(), + trainable=False) + if self.phase_train: + bn, batch_mean, batch_variance = tf.nn.fused_batch_norm( + input_layer, gamma, beta, epsilon=epsilon, + data_format=self.data_format, is_training=True) + mean_update = moving_averages.assign_moving_average( + moving_mean, batch_mean, decay=decay, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + moving_variance, batch_variance, decay=decay, zero_debias=False) + tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mean_update) + tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, variance_update) + else: + bn, _, _ = tf.nn.fused_batch_norm( + input_layer, gamma, beta, mean=moving_mean, + variance=moving_variance, epsilon=epsilon, + data_format=self.data_format, is_training=False) + return bn + + def batch_norm(self, input_layer=None, decay=0.999, scale=False, + epsilon=0.001): + """Adds a Batch Normalization layer.""" + if input_layer is None: + input_layer = self.top_layer + else: + self.top_size = None + name = 'batchnorm' + str(self.counts['batchnorm']) + self.counts['batchnorm'] += 1 + + with tf.variable_scope(name) as scope: + if self.use_tf_layers: + bn = tf.contrib.layers.batch_norm( + input_layer, + decay=decay, + scale=scale, + epsilon=epsilon, + is_training=self.phase_train, + fused=True, + data_format=self.data_format, + scope=scope) + else: + bn = self._batch_norm_without_layers(input_layer, decay, scale, epsilon) + self.top_layer = bn + self.top_size = bn.shape[3] if self.data_format == 'NHWC' else bn.shape[1] + self.top_size = int(self.top_size) + return bn + + def lrn(self, depth_radius, bias, alpha, beta): + """Adds a local response normalization layer.""" + name = 'lrn' + str(self.counts['lrn']) + self.counts['lrn'] += 1 + self.top_layer = tf.nn.lrn( + self.top_layer, depth_radius, bias, alpha, beta, name=name) + return self.top_layer diff --git a/python/ray/experimental/sgd/tfbench/model.py b/python/ray/experimental/sgd/tfbench/model.py new file mode 100644 index 0000000000000..1c3fd658f676b --- /dev/null +++ b/python/ray/experimental/sgd/tfbench/model.py @@ -0,0 +1,114 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Base model configuration for CNN benchmarks."""
+import tensorflow as tf
+
+from . import convnet_builder
+
+
+class Model(object):
+  """Base model configuration for CNN benchmarks."""
+
+  def __init__(self,
+               model,
+               image_size,
+               batch_size,
+               learning_rate,
+               layer_counts=None,
+               fp16_loss_scale=128):
+    self.model = model
+    self.image_size = image_size
+    self.batch_size = batch_size
+    self.default_batch_size = batch_size
+    self.learning_rate = learning_rate
+    self.layer_counts = layer_counts
+    # TODO(reedwm) Set custom loss scales for each model instead of using the
+    # default of 128.
+    self.fp16_loss_scale = fp16_loss_scale
+
+  def get_model(self):
+    return self.model
+
+  def get_image_size(self):
+    return self.image_size
+
+  def get_batch_size(self):
+    return self.batch_size
+
+  def set_batch_size(self, batch_size):
+    self.batch_size = batch_size
+
+  def get_default_batch_size(self):
+    return self.default_batch_size
+
+  def get_layer_counts(self):
+    return self.layer_counts
+
+  def get_fp16_loss_scale(self):
+    return self.fp16_loss_scale
+
+  def get_learning_rate(self, global_step, batch_size):
+    del global_step
+    del batch_size
+    return self.learning_rate
+
+  def add_inference(self, unused_cnn):
+    raise ValueError('Must be implemented in derived classes')
+
+  def skip_final_affine_layer(self):
+    """Returns whether the caller of this class should skip the final affine
+    layer.
+
+    Normally, this class adds a final affine layer to the model after calling
+    self.add_inference(), to generate the logits. If a subclass overrides this
+    method to return True, the caller should not add the final affine layer.
+
+    This is useful for tests.
+    """
+    return False
+
+  def build_network(self, images, phase_train=True, nclass=1001, image_depth=3,
+                    data_type=tf.float32, data_format='NCHW',
+                    use_tf_layers=True, fp16_vars=False):
+    """Returns logits and aux_logits from images."""
+    if data_format == 'NCHW':
+      images = tf.transpose(images, [0, 3, 1, 2])
+    var_type = tf.float32
+    if data_type == tf.float16 and fp16_vars:
+      var_type = tf.float16
+    network = convnet_builder.ConvNetBuilder(
+        images, image_depth, phase_train, use_tf_layers,
+        data_format, data_type, var_type)
+    with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
+      self.add_inference(network)
+      # Add the final fully-connected class layer
+      logits = (network.affine(nclass, activation='linear')
+                if not self.skip_final_affine_layer()
+                else network.top_layer)
+      aux_logits = None
+      if network.aux_top_layer is not None:
+        with network.switch_to_aux_top_layer():
+          aux_logits = network.affine(
+              nclass, activation='linear', stddev=0.001)
+    if data_type == tf.float16:
+      # TODO(reedwm): Determine if we should do this cast here.
+      logits = tf.cast(logits, tf.float32)
+      if aux_logits is not None:
+        aux_logits = tf.cast(aux_logits, tf.float32)
+    return logits, aux_logits
+
+  # Subclasses can override this to define their own loss function. By default,
+  # benchmark_cnn.py defines its own loss function. If overridden, it must have
+  # the same signature as benchmark_cnn.loss_function.
+  loss_function = None
diff --git a/python/ray/experimental/sgd/tfbench/model_config.py b/python/ray/experimental/sgd/tfbench/model_config.py
new file mode 100644
index 0000000000000..e993cf63fa33e
--- /dev/null
+++ b/python/ray/experimental/sgd/tfbench/model_config.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Model configurations for CNN benchmarks. +""" + +from . import resnet_model + + +_model_name_to_imagenet_model = { + 'resnet50': resnet_model.create_resnet50_model, + 'resnet50_v2': resnet_model.create_resnet50_v2_model, + 'resnet101': resnet_model.create_resnet101_model, + 'resnet101_v2': resnet_model.create_resnet101_v2_model, + 'resnet152': resnet_model.create_resnet152_model, + 'resnet152_v2': resnet_model.create_resnet152_v2_model, +} + + +_model_name_to_cifar_model = { +} + + +def _get_model_map(dataset_name): + if 'cifar10' == dataset_name: + return _model_name_to_cifar_model + elif dataset_name in ('imagenet', 'synthetic'): + return _model_name_to_imagenet_model + else: + raise ValueError('Invalid dataset name: %s' % dataset_name) + + +def get_model_config(model_name, dataset): + """Map model name to model network configuration.""" + model_map = _get_model_map(dataset.name) + if model_name not in model_map: + raise ValueError('Invalid model name \'%s\' for dataset \'%s\'' % + (model_name, dataset.name)) + else: + return model_map[model_name]() + + +def register_model(model_name, dataset_name, model_func): + """Register a new model that can be obtained with `get_model_config`.""" + model_map = _get_model_map(dataset_name) + if model_name in model_map: + raise ValueError('Model "%s" is already registered for dataset "%s"' % + (model_name, dataset_name)) + model_map[model_name] = model_func diff --git a/python/ray/experimental/sgd/tfbench/modified_allreduce.py b/python/ray/experimental/sgd/tfbench/modified_allreduce.py new file mode 100644 index 0000000000000..5334daa57dc9a --- /dev/null +++ b/python/ray/experimental/sgd/tfbench/modified_allreduce.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for allreduce.""" + +from __future__ import print_function + +import collections as pycoll +import re +import numpy as np + +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.contrib import nccl +from tensorflow.contrib.all_reduce.python import all_reduce +from allreduce import * + + +def sum_gradients_all_reduce(dev_prefixes, + tower_grads, + num_workers, + alg, + num_shards, + gpu_indices, + agg_small_grads_max_bytes=0): + """Apply all-reduce algorithm over specified gradient tensors. + + Args: + dev_prefixes: list of prefix strings to use to generate PS device names. + tower_grads: the gradients to reduce. + num_workers: number of worker processes across entire job. + alg: the all-reduce algorithm to apply. + num_shards: alg-specific sharding factor. + gpu_indices: indices of local GPUs in order usable for ring-reduce. + agg_small_grads_max_bytes: largest tensor eligible for aggregation, + in number of bytes. + + Returns: + list of reduced tensors, packing values + """ + alg_contains_shuffle = contains_any(alg, ['pscpu', 'psgpu']) + is_hierarchical = '/' in alg + if 'pscpu' in alg: + aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] + elif 'psgpu' in alg: + aux_devices = [ + prefix + '/gpu:%d' % i + for i in range(len(gpu_indices)) + for prefix in dev_prefixes + ] + else: + aux_devices = ['/job:localhost/cpu:0'] + aux_device_groups = group_device_names(aux_devices, num_shards + if alg_contains_shuffle else 1) + group_index = 0 + if agg_small_grads_max_bytes > 0: + tower_grads, packing = pack_small_tensors( + tower_grads, + max_bytes=agg_small_grads_max_bytes) + + else: + packing = None + new_tower_grads = [] + if alg == 'better': + raw_devices = ['/gpu:%i' % (i) for i in gpu_indices] + agg_grads = aggregate_gradients_using_copy_with_device_selection( + tower_grads, raw_devices) + for arr in tower_grads: + new_tower_grads.append( + [(g, v) for (_, v), (g, _) in zip(arr, agg_grads)]) + else: + reduced_gv_list = [] + for grad_and_vars in zip(*tower_grads): + reduced_gv_list.append( + sum_grad_and_var_all_reduce( + grad_and_vars, num_workers, alg, gpu_indices, aux_devices + if is_hierarchical else aux_device_groups[group_index], num_shards)) + group_index = (group_index + 1) % len(aux_device_groups) + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return new_tower_grads, packing + +def print_stats(sizes): + def sizeof_fmt(num, suffix='B'): + for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + stats = { + "avg": np.mean(sizes), + "median": np.median(sizes), + "total size": np.sum(sizes) + } + print("Stats " + ", ".join( + ["%s: %s" % (k, sizeof_fmt(v)) for k, v in stats.items()])) + other_stats = { + "len": len(sizes) + } + print(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()])) + + +def pack_small_tensors(tower_grads, max_bytes=0): + """Concatenate gradients together more intelligently. + + Does binpacking + Args: + tower_grads: List of lists of (gradient, variable) tuples. + max_bytes: Int giving max number of bytes in a tensor that + may be considered small. 
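+
+  Returns:
+    new_tower_grads, packing: tower_grads with each run of consecutive
+    small tensors flattened and concatenated into a single tensor, plus
+    the packing metadata needed to unpack them again later (packing is
+    None when nothing was packed).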
+ """ + assert max_bytes >= 0 + orig_grads = [g for g, _ in tower_grads[0]] + # Check to make sure sizes are accurate; not entirely important + assert all(g.dtype == tf.float32 for g in orig_grads) + sizes = [4 * g.shape.num_elements() for g in orig_grads] + print("Before packing") + print_stats(sizes) + small_ranges = [] + large_indices = [] + new_sizes = [] + + def end_interval(indices, small_ranges, large_indices): + if len(indices) > 1: + small_ranges.insert(0, [indices[0], indices[-1]]) + else: + large_indices.insert(0, indices[0]) + + cur_range = [] + cur_size = 0 + for i, s in reversed(list(enumerate(sizes))): + if cur_size > max_bytes: + end_interval(cur_range, small_ranges, large_indices) + new_sizes.insert(0, cur_size) + cur_range = [] + cur_size = 0 + cur_range.insert(0, i) + cur_size += s + end_interval(cur_range, small_ranges, large_indices) + new_sizes.insert(0, cur_size) + + print("After packing") + print_stats(new_sizes) + num_gv = len(orig_grads) + packing = {} + if len(small_ranges): + new_tower_grads = [] + for dev_idx, gv_list in enumerate(tower_grads): + assert len(gv_list) == num_gv + new_gv_list = [] + for r in small_ranges: + key = '%d:%d' % (dev_idx, len(new_gv_list)) + new_gv_list.append((pack_range(key, packing, gv_list, r), + 'packing_var_placeholder')) + for i in large_indices: + new_gv_list.append(gv_list[i]) + new_tower_grads.append(new_gv_list) + return new_tower_grads, packing + else: + return tower_grads, None diff --git a/python/ray/experimental/sgd/tfbench/resnet_model.py b/python/ray/experimental/sgd/tfbench/resnet_model.py new file mode 100644 index 0000000000000..f2f348a02bd9b --- /dev/null +++ b/python/ray/experimental/sgd/tfbench/resnet_model.py @@ -0,0 +1,346 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Resnet model configuration. + +References: + Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition + arXiv:1512.03385 (2015) + + Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks + arXiv:1603.05027 (2016) + + Liang-Chieh Chen, George Papandreou, Iasonas Kokkinos, Kevin Murphy, + Alan L. Yuille + DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, + Atrous Convolution, and Fully Connected CRFs + arXiv:1606.00915 (2016) +""" + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf +from . import model as model_lib + + +def bottleneck_block_v1(cnn, depth, depth_bottleneck, stride): + """Bottleneck block with identity short-cut for ResNet v1. + + Args: + cnn: the network to append bottleneck blocks. + depth: the number of output filters for this bottleneck block. + depth_bottleneck: the number of bottleneck filters for this block. + stride: Stride used in the first layer of the bottleneck block. 
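+
+  The block applies the usual 1x1 -> 3x3 -> 1x1 convolution stack and leaves
+  its output in cnn.top_layer / cnn.top_size; when `depth` differs from the
+  input channel count, the shortcut is projected with a batch-normalized 1x1
+  convolution so the shapes of the addition match.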
+ """ + input_layer = cnn.top_layer + in_size = cnn.top_size + name_key = 'resnet_v1' + name = name_key + str(cnn.counts[name_key]) + cnn.counts[name_key] += 1 + + with tf.variable_scope(name): + if depth == in_size: + if stride == 1: + shortcut = input_layer + else: + shortcut = cnn.apool( + 1, 1, stride, stride, input_layer=input_layer, + num_channels_in=in_size) + else: + shortcut = cnn.conv( + depth, 1, 1, stride, stride, activation=None, + use_batch_norm=True, input_layer=input_layer, + num_channels_in=in_size, bias=None) + cnn.conv(depth_bottleneck, 1, 1, stride, stride, + input_layer=input_layer, num_channels_in=in_size, + use_batch_norm=True, bias=None) + cnn.conv(depth_bottleneck, 3, 3, 1, 1, mode='SAME_RESNET', + use_batch_norm=True, bias=None) + res = cnn.conv(depth, 1, 1, 1, 1, activation=None, + use_batch_norm=True, bias=None) + output = tf.nn.relu(shortcut + res) + cnn.top_layer = output + cnn.top_size = depth + + +def bottleneck_block_v2(cnn, depth, depth_bottleneck, stride): + """Bottleneck block with identity short-cut for ResNet v2. + + The main difference from v1 is that a batch norm and relu are done at the + start of the block, instead of the end. This initial batch norm and relu is + collectively called a pre-activation. + + Args: + cnn: the network to append bottleneck blocks. + depth: the number of output filters for this bottleneck block. + depth_bottleneck: the number of bottleneck filters for this block. + stride: Stride used in the first layer of the bottleneck block. + """ + input_layer = cnn.top_layer + in_size = cnn.top_size + name_key = 'resnet_v2' + name = name_key + str(cnn.counts[name_key]) + cnn.counts[name_key] += 1 + + preact = cnn.batch_norm() + preact = tf.nn.relu(preact) + with tf.variable_scope(name): + if depth == in_size: + if stride == 1: + shortcut = input_layer + else: + shortcut = cnn.apool( + 1, 1, stride, stride, input_layer=input_layer, + num_channels_in=in_size) + else: + shortcut = cnn.conv( + depth, 1, 1, stride, stride, activation=None, use_batch_norm=False, + input_layer=preact, num_channels_in=in_size, bias=None) + cnn.conv(depth_bottleneck, 1, 1, stride, stride, + input_layer=preact, num_channels_in=in_size, + use_batch_norm=True, bias=None) + cnn.conv(depth_bottleneck, 3, 3, 1, 1, mode='SAME_RESNET', + use_batch_norm=True, bias=None) + res = cnn.conv(depth, 1, 1, 1, 1, activation=None, + use_batch_norm=False, bias=None) + output = shortcut + res + cnn.top_layer = output + cnn.top_size = depth + + +def bottleneck_block(cnn, depth, depth_bottleneck, stride, pre_activation): + """Bottleneck block with identity short-cut. + + Args: + cnn: the network to append bottleneck blocks. + depth: the number of output filters for this bottleneck block. + depth_bottleneck: the number of bottleneck filters for this block. + stride: Stride used in the first layer of the bottleneck block. + pre_activation: use pre_activation structure used in v2 or not. + """ + if pre_activation: + bottleneck_block_v2(cnn, depth, depth_bottleneck, stride) + else: + bottleneck_block_v1(cnn, depth, depth_bottleneck, stride) + + +def residual_block(cnn, depth, stride, pre_activation): + """Residual block with identity short-cut. + + Args: + cnn: the network to append residual blocks. + depth: the number of output filters for this residual block. + stride: Stride used in the first layer of the residual block. + pre_activation: use pre_activation structure or not. + """ + input_layer = cnn.top_layer + in_size = cnn.top_size + if in_size != depth: + # Plan A of shortcut. 
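+    # "Option A" shortcut from the ResNet paper: parameter-free, using
+    # average pooling to downsample and zero-padding to widen the channel
+    # dimension, instead of a learned 1x1 projection.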
+    shortcut = cnn.apool(1, 1, stride, stride,
+                         input_layer=input_layer,
+                         num_channels_in=in_size)
+    padding = (depth - in_size) // 2
+    if cnn.channel_pos == 'channels_last':
+      shortcut = tf.pad(
+          shortcut, [[0, 0], [0, 0], [0, 0], [padding, padding]])
+    else:
+      shortcut = tf.pad(
+          shortcut, [[0, 0], [padding, padding], [0, 0], [0, 0]])
+  else:
+    shortcut = input_layer
+  if pre_activation:
+    res = cnn.batch_norm(input_layer)
+    res = tf.nn.relu(res)
+  else:
+    res = input_layer
+  cnn.conv(depth, 3, 3, stride, stride,
+           input_layer=res, num_channels_in=in_size,
+           use_batch_norm=True, bias=None)
+  if pre_activation:
+    res = cnn.conv(depth, 3, 3, 1, 1, activation=None,
+                   use_batch_norm=False, bias=None)
+    output = shortcut + res
+  else:
+    res = cnn.conv(depth, 3, 3, 1, 1, activation=None,
+                   use_batch_norm=True, bias=None)
+    output = tf.nn.relu(shortcut + res)
+  cnn.top_layer = output
+  cnn.top_size = depth
+
+
+class ResnetModel(model_lib.Model):
+  """Resnet cnn network configuration."""
+
+  def __init__(self, model, layer_counts):
+    default_batch_sizes = {
+        'resnet50': 64,
+        'resnet101': 32,
+        'resnet152': 32,
+        'resnet50_v2': 64,
+        'resnet101_v2': 32,
+        'resnet152_v2': 32,
+    }
+    batch_size = default_batch_sizes.get(model, 32)
+    super(ResnetModel, self).__init__(model, 224, batch_size, 0.005,
+                                      layer_counts)
+    self.pre_activation = 'v2' in model
+
+  def add_inference(self, cnn):
+    if self.layer_counts is None:
+      raise ValueError('Layer counts not specified for %s' % self.get_model())
+    cnn.use_batch_norm = True
+    cnn.batch_norm_config = {'decay': 0.997, 'epsilon': 1e-5, 'scale': True}
+    cnn.conv(64, 7, 7, 2, 2, mode='SAME_RESNET', use_batch_norm=True)
+    cnn.mpool(3, 3, 2, 2, mode='SAME')
+    for _ in xrange(self.layer_counts[0]):
+      bottleneck_block(cnn, 256, 64, 1, self.pre_activation)
+    for i in xrange(self.layer_counts[1]):
+      stride = 2 if i == 0 else 1
+      bottleneck_block(cnn, 512, 128, stride, self.pre_activation)
+    for i in xrange(self.layer_counts[2]):
+      stride = 2 if i == 0 else 1
+      bottleneck_block(cnn, 1024, 256, stride, self.pre_activation)
+    for i in xrange(self.layer_counts[3]):
+      stride = 2 if i == 0 else 1
+      bottleneck_block(cnn, 2048, 512, stride, self.pre_activation)
+    if self.pre_activation:
+      cnn.batch_norm()
+      cnn.top_layer = tf.nn.relu(cnn.top_layer)
+    cnn.spatial_mean()
+
+  def get_learning_rate(self, global_step, batch_size):
+    # No `datasets` module is vendored; 1281167 is the ImageNet train set size.
+    num_batches_per_epoch = 1281167.0 / batch_size
+    boundaries = [int(num_batches_per_epoch * x) for x in [30, 60]]
+    values = [0.1, 0.01, 0.001]
+    return tf.train.piecewise_constant(global_step, boundaries, values)
+
+
+def create_resnet50_model():
+  return ResnetModel('resnet50', (3, 4, 6, 3))
+
+
+def create_resnet50_v2_model():
+  return ResnetModel('resnet50_v2', (3, 4, 6, 3))
+
+
+def create_resnet101_model():
+  return ResnetModel('resnet101', (3, 4, 23, 3))
+
+
+def create_resnet101_v2_model():
+  return ResnetModel('resnet101_v2', (3, 4, 23, 3))
+
+
+def create_resnet152_model():
+  return ResnetModel('resnet152', (3, 8, 36, 3))
+
+
+def create_resnet152_v2_model():
+  return ResnetModel('resnet152_v2', (3, 8, 36, 3))
+
+
+class ResnetCifar10Model(model_lib.Model):
+  """Resnet cnn network configuration for Cifar 10 dataset.
+
+  V1 model architecture follows the one defined in the paper:
+  https://arxiv.org/pdf/1512.03385.pdf.
+
+  V2 model architecture follows the one defined in the paper:
+  https://arxiv.org/pdf/1603.05027.pdf.
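+
+  Both variants stack three groups of `layer_counts` residual blocks, with
+  16, 32, and 64 filters on 32x32, 16x16, and 8x8 feature maps respectively,
+  giving 6n + 2 weighted layers when layer_counts == (n, n, n).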
+  """
+
+  def __init__(self, model, layer_counts):
+    self.pre_activation = 'v2' in model
+    super(ResnetCifar10Model, self).__init__(
+        model, 32, 128, 0.1, layer_counts)
+
+  def add_inference(self, cnn):
+    if self.layer_counts is None:
+      raise ValueError('Layer counts not specified for %s' % self.get_model())
+
+    cnn.use_batch_norm = True
+    cnn.batch_norm_config = {'decay': 0.9, 'epsilon': 1e-5, 'scale': True}
+    if self.pre_activation:
+      cnn.conv(16, 3, 3, 1, 1, use_batch_norm=True)
+    else:
+      cnn.conv(16, 3, 3, 1, 1, activation=None, use_batch_norm=True)
+    for i in xrange(self.layer_counts[0]):
+      # reshape to batch_size x 16 x 32 x 32
+      residual_block(cnn, 16, 1, self.pre_activation)
+    for i in xrange(self.layer_counts[1]):
+      # Subsampling is performed at the first convolution with a stride of 2
+      stride = 2 if i == 0 else 1
+      # reshape to batch_size x 32 x 16 x 16
+      residual_block(cnn, 32, stride, self.pre_activation)
+    for i in xrange(self.layer_counts[2]):
+      stride = 2 if i == 0 else 1
+      # reshape to batch_size x 64 x 8 x 8
+      residual_block(cnn, 64, stride, self.pre_activation)
+    if self.pre_activation:
+      cnn.batch_norm()
+      cnn.top_layer = tf.nn.relu(cnn.top_layer)
+    cnn.spatial_mean()
+
+  def get_learning_rate(self, global_step, batch_size):
+    num_batches_per_epoch = int(50000 / batch_size)
+    boundaries = num_batches_per_epoch * np.array([82, 123, 300],
+                                                  dtype=np.int64)
+    boundaries = [x for x in boundaries]
+    values = [0.1, 0.01, 0.001, 0.0002]
+    return tf.train.piecewise_constant(global_step, boundaries, values)
+
+
+def create_resnet20_cifar_model():
+  return ResnetCifar10Model('resnet20', (3, 3, 3))
+
+
+def create_resnet20_v2_cifar_model():
+  return ResnetCifar10Model('resnet20_v2', (3, 3, 3))
+
+
+def create_resnet32_cifar_model():
+  return ResnetCifar10Model('resnet32', (5, 5, 5))
+
+
+def create_resnet32_v2_cifar_model():
+  return ResnetCifar10Model('resnet32_v2', (5, 5, 5))
+
+
+def create_resnet44_cifar_model():
+  return ResnetCifar10Model('resnet44', (7, 7, 7))
+
+
+def create_resnet44_v2_cifar_model():
+  return ResnetCifar10Model('resnet44_v2', (7, 7, 7))
+
+
+def create_resnet56_cifar_model():
+  return ResnetCifar10Model('resnet56', (9, 9, 9))
+
+
+def create_resnet56_v2_cifar_model():
+  return ResnetCifar10Model('resnet56_v2', (9, 9, 9))
+
+
+def create_resnet110_cifar_model():
+  return ResnetCifar10Model('resnet110', (18, 18, 18))
+
+
+def create_resnet110_v2_cifar_model():
+  return ResnetCifar10Model('resnet110_v2', (18, 18, 18))

From ec2c21b05bedc8f02b65cc1a715acef6b24075e3 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 14 May 2018 20:04:01 -0700
Subject: [PATCH 02/17] idx

---
 python/ray/experimental/sgd/test_sgd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py
index aff28c3194256..60b1d6c49a035 100644
--- a/python/ray/experimental/sgd/test_sgd.py
+++ b/python/ray/experimental/sgd/test_sgd.py
@@ -16,7 +16,7 @@
     ray.init()
 
     model_creator = (
-        lambda i, j: TFBenchModel(batch=1, use_cpus=True))
+        lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True))
 
     sgd = DistributedSGD(
         model_creator, num_workers=2, devices_per_worker=2, use_cpus=True)

From aab1d5a09db9a9d45fed4fe459f7a244f89877bf Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 16 May 2018 11:07:06 -0700
Subject: [PATCH 03/17] foreach_worker foreach_model

---
 python/ray/experimental/sgd/sgd.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py
index 9358c275ef515..5eaf6b6f6573b 100644
--- a/python/ray/experimental/sgd/sgd.py
+++ b/python/ray/experimental/sgd/sgd.py
@@ -195,6 +195,12 @@ def __init__(self,
                          tf.local_variables_initializer())
         self.sess.run(init_op)
 
+    def foreach_model(self, fn):
+        return [fn(m) for m in self.models]
+
+    def foreach_worker(self, fn):
+        return fn(self)
+
     def compute_gradients(self, verbose):
         start = time.time()
         fetches = self.sess.run(
@@ -477,5 +483,16 @@ def __init__(
                 num_devices=devices_per_worker,
                 use_cpus=use_cpus,
                 verbose=True))
+    def foreach_worker(self, fn):
+        results = ray.get([w.foreach_worker.remote(fn) for w in self.workers])
+        return results
+
+    def foreach_model(self, fn):
+        results = ray.get([w.foreach_model.remote(fn) for w in self.workers])
+        out = []
+        for r in results:
+            out.extend(r)
+        return out
+
     def step(self):
         return do_sgd_step(self.workers, True)

From bc1ce2a811ada7bbddd9ea28c1cf09664c6b1c10 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 13 Jun 2018 15:40:00 -0700
Subject: [PATCH 04/17] add feed_dict

---
 python/ray/experimental/sgd/sgd.py        | 8 +++++++-
 python/ray/experimental/sgd/test_model.py | 3 +++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py
index 5eaf6b6f6573b..4d643b78023d4 100644
--- a/python/ray/experimental/sgd/sgd.py
+++ b/python/ray/experimental/sgd/sgd.py
@@ -203,9 +203,15 @@ def foreach_worker(self, fn):
 
     def compute_gradients(self, verbose):
         start = time.time()
+        feed_dict = {}
+        # Aggregate feed dicts for each model on this worker.
+        for model in self.models:
+            feed_dict.update(model.get_feed_dict())
+        # We only need to fetch the first per_device_grad, since they are
+        # averaged across all devices by allreduce.
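+        # The fetch below returns [loss, the packed grads for device 0, and
+        # the nccl control dependencies]. A model opts into feeding data by
+        # returning a placeholder -> value map from get_feed_dict(); an
+        # input-fed model might look like this (illustrative sketch only):
+        #     def get_feed_dict(self):
+        #         return {self.inputs: image_batch, self.labels: label_batch}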
fetches = self.sess.run( [self.models[0].loss, self.per_device_grads[0], - self.nccl_control_out]) + self.nccl_control_out], feed_dict=feed_dict) if verbose: print("compute grad interior time", time.time() - start) return fetches diff --git a/python/ray/experimental/sgd/test_model.py b/python/ray/experimental/sgd/test_model.py index 445746bd432f5..d5b4ff56fc65a 100644 --- a/python/ray/experimental/sgd/test_model.py +++ b/python/ray/experimental/sgd/test_model.py @@ -42,3 +42,6 @@ def __init__(self, batch=64, use_cpus=False): logits=logits, labels=self.labels) self.loss = tf.reduce_mean(loss, name='xentropy-loss') self.optimizer = tf.train.GradientDescentOptimizer(1e-6) + + def get_feed_dict(self): + return {} From 6a0a0682f80156fdafec56aae338d2c71cbb4d6d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 Sep 2018 18:21:40 -0700 Subject: [PATCH 05/17] update --- python/ray/experimental/sgd/__init__.py | 0 .../sgd/{ => example}/test_model.py | 0 python/ray/experimental/sgd/sgd.py | 33 +- python/ray/experimental/sgd/test_sgd.py | 4 +- .../ray/experimental/sgd/tfbench/allreduce.py | 590 ------------------ .../sgd/tfbench/modified_allreduce.py | 450 ++++++++++++- .../sgd/{chrome_timeline.py => util.py} | 23 + 7 files changed, 478 insertions(+), 622 deletions(-) create mode 100644 python/ray/experimental/sgd/__init__.py rename python/ray/experimental/sgd/{ => example}/test_model.py (100%) delete mode 100644 python/ray/experimental/sgd/tfbench/allreduce.py rename python/ray/experimental/sgd/{chrome_timeline.py => util.py} (72%) diff --git a/python/ray/experimental/sgd/__init__.py b/python/ray/experimental/sgd/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/ray/experimental/sgd/test_model.py b/python/ray/experimental/sgd/example/test_model.py similarity index 100% rename from python/ray/experimental/sgd/test_model.py rename to python/ray/experimental/sgd/example/test_model.py diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index 4d643b78023d4..5496453918f7f 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -13,31 +13,8 @@ import tensorflow.contrib.nccl as nccl import tensorflow.contrib.slim as slim -from chrome_timeline import Timeline -from tfbench import allreduce - - -def fetch(oids): - for o in oids: - plasma_id = ray.pyarrow.plasma.ObjectID(o) - ray.worker.global_worker.plasma_client.fetch([plasma_id]) - - -def run_timeline(sess, ops, feed_dict={}, write_timeline=False, name=""): - if write_timeline: - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - run_metadata = tf.RunMetadata() - fetches = sess.run( - ops, options=run_options, run_metadata=run_metadata, - feed_dict=feed_dict) - trace = timeline.Timeline(step_stats=run_metadata.step_stats) - outf = "timeline-{}-{}.json".format(name, os.getpid()) - trace_file = open(outf, "w") - print("wrote tf timeline to", os.path.abspath(outf)) - trace_file.write(trace.generate_chrome_trace_format()) - else: - fetches = sess.run(ops, feed_dict=feed_dict) - return fetches +from util import Timeline, fetch, run_timeline +from tfbench import modified_allreduce class SGDWorker(object): @@ -86,7 +63,6 @@ def __init__(self, self.packed_grads_and_vars = grad_ops else: if max_bytes: - from tfbench import modified_allreduce self.packed_grads_and_vars, packing_vals = ( modified_allreduce.sum_gradients_all_reduce( "", grad_ops, 1, all_reduce_alg, 1, @@ -94,9 +70,10 @@ def __init__(self, agg_small_grads_max_bytes=max_bytes)) 
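+            # Passing agg_small_grads_max_bytes=0 below disables packing in
+            # modified_allreduce, so both branches can share the same
+            # sum_gradients_all_reduce entry point.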
else: self.packed_grads_and_vars = ( - allreduce.sum_gradients_all_reduce( + modified_allreduce.sum_gradients_all_reduce( "", grad_ops, 1, all_reduce_alg, 1, - list(range(num_devices)))) + list(range(num_devices)), + agg_small_grads_max_bytes=0)) self.per_device_grads = [ list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars] assert(len(self.per_device_grads) == num_devices) diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py index 60b1d6c49a035..37ea18a3c3e4c 100644 --- a/python/ray/experimental/sgd/test_sgd.py +++ b/python/ray/experimental/sgd/test_sgd.py @@ -8,8 +8,8 @@ import numpy as np import tensorflow as tf -from test_model import TFBenchModel -from sgd import DistributedSGD +from ray.experimental.sgd.example.test_model import TFBenchModel +from ray.experimental.sgd.sgd import DistributedSGD if __name__ == "__main__": diff --git a/python/ray/experimental/sgd/tfbench/allreduce.py b/python/ray/experimental/sgd/tfbench/allreduce.py deleted file mode 100644 index 06da6c577a623..0000000000000 --- a/python/ray/experimental/sgd/tfbench/allreduce.py +++ /dev/null @@ -1,590 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for allreduce.""" - -from __future__ import print_function - -import collections as pycoll -import re - -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - -from tensorflow.contrib import nccl -from tensorflow.contrib.all_reduce.python import all_reduce - -AllReduceSpecTuple = pycoll.namedtuple('AllReduceSpecTuple', 'alg shards limit') - - -def parse_general_int(s): - """Parse integer with power-of-2 suffix eg. 32k.""" - mo = re.match(r'(\d+)([KkMGT]?)$', s) - if mo: - i, suffix = mo.group(1, 2) - v = int(i) - if suffix: - if suffix == 'K' or suffix == 'k': - v *= 1024 - elif suffix == 'M': - v *= (1024 * 1024) - elif suffix == 'G': - v *= (1024 * 1024 * 1024) - elif suffix == 'T': - v *= (1024 * 1024 * 1024 * 1024) - else: - raise ValueError('invalid integer string %s' % s) - return v - else: - v = int(s) - return v - - -def parse_all_reduce_spec(all_reduce_spec): - """Parse all_reduce_spec. - - Args: - all_reduce_spec: a string specifying a combination of all-reduce - algorithms to apply for gradient reduction. - - Returns: - a list of AllReduceSpecTuple. - - Raises: - ValueError: all_reduce_spec is not well-formed. - - An all_reduce_spec has BNF form: - int ::= positive whole number - g_int ::= int[KkMGT]? - alg_spec ::= alg | alg#int - range_spec ::= alg_spec | alg_spec/alg_spec - spec ::= range_spec | range_spec:g_int:range_spec - - Not all syntactically correct specifications are supported. 
- Examples of supported all_reduce_spec strings, with semantics explained: - - 'xring' == apply ring all-reduce to all tensors - 'xring#2' == apply ring all-reduce to all tensors, using two simultaneous - transfer rings, each operating on 1/2 of each tensor. - 'nccl' == apply NCCL all-reduce to all tensors (only works within - a single worker process where all devices are GPUs) - 'nccl/xring' == apply NCCL all-reduce to all tensors within each worker - to produce at least one full-reduced (locally) value, - then apply ring all-reduce to one such value from each - worker, then apply NCCL broadcast to propagate those globally - reduced values back to every device within each worker. - 'pscpu' == Shuffle reduce using worker CPUs as the gather devices: each - distributed tensor is reduced by copying all instances to - one of the worker CPUs, computing the reduction there, then - copying back to each participating device. Tensor reductions - are assigned to specific CPUs round-robin. - 'psgpu#4' == Arrange all GPUs across all workers into groups of 4. - Each distributed tensor is shuffle reduced against one - such group of 4 GPUs, selected round-robin. That is, each - tensor is split across 4 shards for the reduction. - 'pscpu:2k:pscpu#2:64k:xring' == Apply single-shard pscpu to - tensors of size <= 2048 elements, apply 2-shard pscpu to - tensors up to size 64k elements, apply xring to larger tensors. - 'pscpu/pscpu#2' == Use shuffle gather to locally reduce each tensor on - the worker's CPU, then use 2-shard shuffle to reduce those - locally reduced tensors across workers (on the worker CPUs), then - scatter the globally reduced values locally from each worker CPU. - """ - range_parts = all_reduce_spec.split(':') + ['-1'] - if len(range_parts) % 2: - raise ValueError('all_reduce_spec not well formed: %s' % all_reduce_spec) - limit = 0 - spec = [] - alg = None - shards = 1 - for i, range_part in enumerate(range_parts): - if i % 2 == 1: - try: - limit = parse_general_int(range_part) - spec.append(AllReduceSpecTuple(alg=alg, shards=shards, limit=limit)) - except ValueError: - raise ValueError('all_reduce_spec (%s) contains non-integer range %s' % - (all_reduce_spec, range_part)) - else: - alg = range_part - alg_parts = range_part.split('#') - alg = alg_parts[0] - if len(alg_parts) > 1: - try: - shards = int(alg_parts[1]) - except ValueError: - raise ValueError('all_reduce_spec (%s) contains non-integer ' - 'shards %s' % all_reduce_spec, alg_parts[1]) - else: - shards = 1 - if alg not in [ - 'nccl', 'nccl/xring', 'nccl/rechd', 'nccl/pscpu', 'xring', 'pscpu', - 'psgpu', 'pscpu/pscpu' - ]: - raise ValueError('all_reduce_spec (%s) contains invalid alg %s' % - (all_reduce_spec, alg)) - return spec - - -def build_all_reduce_device_prefixes(job_name, num_tasks): - """Build list of device prefix names for all_reduce. - - Args: - job_name: 'worker', 'ps' or 'localhost'. - num_tasks: number of jobs across which device names should be generated. - - Returns: - A list of device name prefix strings. Each element spells out the full - host name without adding the device. - e.g. '/job:worker/task:0' - """ - if job_name != 'localhost': - return ['/job:%s/task:%d' % (job_name, d) for d in range(0, num_tasks)] - else: - assert num_tasks == 1 - return ['/job:%s' % job_name] - - -def group_device_names(devices, group_size): - """Group device names into groups of group_size. - - Args: - devices: list of strings naming devices. 
- group_size: int >= 1 - - Returns: - list of lists of devices, where each inner list is group_size long, - and each device appears at least once in an inner list. If - len(devices) % group_size = 0 then each device will appear - exactly once. - - Raises: - ValueError: group_size > len(devices) - """ - num_devices = len(devices) - if group_size > num_devices: - raise ValueError('only %d devices, but group_size=%d' % (num_devices, - group_size)) - num_groups = ( - num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) - groups = [[] for i in range(num_groups)] - for i in range(0, num_groups * group_size): - groups[i % num_groups].append(devices[i % num_devices]) - return groups - - -def split_grads_by_size(threshold_size, device_grads): - """Break gradients into two sets according to tensor size. - - Args: - threshold_size: int size cutoff for small vs large tensor. - device_grads: List of lists of (gradient, variable) tuples. The outer - list is over devices. The inner list is over individual gradients. - - Returns: - small_grads: Subset of device_grads where shape is <= theshold_size - elements. - large_grads: Subset of device_grads where shape is > threshold_size - elements. - """ - small_grads = [] - large_grads = [] - for dl in device_grads: - small_dl = [] - large_dl = [] - for (g, v) in dl: - tensor_size = g.get_shape().num_elements() - if tensor_size <= threshold_size: - small_dl.append([g, v]) - else: - large_dl.append([g, v]) - if small_dl: - small_grads.append(small_dl) - if large_dl: - large_grads.append(large_dl) - return small_grads, large_grads - - -def build_reduce_sum(scaled_grads): - stacked = tf.parallel_stack(values=scaled_grads) - reduced = tf.reduce_sum(stacked, 0) - return [reduced] * len(scaled_grads) - -def build_trivial_sum(scaled_grads): - return scaled_grads - -def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, - check_inf_nan): - """Calculate the average gradient for a shared variable across all towers. - - Note that this function provides a synchronization point across all towers. - - Args: - grad_and_vars: A list or tuple of (gradient, variable) tuples. Each - (gradient, variable) pair within the outer list represents the gradient - of the variable calculated for a single tower, and the number of pairs - equals the number of towers. - use_mean: if True, mean is taken, else sum of gradients is taken. - check_inf_nan: check grads for nans and infs. - - Returns: - The tuple ([(average_gradient, variable),], has_nan_or_inf) where the - gradient has been averaged across all towers. The variable is chosen from - the first tower. The has_nan_or_inf indicates the grads has nan or inf. - """ - grads = [g for g, _ in grad_and_vars] - grad = tf.add_n(grads) - - if use_mean and len(grads) > 1: - grad = tf.multiply(grad, 1.0 / len(grads)) - - v = grad_and_vars[0][1] - if check_inf_nan: - has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.is_finite(grads))) - return (grad, v), has_nan_or_inf - else: - return (grad, v), None - - -def aggregate_gradients_using_copy_with_device_selection( - tower_grads, avail_devices, use_mean=True, check_inf_nan=False): - """Aggregate gradients, controlling device for the aggregation. - - Args: - tower_grads: List of lists of (gradient, variable) tuples. The outer list - is over towers. The inner list is over individual gradients. - use_mean: if True, mean is taken, else sum of gradients is taken. - check_inf_nan: If true, check grads for nans and infs. 
- - Returns: - The tuple ([(average_gradient, variable),], has_nan_or_inf) where the - gradient has been averaged across all towers. The variable is chosen from - the first tower. The has_nan_or_inf indicates the grads has nan or inf. - """ - agg_grads = [] - has_nan_or_inf_list = [] - for i, single_grads in enumerate(zip(*tower_grads)): - with tf.device(avail_devices[i % len(avail_devices)]): - grad_and_var, has_nan_or_inf = aggregate_single_gradient_using_copy( - single_grads, use_mean, check_inf_nan) - agg_grads.append(grad_and_var) - has_nan_or_inf_list.append(has_nan_or_inf) - return agg_grads - - -def sum_grad_and_var_all_reduce(grad_and_vars, - num_workers, - alg, - gpu_indices, - aux_devices=None, - num_shards=1): - """Apply all-reduce algorithm over specified gradient tensors.""" - with tf.name_scope('allreduce'): - # Note that each grad_and_vars looks like the following: - # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) - scaled_grads = [g for g, _ in grad_and_vars] - if alg == 'nccl': - summed_grads = nccl.all_sum(scaled_grads) - elif alg == 'simple': - summed_grads = build_reduce_sum(scaled_grads) - elif alg == 'trivial': - summed_grads = build_trivial_sum(scaled_grads) - elif alg == 'xring': - summed_grads = all_reduce.build_ring_all_reduce( - scaled_grads, num_workers, num_shards, gpu_indices, tf.add) - elif alg == 'nccl/xring': - summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, - tf.add) - elif alg == 'nccl/rechd': - summed_grads = all_reduce.build_nccl_then_recursive_hd( - scaled_grads, tf.add) - elif alg == 'nccl/pscpu': - summed_grads = all_reduce.build_nccl_then_shuffle( - scaled_grads, aux_devices, tf.add, tf.add_n) - elif alg == 'pscpu/pscpu': - summed_grads = all_reduce.build_shuffle_then_shuffle( - scaled_grads, - aux_devices, - # TODO(tucker): devise a way of better specifying the device set - # for the second level. - [aux_devices[0]], - tf.add_n) - elif alg in ['pscpu', 'psgpu']: - summed_grads = all_reduce.build_shuffle_all_reduce( - scaled_grads, aux_devices, tf.add_n) - else: - raise ValueError('unsupported all_reduce alg: ', alg) - - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result - - -def contains_any(haystack, needles): - """Tests if any needle is a substring of haystack. - - Args: - haystack: a string - needles: list of strings - - Returns: - True if any element of needles is a substring of haystack, - False otherwise. - """ - for n in needles: - if n in haystack: - return True - return False - - -def sum_gradients_all_reduce(dev_prefixes, - tower_grads, - num_workers, - alg, - num_shards, - gpu_indices, - agg_small_grads_max_bytes=0, - agg_small_grads_max_group=10): - """Apply all-reduce algorithm over specified gradient tensors. - - Args: - dev_prefixes: list of prefix strings to use to generate PS device names. - tower_grads: the gradients to reduce. - num_workers: number of worker processes across entire job. - alg: the all-reduce algorithm to apply. - num_shards: alg-specific sharding factor. - gpu_indices: indices of local GPUs in order usable for ring-reduce. - agg_small_grads_max_bytes: largest tensor eligible for aggregation, - in number of bytes. - agg_small_grads_max_group: largest permitted aggregation of small - tensors. 
- - Returns: - list of reduced tensors - """ - alg_contains_shuffle = contains_any(alg, ['pscpu', 'psgpu']) - is_hierarchical = '/' in alg - if 'pscpu' in alg: - aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] - elif 'psgpu' in alg: - aux_devices = [ - prefix + '/gpu:%d' % i - for i in range(len(gpu_indices)) - for prefix in dev_prefixes - ] - else: - aux_devices = ['/job:localhost/cpu:0'] - aux_device_groups = group_device_names(aux_devices, num_shards - if alg_contains_shuffle else 1) - group_index = 0 - if agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: - tower_grads, packing = pack_small_tensors( - tower_grads, - max_bytes=agg_small_grads_max_bytes, - max_group=agg_small_grads_max_group) - else: - packing = None - reduced_gv_list = [] - for grad_and_vars in zip(*tower_grads): - reduced_gv_list.append( - sum_grad_and_var_all_reduce( - grad_and_vars, num_workers, alg, gpu_indices, aux_devices - if is_hierarchical else aux_device_groups[group_index], num_shards)) - group_index = (group_index + 1) % len(aux_device_groups) - new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] - if packing: - new_tower_grads = unpack_small_tensors(new_tower_grads, packing) - return new_tower_grads - - -def extract_ranges(index_list, range_size_limit=32): - """Extract consecutive ranges and singles from index_list. - - Args: - index_list: List of monotone increasing non-negative integers. - range_size_limit: Largest size range to return. If a larger - consecutive range exists it will be returned as multiple - ranges. - - Returns: - ranges, singles where ranges is a list of [first, last] pairs of - consecutive elements in index_list, and singles is all of the - other elements, in original order. - """ - if not index_list: - return [], [] - first = index_list[0] - last = first - ranges = [] - singles = [] - for i in index_list[1:]: - if i == last + 1 and (last - first) <= range_size_limit: - last = i - else: - if last > first: - ranges.append([first, last]) - else: - singles.append(first) - first = i - last = i - if last > first: - ranges.append([first, last]) - else: - singles.append(first) - return ranges, singles - - -GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes') - - -def pack_range(key, packing, grad_vars, rng): - """Form the concatenation of a specified range of gradient tensors. - - Args: - key: Value under which to store meta-data in packing that will be used - later to restore the grad_var list structure. - packing: Dict holding data describing packed ranges of small tensors. - grad_vars: List of (grad, var) pairs for one tower. - rng: A pair of integers giving the first, last indices of a consecutive - range of tensors to be packed. - - Returns: - A tensor that is the concatenation of all the specified small tensors. - """ - to_pack = grad_vars[rng[0]:rng[1] + 1] - members = [] - variables = [] - restore_shapes = [] - with tf.name_scope('pack'): - for g, v in to_pack: - variables.append(v) - restore_shapes.append(g.shape) - with tf.device(g.device): - members.append(tf.reshape(g, [-1])) - packing[key] = GradPackTuple( - indices=range(rng[0], rng[1] + 1), - vars=variables, - shapes=restore_shapes) - with tf.device(members[0].device): - return tf.concat(members, 0) - - -def unpack_grad_tuple(gv, gpt): - """Unpack a previously packed collection of gradient tensors. - - Args: - gv: A (grad, var) pair to be unpacked. - gpt: A GradPackTuple describing the packing operation that produced gv. 
- - Returns: - A list of (grad, var) pairs corresponding to the values that were - originally packed into gv, maybe following subsequent operations like - reduction. - """ - elt_widths = [x.num_elements() for x in gpt.shapes] - with tf.device(gv[0][0].device): - with tf.name_scope('unpack'): - splits = tf.split(gv[0], elt_widths) - unpacked_gv = [] - for idx, s in enumerate(splits): - unpacked_gv.append((tf.reshape(s, gpt.shapes[idx]), gpt.vars[idx])) - return unpacked_gv - - -def pack_small_tensors(tower_grads, max_bytes=0, max_group=0): - """Concatenate small gradient tensors together for reduction. - - Args: - tower_grads: List of lists of (gradient, variable) tuples. - max_bytes: Int giving max number of bytes in a tensor that - may be considered small. - max_group: Int giving max number of small tensors that may be - concatenated into one new tensor. - - Returns: - new_tower_grads, packing where new_tower_grads is identical to - tower_grads except that all feasible small_tensors have been removed - from their places and concatenated into larger tensors that are - now in the front of the list for each tower, and packing contains - the data necessary to restore the tower_grads structure. - - Look through the first tower for gradients of the same type (float), - and small size, that are all sequential. For each such group, - replace by a new tensor that is a flattened concatenation. Note - that the corresponding variable will be absent, which doesn't matter - because it isn't used during all-reduce. - - Requires: - Every gv_list in towers must have isomorphic structure including identical - tensor sizes and types. - """ - small_indices = [] - large_indices = [] - for idx, (g, _) in enumerate(tower_grads[0]): - if g.dtype == tf.float32 and (4 * g.shape.num_elements()) <= max_bytes: - small_indices.append(idx) - else: - large_indices.append(idx) - small_ranges, small_singles = extract_ranges( - small_indices, range_size_limit=max_group) - large_indices = sorted(large_indices + small_singles) - num_gv = len(tower_grads[0]) - packing = {} - if small_ranges: - new_tower_grads = [] - for dev_idx, gv_list in enumerate(tower_grads): - assert len(gv_list) == num_gv - new_gv_list = [] - for r in small_ranges: - key = '%d:%d' % (dev_idx, len(new_gv_list)) - new_gv_list.append((pack_range(key, packing, gv_list, r), - 'packing_var_placeholder')) - for i in large_indices: - new_gv_list.append(gv_list[i]) - new_tower_grads.append(new_gv_list) - return new_tower_grads, packing - else: - return tower_grads, None - - -def unpack_small_tensors(tower_grads, packing): - """Undo the structure alterations to tower_grads done by pack_small_tensors. - - Args: - tower_grads: List of List of (grad, var) tuples. - packing: A dict generated by pack_small_tensors describing the changes - it made to tower_grads. - - Returns: - new_tower_grads: identical to tower_grads except that concatentations - of small tensors have been split apart and returned to their original - positions, paired with their original variables. 
- """ - if not packing: - return tower_grads - new_tower_grads = [] - num_devices = len(tower_grads) - num_packed = len(packing.keys()) // num_devices - for dev_idx, gv_list in enumerate(tower_grads): - new_gv_list = gv_list[num_packed:] - for i in xrange(0, num_packed): - k = '%d:%d' % (dev_idx, i) - gpt = packing[k] - gv = unpack_grad_tuple(gv_list[i], gpt) - for gi, idx in enumerate(gpt.indices): - assert idx == gpt.indices[gi] - new_gv_list.insert(idx, gv[gi]) - new_tower_grads.append(new_gv_list) - return new_tower_grads diff --git a/python/ray/experimental/sgd/tfbench/modified_allreduce.py b/python/ray/experimental/sgd/tfbench/modified_allreduce.py index 5334daa57dc9a..f86cdddfe219e 100644 --- a/python/ray/experimental/sgd/tfbench/modified_allreduce.py +++ b/python/ray/experimental/sgd/tfbench/modified_allreduce.py @@ -18,14 +18,335 @@ import collections as pycoll import re -import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce -from allreduce import * + +AllReduceSpecTuple = pycoll.namedtuple('AllReduceSpecTuple', 'alg shards limit') + + +def parse_general_int(s): + """Parse integer with power-of-2 suffix eg. 32k.""" + mo = re.match(r'(\d+)([KkMGT]?)$', s) + if mo: + i, suffix = mo.group(1, 2) + v = int(i) + if suffix: + if suffix == 'K' or suffix == 'k': + v *= 1024 + elif suffix == 'M': + v *= (1024 * 1024) + elif suffix == 'G': + v *= (1024 * 1024 * 1024) + elif suffix == 'T': + v *= (1024 * 1024 * 1024 * 1024) + else: + raise ValueError('invalid integer string %s' % s) + return v + else: + v = int(s) + return v + + +def parse_all_reduce_spec(all_reduce_spec): + """Parse all_reduce_spec. + + Args: + all_reduce_spec: a string specifying a combination of all-reduce + algorithms to apply for gradient reduction. + + Returns: + a list of AllReduceSpecTuple. + + Raises: + ValueError: all_reduce_spec is not well-formed. + + An all_reduce_spec has BNF form: + int ::= positive whole number + g_int ::= int[KkMGT]? + alg_spec ::= alg | alg#int + range_spec ::= alg_spec | alg_spec/alg_spec + spec ::= range_spec | range_spec:g_int:range_spec + + Not all syntactically correct specifications are supported. + Examples of supported all_reduce_spec strings, with semantics explained: + + 'xring' == apply ring all-reduce to all tensors + 'xring#2' == apply ring all-reduce to all tensors, using two simultaneous + transfer rings, each operating on 1/2 of each tensor. + 'nccl' == apply NCCL all-reduce to all tensors (only works within + a single worker process where all devices are GPUs) + 'nccl/xring' == apply NCCL all-reduce to all tensors within each worker + to produce at least one full-reduced (locally) value, + then apply ring all-reduce to one such value from each + worker, then apply NCCL broadcast to propagate those globally + reduced values back to every device within each worker. + 'pscpu' == Shuffle reduce using worker CPUs as the gather devices: each + distributed tensor is reduced by copying all instances to + one of the worker CPUs, computing the reduction there, then + copying back to each participating device. Tensor reductions + are assigned to specific CPUs round-robin. + 'psgpu#4' == Arrange all GPUs across all workers into groups of 4. + Each distributed tensor is shuffle reduced against one + such group of 4 GPUs, selected round-robin. That is, each + tensor is split across 4 shards for the reduction. 
+ 'pscpu:2k:pscpu#2:64k:xring' == Apply single-shard pscpu to + tensors of size <= 2048 elements, apply 2-shard pscpu to + tensors up to size 64k elements, apply xring to larger tensors. + 'pscpu/pscpu#2' == Use shuffle gather to locally reduce each tensor on + the worker's CPU, then use 2-shard shuffle to reduce those + locally reduced tensors across workers (on the worker CPUs), then + scatter the globally reduced values locally from each worker CPU. + """ + range_parts = all_reduce_spec.split(':') + ['-1'] + if len(range_parts) % 2: + raise ValueError('all_reduce_spec not well formed: %s' % all_reduce_spec) + limit = 0 + spec = [] + alg = None + shards = 1 + for i, range_part in enumerate(range_parts): + if i % 2 == 1: + try: + limit = parse_general_int(range_part) + spec.append(AllReduceSpecTuple(alg=alg, shards=shards, limit=limit)) + except ValueError: + raise ValueError('all_reduce_spec (%s) contains non-integer range %s' % + (all_reduce_spec, range_part)) + else: + alg = range_part + alg_parts = range_part.split('#') + alg = alg_parts[0] + if len(alg_parts) > 1: + try: + shards = int(alg_parts[1]) + except ValueError: + raise ValueError('all_reduce_spec (%s) contains non-integer ' + 'shards %s' % all_reduce_spec, alg_parts[1]) + else: + shards = 1 + if alg not in [ + 'nccl', 'nccl/xring', 'nccl/rechd', 'nccl/pscpu', 'xring', 'pscpu', + 'psgpu', 'pscpu/pscpu' + ]: + raise ValueError('all_reduce_spec (%s) contains invalid alg %s' % + (all_reduce_spec, alg)) + return spec + + +def build_all_reduce_device_prefixes(job_name, num_tasks): + """Build list of device prefix names for all_reduce. + + Args: + job_name: 'worker', 'ps' or 'localhost'. + num_tasks: number of jobs across which device names should be generated. + + Returns: + A list of device name prefix strings. Each element spells out the full + host name without adding the device. + e.g. '/job:worker/task:0' + """ + if job_name != 'localhost': + return ['/job:%s/task:%d' % (job_name, d) for d in range(0, num_tasks)] + else: + assert num_tasks == 1 + return ['/job:%s' % job_name] + + +def group_device_names(devices, group_size): + """Group device names into groups of group_size. + + Args: + devices: list of strings naming devices. + group_size: int >= 1 + + Returns: + list of lists of devices, where each inner list is group_size long, + and each device appears at least once in an inner list. If + len(devices) % group_size = 0 then each device will appear + exactly once. + + Raises: + ValueError: group_size > len(devices) + """ + num_devices = len(devices) + if group_size > num_devices: + raise ValueError('only %d devices, but group_size=%d' % (num_devices, + group_size)) + num_groups = ( + num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) + groups = [[] for i in range(num_groups)] + for i in range(0, num_groups * group_size): + groups[i % num_groups].append(devices[i % num_devices]) + return groups + + +def split_grads_by_size(threshold_size, device_grads): + """Break gradients into two sets according to tensor size. + + Args: + threshold_size: int size cutoff for small vs large tensor. + device_grads: List of lists of (gradient, variable) tuples. The outer + list is over devices. The inner list is over individual gradients. + + Returns: + small_grads: Subset of device_grads where shape is <= theshold_size + elements. + large_grads: Subset of device_grads where shape is > threshold_size + elements. 
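+
+  For example, with threshold_size=4, a [2, 2] gradient (4 elements) is
+  grouped into small_grads, while a [3, 3] gradient (9 elements) goes to
+  large_grads.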
+ """ + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads + + +def build_reduce_sum(scaled_grads): + stacked = tf.parallel_stack(values=scaled_grads) + reduced = tf.reduce_sum(stacked, 0) + return [reduced] * len(scaled_grads) + +def build_trivial_sum(scaled_grads): + return scaled_grads + +def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, + check_inf_nan): + """Calculate the average gradient for a shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + grad_and_vars: A list or tuple of (gradient, variable) tuples. Each + (gradient, variable) pair within the outer list represents the gradient + of the variable calculated for a single tower, and the number of pairs + equals the number of towers. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + grads = [g for g, _ in grad_and_vars] + grad = tf.add_n(grads) + + if use_mean and len(grads) > 1: + grad = tf.multiply(grad, 1.0 / len(grads)) + + v = grad_and_vars[0][1] + if check_inf_nan: + has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.is_finite(grads))) + return (grad, v), has_nan_or_inf + else: + return (grad, v), None + + +def aggregate_gradients_using_copy_with_device_selection( + tower_grads, avail_devices, use_mean=True, check_inf_nan=False): + """Aggregate gradients, controlling device for the aggregation. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over towers. The inner list is over individual gradients. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: If true, check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + agg_grads = [] + has_nan_or_inf_list = [] + for i, single_grads in enumerate(zip(*tower_grads)): + with tf.device(avail_devices[i % len(avail_devices)]): + grad_and_var, has_nan_or_inf = aggregate_single_gradient_using_copy( + single_grads, use_mean, check_inf_nan) + agg_grads.append(grad_and_var) + has_nan_or_inf_list.append(has_nan_or_inf) + return agg_grads + + +def sum_grad_and_var_all_reduce(grad_and_vars, + num_workers, + alg, + gpu_indices, + aux_devices=None, + num_shards=1): + """Apply all-reduce algorithm over specified gradient tensors.""" + with tf.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... 
, (grad0_gpuN, var0_gpuN))
+    scaled_grads = [g for g, _ in grad_and_vars]
+    if alg == 'nccl':
+      summed_grads = nccl.all_sum(scaled_grads)
+    elif alg == 'simple':
+      summed_grads = build_reduce_sum(scaled_grads)
+    elif alg == 'trivial':
+      summed_grads = build_trivial_sum(scaled_grads)
+    elif alg == 'xring':
+      summed_grads = all_reduce.build_ring_all_reduce(
+          scaled_grads, num_workers, num_shards, gpu_indices, tf.add)
+    elif alg == 'nccl/xring':
+      summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards,
+                                                     tf.add)
+    elif alg == 'nccl/rechd':
+      summed_grads = all_reduce.build_nccl_then_recursive_hd(
+          scaled_grads, tf.add)
+    elif alg == 'nccl/pscpu':
+      summed_grads = all_reduce.build_nccl_then_shuffle(
+          scaled_grads, aux_devices, tf.add, tf.add_n)
+    elif alg == 'pscpu/pscpu':
+      summed_grads = all_reduce.build_shuffle_then_shuffle(
+          scaled_grads,
+          aux_devices,
+          # TODO(tucker): devise a way of better specifying the device set
+          # for the second level.
+          [aux_devices[0]],
+          tf.add_n)
+    elif alg in ['pscpu', 'psgpu']:
+      summed_grads = all_reduce.build_shuffle_all_reduce(
+          scaled_grads, aux_devices, tf.add_n)
+    else:
+      raise ValueError('unsupported all_reduce alg: ', alg)
+
+    result = []
+    for (_, v), g in zip(grad_and_vars, summed_grads):
+      result.append([g, v])
+    return result
+
+def contains_any(haystack, needles):
+  """Tests if any needle is a substring of haystack.
+
+  Args:
+    haystack: a string
+    needles: list of strings
+
+  Returns:
+    True if any element of needles is a substring of haystack,
+    False otherwise.
+  """
+  for n in needles:
+    if n in haystack:
+      return True
+  return False
 def sum_gradients_all_reduce(dev_prefixes,
@@ -111,6 +432,100 @@ def sizeof_fmt(num, suffix='B'):
     print(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()]))
+def extract_ranges(index_list, range_size_limit=32):
+  """Extract consecutive ranges and singles from index_list.
+
+  Args:
+    index_list: List of monotone increasing non-negative integers.
+    range_size_limit: Largest size range to return. If a larger
+      consecutive range exists it will be returned as multiple
+      ranges.
+
+  Returns:
+    ranges, singles where ranges is a list of [first, last] pairs of
+      consecutive elements in index_list, and singles is all of the
+      other elements, in original order.
+  """
+  if not index_list:
+    return [], []
+  first = index_list[0]
+  last = first
+  ranges = []
+  singles = []
+  for i in index_list[1:]:
+    if i == last + 1 and (last - first) <= range_size_limit:
+      last = i
+    else:
+      if last > first:
+        ranges.append([first, last])
+      else:
+        singles.append(first)
+      first = i
+      last = i
+  if last > first:
+    ranges.append([first, last])
+  else:
+    singles.append(first)
+  return ranges, singles
+
+GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes')
+
+def pack_range(key, packing, grad_vars, rng):
+  """Form the concatenation of a specified range of gradient tensors.
+
+  Args:
+    key: Value under which to store meta-data in packing that will be used
+      later to restore the grad_var list structure.
+    packing: Dict holding data describing packed ranges of small tensors.
+    grad_vars: List of (grad, var) pairs for one tower.
+    rng: A pair of integers giving the first, last indices of a consecutive
+      range of tensors to be packed.
+
+  Returns:
+    A tensor that is the concatenation of all the specified small tensors.
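[A quick illustration of extract_ranges above, with inputs chosen for this note rather than taken from the patch:]

    ranges, singles = extract_ranges([2, 3, 4, 5, 9, 12, 13])
    # ranges == [[2, 5], [12, 13]]  consecutive runs collapse to [first, last]
    # singles == [9]                isolated indices keep their original order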
+ """ + to_pack = grad_vars[rng[0]:rng[1] + 1] + members = [] + variables = [] + restore_shapes = [] + with tf.name_scope('pack'): + for g, v in to_pack: + variables.append(v) + restore_shapes.append(g.shape) + with tf.device(g.device): + members.append(tf.reshape(g, [-1])) + packing[key] = GradPackTuple( + indices=range(rng[0], rng[1] + 1), + vars=variables, + shapes=restore_shapes) + with tf.device(members[0].device): + return tf.concat(members, 0) + + +def unpack_grad_tuple(gv, gpt): + """Unpack a previously packed collection of gradient tensors. + + Args: + gv: A (grad, var) pair to be unpacked. + gpt: A GradPackTuple describing the packing operation that produced gv. + + Returns: + A list of (grad, var) pairs corresponding to the values that were + originally packed into gv, maybe following subsequent operations like + reduction. + """ + elt_widths = [x.num_elements() for x in gpt.shapes] + with tf.device(gv[0][0].device): + with tf.name_scope('unpack'): + splits = tf.split(gv[0], elt_widths) + unpacked_gv = [] + for idx, s in enumerate(splits): + unpacked_gv.append((tf.reshape(s, gpt.shapes[idx]), gpt.vars[idx])) + return unpacked_gv + + def pack_small_tensors(tower_grads, max_bytes=0): """Concatenate gradients together more intelligently. @@ -169,3 +584,34 @@ def end_interval(indices, small_ranges, large_indices): return new_tower_grads, packing else: return tower_grads, None + + +def unpack_small_tensors(tower_grads, packing): + """Undo the structure alterations to tower_grads done by pack_small_tensors. + + Args: + tower_grads: List of List of (grad, var) tuples. + packing: A dict generated by pack_small_tensors describing the changes + it made to tower_grads. + + Returns: + new_tower_grads: identical to tower_grads except that concatentations + of small tensors have been split apart and returned to their original + positions, paired with their original variables. 
+ """ + if not packing: + return tower_grads + new_tower_grads = [] + num_devices = len(tower_grads) + num_packed = len(packing.keys()) // num_devices + for dev_idx, gv_list in enumerate(tower_grads): + new_gv_list = gv_list[num_packed:] + for i in xrange(0, num_packed): + k = '%d:%d' % (dev_idx, i) + gpt = packing[k] + gv = unpack_grad_tuple(gv_list[i], gpt) + for gi, idx in enumerate(gpt.indices): + assert idx == gpt.indices[gi] + new_gv_list.insert(idx, gv[gi]) + new_tower_grads.append(new_gv_list) + return new_tower_grads diff --git a/python/ray/experimental/sgd/chrome_timeline.py b/python/ray/experimental/sgd/util.py similarity index 72% rename from python/ray/experimental/sgd/chrome_timeline.py rename to python/ray/experimental/sgd/util.py index 41b34b9d019ea..c41ce39dbc4a0 100644 --- a/python/ray/experimental/sgd/chrome_timeline.py +++ b/python/ray/experimental/sgd/util.py @@ -7,6 +7,29 @@ import time +def fetch(oids): + for o in oids: + plasma_id = ray.pyarrow.plasma.ObjectID(o) + ray.worker.global_worker.plasma_client.fetch([plasma_id]) + + +def run_timeline(sess, ops, feed_dict={}, write_timeline=False, name=""): + if write_timeline: + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + fetches = sess.run( + ops, options=run_options, run_metadata=run_metadata, + feed_dict=feed_dict) + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + outf = "timeline-{}-{}.json".format(name, os.getpid()) + trace_file = open(outf, "w") + print("wrote tf timeline to", os.path.abspath(outf)) + trace_file.write(trace.generate_chrome_trace_format()) + else: + fetches = sess.run(ops, feed_dict=feed_dict) + return fetches + + class Timeline(object): def __init__(self, tid): self.events = [] From 424288617d17371a5776ef02905f3fe9faf329c8 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 Sep 2018 18:21:51 -0700 Subject: [PATCH 06/17] yapf --- python/ray/experimental/sgd/sgd.py | 113 ++- python/ray/experimental/sgd/test_sgd.py | 3 +- .../sgd/tfbench/convnet_builder.py | 862 ++++++++++-------- python/ray/experimental/sgd/tfbench/model.py | 157 ++-- .../experimental/sgd/tfbench/model_config.py | 44 +- .../sgd/tfbench/modified_allreduce.py | 662 +++++++------- .../experimental/sgd/tfbench/resnet_model.py | 485 ++++++---- python/ray/experimental/sgd/util.py | 4 +- 8 files changed, 1253 insertions(+), 1077 deletions(-) diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index 5496453918f7f..1a49f26e8ad60 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -32,7 +32,9 @@ def __init__(self, # TODO(ekl) support custom session tf_session_args = { - "device_count": {"CPU": num_devices}, + "device_count": { + "CPU": num_devices + }, "log_device_placement": False, "gpu_options": tf.GPUOptions(force_gpu_compatible=True), "inter_op_parallelism_threads": 128, @@ -53,30 +55,40 @@ def __init__(self, model = model_creator(worker_index, device_idx) self.models.append(model) model.grads = [ - t for t in model.optimizer.compute_gradients( - model.loss) - if t[0] is not None] + t + for t in model.optimizer.compute_gradients(model.loss) + if t[0] is not None + ] grad_ops.append(model.grads) if num_devices == 1: - assert not max_bytes, "Not supported with 1 GPU" - self.packed_grads_and_vars = grad_ops + assert not max_bytes, "Not supported with 1 GPU" + self.packed_grads_and_vars = grad_ops else: if max_bytes: self.packed_grads_and_vars, packing_vals = ( 
modified_allreduce.sum_gradients_all_reduce( - "", grad_ops, 1, all_reduce_alg, 1, + "", + grad_ops, + 1, + all_reduce_alg, + 1, list(range(num_devices)), agg_small_grads_max_bytes=max_bytes)) else: self.packed_grads_and_vars = ( modified_allreduce.sum_gradients_all_reduce( - "", grad_ops, 1, all_reduce_alg, 1, + "", + grad_ops, + 1, + all_reduce_alg, + 1, list(range(num_devices)), agg_small_grads_max_bytes=0)) self.per_device_grads = [ - list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars] - assert(len(self.per_device_grads) == num_devices) + list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars + ] + assert (len(self.per_device_grads) == num_devices) self.num_grads = num_grads = len(self.packed_grads_and_vars[0]) if max_bytes: print("Packed grads => {} tensors".format(num_grads)) @@ -85,8 +97,8 @@ def __init__(self, nccl_noops = [] for j in range(num_grads)[::-1]: with tf.control_dependencies( - nccl_noops + [dev_grad[j] - for dev_grad in self.per_device_grads]): + nccl_noops + + [dev_grad[j] for dev_grad in self.per_device_grads]): nccl_noops = [tf.no_op()] # You must fetch this otherwise the NCCL allreduce will hang @@ -107,7 +119,8 @@ def __init__(self, self.plasma_in_grads = [] self.plasma_in_grads_oids = [ tf.placeholder(shape=[], dtype=tf.string) - for _ in range(num_grads)] + for _ in range(num_grads) + ] ix = 0 for j in range(num_grads): grad = self.per_device_grads[ix][j] @@ -126,7 +139,8 @@ def __init__(self, unpacked_gv = [] self.plasma_out_grads_oids = [ tf.placeholder(shape=[], dtype=tf.string) - for _ in range(num_grads)] + for _ in range(num_grads) + ] packed_plasma_grads = [] ix = 0 for j in range(num_grads): @@ -136,8 +150,8 @@ def __init__(self, self.plasma_out_grads_oids[j], plasma_store_socket_name=store_socket, plasma_manager_socket_name=manager_socket) - grad_ph = tf.reshape( - grad_ph, self.packed_grads_and_vars[0][j][0].shape) + grad_ph = tf.reshape(grad_ph, + self.packed_grads_and_vars[0][j][0].shape) print("Packed tensor", grad_ph) packed_plasma_grads.append(grad_ph) for i in range(num_devices): @@ -164,9 +178,10 @@ def __init__(self, apply_ops = [] to_apply = unpacked_gv[0] for ix, m in enumerate(self.models): - apply_ops.append(m.optimizer.apply_gradients( - [(g, v) - for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])])) + apply_ops.append( + m.optimizer.apply_gradients( + [(g, v) + for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])])) self.apply_op = tf.group(*apply_ops) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) @@ -187,8 +202,11 @@ def compute_gradients(self, verbose): # We only need to fetch the first per_device_grad, since they are # averaged across all devices by allreduce. 
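[The reversed loop over num_grads above chains no-ops so that every NCCL op becomes a transitive control dependency of the fetched nccl_control_out. In isolation the pattern looks like the following sketch, with placeholder ops standing in for the per-gradient allreduce ops, not the patch's actual graph:]

    import tensorflow as tf

    per_grad_ops = [tf.no_op(name='grad_%d' % i) for i in range(3)]
    noops = []
    for op in reversed(per_grad_ops):
        # each link runs only after the previous link and one more grad op
        with tf.control_dependencies(noops + [op]):
            noops = [tf.no_op()]
    control_out = tf.group(*noops)  # fetching this forces the whole chain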
fetches = self.sess.run( - [self.models[0].loss, self.per_device_grads[0], - self.nccl_control_out], feed_dict=feed_dict) + [ + self.models[0].loss, self.per_device_grads[0], + self.nccl_control_out + ], + feed_dict=feed_dict) if verbose: print("compute grad interior time", time.time() - start) return fetches @@ -196,24 +214,27 @@ def compute_gradients(self, verbose): def apply_gradients(self, avg_grads, verbose): start = time.time() result = { - g: avg_grads[i] for (i, g) in enumerate(self.per_device_grads[0]) + g: avg_grads[i] + for (i, g) in enumerate(self.per_device_grads[0]) } self.sess.run(self.apply_op, feed_dict=result) if verbose: print("apply grad interior time", time.time() - start) - def ps_compute_apply( - self, out_grad_shard_oids, agg_grad_shard_oids, - tl_name="ps_compute_apply", write_timeline=False): + def ps_compute_apply(self, + out_grad_shard_oids, + agg_grad_shard_oids, + tl_name="ps_compute_apply", + write_timeline=False): feed_dict = { ph: oid - for (ph, oid) - in zip(self.plasma_in_grads_oids, out_grad_shard_oids) + for (ph, + oid) in zip(self.plasma_in_grads_oids, out_grad_shard_oids) } feed_dict.update({ ph: oid - for (ph, oid) - in zip(self.plasma_out_grads_oids, agg_grad_shard_oids) + for (ph, + oid) in zip(self.plasma_out_grads_oids, agg_grad_shard_oids) }) fetch(agg_grad_shard_oids) run_timeline( @@ -268,9 +289,9 @@ def add_spinwait(self, grad_shard_ids): for p in plasma_ids: if ray.worker.global_worker.plasma_client.contains(p): self.timeline.start("get_buffers") - [raw_grads] = ( - ray.worker.global_worker.plasma_client.get_buffers( - [p])) + [raw_grads + ] = (ray.worker.global_worker.plasma_client.get_buffers( + [p])) grads = np.frombuffer(raw_grads, dtype=np.float32) self.accumulated += grads self.acc_counter += 1 @@ -298,8 +319,7 @@ def get(self, object_id): client = ray.worker.global_worker.plasma_client assert self.acc_counter == self.num_sgd_workers, self.acc_counter oid = ray.pyarrow.plasma.ObjectID(object_id) - buff = client.create( - oid, self.accumulated.nbytes) + buff = client.create(oid, self.accumulated.nbytes) wrapper = np.frombuffer(buff, dtype=np.float32) np.copyto(wrapper, self.accumulated) client.seal(oid) @@ -354,10 +374,8 @@ def do_sgd_step(actors, verbose): def distributed_sgd_step(actors, ps_list, verbose, write_timeline): # Preallocate object ids that actors will write gradient shards to - grad_shard_oids_list = [ - [np.random.bytes(20) for _ in ps_list] - for _ in actors - ] + grad_shard_oids_list = [[np.random.bytes(20) for _ in ps_list] + for _ in actors] print("generated grad oids") # Preallocate object ids that param servers will write new weights to @@ -366,8 +384,8 @@ def distributed_sgd_step(actors, ps_list, verbose, write_timeline): # Kick off the fused compute grad / update weights tf run for each actor for actor, grad_shard_oids in zip(actors, grad_shard_oids_list): - actor.ps_compute_apply.remote(grad_shard_oids, accum_shard_ids, - write_timeline=write_timeline) + actor.ps_compute_apply.remote( + grad_shard_oids, accum_shard_ids, write_timeline=write_timeline) print("Launched all ps_compute_applys on all actors") # Issue prefetch ops @@ -417,7 +435,7 @@ def create_ps(): ip_mapping = defaultdict(list) while (any(len(v) < min_placed for v in ip_mapping.values()) - or (len(ip_mapping) < num_ips)): + or (len(ip_mapping) < num_ips)): print("generating new ps, ip map so far", ip_mapping) new_ps = create_ps() ps_ip = ray.get(new_ps.ip.remote()) @@ -442,15 +460,16 @@ def create_ps(): else: print("saving ps...") - print("Final PS 
balance: ", Counter(ray.get([ps.ip.remote() for ps in final_list]))) + print("Final PS balance: ", + Counter(ray.get([ps.ip.remote() for ps in final_list]))) for i, ps in enumerate(final_list): ps.set_tid.remote(i) return final_list class DistributedSGD(object): - def __init__( - self, model_creator, num_workers, devices_per_worker, use_cpus): + def __init__(self, model_creator, num_workers, devices_per_worker, + use_cpus): self.model_creator = model_creator if use_cpus: requests = {"num_cpus": devices_per_worker} @@ -462,8 +481,10 @@ def __init__( print("Creating worker", worker_index) self.workers.append( RemoteSGDWorker.remote( - worker_index, model_creator, - num_devices=devices_per_worker, use_cpus=use_cpus, + worker_index, + model_creator, + num_devices=devices_per_worker, + use_cpus=use_cpus, verbose=True)) def foreach_worker(self, fn): diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py index 37ea18a3c3e4c..fc7c0233e867f 100644 --- a/python/ray/experimental/sgd/test_sgd.py +++ b/python/ray/experimental/sgd/test_sgd.py @@ -11,10 +11,9 @@ from ray.experimental.sgd.example.test_model import TFBenchModel from ray.experimental.sgd.sgd import DistributedSGD - if __name__ == "__main__": ray.init() - + model_creator = ( lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True)) diff --git a/python/ray/experimental/sgd/tfbench/convnet_builder.py b/python/ray/experimental/sgd/tfbench/convnet_builder.py index d0cc2755e0bc1..ad4225540ef85 100644 --- a/python/ray/experimental/sgd/tfbench/convnet_builder.py +++ b/python/ray/experimental/sgd/tfbench/convnet_builder.py @@ -30,33 +30,33 @@ class ConvNetBuilder(object): - """Builder of cnn net.""" + """Builder of cnn net.""" - def __init__(self, - input_op, - input_nchan, - phase_train, - use_tf_layers, - data_format='NCHW', - dtype=tf.float32, - variable_dtype=tf.float32): - self.top_layer = input_op - self.top_size = input_nchan - self.phase_train = phase_train - self.use_tf_layers = use_tf_layers - self.data_format = data_format - self.dtype = dtype - self.variable_dtype = variable_dtype - self.counts = defaultdict(lambda: 0) - self.use_batch_norm = False - self.batch_norm_config = {} # 'decay': 0.997, 'scale': True} - self.channel_pos = ('channels_last' - if data_format == 'NHWC' else 'channels_first') - self.aux_top_layer = None - self.aux_top_size = 0 + def __init__(self, + input_op, + input_nchan, + phase_train, + use_tf_layers, + data_format='NCHW', + dtype=tf.float32, + variable_dtype=tf.float32): + self.top_layer = input_op + self.top_size = input_nchan + self.phase_train = phase_train + self.use_tf_layers = use_tf_layers + self.data_format = data_format + self.dtype = dtype + self.variable_dtype = variable_dtype + self.counts = defaultdict(lambda: 0) + self.use_batch_norm = False + self.batch_norm_config = {} # 'decay': 0.997, 'scale': True} + self.channel_pos = ('channels_last' + if data_format == 'NHWC' else 'channels_first') + self.aux_top_layer = None + self.aux_top_size = 0 - def get_custom_getter(self): - """Returns a custom getter that this class's methods must be called under. + def get_custom_getter(self): + """Returns a custom getter that this class's methods must be called under. All methods of this class must be called under a variable scope that was passed this custom getter. Example: @@ -73,395 +73,457 @@ def get_custom_getter(self): self.variable_type, then casted to the requested dtype, instead of directly storing the variable as the requested dtype. 
""" - def inner_custom_getter(getter, *args, **kwargs): - """Custom getter that forces variables to have type self.variable_type.""" - if not self.use_tf_layers: - return getter(*args, **kwargs) - requested_dtype = kwargs['dtype'] - if not (requested_dtype == tf.float32 and - self.variable_dtype == tf.float16): - # Only change the variable dtype if doing so does not decrease variable - # precision. - kwargs['dtype'] = self.variable_dtype - var = getter(*args, **kwargs) - # This if statement is needed to guard the cast, because batch norm - # assigns directly to the return value of this custom getter. The cast - # makes the return value not a variable so it cannot be assigned. Batch - # norm variables are always in fp32 so this if statement is never - # triggered for them. - if var.dtype.base_dtype != requested_dtype: - var = tf.cast(var, requested_dtype) - return var - return inner_custom_getter - @contextlib.contextmanager - def switch_to_aux_top_layer(self): - """Context that construct cnn in the auxiliary arm.""" - if self.aux_top_layer is None: - raise RuntimeError('Empty auxiliary top layer in the network.') - saved_top_layer = self.top_layer - saved_top_size = self.top_size - self.top_layer = self.aux_top_layer - self.top_size = self.aux_top_size - yield - self.aux_top_layer = self.top_layer - self.aux_top_size = self.top_size - self.top_layer = saved_top_layer - self.top_size = saved_top_size + def inner_custom_getter(getter, *args, **kwargs): + """Custom getter that forces variables to have type self.variable_type.""" + if not self.use_tf_layers: + return getter(*args, **kwargs) + requested_dtype = kwargs['dtype'] + if not (requested_dtype == tf.float32 + and self.variable_dtype == tf.float16): + # Only change the variable dtype if doing so does not decrease variable + # precision. + kwargs['dtype'] = self.variable_dtype + var = getter(*args, **kwargs) + # This if statement is needed to guard the cast, because batch norm + # assigns directly to the return value of this custom getter. The cast + # makes the return value not a variable so it cannot be assigned. Batch + # norm variables are always in fp32 so this if statement is never + # triggered for them. + if var.dtype.base_dtype != requested_dtype: + var = tf.cast(var, requested_dtype) + return var - def get_variable(self, name, shape, dtype, cast_dtype, *args, **kwargs): - # TODO(reedwm): Currently variables and gradients are transferred to other - # devices and machines as type `dtype`, not `cast_dtype`. In particular, - # this means in fp16 mode, variables are transferred as fp32 values, not - # fp16 values, which uses extra bandwidth. - var = tf.get_variable(name, shape, dtype, *args, **kwargs) - return tf.cast(var, cast_dtype) + return inner_custom_getter - def _conv2d_impl(self, input_layer, num_channels_in, filters, kernel_size, - strides, padding, kernel_initializer): - if self.use_tf_layers: - return conv_layers.conv2d(input_layer, filters, kernel_size, strides, - padding, self.channel_pos, - kernel_initializer=kernel_initializer, - use_bias=False) - else: - weights_shape = [kernel_size[0], kernel_size[1], num_channels_in, filters] - # We use the name 'conv2d/kernel' so the variable has the same name as its - # tf.layers equivalent. This way, if a checkpoint is written when - # self.use_tf_layers == True, it can be loaded when - # self.use_tf_layers == False, and vice versa. 
- weights = self.get_variable('conv2d/kernel', weights_shape, - self.variable_dtype, self.dtype, - initializer=kernel_initializer) - if self.data_format == 'NHWC': - strides = [1] + strides + [1] - else: - strides = [1, 1] + strides - return tf.nn.conv2d(input_layer, weights, strides, padding, - data_format=self.data_format) + @contextlib.contextmanager + def switch_to_aux_top_layer(self): + """Context that construct cnn in the auxiliary arm.""" + if self.aux_top_layer is None: + raise RuntimeError('Empty auxiliary top layer in the network.') + saved_top_layer = self.top_layer + saved_top_size = self.top_size + self.top_layer = self.aux_top_layer + self.top_size = self.aux_top_size + yield + self.aux_top_layer = self.top_layer + self.aux_top_size = self.top_size + self.top_layer = saved_top_layer + self.top_size = saved_top_size - def conv(self, - num_out_channels, - k_height, - k_width, - d_height=1, - d_width=1, - mode='SAME', - input_layer=None, - num_channels_in=None, - use_batch_norm=None, - stddev=None, - activation='relu', - bias=0.0): - """Construct a conv2d layer on top of cnn.""" - if input_layer is None: - input_layer = self.top_layer - if num_channels_in is None: - num_channels_in = self.top_size - kernel_initializer = None - if stddev is not None: - kernel_initializer = tf.truncated_normal_initializer(stddev=stddev) - name = 'conv' + str(self.counts['conv']) - self.counts['conv'] += 1 - with tf.variable_scope(name): - strides = [1, d_height, d_width, 1] - if self.data_format == 'NCHW': - strides = [strides[0], strides[3], strides[1], strides[2]] - if mode != 'SAME_RESNET': - conv = self._conv2d_impl(input_layer, num_channels_in, num_out_channels, - kernel_size=[k_height, k_width], - strides=[d_height, d_width], padding=mode, - kernel_initializer=kernel_initializer) - else: # Special padding mode for ResNet models - if d_height == 1 and d_width == 1: - conv = self._conv2d_impl(input_layer, num_channels_in, - num_out_channels, - kernel_size=[k_height, k_width], - strides=[d_height, d_width], padding='SAME', - kernel_initializer=kernel_initializer) - else: - rate = 1 # Unused (for 'a trous' convolutions) - kernel_height_effective = k_height + (k_height - 1) * (rate - 1) - pad_h_beg = (kernel_height_effective - 1) // 2 - pad_h_end = kernel_height_effective - 1 - pad_h_beg - kernel_width_effective = k_width + (k_width - 1) * (rate - 1) - pad_w_beg = (kernel_width_effective - 1) // 2 - pad_w_end = kernel_width_effective - 1 - pad_w_beg - padding = [[0, 0], [pad_h_beg, pad_h_end], - [pad_w_beg, pad_w_end], [0, 0]] - if self.data_format == 'NCHW': - padding = [padding[0], padding[3], padding[1], padding[2]] - input_layer = tf.pad(input_layer, padding) - conv = self._conv2d_impl(input_layer, num_channels_in, - num_out_channels, - kernel_size=[k_height, k_width], - strides=[d_height, d_width], padding='VALID', - kernel_initializer=kernel_initializer) - if use_batch_norm is None: - use_batch_norm = self.use_batch_norm - if not use_batch_norm: - if bias is not None: - biases = self.get_variable('biases', [num_out_channels], - self.variable_dtype, self.dtype, - initializer=tf.constant_initializer(bias)) - biased = tf.reshape( - tf.nn.bias_add(conv, biases, data_format=self.data_format), - conv.get_shape()) - else: - biased = conv - else: - self.top_layer = conv - self.top_size = num_out_channels - biased = self.batch_norm(**self.batch_norm_config) - if activation == 'relu': - conv1 = tf.nn.relu(biased) - elif activation == 'linear' or activation is None: - conv1 = biased - elif activation 
== 'tanh': - conv1 = tf.nn.tanh(biased) - else: - raise KeyError('Invalid activation type \'%s\'' % activation) - self.top_layer = conv1 - self.top_size = num_out_channels - return conv1 - - def _pool(self, - pool_name, - pool_function, - k_height, - k_width, - d_height, - d_width, - mode, - input_layer, - num_channels_in): - """Construct a pooling layer.""" - if input_layer is None: - input_layer = self.top_layer - else: - self.top_size = num_channels_in - name = pool_name + str(self.counts[pool_name]) - self.counts[pool_name] += 1 - if self.use_tf_layers: - pool = pool_function( - input_layer, [k_height, k_width], [d_height, d_width], - padding=mode, - data_format=self.channel_pos, - name=name) - else: - if self.data_format == 'NHWC': - ksize = [1, k_height, k_width, 1] - strides = [1, d_height, d_width, 1] - else: - ksize = [1, 1, k_height, k_width] - strides = [1, 1, d_height, d_width] - pool = tf.nn.max_pool(input_layer, ksize, strides, padding=mode, - data_format=self.data_format, name=name) - self.top_layer = pool - return pool - - def mpool(self, - k_height, - k_width, - d_height=2, - d_width=2, - mode='VALID', - input_layer=None, - num_channels_in=None): - """Construct a max pooling layer.""" - return self._pool('mpool', pooling_layers.max_pooling2d, k_height, k_width, - d_height, d_width, mode, input_layer, num_channels_in) - - def apool(self, - k_height, - k_width, - d_height=2, - d_width=2, - mode='VALID', - input_layer=None, - num_channels_in=None): - """Construct an average pooling layer.""" - return self._pool('apool', pooling_layers.average_pooling2d, k_height, - k_width, d_height, d_width, mode, input_layer, - num_channels_in) + def get_variable(self, name, shape, dtype, cast_dtype, *args, **kwargs): + # TODO(reedwm): Currently variables and gradients are transferred to other + # devices and machines as type `dtype`, not `cast_dtype`. In particular, + # this means in fp16 mode, variables are transferred as fp32 values, not + # fp16 values, which uses extra bandwidth. + var = tf.get_variable(name, shape, dtype, *args, **kwargs) + return tf.cast(var, cast_dtype) - def reshape(self, shape, input_layer=None): - if input_layer is None: - input_layer = self.top_layer - self.top_layer = tf.reshape(input_layer, shape) - self.top_size = shape[-1] # HACK This may not always work - return self.top_layer + def _conv2d_impl(self, input_layer, num_channels_in, filters, kernel_size, + strides, padding, kernel_initializer): + if self.use_tf_layers: + return conv_layers.conv2d( + input_layer, + filters, + kernel_size, + strides, + padding, + self.channel_pos, + kernel_initializer=kernel_initializer, + use_bias=False) + else: + weights_shape = [ + kernel_size[0], kernel_size[1], num_channels_in, filters + ] + # We use the name 'conv2d/kernel' so the variable has the same name as its + # tf.layers equivalent. This way, if a checkpoint is written when + # self.use_tf_layers == True, it can be loaded when + # self.use_tf_layers == False, and vice versa. 
+ weights = self.get_variable( + 'conv2d/kernel', + weights_shape, + self.variable_dtype, + self.dtype, + initializer=kernel_initializer) + if self.data_format == 'NHWC': + strides = [1] + strides + [1] + else: + strides = [1, 1] + strides + return tf.nn.conv2d( + input_layer, + weights, + strides, + padding, + data_format=self.data_format) - def affine(self, + def conv(self, num_out_channels, + k_height, + k_width, + d_height=1, + d_width=1, + mode='SAME', input_layer=None, num_channels_in=None, - bias=0.0, + use_batch_norm=None, stddev=None, - activation='relu'): - if input_layer is None: - input_layer = self.top_layer - if num_channels_in is None: - num_channels_in = self.top_size - name = 'affine' + str(self.counts['affine']) - self.counts['affine'] += 1 - with tf.variable_scope(name): - init_factor = 2. if activation == 'relu' else 1. - stddev = stddev or np.sqrt(init_factor / num_channels_in) - kernel = self.get_variable( - 'weights', [num_channels_in, num_out_channels], - self.variable_dtype, self.dtype, - initializer=tf.truncated_normal_initializer(stddev=stddev)) - biases = self.get_variable('biases', [num_out_channels], - self.variable_dtype, self.dtype, - initializer=tf.constant_initializer(bias)) - logits = tf.nn.xw_plus_b(input_layer, kernel, biases) - if activation == 'relu': - affine1 = tf.nn.relu(logits, name=name) - elif activation == 'linear' or activation is None: - affine1 = logits - else: - raise KeyError('Invalid activation type \'%s\'' % activation) - self.top_layer = affine1 - self.top_size = num_out_channels - return affine1 + activation='relu', + bias=0.0): + """Construct a conv2d layer on top of cnn.""" + if input_layer is None: + input_layer = self.top_layer + if num_channels_in is None: + num_channels_in = self.top_size + kernel_initializer = None + if stddev is not None: + kernel_initializer = tf.truncated_normal_initializer(stddev=stddev) + name = 'conv' + str(self.counts['conv']) + self.counts['conv'] += 1 + with tf.variable_scope(name): + strides = [1, d_height, d_width, 1] + if self.data_format == 'NCHW': + strides = [strides[0], strides[3], strides[1], strides[2]] + if mode != 'SAME_RESNET': + conv = self._conv2d_impl( + input_layer, + num_channels_in, + num_out_channels, + kernel_size=[k_height, k_width], + strides=[d_height, d_width], + padding=mode, + kernel_initializer=kernel_initializer) + else: # Special padding mode for ResNet models + if d_height == 1 and d_width == 1: + conv = self._conv2d_impl( + input_layer, + num_channels_in, + num_out_channels, + kernel_size=[k_height, k_width], + strides=[d_height, d_width], + padding='SAME', + kernel_initializer=kernel_initializer) + else: + rate = 1 # Unused (for 'a trous' convolutions) + kernel_height_effective = k_height + (k_height - 1) * ( + rate - 1) + pad_h_beg = (kernel_height_effective - 1) // 2 + pad_h_end = kernel_height_effective - 1 - pad_h_beg + kernel_width_effective = k_width + (k_width - 1) * ( + rate - 1) + pad_w_beg = (kernel_width_effective - 1) // 2 + pad_w_end = kernel_width_effective - 1 - pad_w_beg + padding = [[0, 0], [pad_h_beg, pad_h_end], + [pad_w_beg, pad_w_end], [0, 0]] + if self.data_format == 'NCHW': + padding = [ + padding[0], padding[3], padding[1], padding[2] + ] + input_layer = tf.pad(input_layer, padding) + conv = self._conv2d_impl( + input_layer, + num_channels_in, + num_out_channels, + kernel_size=[k_height, k_width], + strides=[d_height, d_width], + padding='VALID', + kernel_initializer=kernel_initializer) + if use_batch_norm is None: + use_batch_norm = 
self.use_batch_norm + if not use_batch_norm: + if bias is not None: + biases = self.get_variable( + 'biases', [num_out_channels], + self.variable_dtype, + self.dtype, + initializer=tf.constant_initializer(bias)) + biased = tf.reshape( + tf.nn.bias_add( + conv, biases, data_format=self.data_format), + conv.get_shape()) + else: + biased = conv + else: + self.top_layer = conv + self.top_size = num_out_channels + biased = self.batch_norm(**self.batch_norm_config) + if activation == 'relu': + conv1 = tf.nn.relu(biased) + elif activation == 'linear' or activation is None: + conv1 = biased + elif activation == 'tanh': + conv1 = tf.nn.tanh(biased) + else: + raise KeyError('Invalid activation type \'%s\'' % activation) + self.top_layer = conv1 + self.top_size = num_out_channels + return conv1 + + def _pool(self, pool_name, pool_function, k_height, k_width, d_height, + d_width, mode, input_layer, num_channels_in): + """Construct a pooling layer.""" + if input_layer is None: + input_layer = self.top_layer + else: + self.top_size = num_channels_in + name = pool_name + str(self.counts[pool_name]) + self.counts[pool_name] += 1 + if self.use_tf_layers: + pool = pool_function( + input_layer, [k_height, k_width], [d_height, d_width], + padding=mode, + data_format=self.channel_pos, + name=name) + else: + if self.data_format == 'NHWC': + ksize = [1, k_height, k_width, 1] + strides = [1, d_height, d_width, 1] + else: + ksize = [1, 1, k_height, k_width] + strides = [1, 1, d_height, d_width] + pool = tf.nn.max_pool( + input_layer, + ksize, + strides, + padding=mode, + data_format=self.data_format, + name=name) + self.top_layer = pool + return pool - def inception_module(self, name, cols, input_layer=None, in_size=None): - if input_layer is None: - input_layer = self.top_layer - if in_size is None: - in_size = self.top_size - name += str(self.counts[name]) - self.counts[name] += 1 - with tf.variable_scope(name): - col_layers = [] - col_layer_sizes = [] - for c, col in enumerate(cols): - col_layers.append([]) - col_layer_sizes.append([]) - for l, layer in enumerate(col): - ltype, args = layer[0], layer[1:] - kwargs = { - 'input_layer': input_layer, - 'num_channels_in': in_size - } if l == 0 else {} - if ltype == 'conv': - self.conv(*args, **kwargs) - elif ltype == 'mpool': - self.mpool(*args, **kwargs) - elif ltype == 'apool': - self.apool(*args, **kwargs) - elif ltype == 'share': # Share matching layer from previous column - self.top_layer = col_layers[c - 1][l] - self.top_size = col_layer_sizes[c - 1][l] - else: - raise KeyError( - 'Invalid layer type for inception module: \'%s\'' % ltype) - col_layers[c].append(self.top_layer) - col_layer_sizes[c].append(self.top_size) - catdim = 3 if self.data_format == 'NHWC' else 1 - self.top_layer = tf.concat([layers[-1] for layers in col_layers], catdim) - self.top_size = sum([sizes[-1] for sizes in col_layer_sizes]) - return self.top_layer + def mpool(self, + k_height, + k_width, + d_height=2, + d_width=2, + mode='VALID', + input_layer=None, + num_channels_in=None): + """Construct a max pooling layer.""" + return self._pool('mpool', pooling_layers.max_pooling2d, k_height, + k_width, d_height, d_width, mode, input_layer, + num_channels_in) - def spatial_mean(self, keep_dims=False): - name = 'spatial_mean' + str(self.counts['spatial_mean']) - self.counts['spatial_mean'] += 1 - axes = [1, 2] if self.data_format == 'NHWC' else [2, 3] - self.top_layer = tf.reduce_mean( - self.top_layer, axes, keep_dims=keep_dims, name=name) - return self.top_layer + def apool(self, + 
k_height, + k_width, + d_height=2, + d_width=2, + mode='VALID', + input_layer=None, + num_channels_in=None): + """Construct an average pooling layer.""" + return self._pool('apool', pooling_layers.average_pooling2d, k_height, + k_width, d_height, d_width, mode, input_layer, + num_channels_in) - def dropout(self, keep_prob=0.5, input_layer=None): - if input_layer is None: - input_layer = self.top_layer - else: - self.top_size = None - name = 'dropout' + str(self.counts['dropout']) - with tf.variable_scope(name): - if not self.phase_train: - keep_prob = 1.0 - if self.use_tf_layers: - dropout = core_layers.dropout(input_layer, 1. - keep_prob) - else: - dropout = tf.nn.dropout(input_layer, keep_prob) - self.top_layer = dropout - return dropout + def reshape(self, shape, input_layer=None): + if input_layer is None: + input_layer = self.top_layer + self.top_layer = tf.reshape(input_layer, shape) + self.top_size = shape[-1] # HACK This may not always work + return self.top_layer - def _batch_norm_without_layers(self, input_layer, decay, use_scale, epsilon): - """Batch normalization on `input_layer` without tf.layers.""" - # We make this function as similar as possible to the - # tf.contrib.layers.batch_norm, to minimize the differences between using - # layers and not using layers. - shape = input_layer.shape - num_channels = shape[3] if self.data_format == 'NHWC' else shape[1] - beta = self.get_variable('beta', [num_channels], tf.float32, tf.float32, - initializer=tf.zeros_initializer()) - if use_scale: - gamma = self.get_variable('gamma', [num_channels], tf.float32, - tf.float32, initializer=tf.ones_initializer()) - else: - gamma = tf.constant(1.0, tf.float32, [num_channels]) - # For moving variables, we use tf.get_variable instead of self.get_variable, - # since self.get_variable returns the result of tf.cast which we cannot - # assign to. - moving_mean = tf.get_variable('moving_mean', [num_channels], - tf.float32, - initializer=tf.zeros_initializer(), - trainable=False) - moving_variance = tf.get_variable('moving_variance', [num_channels], - tf.float32, - initializer=tf.ones_initializer(), - trainable=False) - if self.phase_train: - bn, batch_mean, batch_variance = tf.nn.fused_batch_norm( - input_layer, gamma, beta, epsilon=epsilon, - data_format=self.data_format, is_training=True) - mean_update = moving_averages.assign_moving_average( - moving_mean, batch_mean, decay=decay, zero_debias=False) - variance_update = moving_averages.assign_moving_average( - moving_variance, batch_variance, decay=decay, zero_debias=False) - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mean_update) - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, variance_update) - else: - bn, _, _ = tf.nn.fused_batch_norm( - input_layer, gamma, beta, mean=moving_mean, - variance=moving_variance, epsilon=epsilon, - data_format=self.data_format, is_training=False) - return bn + def affine(self, + num_out_channels, + input_layer=None, + num_channels_in=None, + bias=0.0, + stddev=None, + activation='relu'): + if input_layer is None: + input_layer = self.top_layer + if num_channels_in is None: + num_channels_in = self.top_size + name = 'affine' + str(self.counts['affine']) + self.counts['affine'] += 1 + with tf.variable_scope(name): + init_factor = 2. if activation == 'relu' else 1. 
+ stddev = stddev or np.sqrt(init_factor / num_channels_in) + kernel = self.get_variable( + 'weights', [num_channels_in, num_out_channels], + self.variable_dtype, + self.dtype, + initializer=tf.truncated_normal_initializer(stddev=stddev)) + biases = self.get_variable( + 'biases', [num_out_channels], + self.variable_dtype, + self.dtype, + initializer=tf.constant_initializer(bias)) + logits = tf.nn.xw_plus_b(input_layer, kernel, biases) + if activation == 'relu': + affine1 = tf.nn.relu(logits, name=name) + elif activation == 'linear' or activation is None: + affine1 = logits + else: + raise KeyError('Invalid activation type \'%s\'' % activation) + self.top_layer = affine1 + self.top_size = num_out_channels + return affine1 - def batch_norm(self, input_layer=None, decay=0.999, scale=False, - epsilon=0.001): - """Adds a Batch Normalization layer.""" - if input_layer is None: - input_layer = self.top_layer - else: - self.top_size = None - name = 'batchnorm' + str(self.counts['batchnorm']) - self.counts['batchnorm'] += 1 + def inception_module(self, name, cols, input_layer=None, in_size=None): + if input_layer is None: + input_layer = self.top_layer + if in_size is None: + in_size = self.top_size + name += str(self.counts[name]) + self.counts[name] += 1 + with tf.variable_scope(name): + col_layers = [] + col_layer_sizes = [] + for c, col in enumerate(cols): + col_layers.append([]) + col_layer_sizes.append([]) + for l, layer in enumerate(col): + ltype, args = layer[0], layer[1:] + kwargs = { + 'input_layer': input_layer, + 'num_channels_in': in_size + } if l == 0 else {} + if ltype == 'conv': + self.conv(*args, **kwargs) + elif ltype == 'mpool': + self.mpool(*args, **kwargs) + elif ltype == 'apool': + self.apool(*args, **kwargs) + elif ltype == 'share': # Share matching layer from previous column + self.top_layer = col_layers[c - 1][l] + self.top_size = col_layer_sizes[c - 1][l] + else: + raise KeyError( + 'Invalid layer type for inception module: \'%s\'' % + ltype) + col_layers[c].append(self.top_layer) + col_layer_sizes[c].append(self.top_size) + catdim = 3 if self.data_format == 'NHWC' else 1 + self.top_layer = tf.concat([layers[-1] for layers in col_layers], + catdim) + self.top_size = sum([sizes[-1] for sizes in col_layer_sizes]) + return self.top_layer + + def spatial_mean(self, keep_dims=False): + name = 'spatial_mean' + str(self.counts['spatial_mean']) + self.counts['spatial_mean'] += 1 + axes = [1, 2] if self.data_format == 'NHWC' else [2, 3] + self.top_layer = tf.reduce_mean( + self.top_layer, axes, keep_dims=keep_dims, name=name) + return self.top_layer + + def dropout(self, keep_prob=0.5, input_layer=None): + if input_layer is None: + input_layer = self.top_layer + else: + self.top_size = None + name = 'dropout' + str(self.counts['dropout']) + with tf.variable_scope(name): + if not self.phase_train: + keep_prob = 1.0 + if self.use_tf_layers: + dropout = core_layers.dropout(input_layer, 1. - keep_prob) + else: + dropout = tf.nn.dropout(input_layer, keep_prob) + self.top_layer = dropout + return dropout + + def _batch_norm_without_layers(self, input_layer, decay, use_scale, + epsilon): + """Batch normalization on `input_layer` without tf.layers.""" + # We make this function as similar as possible to the + # tf.contrib.layers.batch_norm, to minimize the differences between using + # layers and not using layers. 
+ shape = input_layer.shape + num_channels = shape[3] if self.data_format == 'NHWC' else shape[1] + beta = self.get_variable( + 'beta', [num_channels], + tf.float32, + tf.float32, + initializer=tf.zeros_initializer()) + if use_scale: + gamma = self.get_variable( + 'gamma', [num_channels], + tf.float32, + tf.float32, + initializer=tf.ones_initializer()) + else: + gamma = tf.constant(1.0, tf.float32, [num_channels]) + # For moving variables, we use tf.get_variable instead of self.get_variable, + # since self.get_variable returns the result of tf.cast which we cannot + # assign to. + moving_mean = tf.get_variable( + 'moving_mean', [num_channels], + tf.float32, + initializer=tf.zeros_initializer(), + trainable=False) + moving_variance = tf.get_variable( + 'moving_variance', [num_channels], + tf.float32, + initializer=tf.ones_initializer(), + trainable=False) + if self.phase_train: + bn, batch_mean, batch_variance = tf.nn.fused_batch_norm( + input_layer, + gamma, + beta, + epsilon=epsilon, + data_format=self.data_format, + is_training=True) + mean_update = moving_averages.assign_moving_average( + moving_mean, batch_mean, decay=decay, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + moving_variance, + batch_variance, + decay=decay, + zero_debias=False) + tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mean_update) + tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, variance_update) + else: + bn, _, _ = tf.nn.fused_batch_norm( + input_layer, + gamma, + beta, + mean=moving_mean, + variance=moving_variance, + epsilon=epsilon, + data_format=self.data_format, + is_training=False) + return bn + + def batch_norm(self, + input_layer=None, + decay=0.999, + scale=False, + epsilon=0.001): + """Adds a Batch Normalization layer.""" + if input_layer is None: + input_layer = self.top_layer + else: + self.top_size = None + name = 'batchnorm' + str(self.counts['batchnorm']) + self.counts['batchnorm'] += 1 - with tf.variable_scope(name) as scope: - if self.use_tf_layers: - bn = tf.contrib.layers.batch_norm( - input_layer, - decay=decay, - scale=scale, - epsilon=epsilon, - is_training=self.phase_train, - fused=True, - data_format=self.data_format, - scope=scope) - else: - bn = self._batch_norm_without_layers(input_layer, decay, scale, epsilon) - self.top_layer = bn - self.top_size = bn.shape[3] if self.data_format == 'NHWC' else bn.shape[1] - self.top_size = int(self.top_size) - return bn + with tf.variable_scope(name) as scope: + if self.use_tf_layers: + bn = tf.contrib.layers.batch_norm( + input_layer, + decay=decay, + scale=scale, + epsilon=epsilon, + is_training=self.phase_train, + fused=True, + data_format=self.data_format, + scope=scope) + else: + bn = self._batch_norm_without_layers(input_layer, decay, scale, + epsilon) + self.top_layer = bn + self.top_size = bn.shape[ + 3] if self.data_format == 'NHWC' else bn.shape[1] + self.top_size = int(self.top_size) + return bn - def lrn(self, depth_radius, bias, alpha, beta): - """Adds a local response normalization layer.""" - name = 'lrn' + str(self.counts['lrn']) - self.counts['lrn'] += 1 - self.top_layer = tf.nn.lrn( - self.top_layer, depth_radius, bias, alpha, beta, name=name) - return self.top_layer + def lrn(self, depth_radius, bias, alpha, beta): + """Adds a local response normalization layer.""" + name = 'lrn' + str(self.counts['lrn']) + self.counts['lrn'] += 1 + self.top_layer = tf.nn.lrn( + self.top_layer, depth_radius, bias, alpha, beta, name=name) + return self.top_layer diff --git a/python/ray/experimental/sgd/tfbench/model.py 
b/python/ray/experimental/sgd/tfbench/model.py index 1c3fd658f676b..02ebebb52f070 100644 --- a/python/ray/experimental/sgd/tfbench/model.py +++ b/python/ray/experimental/sgd/tfbench/model.py @@ -19,56 +19,56 @@ class Model(object): - """Base model configuration for CNN benchmarks.""" + """Base model configuration for CNN benchmarks.""" - def __init__(self, - model, - image_size, - batch_size, - learning_rate, - layer_counts=None, - fp16_loss_scale=128): - self.model = model - self.image_size = image_size - self.batch_size = batch_size - self.default_batch_size = batch_size - self.learning_rate = learning_rate - self.layer_counts = layer_counts - # TODO(reedwm) Set custom loss scales for each model instead of using the - # default of 128. - self.fp16_loss_scale = fp16_loss_scale + def __init__(self, + model, + image_size, + batch_size, + learning_rate, + layer_counts=None, + fp16_loss_scale=128): + self.model = model + self.image_size = image_size + self.batch_size = batch_size + self.default_batch_size = batch_size + self.learning_rate = learning_rate + self.layer_counts = layer_counts + # TODO(reedwm) Set custom loss scales for each model instead of using the + # default of 128. + self.fp16_loss_scale = fp16_loss_scale - def get_model(self): - return self.model + def get_model(self): + return self.model - def get_image_size(self): - return self.image_size + def get_image_size(self): + return self.image_size - def get_batch_size(self): - return self.batch_size + def get_batch_size(self): + return self.batch_size - def set_batch_size(self, batch_size): - self.batch_size = batch_size + def set_batch_size(self, batch_size): + self.batch_size = batch_size - def get_default_batch_size(self): - return self.default_batch_size + def get_default_batch_size(self): + return self.default_batch_size - def get_layer_counts(self): - return self.layer_counts + def get_layer_counts(self): + return self.layer_counts - def get_fp16_loss_scale(self): - return self.fp16_loss_scale + def get_fp16_loss_scale(self): + return self.fp16_loss_scale - def get_learning_rate(self, global_step, batch_size): - del global_step - del batch_size - return self.learning_rate + def get_learning_rate(self, global_step, batch_size): + del global_step + del batch_size + return self.learning_rate - def add_inference(self, unused_cnn): - raise ValueError('Must be implemented in derived classes') + def add_inference(self, unused_cnn): + raise ValueError('Must be implemented in derived classes') - def skip_final_affine_layer(self): - """Returns if the caller of this class should skip the final affine layer. + def skip_final_affine_layer(self): + """Returns if the caller of this class should skip the final affine layer. Normally, this class adds a final affine layer to the model after calling self.add_inference(), to generate the logits. If a subclass override this @@ -76,39 +76,46 @@ def skip_final_affine_layer(self): This is useful for tests. 
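[The Model base class above leaves add_inference to subclasses; a hypothetical minimal subclass, built only from the ConvNetBuilder methods shown earlier in this series (shapes worked out for a 32x32 input):]

    class TinyModel(Model):
        def __init__(self):
            super(TinyModel, self).__init__(
                'tiny', image_size=32, batch_size=32, learning_rate=0.01)

        def add_inference(self, cnn):
            # cnn is a ConvNetBuilder; each call stacks onto cnn.top_layer
            cnn.conv(16, 3, 3)            # 32x32x16, SAME padding, stride 1
            cnn.mpool(2, 2)               # 16x16x16 after 2x2/2 VALID pool
            cnn.reshape([-1, 16 * 16 * 16])
            cnn.affine(64)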
""" - return False - - def build_network(self, images, phase_train=True, nclass=1001, image_depth=3, - data_type=tf.float32, data_format='NCHW', - use_tf_layers=True, fp16_vars=False): - """Returns logits and aux_logits from images.""" - if data_format == 'NCHW': - images = tf.transpose(images, [0, 3, 1, 2]) - var_type = tf.float32 - if data_type == tf.float16 and fp16_vars: - var_type = tf.float16 - network = convnet_builder.ConvNetBuilder( - images, image_depth, phase_train, use_tf_layers, - data_format, data_type, var_type) - with tf.variable_scope('cg', custom_getter=network.get_custom_getter()): - self.add_inference(network) - # Add the final fully-connected class layer - logits = (network.affine(nclass, activation='linear') - if not self.skip_final_affine_layer() - else network.top_layer) - aux_logits = None - if network.aux_top_layer is not None: - with network.switch_to_aux_top_layer(): - aux_logits = network.affine( - nclass, activation='linear', stddev=0.001) - if data_type == tf.float16: - # TODO(reedwm): Determine if we should do this cast here. - logits = tf.cast(logits, tf.float32) - if aux_logits is not None: - aux_logits = tf.cast(aux_logits, tf.float32) - return logits, aux_logits - - # Subclasses can override this to define their own loss function. By default, - # benchmark_cnn.py defines its own loss function. If overridden, it must have - # the same signature as benchmark_cnn.loss_function. - loss_function = None + return False + + def build_network(self, + images, + phase_train=True, + nclass=1001, + image_depth=3, + data_type=tf.float32, + data_format='NCHW', + use_tf_layers=True, + fp16_vars=False): + """Returns logits and aux_logits from images.""" + if data_format == 'NCHW': + images = tf.transpose(images, [0, 3, 1, 2]) + var_type = tf.float32 + if data_type == tf.float16 and fp16_vars: + var_type = tf.float16 + network = convnet_builder.ConvNetBuilder( + images, image_depth, phase_train, use_tf_layers, data_format, + data_type, var_type) + with tf.variable_scope( + 'cg', custom_getter=network.get_custom_getter()): + self.add_inference(network) + # Add the final fully-connected class layer + logits = (network.affine(nclass, activation='linear') + if not self.skip_final_affine_layer() else + network.top_layer) + aux_logits = None + if network.aux_top_layer is not None: + with network.switch_to_aux_top_layer(): + aux_logits = network.affine( + nclass, activation='linear', stddev=0.001) + if data_type == tf.float16: + # TODO(reedwm): Determine if we should do this cast here. + logits = tf.cast(logits, tf.float32) + if aux_logits is not None: + aux_logits = tf.cast(aux_logits, tf.float32) + return logits, aux_logits + + # Subclasses can override this to define their own loss function. By default, + # benchmark_cnn.py defines its own loss function. If overridden, it must have + # the same signature as benchmark_cnn.loss_function. + loss_function = None diff --git a/python/ray/experimental/sgd/tfbench/model_config.py b/python/ray/experimental/sgd/tfbench/model_config.py index e993cf63fa33e..387bc0345ed8a 100644 --- a/python/ray/experimental/sgd/tfbench/model_config.py +++ b/python/ray/experimental/sgd/tfbench/model_config.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Model configurations for CNN benchmarks. """ from . 
import resnet_model - _model_name_to_imagenet_model = { 'resnet50': resnet_model.create_resnet50_model, 'resnet50_v2': resnet_model.create_resnet50_v2_model, @@ -28,34 +26,32 @@ 'resnet152_v2': resnet_model.create_resnet152_v2_model, } - -_model_name_to_cifar_model = { -} +_model_name_to_cifar_model = {} def _get_model_map(dataset_name): - if 'cifar10' == dataset_name: - return _model_name_to_cifar_model - elif dataset_name in ('imagenet', 'synthetic'): - return _model_name_to_imagenet_model - else: - raise ValueError('Invalid dataset name: %s' % dataset_name) + if 'cifar10' == dataset_name: + return _model_name_to_cifar_model + elif dataset_name in ('imagenet', 'synthetic'): + return _model_name_to_imagenet_model + else: + raise ValueError('Invalid dataset name: %s' % dataset_name) def get_model_config(model_name, dataset): - """Map model name to model network configuration.""" - model_map = _get_model_map(dataset.name) - if model_name not in model_map: - raise ValueError('Invalid model name \'%s\' for dataset \'%s\'' % - (model_name, dataset.name)) - else: - return model_map[model_name]() + """Map model name to model network configuration.""" + model_map = _get_model_map(dataset.name) + if model_name not in model_map: + raise ValueError('Invalid model name \'%s\' for dataset \'%s\'' % + (model_name, dataset.name)) + else: + return model_map[model_name]() def register_model(model_name, dataset_name, model_func): - """Register a new model that can be obtained with `get_model_config`.""" - model_map = _get_model_map(dataset_name) - if model_name in model_map: - raise ValueError('Model "%s" is already registered for dataset "%s"' % - (model_name, dataset_name)) - model_map[model_name] = model_func + """Register a new model that can be obtained with `get_model_config`.""" + model_map = _get_model_map(dataset_name) + if model_name in model_map: + raise ValueError('Model "%s" is already registered for dataset "%s"' % + (model_name, dataset_name)) + model_map[model_name] = model_func diff --git a/python/ray/experimental/sgd/tfbench/modified_allreduce.py b/python/ray/experimental/sgd/tfbench/modified_allreduce.py index f86cdddfe219e..1b7c50521fd15 100644 --- a/python/ray/experimental/sgd/tfbench/modified_allreduce.py +++ b/python/ray/experimental/sgd/tfbench/modified_allreduce.py @@ -25,34 +25,35 @@ from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce -AllReduceSpecTuple = pycoll.namedtuple('AllReduceSpecTuple', 'alg shards limit') +AllReduceSpecTuple = pycoll.namedtuple('AllReduceSpecTuple', + 'alg shards limit') def parse_general_int(s): - """Parse integer with power-of-2 suffix eg. 32k.""" - mo = re.match(r'(\d+)([KkMGT]?)$', s) - if mo: - i, suffix = mo.group(1, 2) - v = int(i) - if suffix: - if suffix == 'K' or suffix == 'k': - v *= 1024 - elif suffix == 'M': - v *= (1024 * 1024) - elif suffix == 'G': - v *= (1024 * 1024 * 1024) - elif suffix == 'T': - v *= (1024 * 1024 * 1024 * 1024) - else: - raise ValueError('invalid integer string %s' % s) + """Parse integer with power-of-2 suffix eg. 
32k.""" + mo = re.match(r'(\d+)([KkMGT]?)$', s) + if mo: + i, suffix = mo.group(1, 2) + v = int(i) + if suffix: + if suffix == 'K' or suffix == 'k': + v *= 1024 + elif suffix == 'M': + v *= (1024 * 1024) + elif suffix == 'G': + v *= (1024 * 1024 * 1024) + elif suffix == 'T': + v *= (1024 * 1024 * 1024 * 1024) + else: + raise ValueError('invalid integer string %s' % s) + return v + else: + v = int(s) return v - else: - v = int(s) - return v def parse_all_reduce_spec(all_reduce_spec): - """Parse all_reduce_spec. + """Parse all_reduce_spec. Args: all_reduce_spec: a string specifying a combination of all-reduce @@ -101,44 +102,48 @@ def parse_all_reduce_spec(all_reduce_spec): locally reduced tensors across workers (on the worker CPUs), then scatter the globally reduced values locally from each worker CPU. """ - range_parts = all_reduce_spec.split(':') + ['-1'] - if len(range_parts) % 2: - raise ValueError('all_reduce_spec not well formed: %s' % all_reduce_spec) - limit = 0 - spec = [] - alg = None - shards = 1 - for i, range_part in enumerate(range_parts): - if i % 2 == 1: - try: - limit = parse_general_int(range_part) - spec.append(AllReduceSpecTuple(alg=alg, shards=shards, limit=limit)) - except ValueError: - raise ValueError('all_reduce_spec (%s) contains non-integer range %s' % - (all_reduce_spec, range_part)) - else: - alg = range_part - alg_parts = range_part.split('#') - alg = alg_parts[0] - if len(alg_parts) > 1: - try: - shards = int(alg_parts[1]) - except ValueError: - raise ValueError('all_reduce_spec (%s) contains non-integer ' - 'shards %s' % all_reduce_spec, alg_parts[1]) - else: - shards = 1 - if alg not in [ - 'nccl', 'nccl/xring', 'nccl/rechd', 'nccl/pscpu', 'xring', 'pscpu', - 'psgpu', 'pscpu/pscpu' - ]: - raise ValueError('all_reduce_spec (%s) contains invalid alg %s' % - (all_reduce_spec, alg)) - return spec + range_parts = all_reduce_spec.split(':') + ['-1'] + if len(range_parts) % 2: + raise ValueError( + 'all_reduce_spec not well formed: %s' % all_reduce_spec) + limit = 0 + spec = [] + alg = None + shards = 1 + for i, range_part in enumerate(range_parts): + if i % 2 == 1: + try: + limit = parse_general_int(range_part) + spec.append( + AllReduceSpecTuple(alg=alg, shards=shards, limit=limit)) + except ValueError: + raise ValueError( + 'all_reduce_spec (%s) contains non-integer range %s' % + (all_reduce_spec, range_part)) + else: + alg = range_part + alg_parts = range_part.split('#') + alg = alg_parts[0] + if len(alg_parts) > 1: + try: + shards = int(alg_parts[1]) + except ValueError: + raise ValueError( + 'all_reduce_spec (%s) contains non-integer ' + 'shards %s' % all_reduce_spec, alg_parts[1]) + else: + shards = 1 + if alg not in [ + 'nccl', 'nccl/xring', 'nccl/rechd', 'nccl/pscpu', 'xring', + 'pscpu', 'psgpu', 'pscpu/pscpu' + ]: + raise ValueError('all_reduce_spec (%s) contains invalid alg %s' + % (all_reduce_spec, alg)) + return spec def build_all_reduce_device_prefixes(job_name, num_tasks): - """Build list of device prefix names for all_reduce. + """Build list of device prefix names for all_reduce. Args: job_name: 'worker', 'ps' or 'localhost'. @@ -149,15 +154,15 @@ def build_all_reduce_device_prefixes(job_name, num_tasks): host name without adding the device. e.g. 
'/job:worker/task:0' """ - if job_name != 'localhost': - return ['/job:%s/task:%d' % (job_name, d) for d in range(0, num_tasks)] - else: - assert num_tasks == 1 - return ['/job:%s' % job_name] + if job_name != 'localhost': + return ['/job:%s/task:%d' % (job_name, d) for d in range(0, num_tasks)] + else: + assert num_tasks == 1 + return ['/job:%s' % job_name] def group_device_names(devices, group_size): - """Group device names into groups of group_size. + """Group device names into groups of group_size. Args: devices: list of strings naming devices. @@ -172,20 +177,21 @@ def group_device_names(devices, group_size): Raises: ValueError: group_size > len(devices) """ - num_devices = len(devices) - if group_size > num_devices: - raise ValueError('only %d devices, but group_size=%d' % (num_devices, - group_size)) - num_groups = ( - num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) - groups = [[] for i in range(num_groups)] - for i in range(0, num_groups * group_size): - groups[i % num_groups].append(devices[i % num_devices]) - return groups + num_devices = len(devices) + if group_size > num_devices: + raise ValueError( + 'only %d devices, but group_size=%d' % (num_devices, group_size)) + num_groups = ( + num_devices // group_size + (1 if + (num_devices % group_size != 0) else 0)) + groups = [[] for i in range(num_groups)] + for i in range(0, num_groups * group_size): + groups[i % num_groups].append(devices[i % num_devices]) + return groups def split_grads_by_size(threshold_size, device_grads): - """Break gradients into two sets according to tensor size. + """Break gradients into two sets according to tensor size. Args: threshold_size: int size cutoff for small vs large tensor. @@ -198,22 +204,22 @@ def split_grads_by_size(threshold_size, device_grads): large_grads: Subset of device_grads where shape is > threshold_size elements. """ - small_grads = [] - large_grads = [] - for dl in device_grads: - small_dl = [] - large_dl = [] - for (g, v) in dl: - tensor_size = g.get_shape().num_elements() - if tensor_size <= threshold_size: - small_dl.append([g, v]) - else: - large_dl.append([g, v]) - if small_dl: - small_grads.append(small_dl) - if large_dl: - large_grads.append(large_dl) - return small_grads, large_grads + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads def build_reduce_sum(scaled_grads): @@ -221,12 +227,14 @@ def build_reduce_sum(scaled_grads): reduced = tf.reduce_sum(stacked, 0) return [reduced] * len(scaled_grads) + def build_trivial_sum(scaled_grads): return scaled_grads + def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, check_inf_nan): - """Calculate the average gradient for a shared variable across all towers. + """Calculate the average gradient for a shared variable across all towers. Note that this function provides a synchronization point across all towers. @@ -243,23 +251,23 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, gradient has been averaged across all towers. The variable is chosen from the first tower. The has_nan_or_inf indicates the grads has nan or inf. 
""" - grads = [g for g, _ in grad_and_vars] - grad = tf.add_n(grads) + grads = [g for g, _ in grad_and_vars] + grad = tf.add_n(grads) - if use_mean and len(grads) > 1: - grad = tf.multiply(grad, 1.0 / len(grads)) + if use_mean and len(grads) > 1: + grad = tf.multiply(grad, 1.0 / len(grads)) - v = grad_and_vars[0][1] - if check_inf_nan: - has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.is_finite(grads))) - return (grad, v), has_nan_or_inf - else: - return (grad, v), None + v = grad_and_vars[0][1] + if check_inf_nan: + has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.is_finite(grads))) + return (grad, v), has_nan_or_inf + else: + return (grad, v), None def aggregate_gradients_using_copy_with_device_selection( - tower_grads, avail_devices, use_mean=True, check_inf_nan=False): - """Aggregate gradients, controlling device for the aggregation. + tower_grads, avail_devices, use_mean=True, check_inf_nan=False): + """Aggregate gradients, controlling device for the aggregation. Args: tower_grads: List of lists of (gradient, variable) tuples. The outer list @@ -272,15 +280,15 @@ def aggregate_gradients_using_copy_with_device_selection( gradient has been averaged across all towers. The variable is chosen from the first tower. The has_nan_or_inf indicates the grads has nan or inf. """ - agg_grads = [] - has_nan_or_inf_list = [] - for i, single_grads in enumerate(zip(*tower_grads)): - with tf.device(avail_devices[i % len(avail_devices)]): - grad_and_var, has_nan_or_inf = aggregate_single_gradient_using_copy( - single_grads, use_mean, check_inf_nan) - agg_grads.append(grad_and_var) - has_nan_or_inf_list.append(has_nan_or_inf) - return agg_grads + agg_grads = [] + has_nan_or_inf_list = [] + for i, single_grads in enumerate(zip(*tower_grads)): + with tf.device(avail_devices[i % len(avail_devices)]): + grad_and_var, has_nan_or_inf = aggregate_single_gradient_using_copy( + single_grads, use_mean, check_inf_nan) + agg_grads.append(grad_and_var) + has_nan_or_inf_list.append(has_nan_or_inf) + return agg_grads def sum_grad_and_var_all_reduce(grad_and_vars, @@ -289,51 +297,51 @@ def sum_grad_and_var_all_reduce(grad_and_vars, gpu_indices, aux_devices=None, num_shards=1): - """Apply all-reduce algorithm over specified gradient tensors.""" - with tf.name_scope('allreduce'): - # Note that each grad_and_vars looks like the following: - # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) - scaled_grads = [g for g, _ in grad_and_vars] - if alg == 'nccl': - summed_grads = nccl.all_sum(scaled_grads) - elif alg == 'simple': - summed_grads = build_reduce_sum(scaled_grads) - elif alg == 'trivial': - summed_grads = build_trivial_sum(scaled_grads) - elif alg == 'xring': - summed_grads = all_reduce.build_ring_all_reduce( - scaled_grads, num_workers, num_shards, gpu_indices, tf.add) - elif alg == 'nccl/xring': - summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, - tf.add) - elif alg == 'nccl/rechd': - summed_grads = all_reduce.build_nccl_then_recursive_hd( - scaled_grads, tf.add) - elif alg == 'nccl/pscpu': - summed_grads = all_reduce.build_nccl_then_shuffle( - scaled_grads, aux_devices, tf.add, tf.add_n) - elif alg == 'pscpu/pscpu': - summed_grads = all_reduce.build_shuffle_then_shuffle( - scaled_grads, - aux_devices, - # TODO(tucker): devise a way of better specifying the device set - # for the second level. 
- [aux_devices[0]], - tf.add_n) - elif alg in ['pscpu', 'psgpu']: - summed_grads = all_reduce.build_shuffle_all_reduce( - scaled_grads, aux_devices, tf.add_n) - else: - raise ValueError('unsupported all_reduce alg: ', alg) - - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result + """Apply all-reduce algorithm over specified gradient tensors.""" + with tf.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + scaled_grads = [g for g, _ in grad_and_vars] + if alg == 'nccl': + summed_grads = nccl.all_sum(scaled_grads) + elif alg == 'simple': + summed_grads = build_reduce_sum(scaled_grads) + elif alg == 'trivial': + summed_grads = build_trivial_sum(scaled_grads) + elif alg == 'xring': + summed_grads = all_reduce.build_ring_all_reduce( + scaled_grads, num_workers, num_shards, gpu_indices, tf.add) + elif alg == 'nccl/xring': + summed_grads = all_reduce.build_nccl_then_ring( + scaled_grads, num_shards, tf.add) + elif alg == 'nccl/rechd': + summed_grads = all_reduce.build_nccl_then_recursive_hd( + scaled_grads, tf.add) + elif alg == 'nccl/pscpu': + summed_grads = all_reduce.build_nccl_then_shuffle( + scaled_grads, aux_devices, tf.add, tf.add_n) + elif alg == 'pscpu/pscpu': + summed_grads = all_reduce.build_shuffle_then_shuffle( + scaled_grads, + aux_devices, + # TODO(tucker): devise a way of better specifying the device set + # for the second level. + [aux_devices[0]], + tf.add_n) + elif alg in ['pscpu', 'psgpu']: + summed_grads = all_reduce.build_shuffle_all_reduce( + scaled_grads, aux_devices, tf.add_n) + else: + raise ValueError('unsupported all_reduce alg: ', alg) + + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result def contains_any(haystack, needles): - """Tests if any needle is a substring of haystack. + """Tests if any needle is a substring of haystack. Args: haystack: a string @@ -343,10 +351,10 @@ def contains_any(haystack, needles): True if any element of needles is a substring of haystack, False otherwise. """ - for n in needles: - if n in haystack: - return True - return False + for n in needles: + if n in haystack: + return True + return False def sum_gradients_all_reduce(dev_prefixes, @@ -356,7 +364,7 @@ def sum_gradients_all_reduce(dev_prefixes, num_shards, gpu_indices, agg_small_grads_max_bytes=0): - """Apply all-reduce algorithm over specified gradient tensors. + """Apply all-reduce algorithm over specified gradient tensors. Args: dev_prefixes: list of prefix strings to use to generate PS device names. 
@@ -371,69 +379,68 @@ def sum_gradients_all_reduce(dev_prefixes,
   Returns:
     list of reduced tensors, packing values
   """
-  alg_contains_shuffle = contains_any(alg, ['pscpu', 'psgpu'])
-  is_hierarchical = '/' in alg
-  if 'pscpu' in alg:
-    aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
-  elif 'psgpu' in alg:
-    aux_devices = [
-        prefix + '/gpu:%d' % i
-        for i in range(len(gpu_indices))
-        for prefix in dev_prefixes
-    ]
-  else:
-    aux_devices = ['/job:localhost/cpu:0']
-  aux_device_groups = group_device_names(aux_devices, num_shards
-                                         if alg_contains_shuffle else 1)
-  group_index = 0
-  if agg_small_grads_max_bytes > 0:
-    tower_grads, packing = pack_small_tensors(
-        tower_grads,
-        max_bytes=agg_small_grads_max_bytes)
-
-  else:
-    packing = None
-  new_tower_grads = []
-  if alg == 'better':
-    raw_devices = ['/gpu:%i' % (i) for i in gpu_indices]
-    agg_grads = aggregate_gradients_using_copy_with_device_selection(
-        tower_grads, raw_devices)
-    for arr in tower_grads:
-      new_tower_grads.append(
-          [(g, v) for (_, v), (g, _) in zip(arr, agg_grads)])
-  else:
-    reduced_gv_list = []
-    for grad_and_vars in zip(*tower_grads):
-      reduced_gv_list.append(
-          sum_grad_and_var_all_reduce(
-              grad_and_vars, num_workers, alg, gpu_indices, aux_devices
-              if is_hierarchical else aux_device_groups[group_index],
-              num_shards))
-      group_index = (group_index + 1) % len(aux_device_groups)
-    new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
-  return new_tower_grads, packing
+    alg_contains_shuffle = contains_any(alg, ['pscpu', 'psgpu'])
+    is_hierarchical = '/' in alg
+    if 'pscpu' in alg:
+        aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
+    elif 'psgpu' in alg:
+        aux_devices = [
+            prefix + '/gpu:%d' % i for i in range(len(gpu_indices))
+            for prefix in dev_prefixes
+        ]
+    else:
+        aux_devices = ['/job:localhost/cpu:0']
+    aux_device_groups = group_device_names(
+        aux_devices, num_shards if alg_contains_shuffle else 1)
+    group_index = 0
+    if agg_small_grads_max_bytes > 0:
+        tower_grads, packing = pack_small_tensors(
+            tower_grads, max_bytes=agg_small_grads_max_bytes)
+
+    else:
+        packing = None
+    new_tower_grads = []
+    if alg == 'better':
+        raw_devices = ['/gpu:%i' % (i) for i in gpu_indices]
+        agg_grads = aggregate_gradients_using_copy_with_device_selection(
+            tower_grads, raw_devices)
+        for arr in tower_grads:
+            new_tower_grads.append(
+                [(g, v) for (_, v), (g, _) in zip(arr, agg_grads)])
+    else:
+        reduced_gv_list = []
+        for grad_and_vars in zip(*tower_grads):
+            reduced_gv_list.append(
+                sum_grad_and_var_all_reduce(
+                    grad_and_vars, num_workers, alg, gpu_indices, aux_devices
+                    if is_hierarchical else aux_device_groups[group_index],
+                    num_shards))
+            group_index = (group_index + 1) % len(aux_device_groups)
+        new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
+    return new_tower_grads, packing
+

 def print_stats(sizes):
-  def sizeof_fmt(num, suffix='B'):
-    for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']:
-      if abs(num) < 1024.0:
-        return "%3.1f%s%s" % (num, unit, suffix)
-      num /= 1024.0
-    return "%.1f%s%s" % (num, 'Yi', suffix)
-  stats = {
-      "avg": np.mean(sizes),
-      "median": np.median(sizes),
-      "total size": np.sum(sizes)
-  }
-  print("Stats " + ", ".join(
-      ["%s: %s" % (k, sizeof_fmt(v)) for k, v in stats.items()]))
-  other_stats = {
-      "len": len(sizes)
-  }
-  print(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()]))
+    def sizeof_fmt(num, suffix='B'):
+        for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
+            if abs(num) < 1024.0:
+                return "%3.1f%s%s" % (num, unit, suffix)
+            num /= 1024.0
+        return "%.1f%s%s" % (num, 'Yi', suffix)
+
+    stats = {
+        "avg": np.mean(sizes),
+        "median": np.median(sizes),
+        "total size": np.sum(sizes)
+    }
+    print("Stats " +
+          ", ".join(["%s: %s" % (k, sizeof_fmt(v)) for k, v in stats.items()]))
+    other_stats = {"len": len(sizes)}
+    print(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()]))
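To make the stats output concrete: for a hypothetical model with three float32 gradient tensors of 1e3, 2.5e5 and 1e6 elements (4 bytes each), the helper above would report roughly the following (field order follows dict iteration order, so it can vary across Python versions):

    >>> sizeof_fmt(4 * 1000)
    '3.9KiB'
    >>> print_stats([4 * 1000, 4 * 250000, 4 * 1000000])
    Stats avg: 1.6MiB, median: 976.6KiB, total size: 4.8MiB
    len: 3.000000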


 def extract_ranges(index_list, range_size_limit=32):
-  """Extract consecutive ranges and singles from index_list.
+    """Extract consecutive ranges and singles from index_list.

   Args:
     index_list: List of monotone increasing non-negative integers.
@@ -446,34 +453,34 @@
     consecutive elements in index_list, and singles is all of the
     other elements, in original order.
   """
-  if not index_list:
-    return [], []
-  first = index_list[0]
-  last = first
-  ranges = []
-  singles = []
-  for i in index_list[1:]:
-    if i == last + 1 and (last - first) <= range_size_limit:
-      last = i
-    else:
-      if last > first:
+    if not index_list:
+        return [], []
+    first = index_list[0]
+    last = first
+    ranges = []
+    singles = []
+    for i in index_list[1:]:
+        if i == last + 1 and (last - first) <= range_size_limit:
+            last = i
+        else:
+            if last > first:
+                ranges.append([first, last])
+            else:
+                singles.append(first)
+            first = i
+            last = i
+    if last > first:
         ranges.append([first, last])
-      else:
+    else:
         singles.append(first)
-      first = i
-      last = i
-  if last > first:
-    ranges.append([first, last])
-  else:
-    singles.append(first)
-  return ranges, singles
+    return ranges, singles
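Because the binpacking in pack_small_tensors below keys off this helper, a worked example clarifies the contract; tracing the code above by hand on a small, illustrative index list:

    >>> extract_ranges([1, 2, 3, 7, 9, 10])
    ([[1, 3], [9, 10]], [7])

Runs of consecutive indices come back as [first, last] pairs in ranges, and isolated indices end up in singles, in their original order.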
""" - elt_widths = [x.num_elements() for x in gpt.shapes] - with tf.device(gv[0][0].device): - with tf.name_scope('unpack'): - splits = tf.split(gv[0], elt_widths) - unpacked_gv = [] - for idx, s in enumerate(splits): - unpacked_gv.append((tf.reshape(s, gpt.shapes[idx]), gpt.vars[idx])) - return unpacked_gv + elt_widths = [x.num_elements() for x in gpt.shapes] + with tf.device(gv[0][0].device): + with tf.name_scope('unpack'): + splits = tf.split(gv[0], elt_widths) + unpacked_gv = [] + for idx, s in enumerate(splits): + unpacked_gv.append((tf.reshape(s, gpt.shapes[idx]), + gpt.vars[idx])) + return unpacked_gv def pack_small_tensors(tower_grads, max_bytes=0): - """Concatenate gradients together more intelligently. + """Concatenate gradients together more intelligently. Does binpacking Args: @@ -535,59 +543,59 @@ def pack_small_tensors(tower_grads, max_bytes=0): max_bytes: Int giving max number of bytes in a tensor that may be considered small. """ - assert max_bytes >= 0 - orig_grads = [g for g, _ in tower_grads[0]] - # Check to make sure sizes are accurate; not entirely important - assert all(g.dtype == tf.float32 for g in orig_grads) - sizes = [4 * g.shape.num_elements() for g in orig_grads] - print("Before packing") - print_stats(sizes) - small_ranges = [] - large_indices = [] - new_sizes = [] - - def end_interval(indices, small_ranges, large_indices): - if len(indices) > 1: - small_ranges.insert(0, [indices[0], indices[-1]]) + assert max_bytes >= 0 + orig_grads = [g for g, _ in tower_grads[0]] + # Check to make sure sizes are accurate; not entirely important + assert all(g.dtype == tf.float32 for g in orig_grads) + sizes = [4 * g.shape.num_elements() for g in orig_grads] + print("Before packing") + print_stats(sizes) + small_ranges = [] + large_indices = [] + new_sizes = [] + + def end_interval(indices, small_ranges, large_indices): + if len(indices) > 1: + small_ranges.insert(0, [indices[0], indices[-1]]) + else: + large_indices.insert(0, indices[0]) + + cur_range = [] + cur_size = 0 + for i, s in reversed(list(enumerate(sizes))): + if cur_size > max_bytes: + end_interval(cur_range, small_ranges, large_indices) + new_sizes.insert(0, cur_size) + cur_range = [] + cur_size = 0 + cur_range.insert(0, i) + cur_size += s + end_interval(cur_range, small_ranges, large_indices) + new_sizes.insert(0, cur_size) + + print("After packing") + print_stats(new_sizes) + num_gv = len(orig_grads) + packing = {} + if len(small_ranges): + new_tower_grads = [] + for dev_idx, gv_list in enumerate(tower_grads): + assert len(gv_list) == num_gv + new_gv_list = [] + for r in small_ranges: + key = '%d:%d' % (dev_idx, len(new_gv_list)) + new_gv_list.append((pack_range(key, packing, gv_list, r), + 'packing_var_placeholder')) + for i in large_indices: + new_gv_list.append(gv_list[i]) + new_tower_grads.append(new_gv_list) + return new_tower_grads, packing else: - large_indices.insert(0, indices[0]) - - cur_range = [] - cur_size = 0 - for i, s in reversed(list(enumerate(sizes))): - if cur_size > max_bytes: - end_interval(cur_range, small_ranges, large_indices) - new_sizes.insert(0, cur_size) - cur_range = [] - cur_size = 0 - cur_range.insert(0, i) - cur_size += s - end_interval(cur_range, small_ranges, large_indices) - new_sizes.insert(0, cur_size) - - print("After packing") - print_stats(new_sizes) - num_gv = len(orig_grads) - packing = {} - if len(small_ranges): - new_tower_grads = [] - for dev_idx, gv_list in enumerate(tower_grads): - assert len(gv_list) == num_gv - new_gv_list = [] - for r in small_ranges: - key = 
'%d:%d' % (dev_idx, len(new_gv_list)) - new_gv_list.append((pack_range(key, packing, gv_list, r), - 'packing_var_placeholder')) - for i in large_indices: - new_gv_list.append(gv_list[i]) - new_tower_grads.append(new_gv_list) - return new_tower_grads, packing - else: - return tower_grads, None + return tower_grads, None def unpack_small_tensors(tower_grads, packing): - """Undo the structure alterations to tower_grads done by pack_small_tensors. + """Undo the structure alterations to tower_grads done by pack_small_tensors. Args: tower_grads: List of List of (grad, var) tuples. @@ -599,19 +607,19 @@ def unpack_small_tensors(tower_grads, packing): of small tensors have been split apart and returned to their original positions, paired with their original variables. """ - if not packing: - return tower_grads - new_tower_grads = [] - num_devices = len(tower_grads) - num_packed = len(packing.keys()) // num_devices - for dev_idx, gv_list in enumerate(tower_grads): - new_gv_list = gv_list[num_packed:] - for i in xrange(0, num_packed): - k = '%d:%d' % (dev_idx, i) - gpt = packing[k] - gv = unpack_grad_tuple(gv_list[i], gpt) - for gi, idx in enumerate(gpt.indices): - assert idx == gpt.indices[gi] - new_gv_list.insert(idx, gv[gi]) - new_tower_grads.append(new_gv_list) - return new_tower_grads + if not packing: + return tower_grads + new_tower_grads = [] + num_devices = len(tower_grads) + num_packed = len(packing.keys()) // num_devices + for dev_idx, gv_list in enumerate(tower_grads): + new_gv_list = gv_list[num_packed:] + for i in xrange(0, num_packed): + k = '%d:%d' % (dev_idx, i) + gpt = packing[k] + gv = unpack_grad_tuple(gv_list[i], gpt) + for gi, idx in enumerate(gpt.indices): + assert idx == gpt.indices[gi] + new_gv_list.insert(idx, gv[gi]) + new_tower_grads.append(new_gv_list) + return new_tower_grads diff --git a/python/ray/experimental/sgd/tfbench/resnet_model.py b/python/ray/experimental/sgd/tfbench/resnet_model.py index f2f348a02bd9b..38c1fc33a9204 100644 --- a/python/ray/experimental/sgd/tfbench/resnet_model.py +++ b/python/ray/experimental/sgd/tfbench/resnet_model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Resnet model configuration. References: @@ -38,7 +37,7 @@ def bottleneck_block_v1(cnn, depth, depth_bottleneck, stride): - """Bottleneck block with identity short-cut for ResNet v1. + """Bottleneck block with identity short-cut for ResNet v1. Args: cnn: the network to append bottleneck blocks. @@ -46,39 +45,64 @@ def bottleneck_block_v1(cnn, depth, depth_bottleneck, stride): depth_bottleneck: the number of bottleneck filters for this block. stride: Stride used in the first layer of the bottleneck block. 
""" - input_layer = cnn.top_layer - in_size = cnn.top_size - name_key = 'resnet_v1' - name = name_key + str(cnn.counts[name_key]) - cnn.counts[name_key] += 1 - - with tf.variable_scope(name): - if depth == in_size: - if stride == 1: - shortcut = input_layer - else: - shortcut = cnn.apool( - 1, 1, stride, stride, input_layer=input_layer, - num_channels_in=in_size) - else: - shortcut = cnn.conv( - depth, 1, 1, stride, stride, activation=None, - use_batch_norm=True, input_layer=input_layer, - num_channels_in=in_size, bias=None) - cnn.conv(depth_bottleneck, 1, 1, stride, stride, - input_layer=input_layer, num_channels_in=in_size, - use_batch_norm=True, bias=None) - cnn.conv(depth_bottleneck, 3, 3, 1, 1, mode='SAME_RESNET', - use_batch_norm=True, bias=None) - res = cnn.conv(depth, 1, 1, 1, 1, activation=None, - use_batch_norm=True, bias=None) - output = tf.nn.relu(shortcut + res) - cnn.top_layer = output - cnn.top_size = depth + input_layer = cnn.top_layer + in_size = cnn.top_size + name_key = 'resnet_v1' + name = name_key + str(cnn.counts[name_key]) + cnn.counts[name_key] += 1 + + with tf.variable_scope(name): + if depth == in_size: + if stride == 1: + shortcut = input_layer + else: + shortcut = cnn.apool( + 1, + 1, + stride, + stride, + input_layer=input_layer, + num_channels_in=in_size) + else: + shortcut = cnn.conv( + depth, + 1, + 1, + stride, + stride, + activation=None, + use_batch_norm=True, + input_layer=input_layer, + num_channels_in=in_size, + bias=None) + cnn.conv( + depth_bottleneck, + 1, + 1, + stride, + stride, + input_layer=input_layer, + num_channels_in=in_size, + use_batch_norm=True, + bias=None) + cnn.conv( + depth_bottleneck, + 3, + 3, + 1, + 1, + mode='SAME_RESNET', + use_batch_norm=True, + bias=None) + res = cnn.conv( + depth, 1, 1, 1, 1, activation=None, use_batch_norm=True, bias=None) + output = tf.nn.relu(shortcut + res) + cnn.top_layer = output + cnn.top_size = depth def bottleneck_block_v2(cnn, depth, depth_bottleneck, stride): - """Bottleneck block with identity short-cut for ResNet v2. + """Bottleneck block with identity short-cut for ResNet v2. The main difference from v1 is that a batch norm and relu are done at the start of the block, instead of the end. This initial batch norm and relu is @@ -90,40 +114,73 @@ def bottleneck_block_v2(cnn, depth, depth_bottleneck, stride): depth_bottleneck: the number of bottleneck filters for this block. stride: Stride used in the first layer of the bottleneck block. 
""" - input_layer = cnn.top_layer - in_size = cnn.top_size - name_key = 'resnet_v2' - name = name_key + str(cnn.counts[name_key]) - cnn.counts[name_key] += 1 - - preact = cnn.batch_norm() - preact = tf.nn.relu(preact) - with tf.variable_scope(name): - if depth == in_size: - if stride == 1: - shortcut = input_layer - else: - shortcut = cnn.apool( - 1, 1, stride, stride, input_layer=input_layer, - num_channels_in=in_size) - else: - shortcut = cnn.conv( - depth, 1, 1, stride, stride, activation=None, use_batch_norm=False, - input_layer=preact, num_channels_in=in_size, bias=None) - cnn.conv(depth_bottleneck, 1, 1, stride, stride, - input_layer=preact, num_channels_in=in_size, - use_batch_norm=True, bias=None) - cnn.conv(depth_bottleneck, 3, 3, 1, 1, mode='SAME_RESNET', - use_batch_norm=True, bias=None) - res = cnn.conv(depth, 1, 1, 1, 1, activation=None, - use_batch_norm=False, bias=None) - output = shortcut + res - cnn.top_layer = output - cnn.top_size = depth + input_layer = cnn.top_layer + in_size = cnn.top_size + name_key = 'resnet_v2' + name = name_key + str(cnn.counts[name_key]) + cnn.counts[name_key] += 1 + + preact = cnn.batch_norm() + preact = tf.nn.relu(preact) + with tf.variable_scope(name): + if depth == in_size: + if stride == 1: + shortcut = input_layer + else: + shortcut = cnn.apool( + 1, + 1, + stride, + stride, + input_layer=input_layer, + num_channels_in=in_size) + else: + shortcut = cnn.conv( + depth, + 1, + 1, + stride, + stride, + activation=None, + use_batch_norm=False, + input_layer=preact, + num_channels_in=in_size, + bias=None) + cnn.conv( + depth_bottleneck, + 1, + 1, + stride, + stride, + input_layer=preact, + num_channels_in=in_size, + use_batch_norm=True, + bias=None) + cnn.conv( + depth_bottleneck, + 3, + 3, + 1, + 1, + mode='SAME_RESNET', + use_batch_norm=True, + bias=None) + res = cnn.conv( + depth, + 1, + 1, + 1, + 1, + activation=None, + use_batch_norm=False, + bias=None) + output = shortcut + res + cnn.top_layer = output + cnn.top_size = depth def bottleneck_block(cnn, depth, depth_bottleneck, stride, pre_activation): - """Bottleneck block with identity short-cut. + """Bottleneck block with identity short-cut. Args: cnn: the network to append bottleneck blocks. @@ -132,14 +189,14 @@ def bottleneck_block(cnn, depth, depth_bottleneck, stride, pre_activation): stride: Stride used in the first layer of the bottleneck block. pre_activation: use pre_activation structure used in v2 or not. """ - if pre_activation: - bottleneck_block_v2(cnn, depth, depth_bottleneck, stride) - else: - bottleneck_block_v1(cnn, depth, depth_bottleneck, stride) + if pre_activation: + bottleneck_block_v2(cnn, depth, depth_bottleneck, stride) + else: + bottleneck_block_v1(cnn, depth, depth_bottleneck, stride) def residual_block(cnn, depth, stride, pre_activation): - """Residual block with identity short-cut. + """Residual block with identity short-cut. Args: cnn: the network to append residual blocks. @@ -147,116 +204,139 @@ def residual_block(cnn, depth, stride, pre_activation): stride: Stride used in the first layer of the residual block. pre_activation: use pre_activation structure or not. """ - input_layer = cnn.top_layer - in_size = cnn.top_size - if in_size != depth: - # Plan A of shortcut. 
- shortcut = cnn.apool(1, 1, stride, stride, - input_layer=input_layer, - num_channels_in=in_size) - padding = (depth - in_size) // 2 - if cnn.channel_pos == 'channels_last': - shortcut = tf.pad( - shortcut, [[0, 0], [0, 0], [0, 0], [padding, padding]]) + input_layer = cnn.top_layer + in_size = cnn.top_size + if in_size != depth: + # Plan A of shortcut. + shortcut = cnn.apool( + 1, + 1, + stride, + stride, + input_layer=input_layer, + num_channels_in=in_size) + padding = (depth - in_size) // 2 + if cnn.channel_pos == 'channels_last': + shortcut = tf.pad(shortcut, + [[0, 0], [0, 0], [0, 0], [padding, padding]]) + else: + shortcut = tf.pad(shortcut, + [[0, 0], [padding, padding], [0, 0], [0, 0]]) + else: + shortcut = input_layer + if pre_activation: + res = cnn.batch_norm(input_layer) + res = tf.nn.relu(res) else: - shortcut = tf.pad( - shortcut, [[0, 0], [padding, padding], [0, 0], [0, 0]]) - else: - shortcut = input_layer - if pre_activation: - res = cnn.batch_norm(input_layer) - res = tf.nn.relu(res) - else: - res = input_layer - cnn.conv(depth, 3, 3, stride, stride, - input_layer=res, num_channels_in=in_size, - use_batch_norm=True, bias=None) - if pre_activation: - res = cnn.conv(depth, 3, 3, 1, 1, activation=None, - use_batch_norm=False, bias=None) - output = shortcut + res - else: - res = cnn.conv(depth, 3, 3, 1, 1, activation=None, - use_batch_norm=True, bias=None) - output = tf.nn.relu(shortcut + res) - cnn.top_layer = output - cnn.top_size = depth + res = input_layer + cnn.conv( + depth, + 3, + 3, + stride, + stride, + input_layer=res, + num_channels_in=in_size, + use_batch_norm=True, + bias=None) + if pre_activation: + res = cnn.conv( + depth, + 3, + 3, + 1, + 1, + activation=None, + use_batch_norm=False, + bias=None) + output = shortcut + res + else: + res = cnn.conv( + depth, 3, 3, 1, 1, activation=None, use_batch_norm=True, bias=None) + output = tf.nn.relu(shortcut + res) + cnn.top_layer = output + cnn.top_size = depth class ResnetModel(model_lib.Model): - """Resnet cnn network configuration.""" - - def __init__(self, model, layer_counts): - default_batch_sizes = { - 'resnet50': 64, - 'resnet101': 32, - 'resnet152': 32, - 'resnet50_v2': 64, - 'resnet101_v2': 32, - 'resnet152_v2': 32, - } - batch_size = default_batch_sizes.get(model, 32) - super(ResnetModel, self).__init__(model, 224, batch_size, 0.005, - layer_counts) - self.pre_activation = 'v2' in model - - def add_inference(self, cnn): - if self.layer_counts is None: - raise ValueError('Layer counts not specified for %s' % self.get_model()) - cnn.use_batch_norm = True - cnn.batch_norm_config = {'decay': 0.997, 'epsilon': 1e-5, 'scale': True} - cnn.conv(64, 7, 7, 2, 2, mode='SAME_RESNET', use_batch_norm=True) - cnn.mpool(3, 3, 2, 2, mode='SAME') - for _ in xrange(self.layer_counts[0]): - bottleneck_block(cnn, 256, 64, 1, self.pre_activation) - for i in xrange(self.layer_counts[1]): - stride = 2 if i == 0 else 1 - bottleneck_block(cnn, 512, 128, stride, self.pre_activation) - for i in xrange(self.layer_counts[2]): - stride = 2 if i == 0 else 1 - bottleneck_block(cnn, 1024, 256, stride, self.pre_activation) - for i in xrange(self.layer_counts[3]): - stride = 2 if i == 0 else 1 - bottleneck_block(cnn, 2048, 512, stride, self.pre_activation) - if self.pre_activation: - cnn.batch_norm() - cnn.top_layer = tf.nn.relu(cnn.top_layer) - cnn.spatial_mean() - - def get_learning_rate(self, global_step, batch_size): - num_batches_per_epoch = ( - float(datasets.IMAGENET_NUM_TRAIN_IMAGES) / batch_size) - boundaries = 
[int(num_batches_per_epoch * x) for x in [30, 60]] - values = [0.1, 0.01, 0.001] - return tf.train.piecewise_constant(global_step, boundaries, values) + """Resnet cnn network configuration.""" + + def __init__(self, model, layer_counts): + default_batch_sizes = { + 'resnet50': 64, + 'resnet101': 32, + 'resnet152': 32, + 'resnet50_v2': 64, + 'resnet101_v2': 32, + 'resnet152_v2': 32, + } + batch_size = default_batch_sizes.get(model, 32) + super(ResnetModel, self).__init__(model, 224, batch_size, 0.005, + layer_counts) + self.pre_activation = 'v2' in model + + def add_inference(self, cnn): + if self.layer_counts is None: + raise ValueError( + 'Layer counts not specified for %s' % self.get_model()) + cnn.use_batch_norm = True + cnn.batch_norm_config = { + 'decay': 0.997, + 'epsilon': 1e-5, + 'scale': True + } + cnn.conv(64, 7, 7, 2, 2, mode='SAME_RESNET', use_batch_norm=True) + cnn.mpool(3, 3, 2, 2, mode='SAME') + for _ in xrange(self.layer_counts[0]): + bottleneck_block(cnn, 256, 64, 1, self.pre_activation) + for i in xrange(self.layer_counts[1]): + stride = 2 if i == 0 else 1 + bottleneck_block(cnn, 512, 128, stride, self.pre_activation) + for i in xrange(self.layer_counts[2]): + stride = 2 if i == 0 else 1 + bottleneck_block(cnn, 1024, 256, stride, self.pre_activation) + for i in xrange(self.layer_counts[3]): + stride = 2 if i == 0 else 1 + bottleneck_block(cnn, 2048, 512, stride, self.pre_activation) + if self.pre_activation: + cnn.batch_norm() + cnn.top_layer = tf.nn.relu(cnn.top_layer) + cnn.spatial_mean() + + def get_learning_rate(self, global_step, batch_size): + num_batches_per_epoch = ( + float(datasets.IMAGENET_NUM_TRAIN_IMAGES) / batch_size) + boundaries = [int(num_batches_per_epoch * x) for x in [30, 60]] + values = [0.1, 0.01, 0.001] + return tf.train.piecewise_constant(global_step, boundaries, values) def create_resnet50_model(): - return ResnetModel('resnet50', (3, 4, 6, 3)) + return ResnetModel('resnet50', (3, 4, 6, 3)) def create_resnet50_v2_model(): - return ResnetModel('resnet50_v2', (3, 4, 6, 3)) + return ResnetModel('resnet50_v2', (3, 4, 6, 3)) def create_resnet101_model(): - return ResnetModel('resnet101', (3, 4, 23, 3)) + return ResnetModel('resnet101', (3, 4, 23, 3)) def create_resnet101_v2_model(): - return ResnetModel('resnet101_v2', (3, 4, 23, 3)) + return ResnetModel('resnet101_v2', (3, 4, 23, 3)) def create_resnet152_model(): - return ResnetModel('resnet152', (3, 8, 36, 3)) + return ResnetModel('resnet152', (3, 8, 36, 3)) def create_resnet152_v2_model(): - return ResnetModel('resnet152_v2', (3, 8, 36, 3)) + return ResnetModel('resnet152_v2', (3, 8, 36, 3)) class ResnetCifar10Model(model_lib.Model): - """Resnet cnn network configuration for Cifar 10 dataset. + """Resnet cnn network configuration for Cifar 10 dataset. V1 model architecture follows the one defined in the paper: https://arxiv.org/pdf/1512.03385.pdf. @@ -265,82 +345,83 @@ class ResnetCifar10Model(model_lib.Model): https://arxiv.org/pdf/1603.05027.pdf. 
""" - def __init__(self, model, layer_counts): - self.pre_activation = 'v2' in model - super(ResnetCifar10Model, self).__init__( - model, 32, 128, 0.1, layer_counts) - - def add_inference(self, cnn): - if self.layer_counts is None: - raise ValueError('Layer counts not specified for %s' % self.get_model()) - - cnn.use_batch_norm = True - cnn.batch_norm_config = {'decay': 0.9, 'epsilon': 1e-5, 'scale': True} - if self.pre_activation: - cnn.conv(16, 3, 3, 1, 1, use_batch_norm=True) - else: - cnn.conv(16, 3, 3, 1, 1, activation=None, use_batch_norm=True) - for i in xrange(self.layer_counts[0]): - # reshape to batch_size x 16 x 32 x 32 - residual_block(cnn, 16, 1, self.pre_activation) - for i in xrange(self.layer_counts[1]): - # Subsampling is performed at the first convolution with a stride of 2 - stride = 2 if i == 0 else 1 - # reshape to batch_size x 32 x 16 x 16 - residual_block(cnn, 32, stride, self.pre_activation) - for i in xrange(self.layer_counts[2]): - stride = 2 if i == 0 else 1 - # reshape to batch_size x 64 x 8 x 8 - residual_block(cnn, 64, stride, self.pre_activation) - if self.pre_activation: - cnn.batch_norm() - cnn.top_layer = tf.nn.relu(cnn.top_layer) - cnn.spatial_mean() - - def get_learning_rate(self, global_step, batch_size): - num_batches_per_epoch = int(50000 / batch_size) - boundaries = num_batches_per_epoch * np.array([82, 123, 300], - dtype=np.int64) - boundaries = [x for x in boundaries] - values = [0.1, 0.01, 0.001, 0.0002] - return tf.train.piecewise_constant(global_step, boundaries, values) + def __init__(self, model, layer_counts): + self.pre_activation = 'v2' in model + super(ResnetCifar10Model, self).__init__(model, 32, 128, 0.1, + layer_counts) + + def add_inference(self, cnn): + if self.layer_counts is None: + raise ValueError( + 'Layer counts not specified for %s' % self.get_model()) + + cnn.use_batch_norm = True + cnn.batch_norm_config = {'decay': 0.9, 'epsilon': 1e-5, 'scale': True} + if self.pre_activation: + cnn.conv(16, 3, 3, 1, 1, use_batch_norm=True) + else: + cnn.conv(16, 3, 3, 1, 1, activation=None, use_batch_norm=True) + for i in xrange(self.layer_counts[0]): + # reshape to batch_size x 16 x 32 x 32 + residual_block(cnn, 16, 1, self.pre_activation) + for i in xrange(self.layer_counts[1]): + # Subsampling is performed at the first convolution with a stride of 2 + stride = 2 if i == 0 else 1 + # reshape to batch_size x 32 x 16 x 16 + residual_block(cnn, 32, stride, self.pre_activation) + for i in xrange(self.layer_counts[2]): + stride = 2 if i == 0 else 1 + # reshape to batch_size x 64 x 8 x 8 + residual_block(cnn, 64, stride, self.pre_activation) + if self.pre_activation: + cnn.batch_norm() + cnn.top_layer = tf.nn.relu(cnn.top_layer) + cnn.spatial_mean() + + def get_learning_rate(self, global_step, batch_size): + num_batches_per_epoch = int(50000 / batch_size) + boundaries = num_batches_per_epoch * np.array( + [82, 123, 300], dtype=np.int64) + boundaries = [x for x in boundaries] + values = [0.1, 0.01, 0.001, 0.0002] + return tf.train.piecewise_constant(global_step, boundaries, values) def create_resnet20_cifar_model(): - return ResnetCifar10Model('resnet20', (3, 3, 3)) + return ResnetCifar10Model('resnet20', (3, 3, 3)) def create_resnet20_v2_cifar_model(): - return ResnetCifar10Model('resnet20_v2', (3, 3, 3)) + return ResnetCifar10Model('resnet20_v2', (3, 3, 3)) def create_resnet32_cifar_model(): - return ResnetCifar10Model('resnet32_v2', (5, 5, 5)) + return ResnetCifar10Model('resnet32_v2', (5, 5, 5)) def create_resnet32_v2_cifar_model(): - return 
ResnetCifar10Model('resnet32_v2', (5, 5, 5)) + return ResnetCifar10Model('resnet32_v2', (5, 5, 5)) def create_resnet44_cifar_model(): - return ResnetCifar10Model('resnet44', (7, 7, 7)) + return ResnetCifar10Model('resnet44', (7, 7, 7)) def create_resnet44_v2_cifar_model(): - return ResnetCifar10Model('resnet44_v2', (7, 7, 7)) + return ResnetCifar10Model('resnet44_v2', (7, 7, 7)) def create_resnet56_cifar_model(): - return ResnetCifar10Model('resnet56', (9, 9, 9)) + return ResnetCifar10Model('resnet56', (9, 9, 9)) def create_resnet56_v2_cifar_model(): - return ResnetCifar10Model('resnet56_v2', (9, 9, 9)) + return ResnetCifar10Model('resnet56_v2', (9, 9, 9)) def create_resnet110_cifar_model(): - return ResnetCifar10Model('resnet110', (18, 18, 18)) + return ResnetCifar10Model('resnet110', (18, 18, 18)) def create_resnet110_v2_cifar_model(): - return ResnetCifar10Model('resnet110_v2', (18, 18, 18)) + return ResnetCifar10Model('resnet110_v2', (18, 18, 18)) diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py index c41ce39dbc4a0..e56566d669733 100644 --- a/python/ray/experimental/sgd/util.py +++ b/python/ray/experimental/sgd/util.py @@ -18,7 +18,9 @@ def run_timeline(sess, ops, feed_dict={}, write_timeline=False, name=""): run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() fetches = sess.run( - ops, options=run_options, run_metadata=run_metadata, + ops, + options=run_options, + run_metadata=run_metadata, feed_dict=feed_dict) trace = timeline.Timeline(step_stats=run_metadata.step_stats) outf = "timeline-{}-{}.json".format(name, os.getpid()) From ffafb6aa500de728a28622cd5ecb7c595e738b16 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 Sep 2018 18:23:44 -0700 Subject: [PATCH 07/17] typo --- python/ray/experimental/sgd/sgd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index 1a49f26e8ad60..b6de7d065ea3a 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -76,7 +76,7 @@ def __init__(self, list(range(num_devices)), agg_small_grads_max_bytes=max_bytes)) else: - self.packed_grads_and_vars = ( + self.packed_grads_and_vars, _ = ( modified_allreduce.sum_gradients_all_reduce( "", grad_ops, From 4525ff0e7c4d20dfe2523f85764570abdcd1c472 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 Sep 2018 18:30:25 -0700 Subject: [PATCH 08/17] lint --- .../sgd/{tfbench => }/modified_allreduce.py | 9 ++-- python/ray/experimental/sgd/sgd.py | 50 +++++++++---------- python/ray/experimental/sgd/test_sgd.py | 3 +- .../sgd/{example => tfbench}/test_model.py | 0 python/ray/experimental/sgd/util.py | 7 ++- 5 files changed, 37 insertions(+), 32 deletions(-) rename python/ray/experimental/sgd/{tfbench => }/modified_allreduce.py (99%) mode change 100644 => 100755 python/ray/experimental/sgd/test_sgd.py rename python/ray/experimental/sgd/{example => tfbench}/test_model.py (100%) diff --git a/python/ray/experimental/sgd/tfbench/modified_allreduce.py b/python/ray/experimental/sgd/modified_allreduce.py similarity index 99% rename from python/ray/experimental/sgd/tfbench/modified_allreduce.py rename to python/ray/experimental/sgd/modified_allreduce.py index 1b7c50521fd15..41cfce61095e9 100644 --- a/python/ray/experimental/sgd/tfbench/modified_allreduce.py +++ b/python/ray/experimental/sgd/modified_allreduce.py @@ -17,6 +17,7 @@ from __future__ import print_function import collections as pycoll +import logger 
import re from six.moves import xrange # pylint: disable=redefined-builtin @@ -25,6 +26,8 @@ from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce +logger = logging.getLogger(__name__) + AllReduceSpecTuple = pycoll.namedtuple('AllReduceSpecTuple', 'alg shards limit') @@ -433,10 +436,10 @@ def sizeof_fmt(num, suffix='B'): "median": np.median(sizes), "total size": np.sum(sizes) } - print("Stats " + + logger.info("Stats " + ", ".join(["%s: %s" % (k, sizeof_fmt(v)) for k, v in stats.items()])) other_stats = {"len": len(sizes)} - print(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()])) + logger.info(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()])) def extract_ranges(index_list, range_size_limit=32): @@ -548,7 +551,6 @@ def pack_small_tensors(tower_grads, max_bytes=0): # Check to make sure sizes are accurate; not entirely important assert all(g.dtype == tf.float32 for g in orig_grads) sizes = [4 * g.shape.num_elements() for g in orig_grads] - print("Before packing") print_stats(sizes) small_ranges = [] large_indices = [] @@ -573,7 +575,6 @@ def end_interval(indices, small_ranges, large_indices): end_interval(cur_range, small_ranges, large_indices) new_sizes.insert(0, cur_size) - print("After packing") print_stats(new_sizes) num_gv = len(orig_grads) packing = {} diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index b6de7d065ea3a..61fdd821a572a 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import logging import os import random import ray @@ -16,6 +17,8 @@ from util import Timeline, fetch, run_timeline from tfbench import modified_allreduce +logger = logging.getLogger(__name__) + class SGDWorker(object): def __init__(self, @@ -91,7 +94,7 @@ def __init__(self, assert (len(self.per_device_grads) == num_devices) self.num_grads = num_grads = len(self.packed_grads_and_vars[0]) if max_bytes: - print("Packed grads => {} tensors".format(num_grads)) + logger.info("Packed grads => {} tensors".format(num_grads)) # Ops for reading grads with the right control deps nccl_noops = [] @@ -152,7 +155,7 @@ def __init__(self, plasma_manager_socket_name=manager_socket) grad_ph = tf.reshape(grad_ph, self.packed_grads_and_vars[0][j][0].shape) - print("Packed tensor", grad_ph) + logger.info("Packed tensor", grad_ph) packed_plasma_grads.append(grad_ph) for i in range(num_devices): per_device = [] @@ -208,7 +211,7 @@ def compute_gradients(self, verbose): ], feed_dict=feed_dict) if verbose: - print("compute grad interior time", time.time() - start) + logger.info("compute grad interior time", time.time() - start) return fetches def apply_gradients(self, avg_grads, verbose): @@ -219,7 +222,7 @@ def apply_gradients(self, avg_grads, verbose): } self.sess.run(self.apply_op, feed_dict=result) if verbose: - print("apply grad interior time", time.time() - start) + logger.info("apply grad interior time", time.time() - start) def ps_compute_apply(self, out_grad_shard_oids, @@ -302,9 +305,6 @@ def add_spinwait(self, grad_shard_ids): def add(self, grad_shard_id): self.timeline.start("add") - # self.timeline.start("add_wait") - # ray.wait([ray.local_scheduler.ObjectID(grad_shard_id)]) - # self.timeline.end("add_wait") self.timeline.start("get_buffers") oid = ray.pyarrow.plasma.ObjectID(grad_shard_id) [raw_grads] = ray.worker.global_worker.plasma_client.get_buffers([oid]) @@ -338,9 +338,9 @@ def pin(self, 
cpu_id): import psutil p = psutil.Process() p.cpu_affinity([cpu_id]) - print("Setting CPU Affinity to: ", cpu_id) + logger.info("Setting CPU Affinity to: ", cpu_id) except Exception as e: - print(e) + logger.error(e) def average_gradients(grads): @@ -356,7 +356,7 @@ def do_sgd_step(actors, verbose): losses = [f[0] for f in fetches] grads = [f[1] for f in fetches] if verbose: - print("compute all grads time", time.time() - start) + logger.info("compute all grads time", time.time() - start) start = time.time() if len(actors) == 1: assert len(grads) == 1 @@ -364,11 +364,11 @@ def do_sgd_step(actors, verbose): else: avg_grad = average_gradients(grads) if verbose: - print("grad reduce time", time.time() - start) + logger.info("grad reduce time", time.time() - start) start = time.time() ray.get([a.apply_gradients.remote(avg_grad, verbose) for a in actors]) if verbose: - print("apply all grads time", time.time() - start) + logger.info("apply all grads time", time.time() - start) return np.mean(losses) @@ -376,17 +376,17 @@ def distributed_sgd_step(actors, ps_list, verbose, write_timeline): # Preallocate object ids that actors will write gradient shards to grad_shard_oids_list = [[np.random.bytes(20) for _ in ps_list] for _ in actors] - print("generated grad oids") + logger.info("generated grad oids") # Preallocate object ids that param servers will write new weights to accum_shard_ids = [np.random.bytes(20) for _ in ps_list] - print("generated accum oids") + logger.info("generated accum oids") # Kick off the fused compute grad / update weights tf run for each actor for actor, grad_shard_oids in zip(actors, grad_shard_oids_list): actor.ps_compute_apply.remote( grad_shard_oids, accum_shard_ids, write_timeline=write_timeline) - print("Launched all ps_compute_applys on all actors") + logger.info("Launched all ps_compute_applys on all actors") # Issue prefetch ops for j, (ps, weight_shard_oid) in list( @@ -396,7 +396,7 @@ def distributed_sgd_step(actors, ps_list, verbose, write_timeline): to_fetch.append(grad_shard_oids[j]) random.shuffle(to_fetch) ps.prefetch.remote(to_fetch) - print("Launched all prefetch ops") + logger.info("Launched all prefetch ops") # Aggregate the gradients produced by the actors. These operations # run concurrently with the actor methods above. 
@@ -405,11 +405,11 @@ def distributed_sgd_step(actors, ps_list, verbose, write_timeline): enumerate(zip(ps_list, accum_shard_ids)))[::-1]: ps.add_spinwait.remote([gs[j] for gs in grad_shard_oids_list]) ps_gets.append(ps.get.remote(weight_shard_oid)) - print("Launched all aggregate ops") + logger.info("Launched all aggregate ops") if verbose: timelines = [ps.get_timeline.remote() for ps in ps_list] - print("launched timeline gets") + logger.info("launched timeline gets") timelines = ray.get(timelines) t0 = timelines[0] for t in timelines[1:]: @@ -436,13 +436,13 @@ def create_ps(): while (any(len(v) < min_placed for v in ip_mapping.values()) or (len(ip_mapping) < num_ips)): - print("generating new ps, ip map so far", ip_mapping) + logger.info("generating new ps, ip map so far", ip_mapping) new_ps = create_ps() ps_ip = ray.get(new_ps.ip.remote()) if spread_ps and ps_ip in worker_ips: - print("ignoring ps that is on same node as worker") + logger.info("ignoring ps that is on same node as worker") elif not spread_ps and ps_ip not in worker_ips: - print("ignoring ps that NOT on same node as some worker") + logger.info("ignoring ps that NOT on same node as some worker") else: ip_mapping[ps_ip] += [new_ps] @@ -456,11 +456,11 @@ def create_ps(): for ps in sum(candidates, []): if ps not in final_list: ps.__ray_terminate__.remote(ps._ray_actor_id.id()) - print("removing a ps...") + logger.info("removing a ps...") else: - print("saving ps...") + logger.info("saving ps...") - print("Final PS balance: ", + logger.info("Final PS balance: ", Counter(ray.get([ps.ip.remote() for ps in final_list]))) for i, ps in enumerate(final_list): ps.set_tid.remote(i) @@ -478,7 +478,7 @@ def __init__(self, model_creator, num_workers, devices_per_worker, RemoteSGDWorker = ray.remote(**requests)(SGDWorker) self.workers = [] for worker_index in range(num_workers): - print("Creating worker", worker_index) + logger.info("Creating worker", worker_index) self.workers.append( RemoteSGDWorker.remote( worker_index, diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py old mode 100644 new mode 100755 index fc7c0233e867f..93cba50c8980c --- a/python/ray/experimental/sgd/test_sgd.py +++ b/python/ray/experimental/sgd/test_sgd.py @@ -8,9 +8,10 @@ import numpy as np import tensorflow as tf -from ray.experimental.sgd.example.test_model import TFBenchModel +from ray.experimental.sgd.tfbench.test_model import TFBenchModel from ray.experimental.sgd.sgd import DistributedSGD + if __name__ == "__main__": ray.init() diff --git a/python/ray/experimental/sgd/example/test_model.py b/python/ray/experimental/sgd/tfbench/test_model.py similarity index 100% rename from python/ray/experimental/sgd/example/test_model.py rename to python/ray/experimental/sgd/tfbench/test_model.py diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py index e56566d669733..3112a89831cf4 100644 --- a/python/ray/experimental/sgd/util.py +++ b/python/ray/experimental/sgd/util.py @@ -4,8 +4,11 @@ import ray import json +import logger import time +logger = logging.getLogger(__name__) + def fetch(oids): for o in oids: @@ -25,7 +28,7 @@ def run_timeline(sess, ops, feed_dict={}, write_timeline=False, name=""): trace = timeline.Timeline(step_stats=run_metadata.step_stats) outf = "timeline-{}-{}.json".format(name, os.getpid()) trace_file = open(outf, "w") - print("wrote tf timeline to", os.path.abspath(outf)) + logger.info("wrote tf timeline to", os.path.abspath(outf)) 
trace_file.write(trace.generate_chrome_trace_format()) else: fetches = sess.run(ops, feed_dict=feed_dict) @@ -90,7 +93,7 @@ def chrome_trace_format(self, filename): }) with open(filename, "w") as f: f.write(json.dumps(out)) - print("Wrote chrome timeline to", filename) + logger.info("Wrote chrome timeline to", filename) if __name__ == "__main__": From 5a6c913bec67cdc26577c04910081b4f36e09ee0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 Sep 2018 18:35:39 -0700 Subject: [PATCH 09/17] plasma op change --- .../experimental/sgd/modified_allreduce.py | 2 +- python/ray/experimental/sgd/sgd.py | 34 +++++++++++-------- python/ray/experimental/sgd/test_sgd.py | 3 +- python/ray/experimental/sgd/util.py | 2 +- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/python/ray/experimental/sgd/modified_allreduce.py b/python/ray/experimental/sgd/modified_allreduce.py index 41cfce61095e9..7927f462e7b70 100644 --- a/python/ray/experimental/sgd/modified_allreduce.py +++ b/python/ray/experimental/sgd/modified_allreduce.py @@ -17,7 +17,7 @@ from __future__ import print_function import collections as pycoll -import logger +import logging import re from six.moves import xrange # pylint: disable=redefined-builtin diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index 61fdd821a572a..c3603f88e588a 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -10,12 +10,13 @@ from tensorflow.python.client import timeline import numpy as np +import pyarrow.plasma as plasma import tensorflow as tf import tensorflow.contrib.nccl as nccl import tensorflow.contrib.slim as slim from util import Timeline, fetch, run_timeline -from tfbench import modified_allreduce +from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce logger = logging.getLogger(__name__) @@ -70,7 +71,7 @@ def __init__(self, else: if max_bytes: self.packed_grads_and_vars, packing_vals = ( - modified_allreduce.sum_gradients_all_reduce( + sum_gradients_all_reduce( "", grad_ops, 1, @@ -80,7 +81,7 @@ def __init__(self, agg_small_grads_max_bytes=max_bytes)) else: self.packed_grads_and_vars, _ = ( - modified_allreduce.sum_gradients_all_reduce( + sum_gradients_all_reduce( "", grad_ops, 1, @@ -113,7 +114,7 @@ def __init__(self, ray.worker.global_worker.plasma_client.store_socket_name) manager_socket = ( ray.worker.global_worker.plasma_client.manager_socket_name) - memcpy_plasma_module = tf.load_op_library( + memcpy_plasma_module = plasma.build_plasma_tensorflow_op( os.path.join( os.path.dirname(os.path.abspath(__file__)), "ops/memcpy_plasma_op.so")) @@ -131,7 +132,7 @@ def __init__(self, ix += 1 # round robin assignment ix %= num_devices with tf.device(self.models[ix].device): - plasma_grad = memcpy_plasma_module.tensor_to_plasma( + plasma_grad = plasma.tf_plasma_op.tensor_to_plasma( [grad], self.plasma_in_grads_oids[j], plasma_store_socket_name=store_socket, @@ -149,7 +150,7 @@ def __init__(self, for j in range(num_grads): with tf.device(self.plasma_in_grads[j].device): with tf.control_dependencies([self.plasma_in_grads[j]]): - grad_ph = memcpy_plasma_module.plasma_to_tensor( + grad_ph = plasma.tf_plasma_op.plasma_to_tensor( self.plasma_out_grads_oids[j], plasma_store_socket_name=store_socket, plasma_manager_socket_name=manager_socket) @@ -211,7 +212,8 @@ def compute_gradients(self, verbose): ], feed_dict=feed_dict) if verbose: - logger.info("compute grad interior time", time.time() - start) + logger.info( + "compute grad interior time 
{}".format(time.time() - start)) return fetches def apply_gradients(self, avg_grads, verbose): @@ -222,7 +224,8 @@ def apply_gradients(self, avg_grads, verbose): } self.sess.run(self.apply_op, feed_dict=result) if verbose: - logger.info("apply grad interior time", time.time() - start) + logger.info( + "apply grad interior time {}".format(time.time() - start)) def ps_compute_apply(self, out_grad_shard_oids, @@ -338,7 +341,7 @@ def pin(self, cpu_id): import psutil p = psutil.Process() p.cpu_affinity([cpu_id]) - logger.info("Setting CPU Affinity to: ", cpu_id) + logger.info("Setting CPU Affinity to: {}".format(cpu_id)) except Exception as e: logger.error(e) @@ -356,7 +359,7 @@ def do_sgd_step(actors, verbose): losses = [f[0] for f in fetches] grads = [f[1] for f in fetches] if verbose: - logger.info("compute all grads time", time.time() - start) + logger.info("compute all grads time {}".format(time.time() - start)) start = time.time() if len(actors) == 1: assert len(grads) == 1 @@ -364,11 +367,11 @@ def do_sgd_step(actors, verbose): else: avg_grad = average_gradients(grads) if verbose: - logger.info("grad reduce time", time.time() - start) + logger.info("grad reduce time {}".format(time.time() - start)) start = time.time() ray.get([a.apply_gradients.remote(avg_grad, verbose) for a in actors]) if verbose: - logger.info("apply all grads time", time.time() - start) + logger.info("apply all grads time {}".format(time.time() - start)) return np.mean(losses) @@ -436,7 +439,7 @@ def create_ps(): while (any(len(v) < min_placed for v in ip_mapping.values()) or (len(ip_mapping) < num_ips)): - logger.info("generating new ps, ip map so far", ip_mapping) + logger.info("generating new ps, ip map so far {}".format(ip_mapping)) new_ps = create_ps() ps_ip = ray.get(new_ps.ip.remote()) if spread_ps and ps_ip in worker_ips: @@ -469,7 +472,7 @@ def create_ps(): class DistributedSGD(object): def __init__(self, model_creator, num_workers, devices_per_worker, - use_cpus): + use_cpus=False, use_plasma_op=False): self.model_creator = model_creator if use_cpus: requests = {"num_cpus": devices_per_worker} @@ -478,12 +481,13 @@ def __init__(self, model_creator, num_workers, devices_per_worker, RemoteSGDWorker = ray.remote(**requests)(SGDWorker) self.workers = [] for worker_index in range(num_workers): - logger.info("Creating worker", worker_index) + logger.info("Creating worker {}".format(worker_index)) self.workers.append( RemoteSGDWorker.remote( worker_index, model_creator, num_devices=devices_per_worker, + plasma_op=use_plasma_op, use_cpus=use_cpus, verbose=True)) diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py index 93cba50c8980c..ad94350a96a5e 100755 --- a/python/ray/experimental/sgd/test_sgd.py +++ b/python/ray/experimental/sgd/test_sgd.py @@ -19,7 +19,8 @@ lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True)) sgd = DistributedSGD( - model_creator, num_workers=2, devices_per_worker=2, use_cpus=True) + model_creator, num_workers=2, devices_per_worker=2, use_cpus=True, + use_plasma_op=True) for _ in range(100): loss = sgd.step() diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py index 3112a89831cf4..7087e1654afd4 100644 --- a/python/ray/experimental/sgd/util.py +++ b/python/ray/experimental/sgd/util.py @@ -4,7 +4,7 @@ import ray import json -import logger +import logging import time logger = logging.getLogger(__name__) From 63e4580e843a559aaa452ac8d9e22e04076dc877 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 
Sep 2018 18:50:40 -0700 Subject: [PATCH 10/17] fix plasma op --- python/ray/experimental/sgd/sgd.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index c3603f88e588a..83a955501f9cb 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -114,10 +114,7 @@ def __init__(self, ray.worker.global_worker.plasma_client.store_socket_name) manager_socket = ( ray.worker.global_worker.plasma_client.manager_socket_name) - memcpy_plasma_module = plasma.build_plasma_tensorflow_op( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "ops/memcpy_plasma_op.so")) + plasma.build_plasma_tensorflow_op() # For fetching grads -> plasma self.plasma_in_grads = [] @@ -131,7 +128,7 @@ def __init__(self, if round_robin_devices: ix += 1 # round robin assignment ix %= num_devices - with tf.device(self.models[ix].device): + with tf.device(self.models[ix].loss.device): plasma_grad = plasma.tf_plasma_op.tensor_to_plasma( [grad], self.plasma_in_grads_oids[j], @@ -152,6 +149,7 @@ def __init__(self, with tf.control_dependencies([self.plasma_in_grads[j]]): grad_ph = plasma.tf_plasma_op.plasma_to_tensor( self.plasma_out_grads_oids[j], + dtype=tf.float32, plasma_store_socket_name=store_socket, plasma_manager_socket_name=manager_socket) grad_ph = tf.reshape(grad_ph, @@ -295,10 +293,7 @@ def add_spinwait(self, grad_shard_ids): for p in plasma_ids: if ray.worker.global_worker.plasma_client.contains(p): self.timeline.start("get_buffers") - [raw_grads - ] = (ray.worker.global_worker.plasma_client.get_buffers( - [p])) - grads = np.frombuffer(raw_grads, dtype=np.float32) + grads = ray.worker.global_worker.plasma_client.get(p) self.accumulated += grads self.acc_counter += 1 self.timeline.end("get_buffers") @@ -310,8 +305,7 @@ def add(self, grad_shard_id): self.timeline.start("add") self.timeline.start("get_buffers") oid = ray.pyarrow.plasma.ObjectID(grad_shard_id) - [raw_grads] = ray.worker.global_worker.plasma_client.get_buffers([oid]) - grads = np.frombuffer(raw_grads, dtype=np.float32) + grads = ray.worker.global_worker.plasma_client.get(oid) self.timeline.end("get_buffers") self.accumulated += grads self.acc_counter += 1 From 2ca7c851c4f2a5ca774acb0fbe9461db9355dcdf Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 Sep 2018 19:16:27 -0700 Subject: [PATCH 11/17] still not working --- .../experimental/sgd/modified_allreduce.py | 5 +- python/ray/experimental/sgd/sgd.py | 89 +++++++++---------- python/ray/experimental/sgd/test_sgd.py | 6 +- python/ray/experimental/sgd/util.py | 12 ++- 4 files changed, 57 insertions(+), 55 deletions(-) diff --git a/python/ray/experimental/sgd/modified_allreduce.py b/python/ray/experimental/sgd/modified_allreduce.py index 7927f462e7b70..88ad851b1f7a2 100644 --- a/python/ray/experimental/sgd/modified_allreduce.py +++ b/python/ray/experimental/sgd/modified_allreduce.py @@ -18,6 +18,7 @@ import collections as pycoll import logging +import numpy as np import re from six.moves import xrange # pylint: disable=redefined-builtin @@ -436,8 +437,8 @@ def sizeof_fmt(num, suffix='B'): "median": np.median(sizes), "total size": np.sum(sizes) } - logger.info("Stats " + - ", ".join(["%s: %s" % (k, sizeof_fmt(v)) for k, v in stats.items()])) + logger.info("Stats " + ", ".join( + ["%s: %s" % (k, sizeof_fmt(v)) for k, v in stats.items()])) other_stats = {"len": len(sizes)} logger.info(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()])) diff --git 
a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index 83a955501f9cb..e0bbdd0c8513f 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -16,7 +16,8 @@ import tensorflow.contrib.slim as slim from util import Timeline, fetch, run_timeline -from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce +from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce, \ + unpack_small_tensors logger = logging.getLogger(__name__) @@ -28,9 +29,8 @@ def __init__(self, all_reduce_alg="simple", num_devices=1, use_cpus=False, - max_bytes=0, - plasma_op=False, - verbose=False): + max_bytes=60000000, + plasma_op=False): self.worker_index = worker_index assert num_devices > 0 @@ -80,15 +80,14 @@ def __init__(self, list(range(num_devices)), agg_small_grads_max_bytes=max_bytes)) else: - self.packed_grads_and_vars, _ = ( - sum_gradients_all_reduce( - "", - grad_ops, - 1, - all_reduce_alg, - 1, - list(range(num_devices)), - agg_small_grads_max_bytes=0)) + self.packed_grads_and_vars, _ = (sum_gradients_all_reduce( + "", + grad_ops, + 1, + all_reduce_alg, + 1, + list(range(num_devices)), + agg_small_grads_max_bytes=0)) self.per_device_grads = [ list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars ] @@ -119,7 +118,7 @@ def __init__(self, # For fetching grads -> plasma self.plasma_in_grads = [] self.plasma_in_grads_oids = [ - tf.placeholder(shape=[], dtype=tf.string) + tf.placeholder(shape=[], dtype=tf.string, name="in_grad_oids") for _ in range(num_grads) ] ix = 0 @@ -139,7 +138,8 @@ def __init__(self, # For applying grads <- plasma unpacked_gv = [] self.plasma_out_grads_oids = [ - tf.placeholder(shape=[], dtype=tf.string) + tf.placeholder( + shape=[], dtype=tf.string, name="grad_out_oids") for _ in range(num_grads) ] packed_plasma_grads = [] @@ -154,7 +154,7 @@ def __init__(self, plasma_manager_socket_name=manager_socket) grad_ph = tf.reshape(grad_ph, self.packed_grads_and_vars[0][j][0].shape) - logger.info("Packed tensor", grad_ph) + logger.debug("Packed tensor {}".format(grad_ph)) packed_plasma_grads.append(grad_ph) for i in range(num_devices): per_device = [] @@ -164,8 +164,7 @@ def __init__(self, unpacked_gv.append(per_device) if max_bytes: - unpacked_gv = allreduce.unpack_small_tensors( - unpacked_gv, packing_vals) + unpacked_gv = unpack_small_tensors(unpacked_gv, packing_vals) elif max_bytes: unpacked_gv = allreduce.unpack_small_tensors( @@ -195,7 +194,7 @@ def foreach_model(self, fn): def foreach_worker(self, fn): return fn(self) - def compute_gradients(self, verbose): + def compute_gradients(self): start = time.time() feed_dict = {} # Aggregate feed dicts for each model on this worker. 
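
(Aside on the recurring log-statement fixes in this series, such as the
"Packed tensor" change just above and "Creating worker" earlier: unlike
print(), the stdlib logging methods treat extra positional arguments as
%-interpolation values for the message string, so a call like
logger.info("Creating worker", worker_index) triggers a formatting error in
the handler instead of printing the value. A minimal standalone sketch,
separate from the patch itself:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    elapsed = 0.25
    # Broken: 0.25 becomes a %-format argument, but the message has no
    # placeholder, so the handler reports a formatting error and the value
    # is never printed.
    logger.info("apply grad interior time", elapsed)
    # Either form works: lazy %-style arguments, or the eager str.format
    # rewrite that these patches use.
    logger.info("apply grad interior time %s", elapsed)
    logger.info("apply grad interior time {}".format(elapsed))

This is why each multi-argument logger call is rewritten into a single
formatted string.)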
@@ -209,21 +208,18 @@ def compute_gradients(self, verbose):
             self.nccl_control_out
         ], feed_dict=feed_dict)
-        if verbose:
-            logger.info(
-                "compute grad interior time {}".format(time.time() - start))
+        logger.debug(
+            "compute grad interior time {}".format(time.time() - start))
         return fetches
 
-    def apply_gradients(self, avg_grads, verbose):
+    def apply_gradients(self, avg_grads):
         start = time.time()
         result = {
             g: avg_grads[i]
             for (i, g) in enumerate(self.per_device_grads[0])
         }
         self.sess.run(self.apply_op, feed_dict=result)
-        if verbose:
-            logger.info(
-                "apply grad interior time {}".format(time.time() - start))
+        logger.debug("apply grad interior time {}".format(time.time() - start))
 
     def ps_compute_apply(self,
                          out_grad_shard_oids,
@@ -316,10 +312,7 @@ def get(self, object_id):
         client = ray.worker.global_worker.plasma_client
         assert self.acc_counter == self.num_sgd_workers, self.acc_counter
         oid = ray.pyarrow.plasma.ObjectID(object_id)
-        buff = client.create(oid, self.accumulated.nbytes)
-        wrapper = np.frombuffer(buff, dtype=np.float32)
-        np.copyto(wrapper, self.accumulated)
-        client.seal(oid)
+        client.put(self.accumulated.flatten(), object_id=oid)
         self.accumulated = np.zeros_like(self.accumulated)
         self.acc_counter = 0
         self.timeline.end("get")
@@ -347,29 +340,26 @@ def average_gradients(grads):
     return out
 
 
-def do_sgd_step(actors, verbose):
+def do_sgd_step(actors):
     start = time.time()
-    fetches = ray.get([a.compute_gradients.remote(verbose) for a in actors])
+    fetches = ray.get([a.compute_gradients.remote() for a in actors])
     losses = [f[0] for f in fetches]
     grads = [f[1] for f in fetches]
-    if verbose:
-        logger.info("compute all grads time {}".format(time.time() - start))
+    logger.debug("compute all grads time {}".format(time.time() - start))
     start = time.time()
     if len(actors) == 1:
         assert len(grads) == 1
         avg_grad = grads[0]
     else:
         avg_grad = average_gradients(grads)
-    if verbose:
-        logger.info("grad reduce time {}".format(time.time() - start))
+    logger.debug("grad reduce time {}".format(time.time() - start))
     start = time.time()
-    ray.get([a.apply_gradients.remote(avg_grad, verbose) for a in actors])
-    if verbose:
-        logger.info("apply all grads time {}".format(time.time() - start))
+    ray.get([a.apply_gradients.remote(avg_grad) for a in actors])
+    logger.debug("apply all grads time {}".format(time.time() - start))
     return np.mean(losses)
 
 
-def distributed_sgd_step(actors, ps_list, verbose, write_timeline):
+def distributed_sgd_step(actors, ps_list, write_timeline):
     # Preallocate object ids that actors will write gradient shards to
     grad_shard_oids_list = [[np.random.bytes(20) for _ in ps_list]
                             for _ in actors]
@@ -404,7 +394,7 @@ def distributed_sgd_step(actors, ps_list, verbose, write_timeline):
             ps_gets.append(ps.get.remote(weight_shard_oid))
 
     logger.info("Launched all aggregate ops")
-    if verbose:
+    if write_timeline:
         timelines = [ps.get_timeline.remote() for ps in ps_list]
         logger.info("launched timeline gets")
         timelines = ray.get(timelines)
@@ -452,21 +442,25 @@ def create_ps():
 
     for ps in sum(candidates, []):
         if ps not in final_list:
-            ps.__ray_terminate__.remote(ps._ray_actor_id.id())
+            ps.__ray_terminate__.remote()
             logger.info("removing a ps...")
         else:
             logger.info("saving ps...")
 
     logger.info("Final PS balance: ",
-        Counter(ray.get([ps.ip.remote() for ps in final_list])))
+                Counter(ray.get([ps.ip.remote() for ps in final_list])))
     for i, ps in enumerate(final_list):
         ps.set_tid.remote(i)
     return final_list
 
 
 class DistributedSGD(object):
-    def __init__(self, model_creator, num_workers, devices_per_worker,
-                 use_cpus=False, use_plasma_op=False):
+    def __init__(self,
+                 model_creator,
+                 num_workers,
+                 devices_per_worker,
+                 use_cpus=False,
+                 use_plasma_op=False):
         self.model_creator = model_creator
         if use_cpus:
             requests = {"num_cpus": devices_per_worker}
@@ -482,8 +476,7 @@ def __init__(self,
                     num_devices=devices_per_worker,
                     plasma_op=use_plasma_op,
-                    use_cpus=use_cpus,
-                    verbose=True))
+                    use_cpus=use_cpus))
 
     def foreach_worker(self, fn):
         results = ray.get([w.foreach_worker.remote(fn) for w in self.workers])
@@ -497,4 +490,4 @@ def foreach_model(self, fn):
         return r
 
     def step(self):
-        return do_sgd_step(self.workers, True)
+        return do_sgd_step(self.workers)
diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py
index ad94350a96a5e..698f19f63617c 100755
--- a/python/ray/experimental/sgd/test_sgd.py
+++ b/python/ray/experimental/sgd/test_sgd.py
@@ -11,7 +11,6 @@
 from ray.experimental.sgd.tfbench.test_model import TFBenchModel
 from ray.experimental.sgd.sgd import DistributedSGD
 
-
 if __name__ == "__main__":
     ray.init()
 
@@ -19,7 +18,10 @@
         lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True))
 
     sgd = DistributedSGD(
-        model_creator, num_workers=2, devices_per_worker=2, use_cpus=True,
+        model_creator,
+        num_workers=2,
+        devices_per_worker=2,
+        use_cpus=True,
         use_plasma_op=True)
 
     for _ in range(100):
diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py
index 7087e1654afd4..859f011ec8cac 100644
--- a/python/ray/experimental/sgd/util.py
+++ b/python/ray/experimental/sgd/util.py
@@ -11,9 +11,15 @@
 
 
 def fetch(oids):
-    for o in oids:
-        plasma_id = ray.pyarrow.plasma.ObjectID(o)
-        ray.worker.global_worker.plasma_client.fetch([plasma_id])
+    if ray.global_state.use_raylet:
+        local_sched_client = ray.worker.global_worker.local_scheduler_client
+        for o in oids:
+            ray_obj_id = ray.ObjectID(o)
+            local_sched_client.reconstruct_objects([ray_obj_id], True)
+    else:
+        for o in oids:
+            plasma_id = ray.pyarrow.plasma.ObjectID(o)
+            ray.worker.global_worker.plasma_client.fetch([plasma_id])
 
 
 def run_timeline(sess, ops, feed_dict={}, write_timeline=False, name=""):

From f37af9099e51cfdefb0765387b6bae3207b9fbf6 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 13 Sep 2018 19:17:47 -0700
Subject: [PATCH 12/17] fix

---
 python/ray/experimental/sgd/sgd.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py
index e0bbdd0c8513f..bc21270ea6f99 100644
--- a/python/ray/experimental/sgd/sgd.py
+++ b/python/ray/experimental/sgd/sgd.py
@@ -113,7 +113,8 @@ def __init__(self,
                 ray.worker.global_worker.plasma_client.store_socket_name)
             manager_socket = (
                 ray.worker.global_worker.plasma_client.manager_socket_name)
-            plasma.build_plasma_tensorflow_op()
+            if not plasma.tf_plasma_op:
+                plasma.build_plasma_tensorflow_op()
 
             # For fetching grads -> plasma
             self.plasma_in_grads = []
@@ -218,6 +219,7 @@ def apply_gradients(self, avg_grads):
             g: avg_grads[i]
             for (i, g) in enumerate(self.per_device_grads[0])
         }
+        print("APPLY DICT", result)
         self.sess.run(self.apply_op, feed_dict=result)
         logger.debug("apply grad interior time {}".format(time.time() - start))
 

From 11672b05cb504bb235df41cc7628a7cc90041ded Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 13 Sep 2018 19:21:04 -0700
Subject: [PATCH 13/17] fix

---
 python/ray/experimental/sgd/sgd.py      | 5 +++--
 python/ray/experimental/sgd/test_sgd.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py
index bc21270ea6f99..6b178e97bccfb 100644
--- a/python/ray/experimental/sgd/sgd.py
+++ b/python/ray/experimental/sgd/sgd.py
@@ -168,7 +168,7 @@ def __init__(self,
                 unpacked_gv = unpack_small_tensors(unpacked_gv, packing_vals)
 
         elif max_bytes:
-            unpacked_gv = allreduce.unpack_small_tensors(
+            unpacked_gv = unpack_small_tensors(
                 self.packed_grads_and_vars, packing_vals)
         else:
             unpacked_gv = self.packed_grads_and_vars
@@ -219,7 +219,6 @@ def apply_gradients(self, avg_grads):
             g: avg_grads[i]
             for (i, g) in enumerate(self.per_device_grads[0])
         }
-        print("APPLY DICT", result)
         self.sess.run(self.apply_op, feed_dict=result)
         logger.debug("apply grad interior time {}".format(time.time() - start))
 
@@ -479,6 +478,8 @@ def __init__(self,
                     num_devices=devices_per_worker,
                     plasma_op=use_plasma_op,
                     use_cpus=use_cpus))
+        assert not use_plasma_op, \
+            "TODO: when use_plasma_op is true, we must run in PS mode"
 
     def foreach_worker(self, fn):
         results = ray.get([w.foreach_worker.remote(fn) for w in self.workers])
diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py
index 698f19f63617c..1f6d657633562 100755
--- a/python/ray/experimental/sgd/test_sgd.py
+++ b/python/ray/experimental/sgd/test_sgd.py
@@ -22,7 +22,7 @@
         num_workers=2,
         devices_per_worker=2,
         use_cpus=True,
-        use_plasma_op=True)
+        use_plasma_op=False)
 
     for _ in range(100):
         loss = sgd.step()

From 8da7c9a3fa46d23b14a8d43f15178ef8c309f2e8 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Sat, 15 Sep 2018 12:47:48 -0700
Subject: [PATCH 14/17] comments

---
 python/ray/experimental/sgd/modified_allreduce.py | 3 +++
 python/ray/experimental/sgd/sgd.py                | 8 ++++----
 python/ray/experimental/sgd/test_sgd.py           | 3 +--
 python/ray/experimental/sgd/tfbench/README.txt    | 1 +
 python/ray/experimental/sgd/util.py               | 6 ++++--
 5 files changed, 13 insertions(+), 8 deletions(-)
 create mode 100644 python/ray/experimental/sgd/tfbench/README.txt

diff --git a/python/ray/experimental/sgd/modified_allreduce.py b/python/ray/experimental/sgd/modified_allreduce.py
index 88ad851b1f7a2..f9e580002e970 100644
--- a/python/ray/experimental/sgd/modified_allreduce.py
+++ b/python/ray/experimental/sgd/modified_allreduce.py
@@ -1,3 +1,6 @@
+# This file is adapted from https://github.com/tensorflow/benchmarks
+# /blob/master/scripts/tf_cnn_benchmarks/allreduce.py
+#
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py
index 6b178e97bccfb..266d500ac7af4 100644
--- a/python/ray/experimental/sgd/sgd.py
+++ b/python/ray/experimental/sgd/sgd.py
@@ -5,7 +5,6 @@
 import logging
 import os
 import random
-import ray
 import time
 
 from tensorflow.python.client import timeline
@@ -15,7 +14,8 @@
 import tensorflow.contrib.nccl as nccl
 import tensorflow.contrib.slim as slim
 
-from util import Timeline, fetch, run_timeline
+import ray
+from ray.experimental.sgd.util import Timeline, fetch, run_timeline
 from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce, \
     unpack_small_tensors
 
@@ -168,8 +168,8 @@ def __init__(self,
                 unpacked_gv = unpack_small_tensors(unpacked_gv, packing_vals)
 
         elif max_bytes:
-            unpacked_gv = unpack_small_tensors(
-                self.packed_grads_and_vars, packing_vals)
+            unpacked_gv = unpack_small_tensors(self.packed_grads_and_vars,
+                                               packing_vals)
         else:
             unpacked_gv = self.packed_grads_and_vars
 
diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py
index 1f6d657633562..396bf44a2f27a 100755
--- a/python/ray/experimental/sgd/test_sgd.py
+++ b/python/ray/experimental/sgd/test_sgd.py
@@ -2,12 +2,11 @@
 from __future__ import division
 from __future__ import print_function
 
-import ray
-
 import argparse
 import numpy as np
 import tensorflow as tf
 
+import ray
 from ray.experimental.sgd.tfbench.test_model import TFBenchModel
 from ray.experimental.sgd.sgd import DistributedSGD
 
diff --git a/python/ray/experimental/sgd/tfbench/README.txt b/python/ray/experimental/sgd/tfbench/README.txt
new file mode 100644
index 0000000000000..7ef1a87f759b5
--- /dev/null
+++ b/python/ray/experimental/sgd/tfbench/README.txt
@@ -0,0 +1 @@
+Files in this directory are adapted from https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks.
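
(At this point in the series the public API is stable enough to exercise end
to end. The following usage sketch, separate from the patches themselves,
mirrors test_sgd.py as it stands after PATCH 13/14 and assumes only names
already introduced there; iteration count and flag values are illustrative:

    import ray
    from ray.experimental.sgd.sgd import DistributedSGD
    from ray.experimental.sgd.tfbench.test_model import TFBenchModel

    ray.init()

    # One model replica per (worker, device) pair; batch=1 keeps the
    # benchmark model tiny.
    model_creator = (
        lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True))

    sgd = DistributedSGD(
        model_creator,
        num_workers=2,
        devices_per_worker=2,
        use_cpus=True,
        use_plasma_op=False)

    # Each step computes gradients on every worker, averages them, and
    # applies the average on all workers; the mean loss is returned.
    for _ in range(10):
        print("Current loss", sgd.step())

Note that use_plasma_op stays False here: PATCH 13 above asserts that the
plasma op path is only meaningful in parameter-server mode.)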
diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py
index 859f011ec8cac..1e4a45e190395 100644
--- a/python/ray/experimental/sgd/util.py
+++ b/python/ray/experimental/sgd/util.py
@@ -2,11 +2,12 @@
 from __future__ import division
 from __future__ import print_function
 
-import ray
 import json
 import logging
 import time
 
+import ray
+
 logger = logging.getLogger(__name__)
 
 
@@ -22,7 +23,8 @@ def fetch(oids):
             ray.worker.global_worker.plasma_client.fetch([plasma_id])
 
 
-def run_timeline(sess, ops, feed_dict={}, write_timeline=False, name=""):
+def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
+    feed_dict = feed_dict or {}
     if write_timeline:
         run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
         run_metadata = tf.RunMetadata()

From 8e226a192aa5dfe871c39c3f46c064ae761eff6d Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 17 Sep 2018 23:34:16 -0700
Subject: [PATCH 15/17] yapf

---
 .../experimental/sgd/modified_allreduce.py    |  7 +--
 python/ray/experimental/sgd/sgd.py            | 58 ++-----
 python/ray/experimental/sgd/test_sgd.py       |  4 --
 .../sgd/tfbench/convnet_builder.py            | 28 +---
 python/ray/experimental/sgd/tfbench/model.py  |  9 +--
 .../experimental/sgd/tfbench/resnet_model.py  |  7 +--
 .../experimental/sgd/tfbench/test_model.py    |  1 -
 python/ray/experimental/sgd/util.py           |  4 +-
 8 files changed, 16 insertions(+), 102 deletions(-)

diff --git a/python/ray/experimental/sgd/modified_allreduce.py b/python/ray/experimental/sgd/modified_allreduce.py
index f9e580002e970..a9d6879f99c7b 100644
--- a/python/ray/experimental/sgd/modified_allreduce.py
+++ b/python/ray/experimental/sgd/modified_allreduce.py
@@ -239,8 +239,7 @@ def build_trivial_sum(scaled_grads):
     return scaled_grads
 
 
-def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
-                                         check_inf_nan):
+def aggregate_single_gradient(grad_and_vars, use_mean, check_inf_nan):
     """Calculate the average gradient for a shared variable across all towers.
 
     Note that this function provides a synchronization point across all towers.
@@ -291,7 +290,7 @@ def aggregate_gradients_using_copy_with_device_selection(
     has_nan_or_inf_list = []
     for i, single_grads in enumerate(zip(*tower_grads)):
         with tf.device(avail_devices[i % len(avail_devices)]):
-            grad_and_var, has_nan_or_inf = aggregate_single_gradient_using_copy(
+            grad_and_var, has_nan_or_inf = aggregate_single_gradient(
                 single_grads, use_mean, check_inf_nan)
             agg_grads.append(grad_and_var)
             has_nan_or_inf_list.append(has_nan_or_inf)
@@ -331,7 +330,7 @@ def sum_grad_and_var_all_reduce(grad_and_vars,
         summed_grads = all_reduce.build_shuffle_then_shuffle(
             scaled_grads,
             aux_devices,
-            # TODO(tucker): devise a way of better specifying the device set
+            # TODO(tucker): devise a way of better specifying the device
             # for the second level.
             [aux_devices[0]],
             tf.add_n)
diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py
index 266d500ac7af4..c569c036f1b10 100644
--- a/python/ray/experimental/sgd/sgd.py
+++ b/python/ray/experimental/sgd/sgd.py
@@ -3,16 +3,12 @@
 from __future__ import print_function
 
 import logging
-import os
 import random
-import ray
 import time
 
-from tensorflow.python.client import timeline
 import numpy as np
 import pyarrow.plasma as plasma
 import tensorflow as tf
-import tensorflow.contrib.nccl as nccl
-import tensorflow.contrib.slim as slim
 
 import ray
 from ray.experimental.sgd.util import Timeline, fetch, run_timeline
@@ -99,9 +95,10 @@ def __init__(self,
         # Ops for reading grads with the right control deps
         nccl_noops = []
         for j in range(num_grads)[::-1]:
-            with tf.control_dependencies(
-                    nccl_noops +
-                    [dev_grad[j] for dev_grad in self.per_device_grads]):
+            deps = nccl_noops + [
+                dev_grad[j] for dev_grad in self.per_device_grads
+            ]
+            with tf.control_dependencies(deps):
                 nccl_noops = [tf.no_op()]
 
         # You must fetch this otherwise the NCCL allreduce will hang
@@ -408,53 +405,6 @@ def distributed_sgd_step(actors, ps_list, write_timeline):
     ray.get(ps_gets)
 
 
-def roundrobin_ps(ps_cls, sgd_workers, shard_shapes, spread_ps):
-    worker_ips = ray.get([w.ip.remote() for w in sgd_workers])
-    num_ips = len(set(worker_ips))
-    num_workers = len(sgd_workers)
-    min_placed = np.ceil(len(shard_shapes) / num_ips)
-    from collections import Counter, defaultdict
-    tid_counter = [0]
-
-    def create_ps():
-        tid_counter[0] += 1
-        return RemotePS.remote(num_workers, tid_counter[0])
-
-    ip_mapping = defaultdict(list)
-
-    while (any(len(v) < min_placed for v in ip_mapping.values())
-           or (len(ip_mapping) < num_ips)):
-        logger.info("generating new ps, ip map so far {}".format(ip_mapping))
-        new_ps = create_ps()
-        ps_ip = ray.get(new_ps.ip.remote())
-        if spread_ps and ps_ip in worker_ips:
-            logger.info("ignoring ps that is on same node as worker")
-        elif not spread_ps and ps_ip not in worker_ips:
-            logger.info("ignoring ps that NOT on same node as some worker")
-        else:
-            ip_mapping[ps_ip] += [new_ps]
-
-    final_list = []
-    candidates = list(ip_mapping.values())
-    for i, s in enumerate(shard_shapes):
-        ps = candidates[i % num_ips][i // num_ips]
-        final_list += [ps]
-        ps.initialize.remote(s)
-
-    for ps in sum(candidates, []):
-        if ps not in final_list:
-            ps.__ray_terminate__.remote()
-            logger.info("removing a ps...")
-        else:
-            logger.info("saving ps...")
-
-    logger.info("Final PS balance: ",
-                Counter(ray.get([ps.ip.remote() for ps in final_list])))
-    for i, ps in enumerate(final_list):
-        ps.set_tid.remote(i)
-    return final_list
-
-
 class DistributedSGD(object):
     def __init__(self,
                  model_creator,
diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py
index 396bf44a2f27a..259af6f4eb5dc 100755
--- a/python/ray/experimental/sgd/test_sgd.py
+++ b/python/ray/experimental/sgd/test_sgd.py
@@ -2,10 +2,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import argparse
-import numpy as np
-import tensorflow as tf
-
 import ray
 from ray.experimental.sgd.tfbench.test_model import TFBenchModel
 from ray.experimental.sgd.sgd import DistributedSGD
diff --git a/python/ray/experimental/sgd/tfbench/convnet_builder.py b/python/ray/experimental/sgd/tfbench/convnet_builder.py
index ad4225540ef85..aebae86e03033 100644
--- a/python/ray/experimental/sgd/tfbench/convnet_builder.py
+++ b/python/ray/experimental/sgd/tfbench/convnet_builder.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
+# =============================================================================
 """CNN builder."""
 
 from __future__ import print_function
@@ -56,7 +56,8 @@ def __init__(self,
         self.aux_top_size = 0
 
     def get_custom_getter(self):
-        """Returns a custom getter that this class's methods must be called under.
+        """Returns a custom getter that this class's methods must be called
+        under.
 
         All methods of this class must be called under a variable scope that was
         passed this custom getter. Example:
@@ -75,21 +75,13 @@ def get_custom_getter(self):
         """
 
         def inner_custom_getter(getter, *args, **kwargs):
-            """Custom getter that forces variables to have type self.variable_type."""
             if not self.use_tf_layers:
                 return getter(*args, **kwargs)
             requested_dtype = kwargs['dtype']
             if not (requested_dtype == tf.float32
                     and self.variable_dtype == tf.float16):
-                # Only change the variable dtype if doing so does not decrease variable
-                # precision.
                 kwargs['dtype'] = self.variable_dtype
             var = getter(*args, **kwargs)
-            # This if statement is needed to guard the cast, because batch norm
-            # assigns directly to the return value of this custom getter. The cast
-            # makes the return value not a variable so it cannot be assigned. Batch
-            # norm variables are always in fp32 so this if statement is never
-            # triggered for them.
             if var.dtype.base_dtype != requested_dtype:
                 var = tf.cast(var, requested_dtype)
             return var
@@ -112,10 +104,6 @@ def switch_to_aux_top_layer(self):
         self.top_size = saved_top_size
 
     def get_variable(self, name, shape, dtype, cast_dtype, *args, **kwargs):
-        # TODO(reedwm): Currently variables and gradients are transferred to other
-        # devices and machines as type `dtype`, not `cast_dtype`. In particular,
-        # this means in fp16 mode, variables are transferred as fp32 values, not
-        # fp16 values, which uses extra bandwidth.
         var = tf.get_variable(name, shape, dtype, *args, **kwargs)
         return tf.cast(var, cast_dtype)
 
@@ -135,10 +123,6 @@ def _conv2d_impl(self, input_layer, num_channels_in, filters, kernel_size,
             weights_shape = [
                 kernel_size[0], kernel_size[1], num_channels_in, filters
             ]
-            # We use the name 'conv2d/kernel' so the variable has the same name as its
-            # tf.layers equivalent. This way, if a checkpoint is written when
-            # self.use_tf_layers == True, it can be loaded when
-            # self.use_tf_layers == False, and vice versa.
             weights = self.get_variable(
                 'conv2d/kernel',
                 weights_shape,
@@ -385,7 +369,7 @@ def inception_module(self, name, cols, input_layer=None, in_size=None):
                     self.mpool(*args, **kwargs)
                 elif ltype == 'apool':
                     self.apool(*args, **kwargs)
-                elif ltype == 'share':  # Share matching layer from previous column
+                elif ltype == 'share':
                     self.top_layer = col_layers[c - 1][l]
                     self.top_size = col_layer_sizes[c - 1][l]
                 else:
@@ -427,9 +411,6 @@ def dropout(self, keep_prob=0.5, input_layer=None):
     def _batch_norm_without_layers(self, input_layer, decay, use_scale,
                                    epsilon):
         """Batch normalization on `input_layer` without tf.layers."""
-        # We make this function as similar as possible to the
-        # tf.contrib.layers.batch_norm, to minimize the differences between using
-        # layers and not using layers.
         shape = input_layer.shape
         num_channels = shape[3] if self.data_format == 'NHWC' else shape[1]
         beta = self.get_variable(
@@ -445,9 +426,6 @@ def _batch_norm_without_layers(self, input_layer, decay, use_scale,
                 initializer=tf.ones_initializer())
         else:
             gamma = tf.constant(1.0, tf.float32, [num_channels])
-        # For moving variables, we use tf.get_variable instead of self.get_variable,
-        # since self.get_variable returns the result of tf.cast which we cannot
-        # assign to.
         moving_mean = tf.get_variable(
             'moving_mean', [num_channels],
             tf.float32,
diff --git a/python/ray/experimental/sgd/tfbench/model.py b/python/ray/experimental/sgd/tfbench/model.py
index 02ebebb52f070..c33c502479c62 100644
--- a/python/ray/experimental/sgd/tfbench/model.py
+++ b/python/ray/experimental/sgd/tfbench/model.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
+# =============================================================================
 """Base model configuration for CNN benchmarks."""
 
 import tensorflow as tf
@@ -34,8 +34,6 @@ def __init__(self,
         self.default_batch_size = batch_size
         self.learning_rate = learning_rate
         self.layer_counts = layer_counts
-        # TODO(reedwm) Set custom loss scales for each model instead of using the
-        # default of 128.
         self.fp16_loss_scale = fp16_loss_scale
 
     def get_model(self):
@@ -68,7 +66,8 @@ def add_inference(self, unused_cnn):
         raise ValueError('Must be implemented in derived classes')
 
     def skip_final_affine_layer(self):
-        """Returns if the caller of this class should skip the final affine layer.
+        """Returns if the caller of this class should skip the final affine
+        layer.
 
         Normally, this class adds a final affine layer to the model after calling
         self.add_inference(), to generate the logits. If a subclass override this
@@ -115,7 +113,4 @@ def build_network(self,
             aux_logits = tf.cast(aux_logits, tf.float32)
         return logits, aux_logits
 
-    # Subclasses can override this to define their own loss function. By default,
-    # benchmark_cnn.py defines its own loss function. If overridden, it must have
-    # the same signature as benchmark_cnn.loss_function.
     loss_function = None
diff --git a/python/ray/experimental/sgd/tfbench/resnet_model.py b/python/ray/experimental/sgd/tfbench/resnet_model.py
index 38c1fc33a9204..59052ed576fbf 100644
--- a/python/ray/experimental/sgd/tfbench/resnet_model.py
+++ b/python/ray/experimental/sgd/tfbench/resnet_model.py
@@ -304,11 +304,7 @@ def add_inference(self, cnn):
         cnn.spatial_mean()
 
     def get_learning_rate(self, global_step, batch_size):
-        num_batches_per_epoch = (
-            float(datasets.IMAGENET_NUM_TRAIN_IMAGES) / batch_size)
-        boundaries = [int(num_batches_per_epoch * x) for x in [30, 60]]
-        values = [0.1, 0.01, 0.001]
-        return tf.train.piecewise_constant(global_step, boundaries, values)
+        raise NotImplementedError
 
 
 def create_resnet50_model():
@@ -365,7 +361,6 @@ def add_inference(self, cnn):
             # reshape to batch_size x 16 x 32 x 32
             residual_block(cnn, 16, 1, self.pre_activation)
         for i in xrange(self.layer_counts[1]):
-            # Subsampling is performed at the first convolution with a stride of 2
             stride = 2 if i == 0 else 1
             # reshape to batch_size x 32 x 16 x 16
             residual_block(cnn, 32, stride, self.pre_activation)
diff --git a/python/ray/experimental/sgd/tfbench/test_model.py b/python/ray/experimental/sgd/tfbench/test_model.py
index d5b4ff56fc65a..0dd48607ef0a6 100644
--- a/python/ray/experimental/sgd/tfbench/test_model.py
+++ b/python/ray/experimental/sgd/tfbench/test_model.py
@@ -2,7 +2,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
 import tensorflow as tf
 
 from tfbench import model_config
diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py
index 1e4a45e190395..ca72bb5e9ef43 100644
--- a/python/ray/experimental/sgd/util.py
+++ b/python/ray/experimental/sgd/util.py
@@ -4,7 +4,10 @@
 
 import json
 import logging
+import os
 import time
+import tensorflow as tf
+from tensorflow.python.client import timeline
 
 import ray
 
@@ -33,7 +35,8 @@ def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
             options=run_options,
             run_metadata=run_metadata,
             feed_dict=feed_dict)
-        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
+        trace = timeline.Timeline(  # TF's trace timeline, not util's Timeline
+            step_stats=run_metadata.step_stats)
         outf = "timeline-{}-{}.json".format(name, os.getpid())
         trace_file = open(outf, "w")
         logger.info("wrote tf timeline to", os.path.abspath(outf))

From 8c64bde9277a2536e67557ecb2ae579c4c8e1420 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 19 Sep 2018 16:27:41 -0700
Subject: [PATCH 16/17] silly flake8

---
 python/ray/experimental/sgd/tfbench/convnet_builder.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/ray/experimental/sgd/tfbench/convnet_builder.py b/python/ray/experimental/sgd/tfbench/convnet_builder.py
index aebae86e03033..e59085e5fc298 100644
--- a/python/ray/experimental/sgd/tfbench/convnet_builder.py
+++ b/python/ray/experimental/sgd/tfbench/convnet_builder.py
@@ -357,12 +357,12 @@ def inception_module(self, name, cols, input_layer=None, in_size=None):
         for c, col in enumerate(cols):
             col_layers.append([])
             col_layer_sizes.append([])
-            for l, layer in enumerate(col):
+            for lx, layer in enumerate(col):
                 ltype, args = layer[0], layer[1:]
                 kwargs = {
                     'input_layer': input_layer,
                     'num_channels_in': in_size
-                } if l == 0 else {}
+                } if lx == 0 else {}
                 if ltype == 'conv':
                     self.conv(*args, **kwargs)
                 elif ltype == 'mpool':
                     self.mpool(*args, **kwargs)
                 elif ltype == 'apool':
                     self.apool(*args, **kwargs)
                 elif ltype == 'share':
-                    self.top_layer = col_layers[c - 1][l]
-                    self.top_size = col_layer_sizes[c - 1][l]
+                    self.top_layer = col_layers[c - 1][lx]
+                    self.top_size = col_layer_sizes[c - 1][lx]
                 else:
                     raise KeyError(
                         'Invalid layer type for inception module: \'%s\'' %
@@ -381,7 +381,7 @@ def inception_module(self, name, cols, input_layer=None, in_size=None):
         catdim = 3 if self.data_format == 'NHWC' else 1
         self.top_layer = tf.concat([layers[-1] for layers in col_layers],
                                    catdim)
-        self.top_size = sum([sizes[-1] for sizes in col_layer_sizes])
+        self.top_size = sum(sizes[-1] for sizes in col_layer_sizes)
         return self.top_layer
 
     def spatial_mean(self, keep_dims=False):

From b840c3f5f43e7de27a72b9bd80144ef769e01c71 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 19 Sep 2018 16:30:47 -0700
Subject: [PATCH 17/17] small test

---
 python/ray/experimental/sgd/test_sgd.py    | 10 +++++++++-
 test/jenkins_tests/run_multi_node_tests.sh |  3 +++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py
index 259af6f4eb5dc..d6369a4e00011 100755
--- a/python/ray/experimental/sgd/test_sgd.py
+++ b/python/ray/experimental/sgd/test_sgd.py
@@ -2,13 +2,21 @@
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+
 import ray
 from ray.experimental.sgd.tfbench.test_model import TFBenchModel
 from ray.experimental.sgd.sgd import DistributedSGD
 
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--num-iters", default=100, type=int, help="Number of iterations to run")
+
 if __name__ == "__main__":
     ray.init()
 
+    args, _ = parser.parse_known_args()
+
     model_creator = (
         lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True))
 
@@ -19,6 +27,6 @@
         use_cpus=True,
         use_plasma_op=False)
 
-    for _ in range(100):
+    for _ in range(args.num_iters):
         loss = sgd.step()
         print("Current loss", loss)
diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh
index e3813e54af68d..c36f129f095a7 100755
--- a/test/jenkins_tests/run_multi_node_tests.sh
+++ b/test/jenkins_tests/run_multi_node_tests.sh
@@ -293,6 +293,9 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
 docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/examples/multiagent_two_trainers.py --num-iters=2
 
+docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2
+
 # No Xray for PyTorch
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \