
Conversation

@k-w-w (Contributor) commented Jun 5, 2018

Hi all, I added DistributionStrategy to the Transformer model. Currently, the model isn't running very well with MirroredStrategy and I'm not sure why. @robieta @guptapriya As people familiar with DistributionStrategy, please help!

Current stats:

GPUs     Global steps/sec     Batch size (per device)
1 GPU    1.11                 4096
4 GPU    0.34                 3072

I decreased the per-device batch size for the 4-GPU run because of OOM errors.

@k-w-w requested review from qlzh727, robieta and guptapriya June 5, 2018 21:35
@k-w-w requested review from karmel and a team as code owners June 5, 2018 21:36
@karmel (Contributor) left a comment:

I assume this works; have you validated that it does for CPU, 1-GPU, and multi-GPU?

  else:
    params["batch_size"] = distribution_utils.per_device_batch_size(
        flags_obj.batch_size or params["default_batch_size"],
        flags_core.get_num_gpus(flags_obj))
Contributor:

nit: cleaner:

params["batch_size"] = flags_obj.batch_size or params["default_batch_size_tpu"]
if not params["use_tpu"]:
  params["batch_size"] = distribution_utils.per_device_batch_size(
      params["batch_size"], flags_core.get_num_gpus(flags_obj))

# limitations under the License.
# ==============================================================================
"""Helper functions for running models in a distributed setting."""

Contributor:

Note to selves: Build files will be required for this.


  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
Contributor:

nit: double quotes
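For context on this thread, here is a rough sketch of what the per_device_batch_size helper might look like, pieced together from the excerpt above; the exact error wording and the integer division at the end are assumptions, and the strings use double quotes per the nit.

def per_device_batch_size(batch_size, num_gpus):
  """Scales the global batch size down to a per-device batch size (sketch only)."""
  if num_gpus <= 1:
    return batch_size

  remainder = batch_size % num_gpus
  if remainder:
    err = ("When running with multiple GPUs, batch size must be a multiple "
           "of the number of available GPUs. Found {} GPUs with a batch size "
           "of {}; try --batch_size={} instead.".format(
               num_gpus, batch_size, batch_size - remainder))
    raise ValueError(err)
  return int(batch_size / num_gpus)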

@robieta (Contributor) commented Jun 6, 2018

FYI training blows up with synthetic data if the data_dir doesn't exist, even though it isn't used.

edit (kathy): can't reply, but this has been fixed. There was a typo that prevented the synthetic data flag from being seen.

@k-w-w (Contributor, Author) commented Jun 7, 2018

This should be ready for review. Thanks @guptapriya, @robieta, and @yuefengz for the DistributionStrategy help! We've determined that the embedding is slowing the model down; a feature request has been filed with Dist Strat.

With the hierarchical all_reduce setting, the speeds are more reasonable:

GPUs                          Global steps/sec     Batch size (per device)
1 GPU                         1.11                 4096
4 GPU                         0.57                 3072
4 GPU (on an older TF rev)    0.66                 3072

@qlzh727 (Member) left a comment:

Thanks for adding distribution_utils

      num_gpus=flags_core.get_num_gpus(flags_obj)
  )
  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), use_hierarchical_copy=True)
Contributor:

I think it would be better if the all_reduce algorithm were a performance flag, since the optimal one will vary by hardware. For ResNet in particular, DistStrat auto selection should be fine in a week or so.

dataset = tf.data.Dataset.from_tensors(tf.ones([batch, length], tf.int64))
dataset = dataset.map(lambda x: (x, x))
dataset = dataset.cache()
dataset = dataset.repeat(1000)
Contributor:

I would prefer that we don't hard-code a dummy size into the synthetic data. Because there is already support for setting a step-based schedule, infinite repeat should be fine. (This is what ResNet does.)
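A minimal sketch of that suggestion, reusing the shape and dtype from the excerpt above (batch and length are placeholder values here):

import tensorflow as tf

batch, length = 16, 64  # placeholders; the real values come from the model params

# Constant (inputs, targets) pairs repeated indefinitely; the step-based
# training schedule bounds the run, so no hard-coded repeat count is needed.
dataset = tf.data.Dataset.from_tensors(tf.ones([batch, length], tf.int64))
dataset = dataset.map(lambda x: (x, x))
dataset = dataset.cache()
dataset = dataset.repeat()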

Contributor:

See the similar message in #4476, but we should make a helper fn for getting synthetic data in a particular shape... let's not have both PRs do that, though.

"base": model_params.BASE_PARAMS,
"base_multi_gpu": model_params.BASE_MULTI_GPU_PARAMS,
"big": model_params.BIG_PARAMS,
"big_multi_gpu": model_params.BIG_MULTI_GPU_PARAMS,
Contributor:

I feel like it might be better to exclude these from the choices and hot-swap them in if there are multiple GPUs, logging a message so that people know it happened. Thoughts?

Contributor (Author):

Yeah, I was conflicted; swapping the parameters automatically vs. creating a separate set of params seemed about equal. Usability-wise, I prefer swapping the params in the multi-GPU case so the run commands don't have to change as much. I ended up choosing the other option to keep models "mathematically equivalent" when moving to multiple GPUs (like how the batch size is global instead of per-device).

I don't have strong preferences. Do you still think we should go with swapping the params?

Contributor:

I was thinking we'd have both sets of params, but not include the multi-GPU ones in the mapping dict. Then, if running on multiple GPUs, swap the entire param set up front.
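A small sketch of that idea (flag and variable names such as flags_obj.param_set are assumptions, and model_params, flags_obj, and num_gpus come from the surrounding module; this is not the code in this PR): keep only the single-GPU names in the user-facing choices and swap the whole param set up front when several GPUs are present.

import tensorflow as tf

PARAMS_MAP = {
    "base": model_params.BASE_PARAMS,
    "big": model_params.BIG_PARAMS,
}

params = PARAMS_MAP[flags_obj.param_set]
if num_gpus > 1:
  # Hypothetical hot-swap: upgrade to the multi-GPU variant and say so.
  params = {
      "base": model_params.BASE_MULTI_GPU_PARAMS,
      "big": model_params.BIG_MULTI_GPU_PARAMS,
  }[flags_obj.param_set]
  tf.logging.info("Multiple GPUs detected; using the %s_multi_gpu parameter "
                  "set." % flags_obj.param_set)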

"""Add flags and flag validators for running transformer_main."""
# Add common flags (data_dir, model_dir, train_epochs, etc.).
flags_core.define_base(multi_gpu=False, num_gpu=False)
flags_core.define_base(multi_gpu=False)
Contributor:

I don't think we need the multi_gpu=False here, and it is deceptive, so better to remove?

@k-w-w (Contributor, Author) Jun 7, 2018:

Perhaps we should remove the multi_gpu flag, and use num_gpu instead? MNIST uses multi_gpu but not num_gpu, and the resnet models use num_gpu but not multi_gpu.

I have a question/suggestion about the num_gpu flag (+@robieta on this): currently, the default value of num_gpu is 0 or 1 depending on whether a GPU is available. Can we change this default to None instead?
OneDeviceStrategy requires more memory than having no DistributionStrategy (batch size of 4096 causes OOM errors when using OneDeviceStrategy). When the user doesn't specify num_gpus, we should default to using no DistributionStrategy.
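A tiny sketch of the proposal being floated here (hypothetical flag handling only; further down the thread it is left as-is for this PR):

# Hypothetical: --num_gpus defaults to None rather than 0 or 1.
num_gpus = flags_obj.num_gpus
if num_gpus is None:
  # No DistributionStrategy at all: plain Estimator on the default device,
  # avoiding the extra memory overhead seen with OneDeviceStrategy.
  distribution_strategy = None
else:
  distribution_strategy = distribution_utils.get_distribution_strategy(num_gpus)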

Contributor:

Hm. Can you ping DistStrat and see if this is expected, and what the advantages of using OneDevice are/will be?

Contributor (Author):

Hmm, tried running it again, and it went through without failing. Seems like this might be Transformer getting unlucky with the dynamic batching? For now, I'll leave it as is.

What are your thoughts about the num_gpu vs multi_gpu flags?

Contributor:

We want to get everything on num_gpus and remove multi_gpu altogether.

dataset = tf.data.Dataset.from_tensors(tf.ones([batch, length], tf.int64))
dataset = dataset.map(lambda x: (x, x))
dataset = dataset.cache()
dataset = dataset.repeat(1000)
Contributor:

See the similar message in #4476, but we should make a helper fn for getting synthetic data in a particular shape... let's not have both PRs do that, though.

import tensorflow as tf


def get_distribution_strategy(num_gpus, use_hierarchical_copy=False):
Contributor:

Add Args and Returns sections? In particular, a description of hierarchical_copy would be good.
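For illustration, one way the documented helper could read. This is a sketch assuming the contrib-era tf.contrib.distribute API that was current at the time of this PR, not the code that was merged:

import tensorflow as tf


def get_distribution_strategy(num_gpus, use_hierarchical_copy=False):
  """Returns a DistributionStrategy for the requested number of GPUs (sketch).

  Args:
    num_gpus: Number of GPUs to run this model on.
    use_hierarchical_copy: If True, MirroredStrategy uses the
      hierarchical_copy all-reduce, which can be faster on machines whose
      GPUs are connected through a hierarchy of links (e.g. a DGX-1) rather
      than talking to NCCL directly.

  Returns:
    A OneDeviceStrategy pinned to the CPU when num_gpus is 0, a
    OneDeviceStrategy on the single GPU when num_gpus is 1, and a
    MirroredStrategy across all GPUs otherwise.
  """
  if num_gpus == 0:
    return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
  if num_gpus == 1:
    return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
  if use_hierarchical_copy:
    return tf.contrib.distribute.MirroredStrategy(
        num_gpus=num_gpus,
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossTowerOps(
            "hierarchical_copy", num_packs=num_gpus))
  return tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)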

@k-w-w (Contributor, Author) commented Jun 7, 2018 via email

@k-w-w (Contributor, Author) commented Jun 11, 2018

@robieta @karmel PTAL: added the synthetic dataset helper and removed the multi_gpu flag.

@robieta (Contributor) left a comment:

Just a couple minor things, but overall LGTM.


    return loss_scale > 0

  if all_reduce_alg:
Contributor:

This should be an enum.

Contributor (Author):

(Replied in another comment)
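If the flag were restricted as suggested, an enum-style definition with absl.flags might look like this (hypothetical; the author argues below against limiting the choices):

from absl import flags

flags.DEFINE_enum(
    name="all_reduce_alg", default=None,
    enum_values=["nccl", "hierarchical_copy"],
    help="Algorithm to use when performing all-reduce across GPUs. If left "
         "unset, DistributionStrategy picks one based on the device topology.")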

  return False


def generate_synthetic_data(
Contributor:

Why does this implementation differ from the current synthetic data approach? And could you document why it is better than just the naive .from_tensor_slices(...).repeat(...)?

Contributor (Author):

I think I'm missing something. The implementation is the same as the current synthetic data approach, except that it allows the input/label shapes to be nested.
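For readers following along, a minimal sketch of such a helper, assuming TensorFlow's tf.nest utilities; the names and defaults here are illustrative rather than the exact merged code:

import tensorflow as tf


def generate_synthetic_data(input_shape, input_value=0, input_dtype=None,
                            label_shape=None, label_value=0, label_dtype=None):
  """Creates a repeating dataset of constant tensors (sketch only).

  input_shape and label_shape may each be a single tf.TensorShape or a nested
  structure of them, so models with multiple inputs can reuse the helper.
  """
  inputs = tf.nest.map_structure(
      lambda shape: tf.constant(input_value, input_dtype, shape), input_shape)
  element = inputs
  if label_shape is not None:
    labels = tf.nest.map_structure(
        lambda shape: tf.constant(label_value, label_dtype, shape), label_shape)
    element = (inputs, labels)
  # Infinite repeat; the step-based training schedule bounds the run.
  return tf.data.Dataset.from_tensors(element).repeat()


# Example mirroring the Transformer's (inputs, targets) excerpt above:
dataset = generate_synthetic_data(
    input_shape=tf.TensorShape([16, 64]), input_value=1, input_dtype=tf.int64,
    label_shape=tf.TensorShape([16, 64]), label_value=1, label_dtype=tf.int64)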

Contributor:

I think I'm likely the one missing something.

@karmel (Contributor) left a comment:

One nit, but other than that, looks good. Thanks!

  Args:
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Specify which algorithm to use when performing all-reduce.
Contributor:

What are the choices and the default here? It might be nice to mention looking at DistributionStrategy if more detail is desired.

Contributor:

The choices are nccl and hierarchical_copy; if not specified, DistStrat will look at the device topology and choose (hierarchical_copy if it looks like a DGX, otherwise nccl, I believe). I think we want None to be the default.

Contributor (Author):

A concern I have with limiting the choices here is that more algorithms may be implemented in the future (and it might be a hassle to update our code each time). I'll put a mention of DistStrat here.

@k-w-w merged commit 29c9f98 into master Jun 12, 2018
@k-w-w deleted the t-multi branch June 12, 2018 16:54
@venuswu commented Aug 25, 2019

> Hi all, I added DistributionStrategy to the Transformer model. Currently, the model isn't running very well with MirroredStrategy and I'm not sure why. @robieta @guptapriya As people familiar with DistributionStrategy, please help!
>
> Current stats:
>
> GPUs     Global steps/sec     Batch size (per device)
> 1 GPU    1.11                 4096
> 4 GPU    0.34                 3072
>
> I decreased the batch size because of OOM errors.

I have run the Transformer, and it seems to be very slow. @robieta
