# TensorFloat-32 in TensorFlow

| Status        | Accepted                                                 |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [247](https://github.com/tensorflow/community/pull/247)   |
| **Author(s)** | Reed Wanderman-Milne (reedwm@google.com)                  |
| **Sponsor**   | Sanjoy Das (sanjoy@google.com)                            |
| **Updated**   | 2020-06-10                                                |

## Objective

Allow [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format) to be used in TensorFlow to improve performance.

## Motivation

[NVIDIA Ampere](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/), an upcoming generation of NVIDIA GPUs announced at GTC 2020, introduces a new numeric format called TensorFloat-32, or TF32 for short. TF32 has the range of float32/bfloat16 (i.e. 8 bits of exponent) and the precision of float16 (i.e. 10 bits of mantissa). TF32 is not an in-memory format; tensor cores support it natively as a computation format, so it should be thought of not as a dtype but as a computation mode that increases performance and decreases numeric precision for certain float32 operations. NVIDIA has not found any cases where TF32 reduces the convergence of deep learning models.

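As an illustration only (not part of this proposal), the effect of keeping 10 mantissa bits can be emulated in NumPy by zeroing the 13 low-order mantissa bits of a float32 value. The helper below is hypothetical and only approximates the hardware's rounding:

```python
import numpy as np

def simulate_tf32(x):
    """Emulate TF32 precision by zeroing the 13 lowest mantissa bits of
    float32 (float32 keeps 23 mantissa bits, TF32 keeps 10). This is a
    simple truncation; the actual hardware rounding may differ."""
    x = np.asarray(x, dtype=np.float32)
    bits = x.view(np.uint32) & np.uint32(0xFFFFE000)
    return bits.view(np.float32)

x = np.array([1.0 + 2.0 ** -12], dtype=np.float32)
print(x)                  # [1.0002441]  -- representable in float32
print(simulate_tf32(x))   # [1.]         -- below TF32's 10-bit mantissa
```
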
Upcoming versions of cuDNN, cuBLAS, and other CUDA libraries will expose a mode of execution that has float32 inputs and outputs, but internally truncates float32 to TF32 and uses tensor cores. This is expected to be sufficiently accurate to reach the same convergence as the “full” float32 mode of execution but significantly faster. Each element still takes four bytes, so there is still a memory and performance penalty compared to using float16 or bfloat16.

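To give a rough, back-of-the-envelope sense of this trade-off, the sketch below (again illustrative, not the actual library behavior) truncates matmul inputs to TF32-like precision while accumulating in float32, and compares both against a float64 reference:

```python
import numpy as np

def truncate_to_tf32(x):
    """Keep TF32's 10 mantissa bits by zeroing the 13 lowest mantissa bits
    of float32. Illustrative only; hardware rounding may differ."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((256, 256)).astype(np.float32)
b = rng.standard_normal((256, 256)).astype(np.float32)

exact = a.astype(np.float64) @ b.astype(np.float64)   # float64 reference
fp32 = a @ b                                          # plain float32 matmul
tf32 = truncate_to_tf32(a) @ truncate_to_tf32(b)      # TF32-like: truncated
                                                      # inputs, float32 accumulation

rel_err = lambda y: np.linalg.norm(y - exact) / np.linalg.norm(exact)
print("float32 relative error:  ", rel_err(fp32))  # small: float32 rounding only
print("TF32-like relative error:", rel_err(tf32))  # larger: ~10-bit-mantissa inputs
```
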
As TF32 is only usable by tensor cores, it can only be used for matrix multiplications and other ops implemented in terms of matrix multiplications, such as convolutions. It is not used for pointwise ops or reductions.

TF32 will benefit users who run float32 models on Ampere GPUs, so we need an API to allow these users to enable TF32.

## Design Proposal

In TensorFlow, TF32 can be enabled for supported ops on Ampere GPUs with the following call:

```python
tf.config.allow_tensor_float_32_execution(True)
```

The word "allow" emphasizes that only certain devices (Ampere GPUs) and ops (such as matmuls and convolutions) will be affected. Once enabled, all local and remote Ampere GPUs use TF32 for supported float32 ops.

> **Reviewer:** Should an error be raised (or a warning) if allow=True and the device does not support TF32? One could imagine users being surprised that no complaint is raised when this mode is requested. I guess in that case the flag would be "use_tensor_float_32_execution" instead of allow... but maybe explicit is preferable here?
>
> **Author:** I considered this and the original draft did warn. But I think we should encourage users to put the ...

Passing `False` to `allow_tensor_float_32_execution` will disable TF32 if already enabled. This is useful if multiple models are run sequentially in the same process, where only some should use TF32. It is also useful for tests, as it allows a test class to test both TF32 being enabled and disabled.

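For example, a test might look roughly like the following sketch, written against the API proposed in this RFC (the shapes and tolerances are arbitrary):

```python
import tensorflow as tf


class MatmulPrecisionTest(tf.test.TestCase):

  def _check_matmul(self, rtol):
    x = tf.random.normal([32, 32])
    expected = tf.matmul(tf.cast(x, tf.float64), tf.cast(x, tf.float64))
    result = tf.cast(tf.matmul(x, x), tf.float64)
    self.assertAllClose(result, expected, rtol=rtol, atol=rtol)

  def test_full_float32(self):
    tf.config.allow_tensor_float_32_execution(False)
    self._check_matmul(rtol=1e-5)

  def test_tensor_float_32(self):
    tf.config.allow_tensor_float_32_execution(True)
    # TF32 keeps only 10 mantissa bits, so use a much looser tolerance.
    self._check_matmul(rtol=1e-2)


if __name__ == "__main__":
  tf.test.main()
```
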
We call the function `allow_tensor_float_32_execution` instead of the more concise `allow_tf32_execution` because people may mistakenly interpret the phrase "tf32" to refer to TensorFlow instead of TensorFloat.

The following can be used to query whether TF32 is enabled. The function returns a bool.

```python
tf.config.tensor_float_32_execution_allowed()
```

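Under this proposal the query simply reflects the most recent setting, e.g.:

```python
import tensorflow as tf

tf.config.allow_tensor_float_32_execution(True)
assert tf.config.tensor_float_32_execution_allowed()

tf.config.allow_tensor_float_32_execution(False)
assert not tf.config.tensor_float_32_execution_allowed()
```
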
Since TF32 only affects Ampere GPUs, moving an op to a GPU can affect numerics. Grappler and other graph optimizations will not consider this, and will freely move ops between devices without regard to numeric stability. As a result, explicitly putting an op on the CPU does not ensure it will use the full float32 precision instead of TF32.

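For instance, a placement like the one in this sketch (written against the proposed API) is not a reliable way to opt a single op out of TF32:

```python
import tensorflow as tf

tf.config.allow_tensor_float_32_execution(True)

@tf.function
def matmul_on_cpu(x):
  # Explicit CPU placement is not a guarantee of full float32 precision:
  # graph optimizations may still move this matmul to an Ampere GPU,
  # where it would run in TF32.
  with tf.device("/CPU:0"):
    return tf.matmul(x, x)
```
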
Since TensorFlow 2.3 will not support CUDA 11, which is required for TF32, this API will first be exposed in TensorFlow 2.4. However, downstream repackagers of TensorFlow (such as Google Cloud) are encouraged to cherrypick CUDA 11 and this API into their version of 2.3, so they can offer TF32 support to their customers who use TensorFlow 2.3.

### Turning TF32 on by default

Numerical studies by NVIDIA covering many common models suggest that TF32 is numerically robust for deep learning applications. In order to take advantage of these new accelerations in Ampere hardware for float32 models, we would like to enable TF32 by default. However, since the TensorFlow 2.4 release is still months away and we intend to use that time to further test and evaluate TF32, it is too early to decide in this RFC whether TF32 execution will be enabled or disabled by default. Here we begin the discussion by listing the most likely scenarios; comments are welcome. The scenarios are:

1. Turn it on by default in 2.4, the first release with the TF32 API.
2. Turn it on by default in 2.5, the second release with the TF32 API.
3. Do not turn it on by default.

The advantage of (1) is that all Ampere float32 users get the performance benefit unless they opt out. Additionally, Ampere numerics will not be loosened in a new release: TensorFlow 2.4 will be the first release with Ampere support, and it will immediately default to TF32 being enabled. The disadvantage is that we cannot collect as much feedback from users before defaulting to TF32, because there would be no stable release in which TF32 is supported but disabled by default.

> **Reviewer:** I don't buy this: the models that TF32 targets use FP32 today, so I'd expect users to notice a regression even if 2.4 enables it by default, which they can corroborate further by comparing the accuracy with disabling it explicitly.
>
> **Author:** I don't fully understand your argument. We'd like to have a release where users can try TF32 and give us feedback before we decide whether to turn it on by default. If we immediately turn it on by default in 2.4, users can still give feedback, but it will be too late: we will have already made our decision.
>
> **Reviewer:** Is the assumption that disabling tf32 by default (if users report problems after we enable it by default in 2.4) is more of a breaking change than enabling it by default (if users try it with 2.4 and don't report problems)?
>
> **Author:** No, enabling tf32 is probably more of a breaking change. However, we only want to make such a change at most once. After enabling tf32, I don't think we should subsequently disable it.

The advantage of (2) is that it allows users to test and give feedback on TF32 with a stable version of TensorFlow before we decide whether it should be the default. The disadvantage is that we may break Ampere users who relied on the full float32 precision in 2.4 when they upgrade to 2.5.

The advantage of (3) is that a user’s model will never break due to using reduced precision, even if they upgrade from an earlier GPU to Ampere. The disadvantage is that many Ampere users would not get the performance benefit from TF32 as they would not know about the API to enable it.

Another advantage of turning on TF32 by default is that it makes TensorFlow’s behavior with GPUs more consistent with TPUs. TPUs internally use lower precision for float32 matmuls and convolutions, similar to how Ampere GPUs will use lower precision for float32 matmuls and convolutions if TF32 is enabled.

**If you know of any models whose accuracy may be impacted by TF32, please comment on this RFC.** Note that TF32 is equivalent to float32 except it has 10 bits of mantissa instead of 23 bits. It will initially be used only for matmuls and convolutions, but may be used for other ops in the future if they are implemented in terms of a matmul. Once TensorFlow 2.4 is released, you will be able to test the impact of TF32 on your models if you have Ampere GPUs. You will be able to test earlier if you use TensorFlow nightly packages, and even earlier if you build from source with CUDA 11 support.

> **Reviewer:** You might want to indicate a way to receive private feedback about this too.
>
> **Author:** Do you think it's likely someone would be willing to share with us but not publicly? I could recommend emailing me for private feedback, but I would rather people post feedback publicly since I want to be transparent about why we make whatever decision we make.
>
> **Reviewer:** +1 for the transparency.
>
> **Reviewer:** We don't have to recommend it, but not everyone may be at liberty to talk about what they're working on in a public forum. So mentioning a private channel seems like a good idea.
>
> **Author:** OK, I'll mention this but state that we much prefer it be posted publicly, even if that requires being vague about the use case. We should list at least two emails in case one of us is sick. @sanjoy, should I list my email and yours?
>
> **Reviewer:** Are there users that NVIDIA can help us find directly?
>
> **Author:** NVIDIA will probably directly collect feedback and tell us. They have already tested themselves on many models.

### Remote devices

Enabling TF32 will affect remote Ampere GPUs in addition to local Ampere GPUs. In particular, it will affect devices on hosts connected via [`tf.config.experimental_connect_to_host`](https://www.tensorflow.org/api_docs/python/tf/config/experimental_connect_to_host) or [`tf.config.experimental_connect_to_cluster`](https://www.tensorflow.org/api_docs/python/tf/config/experimental_connect_to_cluster). The initial, unexposed version of the function in TensorFlow 2.3 will likely only support local devices, not remote devices, since we will probably not have time to implement remote device support.

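A sketch of the intended usage with a remote worker, using the API proposed in this RFC (the server address below is hypothetical):

```python
import tensorflow as tf

# Hypothetical address of a remote TensorFlow server whose host has Ampere GPUs.
tf.config.experimental_connect_to_host("10.0.0.2:8470", job_name="worker")

# Per this proposal, enabling TF32 applies to the remote worker's Ampere GPUs
# as well as any local ones; the setting is propagated to remote devices.
tf.config.allow_tensor_float_32_execution(True)
```
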
We will need to issue an RPC to remote devices when TF32 is enabled or disabled. This means calling `allow_tensor_float_32_execution` will be a fairly heavy function call. It should only be used at the beginning of the program, or in between executing two models or tests. It is not intended to be used within a single model to make parts of it run in TF32 and parts of it run in float32, especially considering that approach would also not work within a `tf.function`.

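In other words, the intended usage is coarse-grained, roughly like the following sketch (the model-running helpers are hypothetical):

```python
import tensorflow as tf

def train_model_a():
  ...  # hypothetical model that tolerates TF32 precision

def train_model_b():
  ...  # hypothetical model that needs full float32 precision

# Toggle only at coarse boundaries, e.g. between models or test cases.
# The call may issue RPCs to remote devices, and toggling inside a
# tf.function would not work as intended.
tf.config.allow_tensor_float_32_execution(True)
train_model_a()

tf.config.allow_tensor_float_32_execution(False)
train_model_b()
```
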
### Alternatives considered

We could have an API to enable TF32 on a per-op basis, to allow users to run only part of their model in TF32. This would be useful if they discover certain TF32 ops in their model need the full float32 precision. However, we anticipate that almost every model can run safely in TF32, so we do not think this alternative is necessary. If we discover specifying TF32 on a per-op basis is useful, we can later add a TF32 scope or some other mechanism to do this.

We could disallow enabling/disabling TF32 once a tensor has been created. This makes dealing with remote devices simpler, since we would only have to modify an RPC to create a context with TF32 enabled. We would not have to support updating a context to enable/disable TF32 after the context has been created. `tf.config.set_visible_devices` has this behavior. However, this is more limiting, and it will be non-obvious to users that they have to enable TF32 before creating any tensors.

We could export this API in TensorFlow 2.3. The issue is we don’t plan on building TensorFlow 2.3 with CUDA 11. Without CUDA 11 support, TF32 cannot be used, so the API would not be usable except by those who build TensorFlow from source.

> **Reviewer:** Why not use the Keras mixed precision policy API?
>
> **Author:** This affects ops outside Keras, so it shouldn't be under `tf.keras`. In a sense, TF32 is a form of mixed precision, as some ops use TF32 and others use float32. We could put it under `tf.mixed_precision`, but I think `tf.config` is better since TF32 should be thought of as a mode, not a dtype.
>
> **Author:** Also, the mixed precision API mostly changes the dtype of tensors, while TF32 doesn't affect tensor dtypes (AFAICT), just the dtype of accumulators inside ops.