train.init_from_checkpoint does not support mirrorredStrategy and CollectiveAllReduceStrategy #23986

Closed
libliang opened this issue Nov 27, 2018 · 15 comments

Comments

@libliang libliang commented Nov 27, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
    Yes. We are modifying run_classifier.py from https://github.com/google-research/bert so it can run on a machine with multiple GPUs.

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    Distributor ID: Ubuntu
    Description: Ubuntu 16.04.5 LTS
    Release: 16.04
    Codename: xenial

  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
    NA

  • TensorFlow installed from (source or binary):
    binary

  • TensorFlow version (use command below):
    v1.11.0-0-gc19e29306c and v1.12.0-0-ga6d8ffae09

  • Python version:
    2.7.12

  • Bazel version (if compiling from source):
    NA

  • GCC/Compiler version (if compiling from source):
    NA

  • CUDA/cuDNN version:

  • GPU model and memory:

Describe the current behavior
In BERT's source code, the model_fn calls tf.train.init_from_checkpoint(init_checkpoint, assignment_map).

We changed some of BERT's code so it can run with MirroredStrategy, but we hit a failure with the call stack below:

Traceback (most recent call last):
File "run_classifier_collect.py", line 859, in
tf.app.run()
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "run_classifier_collect.py", line 815, in main
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/estimator/training.py", line 462, in train_and_evaluate
estimator, train_spec, eval_spec, _TrainingExecutor)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/distribute/estimator_training.py", line 279, in train_and_evaluate
session_config=run_config.session_config)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/distribute/distribute_coordinator.py", line 792, in run_distribute_coordinator
task_id, session_config, rpc_layer)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
worker_fn(strategy)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/distribute/estimator_training.py", line 246, in _worker_fn
hooks=hooks)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
rendezvous.raise_errors()
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
six.reraise(typ, value, traceback)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
saving_listeners=saving_listeners
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
return self._train_model_distributed(input_fn, hooks, saving_listeners)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1316, in _train_model_distributed
self.config)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/training/distribute.py", line 721, in call_for_each_tower
return self._call_for_each_tower(fn, *args, **kwargs)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 556, in _call_for_each_tower
return _call_for_each_tower(self, fn, *args, **kwargs)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 183, in _call_for_each_tower
coord.join(threads)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
yield
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 177, in _call_for_each_tower
**merge_kwargs)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/training/checkpoint_utils.py", line 211, in _init_from_checkpoint
var = _collect_partitioned_variable(current_var_or_name, store_vars)
File "/data/anaconda2/envs/py27tf12/lib/python2.7/site-packages/tensorflow/python/training/checkpoint_utils.py", line 365, in _collect_partitioned_variable
if name + "/part_0" in all_vars:
TypeError: unsupported operand type(s) for +: 'PerDevice' and 'str'

We can see the code in init_from_checkpoint:

if distribution_strategy_context.get_cross_tower_context():
  _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
else:
  distribution_strategy_context.get_tower_context().merge_call(
      _init_from_checkpoint, ckpt_dir_or_file, assignment_map)

It seems the code above goes into the merge_call branch.

I think init_from_checkpoint is intended to support the cross-tower scenario, and in BERT's case it does take the merge_call path, but does _init_from_checkpoint properly handle the parameters passed across towers? By the time it runs, the assignment_map values are PerDevice objects rather than strings, which is what triggers the TypeError above.

Describe the expected behavior
init_from_checkpoint can be called successfully

Code to reproduce the issue

  1. Modified BERT's code to use MirroredStrategy (a hypothetical minimal sketch follows below).
  2. Modified BERT's optimization.py so it can run in a distributed environment.
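
The sketch below is a hypothetical minimal reproduction, not the actual modified BERT code: the checkpoint path, the toy model, and the two-GPU setup are all assumptions. It only shows the failing pattern, an Estimator model_fn that calls tf.train.init_from_checkpoint with a string-to-string assignment_map while training under MirroredStrategy (it assumes a previously saved checkpoint of the same toy model at the given path).

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # A trivial model standing in for BERT.
    pred = tf.layers.dense(features["x"], 1, name="dense")
    loss = tf.reduce_mean(tf.square(pred - labels))
    # Build the assignment map the way BERT's modeling.py does:
    # checkpoint variable name -> variable name (string -> string).
    ckpt_vars = tf.train.list_variables(params["init_checkpoint"])
    assignment_map = {name: name for name, _ in ckpt_vars
                      if name in ("dense/kernel", "dense/bias")}
    tf.train.init_from_checkpoint(params["init_checkpoint"], assignment_map)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def input_fn():
    data = ({"x": [[1.0]]}, [[2.0]])
    return tf.data.Dataset.from_tensors(data).repeat().batch(8)

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
run_config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    config=run_config,
    params={"init_checkpoint": "/path/to/previously/saved/checkpoint"})
estimator.train(input_fn, max_steps=10)  # expected to hit the PerDevice TypeError on TF 1.11/1.12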

@libliang libliang changed the title from "train.init_from_checkpoint does not support mirrorredStrategy and" to "train.init_from_checkpoint does not support mirrorredStrategy and CollectiveAllReduceStrategy" on Nov 27, 2018

@YangXuefeng YangXuefeng commented Nov 27, 2018

I have the same issue.
Traceback (most recent call last):
File "run_squad.py", line 1285, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "run_squad.py", line 1212, in main
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
return self._train_model_distributed(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1316, in _train_model_distributed
self.config)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/distribute.py", line 721, in call_for_each_tower
return self._call_for_each_tower(fn, *args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 556, in _call_for_each_tower
return _call_for_each_tower(self, fn, *args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 183, in _call_for_each_tower
coord.join(threads)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
yield
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 177, in _call_for_each_tower
**merge_kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 213, in _init_from_checkpoint
var = _collect_partitioned_variable(current_var_or_name, store_vars)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 368, in _collect_partitioned_variable
if name + "/part_0" in all_vars:
TypeError: unsupported operand type(s) for +: 'PerDevice' and 'str'

Just like libliang said, there is a problem in checkpoint_utils.py: the assignment_map argument given to init_from_checkpoint is of {str: str} type, but it arrives in _init_from_checkpoint as {str: PerDevice}.

How should the restore be handled for a PerDevice object?

PerDevice: {'/replica:0/task:0/device:GPU:0': 'bert/embeddings/LayerNorm/beta', '/replica:0/task:0/device:GPU:1': 'bert/embeddings/LayerNorm/beta', '/replica:0/task:0/device:GPU:2': 'bert/embeddings/LayerNorm/beta'}, <class 'tensorflow.contrib.distribute.python.values.PerDevice'>

@JayYip JayYip commented Nov 29, 2018

A workaround: first create a checkpoint by warm-starting from BERT on a single GPU (say, train 10 steps), then disable the scaffold (the warm-start init op) and do the multi-GPU training. It's ugly, but it works.

First run with single GPU:

warm_start = True

def scaffold():
    # tf.train.init_from_checkpoint returns None; it works by overriding the
    # matched variables' initializers, so the default Scaffold init op picks
    # up the checkpoint values. assignment_map and self.config.init_checkpoint
    # come from the surrounding model_fn, as in BERT's run_classifier.py.
    tf.train.init_from_checkpoint(
        self.config.init_checkpoint, assignment_map)
    return tf.train.Scaffold()

if not warm_start:
    train_scaffold = None
else:
    train_scaffold = scaffold()

Second run with MirroredStrategy:

warm_start = False

def scaffold():
    # Same scaffold as before; with warm_start = False it is simply not built.
    tf.train.init_from_checkpoint(
        self.config.init_checkpoint, assignment_map)
    return tf.train.Scaffold()

if not warm_start:
    train_scaffold = None
else:
    train_scaffold = scaffold()

Update:
I have figured out a better solution: use a train hook to initialize from checkpoint.

You can refer to this script:
https://github.com/JayYip/bert-multiple-gpu/blob/master/src/ckpt_restore_hook.py
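
For reference, a minimal sketch of that idea follows. This is not the linked script; it assumes BERT's modeling.get_assignment_map_from_checkpoint is importable and that the hook is passed to estimator.train(hooks=[...]). The point is that begin() runs once on the main thread before the session is created, outside the per-tower model_fn calls, so init_from_checkpoint sees a plain string-to-string assignment map.

import tensorflow as tf
import modeling  # BERT's modeling.py, assumed to be on the Python path

class RestoreCheckpointHook(tf.train.SessionRunHook):
    """Initializes model variables from a pre-trained BERT checkpoint."""

    def __init__(self, init_checkpoint):
        self._init_checkpoint = init_checkpoint

    def begin(self):
        # Called once before the MonitoredSession is created.
        tvars = tf.trainable_variables()
        assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
            tvars, self._init_checkpoint)
        # Overrides the matched variables' initializers so they load their
        # values from the checkpoint instead of random initialization.
        tf.train.init_from_checkpoint(self._init_checkpoint, assignment_map)

Usage would look like estimator.train(input_fn, hooks=[RestoreCheckpointHook(init_checkpoint)]).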

@sunyerui sunyerui commented Dec 4, 2018

I don't think this is a TensorFlow issue.
An alternative solution is to modify modeling.py in the BERT code so that the assignment map is generated in string->variable form instead of string->string.
Specifically, change line 338 of modeling.py from assignment_map[name] = name to assignment_map[name] = name_to_variable[name], and MirroredStrategy works well.
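
For context, the change lands inside get_assignment_map_from_checkpoint in BERT's modeling.py. The snippet below paraphrases that function from memory, so the exact line numbers and surrounding code may differ between BERT versions:

import collections
import re

import tensorflow as tf

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Maps checkpoint variable names to the current graph's variables."""
    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match("^(.*):\\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    initialized_variable_names = {}
    assignment_map = collections.OrderedDict()
    for name, _ in tf.train.list_variables(init_checkpoint):
        if name not in name_to_variable:
            continue
        # Original line (string -> string), whose values get wrapped into
        # PerDevice objects by MirroredStrategy's merge_call:
        #   assignment_map[name] = name
        # Suggested change (string -> variable):
        assignment_map[name] = name_to_variable[name]
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1
    return (assignment_map, initialized_variable_names)

Presumably this works because the variable objects are the same across towers, so merge_call does not wrap them into PerDevice values, and _init_from_checkpoint then takes its variable (rather than string) code path.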

@YangXuefeng YangXuefeng commented Dec 5, 2018

I don't think this is a TensorFlow issue.
An alternative solution is to modify modeling.py in the BERT code so that the assignment map is generated in string->variable form instead of string->string.
Specifically, change line 338 of modeling.py from assignment_map[name] = name to assignment_map[name] = name_to_variable[name], and MirroredStrategy works well.

That's PERFECT! Thank you.

@ohwe ohwe commented Dec 5, 2018

@sunyerui thanks a lot for your fix! But right after that MirroredStrategy still fails with

ValueError: You must specify an aggregation method to update a MirroredVariable in Tower Context"

Here's the entire error message, if needed:

INFO:tensorflow:Error reported to Coordinator: You must specify an aggregation method to update a MirroredVariable in Tower Context.
Traceback (most recent call last):
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
    yield
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 795, in run
    self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2195, in _call_model_fn
    features, labels, mode, config)
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 1195, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2479, in _model_fn
    features, labels, is_export_mode=is_export_mode)
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 1259, in call_without_tpu
    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 1533, in _call_model_fn
    estimator_spec = self._model_fn(features=features, **kwargs)
  File "run_pretraining.py", line 179, in model_fn
    total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
  File "/home/borislav/bert/optimization.py", line 77, in create_optimizer
    zip(grads, tvars), global_step=global_step)
  File "/home/borislav/bert/optimization.py", line 151, in apply_gradients
    [param.assign(next_param),
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/contrib/distribute/python/values.py", line 402, in assign
    return self._assign_func(f=assign_fn, *args, **kwargs)
  File "/home/borislav/tf/local/lib/python2.7/site-packages/tensorflow/contrib/distribute/python/values.py", line 379, in _assign_func
    raise ValueError("You must specify an aggregation method to update a "
ValueError: You must specify an aggregation method to update a MirroredVariable in Tower Context.

@JayYip JayYip commented Dec 5, 2018

@sunyerui thanks a lot for your fix! But right after that MirroredStrategy still fails with

ValueError: You must specify an aggregation method to update a MirroredVariable in Tower Context"

This is not related to this issue. You need to re-implement the optimizer.

You can take my implementation as a reference:
https://github.com/JayYip/bert-multiple-gpu/blob/master/src/optimizer.py

@libliang libliang (Author) commented Dec 5, 2018

@JayYip is right: in order to use MirroredStrategy, we need to re-implement the optimizer, because the current optimizer in BERT does not support merge_call, which MirroredStrategy requires.
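
To illustrate the kind of change involved, here is a rough, hypothetical sketch (it is neither BERT's optimization.py nor the optimizer linked above): instead of overriding apply_gradients and calling param.assign(...) directly inside the tower, the optimizer implements the per-variable _apply_dense / _resource_apply_dense methods and lets the base tf.train.Optimizer.apply_gradients perform the merge_call and the distributed variable updates.

import tensorflow as tf

class AdamWeightDecay(tf.train.Optimizer):
    """Adam with decoupled weight decay, sketched as a standard Optimizer
    subclass so that DistributionStrategy can route the variable updates."""

    def __init__(self, learning_rate, weight_decay_rate=0.01, beta_1=0.9,
                 beta_2=0.999, epsilon=1e-6, name="AdamWeightDecay"):
        super(AdamWeightDecay, self).__init__(False, name)
        self._lr = learning_rate
        self._wd = weight_decay_rate
        self._beta_1 = beta_1
        self._beta_2 = beta_2
        self._epsilon = epsilon

    def _create_slots(self, var_list):
        # One first-moment and one second-moment slot per variable.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "v", self._name)

    def _apply(self, grad, var):
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")
        next_m = self._beta_1 * m + (1.0 - self._beta_1) * grad
        next_v = self._beta_2 * v + (1.0 - self._beta_2) * tf.square(grad)
        update = next_m / (tf.sqrt(next_v) + self._epsilon) + self._wd * var
        next_var = var - self._lr * update
        return tf.group(var.assign(next_var), m.assign(next_m), v.assign(next_v))

    def _apply_dense(self, grad, var):
        return self._apply(grad, var)

    def _resource_apply_dense(self, grad, var):
        return self._apply(grad, var)

    # A complete version would also implement _apply_sparse /
    # _resource_apply_sparse to handle the embedding-table gradients.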

@libliang libliang (Author) commented Dec 5, 2018

Thanks for @sunyerui's solution, which works fine with the current model_fn code. One thing we still need to consider: the logic in model_fn is executed in every tower, which means the variables get initialized once per GPU. I suspect there is still room for improvement so that we avoid this redundant initialization.

@JayYip JayYip commented Dec 5, 2018

Thanks for @sunyerui's solution, which works fine with the current model_fn code. One thing we still need to consider: the logic in model_fn is executed in every tower, which means the variables get initialized once per GPU. I suspect there is still room for improvement so that we avoid this redundant initialization.

Please see my comment above. I think using a session run hook to restore variables will only initialize once. Please correct me if I am wrong.

@YFwang1992 YFwang1992 commented Dec 25, 2018

Can anyone release multi-GPU pretraining code for BERT (TensorFlow) using MirroredStrategy?

@yuefengz yuefengz (Member) commented Jan 28, 2019

There is a PR that should fix this issue: https://github.com/tensorflow/tensorflow/pull/24245/files. You can try it out or wait until it is merged.

@yuefengz yuefengz (Member) commented Mar 12, 2019

Closing this bug since the PR is merged.

@yuefengz yuefengz closed this Mar 12, 2019

@megjoshi megjoshi commented Jun 14, 2019

Which version of TensorFlow has this fix? I'm trying 1.12.2 but am still facing the same error.

@leonsim leonsim commented Sep 13, 2019

Hi, same question as above. I'm trying 1.13.1 but still facing the same error.

@xmy7216 xmy7216 commented Sep 17, 2019

Hi, same question as above. I'm trying 1.14.0 but still facing the same error.
