New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TBE UVM cache prefetch pipeline #1893
Closed
Closed
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
✅ Deploy Preview for pytorch-fbgemm-docs canceled.
|
This pull request was exported from Phabricator. Differential Revision: D47418650 |
yuguo68
added a commit
to yuguo68/FBGEMM
that referenced
this pull request
Jul 26, 2023
Summary: Pull Request resolved: pytorch#1893 This diff is to enable cache prefetch pipeline of TBE, so that prefetch of batch_{i+1} can overlap with forward/backward of batch_i. As the cache can be evicted by prefetch and the weights can be updated by the backward, we need to carefully protect a few scenarios that result in cache invalidation. ## 1. prevent immature cache eviction: cache gets evicted while it is being used by forward pass Since prefetch can overlap with forward/backward pass, it is possible that prefetch tries to evict cache but the cache is being used by forward/backward pass. The fix is to use the `lxu_cache_locking_counter` in D46172802/pytorch#1883 to check whether a cache slot is in use or not when an eviction is attempted. ## 2. prevent dirty cache: weight is being updated while it is loading to cache If the prefetch overlaps with TBE backward pass, the backward may write to uvm (idx not in cache) and at the same time prefetch (idx is inserted to cache) loads the weight from uvm to cache. We sync the streams to avoid TBE backward pass overlapping with prefetch. The backward of the rest of the module can still overlap with prefetch of TBE. The stream sync looks like: ``` # backward(batch_i) waits for prefetch(batch_{i+1}) backward pre_hook: cur_stream.wait_stream(prefetch_stream) # backward(batch_i) TBE.backward() # prefetch(batch_{i+2}) waits for backward(batch_i) backward hook: prefetch_stream.wait_stream(cur_stream) ``` ## 3. prevent cache inconsistency: weight get updated after it is loaded to cache With pipeline, in the case that the same index is not inserted into cache in batch_i, but it is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not have the backward update of batch_i. Example of the issue is as follows: idx is in batch_i, batch_{i+1} prefetch(batch_i) - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss) forward(batch_i) prefetch(batch_{i+1}) - insert idx into cache, cache is loaded from host memory backward(batch_i) - cache_locations_batch_i of idx is -1, the host memory is updated forward(batch_{i+1}) - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated. The fix to this cache invalidation is to update the cache_locations_batch_i before backward of batch_i,so that the cache gets updated correctly by the backward pass of TBE. Reviewed By: sryap Differential Revision: D47418650 fbshipit-source-id: 6b01aa4df8bc7c121a7c66893fcf27b73cd8be73
yuguo68
force-pushed
the
export-D47418650
branch
from
July 26, 2023 17:35
406ef0f
to
0c17608
Compare
This pull request was exported from Phabricator. Differential Revision: D47418650 |
yuguo68
added a commit
to yuguo68/FBGEMM
that referenced
this pull request
Jul 26, 2023
Summary: Pull Request resolved: pytorch#1893 This diff is to enable cache prefetch pipeline of TBE, so that prefetch of batch_{i+1} can overlap with forward/backward of batch_i. As the cache can be evicted by prefetch and the weights can be updated by the backward, we need to carefully protect a few scenarios that result in cache invalidation. ## 1. prevent immature cache eviction: cache gets evicted while it is being used by forward pass Since prefetch can overlap with forward/backward pass, it is possible that prefetch tries to evict cache but the cache is being used by forward/backward pass. The fix is to use the `lxu_cache_locking_counter` in D46172802/pytorch#1883 to check whether a cache slot is in use or not when an eviction is attempted. ## 2. prevent dirty cache: weight is being updated while it is loading to cache If the prefetch overlaps with TBE backward pass, the backward may write to uvm (idx not in cache) and at the same time prefetch (idx is inserted to cache) loads the weight from uvm to cache. We sync the streams to avoid TBE backward pass overlapping with prefetch. The backward of the rest of the module can still overlap with prefetch of TBE. The stream sync looks like: ``` # backward(batch_i) waits for prefetch(batch_{i+1}) backward pre_hook: cur_stream.wait_stream(prefetch_stream) # backward(batch_i) TBE.backward() # prefetch(batch_{i+2}) waits for backward(batch_i) backward hook: prefetch_stream.wait_stream(cur_stream) ``` ## 3. prevent cache inconsistency: weight get updated after it is loaded to cache With pipeline, in the case that the same index is not inserted into cache in batch_i, but it is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not have the backward update of batch_i. Example of the issue is as follows: idx is in batch_i, batch_{i+1} prefetch(batch_i) - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss) forward(batch_i) prefetch(batch_{i+1}) - insert idx into cache, cache is loaded from host memory backward(batch_i) - cache_locations_batch_i of idx is -1, the host memory is updated forward(batch_{i+1}) - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated. The fix to this cache invalidation is to update the cache_locations_batch_i before backward of batch_i,so that the cache gets updated correctly by the backward pass of TBE. Reviewed By: sryap Differential Revision: D47418650 fbshipit-source-id: 84811c423ef30fec82282702be181c00310c4e84
yuguo68
force-pushed
the
export-D47418650
branch
from
July 26, 2023 17:44
0c17608
to
d9f861b
Compare
This pull request was exported from Phabricator. Differential Revision: D47418650 |
Summary: Pull Request resolved: pytorch#1893 This diff is to enable cache prefetch pipeline of TBE, so that prefetch of batch_{i+1} can overlap with forward/backward of batch_i. As the cache can be evicted by prefetch and the weights can be updated by the backward, we need to carefully protect a few scenarios that result in cache invalidation. ## 1. prevent immature cache eviction: cache gets evicted while it is being used by forward pass Since prefetch can overlap with forward/backward pass, it is possible that prefetch tries to evict cache but the cache is being used by forward/backward pass. The fix is to use the `lxu_cache_locking_counter` in D46172802/pytorch#1883 to check whether a cache slot is in use or not when an eviction is attempted. ## 2. prevent dirty cache: weight is being updated while it is loading to cache If the prefetch overlaps with TBE backward pass, the backward may write to uvm (idx not in cache) and at the same time prefetch (idx is inserted to cache) loads the weight from uvm to cache. We sync the streams to avoid TBE backward pass overlapping with prefetch. The backward of the rest of the module can still overlap with prefetch of TBE. The stream sync looks like: ``` # backward(batch_i) waits for prefetch(batch_{i+1}) backward pre_hook: cur_stream.wait_stream(prefetch_stream) # backward(batch_i) TBE.backward() # prefetch(batch_{i+2}) waits for backward(batch_i) backward hook: prefetch_stream.wait_stream(cur_stream) ``` ## 3. prevent cache inconsistency: weight get updated after it is loaded to cache With pipeline, in the case that the same index is not inserted into cache in batch_i, but it is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not have the backward update of batch_i. Example of the issue is as follows: idx is in batch_i, batch_{i+1} prefetch(batch_i) - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss) forward(batch_i) prefetch(batch_{i+1}) - insert idx into cache, cache is loaded from host memory backward(batch_i) - cache_locations_batch_i of idx is -1, the host memory is updated forward(batch_{i+1}) - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated. The fix to this cache invalidation is to update the cache_locations_batch_i before backward of batch_i,so that the cache gets updated correctly by the backward pass of TBE. Reviewed By: sryap Differential Revision: D47418650 fbshipit-source-id: 9500c0b902761f24aa5086198574f5b85b141cbd
yuguo68
force-pushed
the
export-D47418650
branch
from
July 26, 2023 17:50
d9f861b
to
056f1f0
Compare
This pull request was exported from Phabricator. Differential Revision: D47418650 |
This pull request has been merged in 78c60ce. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
This diff is to enable cache prefetch pipeline of TBE, so that prefetch of batch_{i+1} can overlap with forward/backward of batch_i. As the cache can be evicted by prefetch and the weights can be updated by the backward, we need to carefully protect a few scenarios that result in cache invalidation.
1. prevent immature cache eviction: cache gets evicted while it is being used by forward pass
Since prefetch can overlap with forward/backward pass, it is possible that prefetch tries to evict cache but the cache is being used by forward/backward pass. The fix is to use the
lxu_cache_locking_counter
in D46172802/#1883 to check whether a cache slot is in use or not when an eviction is attempted.2. prevent dirty cache: weight is being updated while it is loading to cache
If the prefetch overlaps with TBE backward pass, the backward may write to uvm (idx not in cache) and at the same time prefetch (idx is inserted to cache) loads the weight from uvm to cache. We sync the streams to avoid TBE backward pass overlapping with prefetch. The backward of the rest of the module can still overlap with prefetch of TBE.
The stream sync looks like:
3. prevent cache inconsistency: weight get updated after it is loaded to cache
With pipeline, in the case that the same index is not inserted into cache in batch_i, but it is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not have the backward update of batch_i.
Example of the issue is as follows:
idx is in batch_i, batch_{i+1}
prefetch(batch_i)
- failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
forward(batch_i)
prefetch(batch_{i+1})
- insert idx into cache, cache is loaded from host memory
backward(batch_i)
- cache_locations_batch_i of idx is -1, the host memory is updated
forward(batch_{i+1})
- OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated.
The fix to this cache invalidation is to update the cache_locations_batch_i before backward of batch_i,so that the cache gets updated correctly by the backward pass of TBE.
Reviewed By: sryap
Differential Revision: D47418650