
TBE UVM cache line locking - backend #1883

Closed

Conversation


yuguo68 (Contributor) commented Jul 19, 2023

Summary:
This diff supports a cache prefetch pipeline, where cache inserts can execute in parallel with the embedding table forward/backward passes. Since cache prefetch may evict cache lines, we must ensure that cache lines in use by the forward/backward pass are not evicted.

The implementation targets the training kernel and the LRU cache policy. We create a `lxu_cache_locking_counter` of size `(cache_sets, warp_size)` to indicate whether a cache slot is in use (`counter > 0`) or not (`counter == 0`).
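
For illustration, a minimal sketch of the counter's shape and semantics (the sizes below are example values, not the actual TBE cache configuration):

```python
import torch

# Example sizes; the real values come from the TBE cache configuration.
cache_sets, warp_size = 1024, 32

# One counter per cache slot; counter > 0 means the slot is locked
# (in use by an in-flight forward/backward) and must not be evicted.
lxu_cache_locking_counter = torch.zeros((cache_sets, warp_size), dtype=torch.int32)
```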

Operations on lxu_cache_locking_counter:

In `lru_cache_find_uncached_cuda`, if an index is already in the cache, the `lxu_cache_locking_counter` of the corresponding cache slot is incremented.

In `lru_cache_insert_cuda`, we first sort the cache slots within a cache set by timestamp, as in the original LRU implementation. When inserting, we check whether the `lxu_cache_locking_counter` of each candidate cache slot is positive. If the counter is positive, we skip that slot and move on to the next one. When a cache slot is inserted into, its `lxu_cache_locking_counter` is incremented; see the sketch below.
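
A hypothetical Python rendering of the per-set insertion logic just described; the real code is the `lru_cache_insert_cuda` CUDA kernel, and `evict_and_fill` is a made-up helper standing in for the actual eviction and weight copy:

```python
def insert_into_set(set_idx, indices_to_insert, timestamps, counter):
    # Sort slots so the least recently used (smallest timestamp) come first.
    slots_by_age = sorted(
        range(len(timestamps[set_idx])), key=lambda s: timestamps[set_idx][s]
    )
    candidates = iter(slots_by_age)
    for idx in indices_to_insert:
        for slot in candidates:
            if counter[set_idx][slot] > 0:
                # Slot is locked by an in-flight forward/backward: skip it.
                continue
            evict_and_fill(set_idx, slot, idx)  # hypothetical helper
            counter[set_idx][slot] += 1  # lock the slot for the current batch
            break
```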

After the backward pass is done, we call `lxu_cache_locking_counter_decrement` through a backward hook. For every cache slot in `lxu_cache_locations`, the counter of that slot is decremented by 1. Duplicate cache slots are decremented only once.
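
A minimal sketch of the decrement step, assuming `lxu_cache_locations` holds flattened slot indices with `-1` marking cache misses (that tensor layout is an assumption for illustration):

```python
import torch

def lxu_cache_locking_counter_decrement(
    counter: torch.Tensor, lxu_cache_locations: torch.Tensor
) -> None:
    # Deduplicate so each slot touched by this batch is decremented exactly once.
    slots = torch.unique(lxu_cache_locations)
    slots = slots[slots >= 0]  # drop the -1 sentinel marking cache misses
    counter.view(-1)[slots] -= 1
```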

With the pipeline, if the same index is not inserted into the cache in batch_i but is inserted in batch_{i+1}, the cache can become invalid in the sense that the cached weight for this index lacks the backward update of batch_i.

An example of the issue:

```
idx is in batch_i, batch_{i+1}
prefetch(batch_i)
  - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
forward(batch_i)
prefetch(batch_{i+1})
  - insert idx into cache, cache is loaded from host memory
backward(batch_i)
  - cache_locations_batch_i of idx is -1, the host memory is updated
forward(batch_{i+1})
  - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated.
```

The fix to this cache invalidation is to update `cache_locations_batch_i` before the backward of batch_i, so that the cache is updated correctly by the backward pass of TBE.
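
A sketch of the fix; `lookup_cache_locations` is a hypothetical helper standing in for the kernel that maps embedding indices to their current cache slots:

```python
def refresh_locations_before_backward(indices_batch_i):
    # Re-resolve cache locations right before backward(batch_i): an idx that
    # missed at prefetch(batch_i) but was inserted by prefetch(batch_{i+1})
    # now resolves to its cache slot, so the backward writes the update into
    # the cache instead of the stale host-memory copy.
    return lookup_cache_locations(indices_batch_i)  # hypothetical helper
```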

Reviewed By: sryap

Differential Revision: D46172802


@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D46172802

@facebook-github-bot

This pull request has been merged in 4096e8d.

yuguo68 added a commit to yuguo68/FBGEMM that referenced this pull request Jul 26, 2023
yuguo68 added a commit to yuguo68/FBGEMM that referenced this pull request Jul 26, 2023
yuguo68 added a commit to yuguo68/FBGEMM that referenced this pull request Jul 26, 2023
yuguo68 added a commit to yuguo68/FBGEMM that referenced this pull request Jul 26, 2023
facebook-github-bot pushed a commit that referenced this pull request Jul 27, 2023
Summary:
Pull Request resolved: #1893

This diff enables the cache prefetch pipeline of TBE, so that the prefetch of batch_{i+1} can overlap with the forward/backward of batch_i. Since cache lines can be evicted by prefetch and weights can be updated by the backward pass, we need to carefully guard against a few scenarios that would result in cache invalidation.

## 1. Prevent premature cache eviction: cache gets evicted while it is being used by the forward pass

Since prefetch can overlap with the forward/backward pass, it is possible for prefetch to try to evict a cache line while that line is being used by the forward/backward pass. The fix is to use the `lxu_cache_locking_counter` from D46172802/#1883 to check whether a cache slot is in use when an eviction is attempted.

## 2. Prevent dirty cache: weight is being updated while it is being loaded into the cache

If the prefetch overlaps with the TBE backward pass, the backward may write to UVM (for an idx not in the cache) while, at the same time, prefetch (which inserts that idx into the cache) loads the weight from UVM into the cache. We sync the streams so that the TBE backward pass does not overlap with prefetch. The backward of the rest of the module can still overlap with the TBE prefetch.

The stream sync looks like:
```
# backward(batch_i) waits for prefetch(batch_{i+1})
backward pre_hook: cur_stream.wait_stream(prefetch_stream)
# backward(batch_i)
TBE.backward()
# prefetch(batch_{i+2}) waits for backward(batch_i)
backward hook: prefetch_stream.wait_stream(cur_stream)
```
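
For concreteness, a runnable sketch of this synchronization using standard PyTorch stream and module-hook APIs; `tbe` is a placeholder for the TBE module instance (an assumption for illustration):

```python
import torch

prefetch_stream = torch.cuda.Stream()

def backward_pre_hook(module, grad_output):
    # backward(batch_i) waits for prefetch(batch_{i+1}) to finish.
    torch.cuda.current_stream().wait_stream(prefetch_stream)

def backward_hook(module, grad_input, grad_output):
    # prefetch(batch_{i+2}) waits for backward(batch_i) to finish.
    prefetch_stream.wait_stream(torch.cuda.current_stream())

# `tbe` stands in for the TBE module instance.
tbe.register_full_backward_pre_hook(backward_pre_hook)
tbe.register_full_backward_hook(backward_hook)
```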

## 3. Prevent cache inconsistency: weight gets updated after it is loaded into the cache

With the pipeline, if the same index is not inserted into the cache in batch_i but is inserted in batch_{i+1}, the cache can become invalid in the sense that the cached weight for this index lacks the backward update of batch_i.

An example of the issue:
```
idx is in batch_i, batch_{i+1}
prefetch(batch_i)
  - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
forward(batch_i)
prefetch(batch_{i+1})
  - insert idx into cache, cache is loaded from host memory
backward(batch_i)
  - cache_locations_batch_i of idx is -1, the host memory is updated
forward(batch_{i+1})
  - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated.
```

The fix to this cache invalidation is to update `cache_locations_batch_i` before the backward of batch_i, so that the cache is updated correctly by the backward pass of TBE.

Reviewed By: sryap

Differential Revision: D47418650

fbshipit-source-id: 144855513814ca9eb4a181c46c318d5cb70efb4d