All_reduce of collective_ops hangs in a distributed environment #31913

Closed
whhu opened this issue Aug 23, 2019 · 8 comments
Labels
comp:ops (OPs related issues), TF 1.13 (Issues related to TF 1.13), type:bug (Bug)

Comments

@whhu

whhu commented Aug 23, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux CentOS 7.6.1810
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: No
  • TensorFlow installed from (source or binary): No
  • TensorFlow version (use command below): v1.13.1-0-g6612da8951
  • Python version: 3.6.8
  • Bazel version (if compiling from source): None
  • GCC/Compiler version (if compiling from source): None
  • CUDA/cuDNN version: None
  • GPU model and memory: None

Describe the current behavior

The monitored session hangs while fetching reduced_weight.

Describe the expected behavior

The all_reduce tensor reduced_weight should give the correct result on all workers.

Code to reproduce the issue

"""Illustrate AllReduce"""

import multiprocessing as mp

MP_METHOD = 'fork'  # 'fork' (UNIX), 'spawn' (WINDOWS);
NUM_PROCESSES = 2

def process_fn(worker_hosts, task_index):
    """allreduce process"""
    import time
    import tensorflow as tf
    from tensorflow.python.ops import collective_ops

    num_workers = len(worker_hosts)

    cluster_spec = tf.train.ClusterSpec({'worker': worker_hosts})

    server = tf.train.Server(cluster_spec,
                             job_name='worker', task_index=task_index)
    group_key = 0
    instance_key = 0
    with tf.Graph().as_default():
        weights = list()
        reduced_weight = None
        for worker_index in range(num_workers):
            with tf.variable_scope('worker{}'.format(worker_index)), \
                    tf.device('/job:worker/task:{}/device:CPU:0'.format(
                            worker_index)):
                weight = tf.get_variable('weight', shape=[])
                weights.append(weight)
                if worker_index == task_index:
                    reduced_weight = collective_ops.all_reduce(
                        weight, num_workers, group_key, instance_key,
                        'Add', 'Div')

        session_creator = tf.train.ChiefSessionCreator(master=server.target)
        with tf.train.MonitoredSession(session_creator=session_creator) \
                as mon_sess:
            print('task {} have {}'.format(task_index, mon_sess.run(weights)))
            result = mon_sess.run(reduced_weight)
        print('task {} reduce {}'.format(task_index, result))
        time.sleep(1)

def start_process():
    """start process"""
    port = 60000
    host_fmt = 'localhost:{}'
    worker_hosts = list()
    for process_index in range(NUM_PROCESSES):
        worker_hosts.append(host_fmt.format(port + process_index))
    mp_ctx = mp.get_context(MP_METHOD)
    processes = list()
    for process_index in range(NUM_PROCESSES):
        process = mp_ctx.Process(target=process_fn,
                                 args=(worker_hosts, process_index,))
        processes.append(process)
        process.start()
    for process in processes:
        process.join()

if __name__ == '__main__':
    start_process()

Other info / logs

(tf-1.13-py3) [huwh1@huwh1-centos worksync]$ python ./tf_distribute_collective_ops.py
2019-08-23 10:55:20.150797: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-08-23 10:55:20.152951: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-08-23 10:55:20.163464: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3408000000 Hz
2019-08-23 10:55:20.163852: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x4364100 executing computations on platform Host. Devices:
2019-08-23 10:55:20.163883: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
2019-08-23 10:55:20.165614: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:252] Initialize GrpcChannelCache for job worker -> {0 -> localhost:60000, 1 -> localhost:60001}
2019-08-23 10:55:20.165828: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3408000000 Hz
2019-08-23 10:55:20.166148: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x4363fe0 executing computations on platform Host. Devices:
2019-08-23 10:55:20.166174: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
2019-08-23 10:55:20.166519: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:391] Started server with target: grpc://localhost:60001
2019-08-23 10:55:20.167632: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:252] Initialize GrpcChannelCache for job worker -> {0 -> localhost:60000, 1 -> localhost:60001}
2019-08-23 10:55:20.168829: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:391] Started server with target: grpc://localhost:60000
WARNING:tensorflow:From /home/huwh1/virtualenv/tf-1.13-py3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /home/huwh1/virtualenv/tf-1.13-py3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
2019-08-23 10:55:20.262106: I tensorflow/core/distributed_runtime/master_session.cc:1192] Start master session c45b1693e334d401 with config: 
2019-08-23 10:55:20.269965: I tensorflow/core/distributed_runtime/master_session.cc:1192] Start master session a7c551a16b557bd8 with config: 
task 1 have [-0.82924074, -0.72853804]
task 0 have [-0.82924074, -0.72853804]

collective_ops.all_reduce seems to be referenced only once, in build_collective_reduce, whose documentation says "input_tensors: tensors within a single worker graph that are to be reduced together; must be one per device." Does that mean all_reduce is only applicable to in-graph replication?

@gadagashwini-zz gadagashwini-zz self-assigned this Aug 23, 2019
@gadagashwini-zz gadagashwini-zz added TF 1.13 Issues related to TF 1.13 comp:apis Highlevel API related issues type:bug Bug labels Aug 23, 2019
@gadagashwini-zz gadagashwini-zz added the comp:ops OPs related issues label Aug 23, 2019
@ymodak ymodak removed the comp:apis Highlevel API related issues label Aug 23, 2019
@ymodak ymodak assigned mrry and unassigned ymodak Aug 23, 2019
@mrry mrry assigned dubey and unassigned mrry Aug 23, 2019
@mrry
Contributor

mrry commented Aug 23, 2019

Reassigning this to @dubey since it involves collective_ops.

@dubey
Member

dubey commented Aug 27, 2019

Collective ops support both in-graph and between-graph communication.

To enable between-graph communication, each worker needs to know the collective_group_leader. If left unconfigured, each worker assumes it is the leader and waits for a message from other workers in the group.

The fix is to add the following snippet:

from tensorflow.core.protobuf import config_pb2
config = config_pb2.ConfigProto()
config.experimental.collective_group_leader = '/job:worker/replica:0/task:0'

and pass this config to the tf.train.Server constructor.
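
Applied to process_fn in the reproduction script above, the change would look roughly like this (a minimal sketch; cluster_spec and task_index are the variables already defined in that script):

from tensorflow.core.protobuf import config_pb2
import tensorflow as tf

# Every worker must agree on the same group leader; when it is left unset,
# each worker assumes it is the leader and waits for the others indefinitely.
config = config_pb2.ConfigProto()
config.experimental.collective_group_leader = '/job:worker/replica:0/task:0'

# Pass the config to the in-process server started by each worker.
server = tf.train.Server(cluster_spec, config=config,
                         job_name='worker', task_index=task_index)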

@whhu
Author

whhu commented Aug 28, 2019

Thanks very much to all of you! The suggested fix resolves the hang. 🙂

@whhu whhu closed this as completed Aug 28, 2019

@whhu
Author

whhu commented Aug 29, 2019

A further question: does this fix support in-graph and between-graph communication at the same time? I create two distributed workers with three towers each; the between-graph communication works fine, but the in-graph allreduce among the three towers on the same worker fails with the message

Check failed: cp->group.group_size == cp->instance.device_names.size() (3 vs. 2)0x7ff630142180

and worker0 exits. It seems that the group_key does not behave like an MPI communicator. Is there a way to carry out both kinds of communication under the same mon_sess? The complete code and log are attached below.

Code:

"""Illustrate AllReduce"""

import multiprocessing as mp

MP_METHOD = 'fork'  # 'fork' (UNIX), 'spawn' (WINDOWS);
NUM_PROCESSES = 2
NUM_TOWERS = 3

def process_fn(worker_hosts, task_index):
    """allreduce process"""

    import time
    import tensorflow as tf
    from tensorflow.python.ops import collective_ops

    cluster_spec = tf.train.ClusterSpec({'worker': worker_hosts})

    # An unconfigured collective_group_leader makes each worker assume it is the leader.
    # '/replica:0' is required in the leader string.
    config = tf.ConfigProto(device_count={'CPU': NUM_TOWERS})
    config.experimental.collective_group_leader = '/job:worker/replica:0/task:0'
    server = tf.train.Server(cluster_spec, config=config,
                             job_name='worker', task_index=task_index)
    with tf.Graph().as_default():
        # create weight
        all_weights = list()
        for worker_index in range(NUM_PROCESSES):
            worker_weights = list()
            with tf.variable_scope('worker{}'.format(worker_index)):
                for tower_index in range(NUM_TOWERS):
                    device = '/job:worker/replica:0/task:{}/device:CPU:{}'.format(
                        worker_index, tower_index)
                    with tf.device(device):
                        worker_weights.append(tf.get_variable(
                            'weight{}'.format(tower_index), shape=[]))
            all_weights.append(worker_weights)

        intra_reduced = list()
        inter_reduced = None
        with tf.variable_scope('worker{}'.format(task_index)):
            # intra-worker allreduce
            for weight in all_weights[task_index]:
                with tf.device(weight.device):
                    intra_reduced.append(collective_ops.all_reduce(
                        weight, NUM_TOWERS, task_index, 0, 'Add', 'Div'))
            # inter-worker allreduce
            weight = all_weights[task_index][0]
            with tf.device(weight.device):
                inter_reduced = collective_ops.all_reduce(
                    weight, NUM_PROCESSES, NUM_PROCESSES, 0, 'Add', 'Div')

        if task_index == 0:
            session_creator = tf.train.ChiefSessionCreator(
                master=server.target)
        else:
            session_creator = tf.train.WorkerSessionCreator(
                master=server.target)
        with tf.train.MonitoredSession(session_creator=session_creator) \
                as mon_sess:
            result_all = mon_sess.run(all_weights)
            print('task {} sense {}'.format(task_index, result_all))
            result_inter = mon_sess.run([inter_reduced])
            print('task {} inter_reduce {}'.format(task_index, result_inter))
            result_intra = mon_sess.run([intra_reduced])
            print('task {} intra_reduce {}'.format(task_index, result_intra))
            time.sleep(1)

def start_process():
    """start process"""
    port = 60000
    host_fmt = 'localhost:{}'
    worker_hosts = list()
    for process_index in range(NUM_PROCESSES):
        worker_hosts.append(host_fmt.format(port + process_index))
    mp_ctx = mp.get_context(MP_METHOD)
    processes = list()
    for process_index in range(NUM_PROCESSES):
        process = mp_ctx.Process(target=process_fn,
                                 args=(worker_hosts, process_index,))
        processes.append(process)
        process.start()
    for process in processes:
        process.join()

if __name__ == '__main__':
    start_process()

Log:

(tf-1.13-py3) [huwh1@huwh1-centos worksync]$ python ./tf_distribute_collective_ops_rev2.py 
2019-08-29 15:28:57.955533: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-08-29 15:28:57.957342: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-08-29 15:28:57.969872: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3408000000 Hz
2019-08-29 15:28:57.970223: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x3c8cb80 executing computations on platform Host. Devices:
2019-08-29 15:28:57.970257: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
2019-08-29 15:28:57.971219: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3408000000 Hz
2019-08-29 15:28:57.971524: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x3c8ce20 executing computations on platform Host. Devices:
2019-08-29 15:28:57.971544: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
2019-08-29 15:28:57.972298: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:252] Initialize GrpcChannelCache for job worker -> {0 -> localhost:60000, 1 -> localhost:60001}
2019-08-29 15:28:57.973081: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:252] Initialize GrpcChannelCache for job worker -> {0 -> localhost:60000, 1 -> localhost:60001}
2019-08-29 15:28:57.973289: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:391] Started server with target: grpc://localhost:60000
2019-08-29 15:28:57.974456: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:391] Started server with target: grpc://localhost:60001
WARNING:tensorflow:From /home/huwh1/virtualenv/tf-1.13-py3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /home/huwh1/virtualenv/tf-1.13-py3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
2019-08-29 15:28:58.154712: I tensorflow/core/distributed_runtime/master_session.cc:1192] Start master session 35144c1b1c60a213 with config: device_count { key: "CPU" value: 3 } experimental { collective_group_leader: "/job:worker/replica:0/task:0" }
2019-08-29 15:28:58.170439: I tensorflow/core/distributed_runtime/master_session.cc:1192] Start master session 3c9b794a1f1e9112 with config: device_count { key: "CPU" value: 3 } experimental { collective_group_leader: "/job:worker/replica:0/task:0" }
task 0 sense [[-0.8433976, 1.4618217, 1.4098305], [-0.9607944, -0.2295152, 1.4139572]]
2019-08-29 15:29:28.195776: I tensorflow/core/distributed_runtime/master_session.cc:1192] Start master session 12138e7e76d8c500 with config: device_count { key: "CPU" value: 3 } experimental { collective_group_leader: "/job:worker/replica:0/task:0" }
task 1 sense [[-0.8433976, 1.4618217, 1.4098305], [-0.9607944, -0.2295152, 1.4139572]]
task 0 inter_reduce [-0.90209603]
task 1 inter_reduce [-0.90209603]
2019-08-29 15:29:28.244516: F tensorflow/core/common_runtime/collective_param_resolver_local.cc:389] Check failed: cp->group.group_size == cp->instance.device_names.size() (3 vs. 2)0x7ff630142180
2019-08-29 15:29:30.742565: W tensorflow/core/common_runtime/base_collective_executor.cc:203] BaseCollectiveExecutor::StartAbort Unavailable: OS Error
	 [[{{node worker1_1/CollectiveReduce}}]]
2019-08-29 15:29:30.742915: W tensorflow/core/common_runtime/base_collective_executor.cc:203] BaseCollectiveExecutor::StartAbort Unavailable: OS Error
	 [[{{node worker1_1/CollectiveReduce_1}}]]
2019-08-29 15:29:30.743016: W tensorflow/core/common_runtime/base_collective_executor.cc:203] BaseCollectiveExecutor::StartAbort Unavailable: OS Error
	 [[{{node worker1_1/CollectiveReduce_2}}]]
2019-08-29 15:29:40.753678: W tensorflow/core/distributed_runtime/master_session.cc:1363] Timeout for closing worker session
2019-08-29 15:29:50.758593: I tensorflow/core/distributed_runtime/master.cc:267] CreateSession still waiting for response from worker: /job:worker/replica:0/task:0
Terminated

@dubey
Member

dubey commented Sep 25, 2019

Are you trying to do something like a hierarchical all-reduce?

I haven't tried this myself, but one thing that confuses me is the assignment of group keys and instance keys. You want a unique group key for each set of devices participating in a collective, and a unique instance key for every instance of a collective.
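
As a hypothetical illustration against the script above (reusing its names; the particular key values are arbitrary, not a documented convention), the intra-worker and inter-worker reductions could be given non-overlapping keys like this:

# Group keys: one per distinct set of participating devices.
intra_group_key = task_index        # the towers of this worker only
inter_group_key = NUM_PROCESSES     # the group spanning all workers

# Instance keys: unique per collective instance and not reused across
# the intra-worker and inter-worker groups.
intra_instance_key = 1 + task_index
inter_instance_key = 100

intra_reduced = list()
for weight in all_weights[task_index]:
    with tf.device(weight.device):
        intra_reduced.append(collective_ops.all_reduce(
            weight, NUM_TOWERS, intra_group_key, intra_instance_key,
            'Add', 'Div'))

weight = all_weights[task_index][0]
with tf.device(weight.device):
    inter_reduced = collective_ops.all_reduce(
        weight, NUM_PROCESSES, inter_group_key, inter_instance_key,
        'Add', 'Div')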

@whhu
Author

whhu commented Sep 26, 2019

Thank you for the helpful advice. 😃 It seems that a globally unique instance_key across all groups is required for the all_reduce op, even when the groups themselves are distinguished by different group_keys in a distributed environment. I think this issue can be closed.

@whhu whhu closed this as completed Sep 26, 2019