RFC: Embedding and Partitioned Variables in TF 2.0 #55
Conversation
Gentle ping to @yangjunpro and @wangsiyu; it seems you were eager for this feature in the past. tensorflow/tensorflow#22473 tensorflow/tensorflow#22937 tensorflow/tensorflow#23254
Sorry for missing this discussion thread for a while. We will take a serious look at this RFC and provide feedback soon.
### Embedding in Mirrored Architecture: Sharding or Mirroring?

In `MirroredStrategy` or `CollectiveAllReduceStrategy`, there are several ways of handling embeddings. However, because `dynamic_partition` and `dynamic_stitch`, which embedding lookups on sharded embeddings require, only have CPU kernels, we cannot shard embeddings across GPUs.
A little bit curious about this sentence, since it looks like TF already has GPU versions of `dynamic_partition` and `dynamic_stitch`.
You are right. `DynamicPartition` and `DynamicStitch` have GPU kernels. Modified these paragraphs.
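For readers unfamiliar with the pattern being discussed: a sharded embedding lookup routes each query id to the shard that owns it, gathers locally on every shard, and then reassembles the results in the original query order. Below is a pure-Python sketch of that partition/gather/stitch flow (a simplified model of what the ops above do, not TF code; the "mod" shard assignment is used for illustration):

```python
# Sketch of a sharded embedding lookup: "dynamic_partition" routes ids to
# shards, each shard gathers locally, "dynamic_stitch" restores query order.
def sharded_lookup(shards, ids):
    num_shards = len(shards)
    per_shard_rows = [[] for _ in range(num_shards)]
    per_shard_pos = [[] for _ in range(num_shards)]
    for pos, i in enumerate(ids):
        s = i % num_shards                        # "mod" partition strategy
        per_shard_rows[s].append(i // num_shards) # row within that shard
        per_shard_pos[s].append(pos)
    # local gather on every shard
    gathered = [[shards[s][r] for r in rows]
                for s, rows in enumerate(per_shard_rows)]
    # stitch results back into the original query order
    out = [None] * len(ids)
    for s in range(num_shards):
        for pos, vec in zip(per_shard_pos[s], gathered[s]):
            out[pos] = vec
    return out

# Two shards of a 6-row table under "mod": shard 0 holds rows 0, 2, 4.
shard0 = [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]]
shard1 = [[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]
print(sharded_lookup([shard0, shard1], [3, 0, 5]))
# -> [[3.0, 3.0], [0.0, 0.0], [5.0, 5.0]]
```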
#### Layers

We will need to call `strategy.experimental_create_sharded_variable()` in Keras' `Embeeding` layer. Under `ParameterServerStrategy` scope, all variables can potentially be `PartitionedVariable` for load balancing only.
typo, should be embedding
Suggested change:
Before: We will need to call `strategy.experimental_create_sharded_variable()` in Keras' `Embeeding` layer. Under `ParameterServerStrategy` scope, all variables can potentially be `PartitionedVariable` for load balancing only.
After: We will need to call `strategy.experimental_create_sharded_variable()` in Keras' `Embedding` layer. Under `ParameterServerStrategy` scope, all variables can potentially be `PartitionedVariable` for load balancing only.
Done.
##### Partition Strategy: div or mod

There are two strategies for looking up embedding vectors on a sharded embedding variable: "div" and "mod". See an explanation [here](https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup). The "mod" strategy may be useful when references to embedding slices are not evenly distributed over their indices. However, "mod" is a poor approximation of the actual load balancing users want. For example, when the vocabulary is sorted by frequency, with the "mod" strategy the first parameter server always has a larger load than the second one. Furthermore, its current checkpointing mechanism also prevents users from migrating to a cluster with a different number of parameter servers.
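For concreteness, the two strategies can be sketched in a few lines of pure Python (a simplified model of the rules in the linked docs, not TF's implementation). "div" keeps contiguous id ranges together, with the first `vocab % num_shards` shards holding one extra row; "mod" interleaves ids round-robin:

```python
def div_shard(i, vocab, num_shards):
    # "div": contiguous ranges; the first (vocab % num_shards) shards
    # each hold one extra row.
    big = vocab % num_shards
    big_size = vocab // num_shards + 1
    if i < big * big_size:
        return i // big_size
    return big + (i - big * big_size) // (vocab // num_shards)

def mod_shard(i, vocab, num_shards):
    # "mod": round-robin interleaving.
    return i % num_shards

vocab, shards = 10, 3
print([div_shard(i, vocab, shards) for i in range(vocab)])
# -> [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
print([mod_shard(i, vocab, shards) for i in range(vocab)])
# -> [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
```

Under "mod" with a frequency-sorted vocabulary, shard 0 receives ids 0, 3, 6, ..., each slightly more frequent than the ids on shard 1, hence the imbalance described above.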
Based on the above sentence, it seemingly wants to tell us that the "div" strategy performs better than "mod" when the vocabulary is sorted by frequency. However, due to the nature of the "mod" and "div" partition strategies, I think "mod" could bring a more balanced partition. So I am a little bit puzzled about the example provided here.
We want to put the random shuffling logic in `keras.Embedding`, which would balance load better than "mod".
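A small sketch of the idea in this reply (hypothetical, not the actual `keras.Embedding` code): apply a fixed random permutation to vocabulary ids before "div" partitioning, so frequent ids are spread across shards regardless of vocabulary order. With a Zipf-like frequency on a frequency-sorted vocabulary:

```python
import random

vocab, num_shards = 12, 3
freq = [1.0 / (i + 1) for i in range(vocab)]   # id 0 is the most frequent
rows = vocab // num_shards                     # even "div" split for simplicity

def shard_load(assign):
    # total lookup frequency landing on each shard
    load = [0.0] * num_shards
    for i, s in enumerate(assign):
        load[s] += freq[i]
    return load

div_assign = [i // rows for i in range(vocab)]
mod_assign = [i % num_shards for i in range(vocab)]
perm = list(range(vocab))
random.Random(0).shuffle(perm)                 # fixed seed: a stable remapping
shuffled_div = [perm[i] // rows for i in range(vocab)]

print(shard_load(div_assign))    # heavily skewed toward shard 0
print(shard_load(mod_assign))    # shard 0 still somewhat heavier
print(shard_load(shuffled_div))  # frequent ids scattered across shards
```

The permutation must be fixed (e.g. seeded) so that lookups and checkpoints agree on where each id lives.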
However, to parallelize computation, people have to write methods that respect the partitions, e.g. `tf.nn.sampled_softmax_loss`. Therefore, this class has been overloaded by different use cases.
On the other hand, when `PartitionedVariable` is used for sharded embeddings, `partitioned_strategy` has to be kept consistent when it is required by several methods downstream, such as `tf.nn.embedding_lookup` and `tf.nn.nce_loss`. This is unnecessary for users and any inconsistency would lead to subtle bugs.
Maybe we could elaborate more about the justification.
Suggested change:
Before: On the other hand, when `PartitionedVariable` is used for sharded embeddings, `partitioned_strategy` has to be kept consistent when it is required by several methods downstream, such as `tf.nn.embedding_lookup` and `tf.nn.nce_loss`. This is unnecessary for users and any inconsistency would lead to subtle bugs.
After: On the other hand, when `PartitionedVariable` is used for sharded embeddings, `partitioned_strategy` has to be kept consistent when it is required by several methods downstream, such as `tf.nn.embedding_lookup` and `tf.nn.nce_loss` which requires an additional `partition_strategy` argument. This is unnecessary for users and any inconsistency would lead to subtle bugs.
Thanks for the suggestion! I've already mentioned "`partitioned_strategy` has ... when it is required by several methods"; prepending it with "which requires ..." seems redundant.
1. Support partitioned embeddings and partitioned layers in TF 2.0 in the parameter server architecture via Distribution Strategy's and Keras layers' API.
2. Better support for embeddings in the mirrored and collective allreduce architecture. We will not shard them in this architecture in our pre-TF 2.0 design.
3. We will only support the "div" partition strategy, but as post-TF 2.0 work we will support re-balancing embeddings in Keras' `Embedding` layer. That means we will no longer support the "mod" partition strategy.
So, the reason for choosing the "div" partition strategy is related to the checkpoint compatibility issue?
That is another issue of "mod".
## Non-goals

1. We don't plan to support any flavor of model parallelism beyond the current implementations of embedding lookup and loss functions that respect partitioning.
Why do we emphasize the "loss functions" here in addition to the "embedding lookup"? In my understanding, with partitioned layer support, we could do more than partition the loss functions?
The current embedding lookup and a few loss functions respect the partitions of a `PartitionedVariable`, while other usages of `PartitionedVariable` concat all its partitions before using it. The former case is a bit like model parallelism while the latter is only for load balancing. This is also the reason why we proposed `ShardedVariable`.
As discussed above, `PartitionedVariable` concatenates component variables silently, which can only serve the purpose of load balancing. So we need another object for use cases where computation needs to be parallelized or storage needs to be sharded. Users don't have to deal with it as long as they use Keras' layer API, and it won't be exposed as a public API in the pre-TF 2.0 stage.

```python
class ShardedVariable(object):
  ...
```
Would it be better to provide some specific use cases of `ShardedVariable` to demonstrate its difference from `PartitionedVariable`?
The `ShardedVariable` is just like `PartitionedVariable` except it doesn't auto-concat partitions. It will be passed to embedding lookup and a few loss functions. So it is a pretty dumb object.
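To make the distinction concrete, here is a tiny pure-Python sketch (all names and behavior hypothetical, not the actual TF classes): both objects hold a list of per-shard values, but only `PartitionedVariable` offers the implicit concatenation that makes it behave like one big variable.

```python
class ShardedVariable:
    """Holds a list of per-shard values; never auto-concatenates."""
    def __init__(self, shards):
        self.shards = list(shards)

    def __iter__(self):
        # consumers (lookup, partition-aware losses) iterate shards explicitly
        return iter(self.shards)

class PartitionedVariable(ShardedVariable):
    def concat(self):
        # the silent concatenation that makes it act like a single variable
        return [row for shard in self.shards for row in shard]

pv = PartitionedVariable([[[0.0], [1.0]], [[2.0], [3.0]]])
print(pv.concat())        # behaves like a single 4-row variable
sv = ShardedVariable([[[0.0], [1.0]], [[2.0], [3.0]]])
print(len(sv.shards))     # consumers see 2 shards; no concat exists
```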
When an embedding variable is small, it may benefit from being mirrored on devices. Lookup can be performed locally on each device. Updates can be cast to dense tensors, using the existing allreduce primitive to exchange gradients. This can still be faster than running allgather on the corresponding sparse updates.

When an embedding variable is larger, we'll need an efficient allgather primitive to exchange updates between devices. Alternatively, we can place it in host memory at the cost of transferring embeddings from host to devices in the forward pass and gathering updates from devices to host in the backward pass.
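The small-table mirrored case above can be sketched in pure Python (a stand-in simulation, not TF code: the cross-replica sum plays the role of allreduce): each replica holds a full copy, sparse per-replica gradients are densified and summed, and every copy applies the same dense update, keeping the mirrors identical.

```python
rows, dim, lr = 4, 2, 0.25
# two replicas, each with a full (mirrored) copy of a tiny embedding table
replicas = [[[1.0] * dim for _ in range(rows)] for _ in range(2)]

# sparse gradients per replica: {row_id: grad_vector}
sparse_grads = [{0: [0.5, 0.5]}, {0: [0.5, 0.5], 3: [2.0, 2.0]}]

# densify, then sum across replicas (the "allreduce" step)
dense = [[0.0] * dim for _ in range(rows)]
for g in sparse_grads:
    for r, vec in g.items():
        for d in range(dim):
            dense[r][d] += vec[d]

# identical dense SGD update on every mirrored copy
for table in replicas:
    for r in range(rows):
        for d in range(dim):
            table[r][d] -= lr * dense[r][d]

print(replicas[0][0], replicas[0][3])
# row 0 got total grad 1.0 -> [0.75, 0.75]; row 3 got 2.0 -> [0.5, 0.5]
```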
Have we considered partitioning embeddings in GPU memory rather than host memory? This may bring more complexity, but may also bring more performance gains since host-to-device traffic can be reduced.
Also, it looks like many embedding-dominated models are not complex, so GPU device memory may be a viable choice for holding the embedding partitions.
Actually, inside Alibaba, we are working on prototyping a new strategy called PEARLStrategy (Partitioned Embedding And RepLicated variables) to support GPU-based embedding partitioning.
I've just updated these paragraphs. Sharding on GPUs might be faster but is more difficult to implement. We'll implement mirrored embeddings and allgather first.
I am not sure how much faster it is to shard across GPUs than to put embeddings in host memory. You've probably done some experiments : ). If it's worth the complexity, we'll definitely implement it.
* mirroring on all workers' host memory;
* mirroring on all replicas, i.e. devices.

On multiple hosts, mirroring could be cheaper than sharding in terms of communication cost, since it requires a smaller number of communications although it transfers more data. This is true even compared to the optimal implementation of sharding. When updates to an embedding are not very sparse, converting them to dense updates and applying allreduce in the mirrored case can be faster than sharding, which relies on all-to-all personalized communication or allgather. Furthermore, mirroring is much easier to implement.
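A rough back-of-envelope model of this tradeoff (all constants hypothetical, and the cost formulas are the textbook approximations, not measurements): a ring allreduce on the dense gradient moves a fixed volume per worker regardless of sparsity, while gathering sparse updates scales with the fraction of rows touched.

```python
def ring_allreduce_bytes(table_bytes, workers):
    # classic ring allreduce: ~2 * (workers - 1) / workers * data per worker
    return 2 * (workers - 1) / workers * table_bytes

def allgather_bytes(touched_bytes, workers):
    # every worker receives every other worker's sparse update
    return (workers - 1) * touched_bytes

table = 100 * 2 ** 20          # hypothetical 100 MiB embedding table
workers = 8
for density in (0.01, 0.5):    # fraction of rows touched per step
    touched = density * table
    dense_wins = ring_allreduce_bytes(table, workers) < allgather_bytes(touched, workers)
    print(density, dense_wins)
# with these numbers, dense allreduce loses at 1% density and wins at 50%
```

This matches the paragraph's claim: the denser the updates, the more attractive mirroring plus dense allreduce becomes.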
Could we elaborate on this more? I think this only holds true when the sparsity is not high.
Maybe some sentences were confusing. I meant that when some gradients are not very sparse, we can densify and allreduce them, which would be faster than allgather. This is one of the benefits of mirroring embeddings.
…ave GPU kernels.
… to their saving and restoring mechanism.
* Move performance heuristics to the Distribution Strategy level. We will not expose knobs for users to control;
* Emphasize that embedding support in v2 will all be via the `Embedding` layer. Users can use `tf.compat.v1` to handle embeddings by themselves;
* Mention that the default `partition_strategy` in v1 `embedding_lookup` is "mod", which will possibly break users' models when they update to TF 2.0;
* We want to prioritize shuffling embeddings after the 2.0 release;
* We have plans to serialize and deserialize the `Embedding` layer and Distribution Strategies to allow loading a saved model with a different number of partitions.
Any progress here?
@yuefengz is there a development branch we can follow for this feature?
Review period will close at the end of 2019-01-30
Summary: this RFC describes the design and interfaces for embeddings and partitioned variables in TF 2.0 based on Distribution Strategy. We propose to support embeddings and sharded layers with Distribution Strategy and Keras' layer API. We also propose to shard embeddings and layers only in `ParameterServerStrategy` and mirror them in `MirroredStrategy` or `CollectiveAllReduceStrategy`.