TBE UVM cache prefetch pipeline
Summary:
This diff enables the cache prefetch pipeline of TBE, so that the prefetch of batch_{i+1} can overlap with the forward/backward of batch_i. Because the cache can be evicted by prefetch and the weights can be updated by the backward pass, we need to carefully guard against a few scenarios that would otherwise invalidate the cache.

## 1. Prevent premature cache eviction: cache gets evicted while it is being used by the forward pass

Since prefetch can overlap with the forward/backward pass, it is possible for prefetch to try to evict a cache line while that line is still being used by the forward/backward pass. The fix is to use the `lxu_cache_locking_counter` introduced in D46172802/#1883 to check whether a cache slot is in use when an eviction is attempted.
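
As a rough illustration of the locking idea (not the actual fbgemm implementation, which lives in the fused CUDA ops — `lru_cache_populate` with `lock_cache_line=True` and `lxu_cache_locking_counter_decrement` in the diff below), a host-side sketch with hypothetical helpers `lock_slots`/`unlock_slots`/`can_evict` might look like:

```
import torch

# shape of lxu_cache_locking_counter: (cache_sets, DEFAULT_ASSOC); values here are arbitrary
cache_sets, assoc = 4, 32
locking_counter = torch.zeros(cache_sets, assoc, dtype=torch.int32)

def lock_slots(slots: torch.Tensor) -> None:
    # prefetch(batch_i): lock every (set, way) slot the batch will touch
    sets, ways = slots.unbind(dim=1)
    locking_counter.index_put_(
        (sets, ways), torch.ones_like(sets, dtype=torch.int32), accumulate=True
    )

def unlock_slots(slots: torch.Tensor) -> None:
    # backward(batch_i) is done with the slots: release the locks
    # (mirrors lxu_cache_locking_counter_decrement in the backward pre-hook)
    sets, ways = slots.unbind(dim=1)
    locking_counter.index_put_(
        (sets, ways), -torch.ones_like(sets, dtype=torch.int32), accumulate=True
    )

def can_evict(set_idx: int, way: int) -> bool:
    # eviction during prefetch(batch_{i+1}) must skip slots that are still locked
    return locking_counter[set_idx, way].item() == 0

# usage: batch_i touches slots (0, 3) and (1, 7); prefetch(batch_{i+1}) may not evict them
slots_i = torch.tensor([[0, 3], [1, 7]])
lock_slots(slots_i)
assert not can_evict(0, 3)
unlock_slots(slots_i)   # after backward(batch_i)
assert can_evict(0, 3)
```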

## 2. Prevent dirty cache: weight is being updated while it is being loaded into the cache
If prefetch overlaps with the TBE backward pass, the backward may write to UVM (for an idx not in the cache) while, at the same time, prefetch (which inserts that idx into the cache) loads the weight from UVM into the cache. We sync the streams so that the TBE backward pass does not overlap with prefetch. The backward of the rest of the module can still overlap with the TBE prefetch.

The stream sync looks like:
```
# backward(batch_i) waits for prefetch(batch_{i+1})
backward pre_hook: cur_stream.wait_stream(prefetch_stream)
# backward(batch_i)
TBE.backward()
# prefetch(batch_{i+2}) waits for backward(batch_i)
backward hook: prefetch_stream.wait_stream(cur_stream)
```
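
For concreteness, here is a minimal sketch of how a training loop might drive the pipeline with a dedicated prefetch stream, following the contract in the `__init__` comment (`set_prefetch_stream` once after construction, `prefetch_tensors_record_stream` after each prefetch). `emb` (a `SplitTableBatchedEmbeddingBagsCodegen` built with `prefetch_pipeline=True`) and `batches` (a list of `(indices, offsets, labels)` CUDA tensors) are placeholders; the real integration depends on the trainer.

```
import torch

# placeholders: `emb` is a SplitTableBatchedEmbeddingBagsCodegen with prefetch_pipeline=True,
# `batches` is a list of (indices, offsets, labels) CUDA tensors
prefetch_stream = torch.cuda.Stream()
emb.set_prefetch_stream(prefetch_stream)

def prefetch_next(indices: torch.Tensor, offsets: torch.Tensor) -> None:
    # run prefetch on its own stream so it can overlap with the current batch's compute
    with torch.cuda.stream(prefetch_stream):
        emb.prefetch(indices, offsets)
    # hand ownership of prefetch-allocated tensors to the stream running forward/backward
    emb.prefetch_tensors_record_stream(torch.cuda.current_stream())

prefetch_next(*batches[0][:2])                             # warm up: prefetch(batch_0)
torch.cuda.current_stream().wait_stream(prefetch_stream)   # only needed before the very first forward

for i, (indices, offsets, labels) in enumerate(batches):
    out = emb(indices, offsets)                            # forward(batch_i)
    if i + 1 < len(batches):
        prefetch_next(*batches[i + 1][:2])                 # prefetch(batch_{i+1})
    loss = (out - labels).float().pow(2).mean()            # placeholder loss
    loss.backward()                                        # TBE hooks sync backward(batch_i) with prefetch
```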

## 3. Prevent cache inconsistency: weight gets updated after it is loaded into the cache

With the pipeline, if the same index is not inserted into the cache in batch_i but is inserted in batch_{i+1}, the cache can become invalid in the sense that the cached weight for this index does not include the backward update of batch_i.
An example of the issue is as follows:
        idx is in batch_i and batch_{i+1}
        prefetch(batch_i)
          - fails to insert idx into the cache; cache_locations_batch_i of idx is -1 (cache miss)
        forward(batch_i)
        prefetch(batch_{i+1})
          - inserts idx into the cache; the cache line is loaded from host memory
        backward(batch_i)
          - cache_locations_batch_i of idx is -1, so host memory is updated
        forward(batch_{i+1})
          - OUTPUT IS WRONG: the weight for idx is fetched from the cache, but the cache is outdated.

The fix for this cache inconsistency is to update cache_locations_batch_i before the backward of batch_i, so that the cache gets updated correctly by the backward pass of TBE.
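
Conceptually (illustration only — the diff below implements this with the fused `lxu_cache_lookup` and `lxu_cache_locations_update` ops, whose exact in-place semantics may differ), refreshing the locations amounts to something like the following; `refresh_locations`, `stale`, and `fresh` are hypothetical names:

```
import torch

def refresh_locations(stale: torch.Tensor, fresh: torch.Tensor) -> torch.Tensor:
    # Wherever a fresh lookup (run just before backward(batch_i)) now finds the row
    # in the cache (slot >= 0), use the new slot so backward(batch_i) writes into the
    # cache line that forward(batch_{i+1}) will read; otherwise keep the stale value
    # (-1 means the backward keeps updating host/UVM memory).
    return torch.where(fresh >= 0, fresh, stale)

stale = torch.tensor([-1, 5, -1])   # lookup at prefetch(batch_i): rows 0 and 2 missed
fresh = torch.tensor([12, 5, -1])   # lookup after prefetch(batch_{i+1}): row 0 is now at slot 12
print(refresh_locations(stale, fresh))  # tensor([12,  5, -1])
```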

Reviewed By: sryap

Differential Revision: D47418650

fbshipit-source-id: 05081a3b61d924238884e4263396847fe4fac4ed
  • Loading branch information
yuguo68 authored and facebook-github-bot committed Jul 26, 2023
1 parent 99f2287 commit 406ef0f
Show file tree
Hide file tree
Showing 2 changed files with 397 additions and 27 deletions.
168 changes: 158 additions & 10 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -279,6 +279,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
record_cache_metrics: RecordCacheMetrics
uvm_cache_stats: torch.Tensor
local_uvm_cache_stats: torch.Tensor
linear_cache_indices_list: List[Tensor]

def __init__( # noqa C901
self,
@@ -323,13 +324,23 @@ def __init__( # noqa C901
bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
uvm_non_rowwise_momentum: bool = False, # place non-rowwise momentum on UVM
use_experimental_tbe: bool = False, # set to True to use TBE v2 (only support NVIDIA GPUs)
# set to True to enable the prefetch pipeline; currently only the LRU cache policy is supported.
# If a separate stream is used for prefetch, the user is responsible for calling
# set_prefetch_stream after module initialization and prefetch_tensors_record_stream
# after each prefetch call.
prefetch_pipeline: bool = False,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

self.pooling_mode = pooling_mode
self.bounds_check_mode_int: int = bounds_check_mode.value
self.weights_precision = weights_precision
self.output_dtype: int = output_dtype.as_int()
assert (
not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
), "Only LRU cache policy supports prefetch_pipeline."
self.prefetch_pipeline: bool = prefetch_pipeline
self.lock_cache_line: bool = self.prefetch_pipeline

if record_cache_metrics is not None:
self.record_cache_metrics = record_cache_metrics
@@ -922,7 +933,7 @@ def forward( # noqa: C901
self.prefetch(indices, offsets)

self.timesteps_prefetched.pop(0)
self.lxu_cache_locations = (
self.lxu_cache_locations_empty
if len(self.lxu_cache_locations_list) == 0
else self.lxu_cache_locations_list.pop(0)
@@ -945,7 +956,7 @@ def forward( # noqa: C901
pooling_mode=self.pooling_mode,
indice_weights=per_sample_weights,
feature_requires_grad=feature_requires_grad,
lxu_cache_locations=self.lxu_cache_locations,
output_dtype=self.output_dtype,
vbe_metadata=vbe_metadata,
is_experimental=self.is_experimental,
@@ -1163,6 +1174,8 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
self.stochastic_rounding,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
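# With prefetch_pipeline, lock_cache_line and lxu_cache_locking_counter let
# lru_cache_populate track which cache lines are still in use by in-flight batches
# so they are not evicted prematurely (summary item 1); the counter is decremented
# again in the backward pre-hook.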
self.lock_cache_line,
self.lxu_cache_locking_counter,
)
elif self.cache_algorithm == CacheAlgorithm.LFU:
torch.ops.fbgemm.lfu_cache_populate(
@@ -1182,15 +1195,19 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
assert (
len(self.lxu_cache_locations_list) < self.max_prefetch_depth
), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"

lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)

self.lxu_cache_locations_list.append(lxu_cache_locations)
if self.prefetch_pipeline:
self.linear_cache_indices_list.append(linear_cache_indices)

if self.gather_uvm_cache_stats:
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
# We may wanna do this accumulation atomically, but as it's only for monitoring,
@@ -1521,6 +1538,9 @@ def _apply_cache_state(
self.lxu_cache_locations_empty = torch.empty(
0, device=self.current_device, dtype=torch.int32
).fill_(-1)
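# lxu_cache_locations of the current forward batch; with prefetch_pipeline, the
# backward hooks also consume prefetch_stream (set via set_prefetch_stream) and
# linear_cache_indices_list (used to refresh the locations before backward).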
self.lxu_cache_locations = self.lxu_cache_locations_empty
self.prefetch_stream: Optional[torch.cuda.Stream] = None
self.linear_cache_indices_list = []

self._init_uvm_cache_stats()

@@ -1561,6 +1581,7 @@ def _apply_cache_state(
torch.tensor([0, 0], dtype=torch.int64),
persistent=False,
)
self._init_uvm_cache_counter(cache_sets, persistent=False)
return

assert cache_load_factor > 0
@@ -1648,13 +1669,137 @@ def _apply_cache_state(
"cache_miss_counter",
torch.tensor([0, 0], device=self.current_device, dtype=torch.int64),
)
self._init_uvm_cache_counter(cache_sets, persistent=True)
if self.prefetch_pipeline:
# Use the placeholder_autograd_tensor to make sure the hook is executed
# after the backward pass. register_module_full_backward_hook is not used
# due to https://github.com/pytorch/pytorch/issues/100528.
self.placeholder_autograd_tensor.register_hook(
self._sync_stream_post_backward
)
self.register_full_backward_pre_hook(
self._update_cache_counter_and_locations
)

if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
raise ValueError(
f"cache_algorithm must be {CacheAlgorithm.LRU} "
f"or {CacheAlgorithm.LFU}"
)

def prefetch_tensors_record_stream(self, stream: torch.cuda.Stream) -> None:
# Record the tensors that were created on the prefetch stream and are consumed
# by forward/backward on the forward stream. In PyTorch, each backward CUDA op
# runs on the same stream that was used for its corresponding forward op.
if self.prefetch_stream is None:
return
for t in self.lxu_cache_locations_list:
# pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream`
t.record_stream(stream)
for t in self.linear_cache_indices_list:
# pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream`
t.record_stream(stream)

def _sync_stream_post_backward(
self,
grad: Tensor,
) -> None:
"""
Backward hook function used when prefetch_pipeline is enabled.
With the pipeline, prefetch(batch_{i+2}) may overlap with backward(batch_i).
There is a race condition where backward(batch_i) writes to UVM memory while,
at the same time, prefetch(batch_{i+2}) loads UVM memory into the cache. This
stream sync forces backward(batch_i) to finish before prefetch(batch_{i+2}).
"""
if self.prefetch_stream is not None:
self.prefetch_stream.wait_stream(torch.cuda.current_stream())

def _update_cache_counter_and_locations(
self,
module: nn.Module,
grad_input: Union[Tuple[Tensor, ...], Tensor],
) -> None:
"""
Backward prehook function when prefetch_pipeline is enabled.
This function does 3 things:
1. backward stream waits for prefetch stream to finish.
Otherwise the prefetch(batch_{i+1}) might overlap with backward(batch_i).
If an idx is not in the cache in batch_i but is being inserted in batch_{i+1},
there is a race condition where backward(batch_i) writes to UVM memory while,
at the same time, prefetch(batch_{i+1}) loads UVM memory into the cache.
2. decrement the lxu_cache_locking_counter to indicate the current batch is finished.
The lxu_cache_locking_counter is updated in both prefetch and TBE backward.
As there is no overlap between prefetch and backward, we can decrement either before or
after backward. It's better to decrement before lxu_cache_locations gets updated.
3. update lxu_cache_locations to address the cache inconsistency issue.
In the case that the same index is not inserted into cache in batch_i,
but it is inserted in batch_{i+1}, the cache can be invalid in
the sense that the cached weight for this index does not have the
backward update of batch_i.
Example of the issue is as follows:
idx is in batch_i, batch_{i+1}
prefetch(batch_i)
- failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
forward(batch_i)
prefetch(batch_{i+1})
- insert idx into cache, cache is loaded from host memory
backward(batch_i)
- cache_locations_batch_i of idx is -1, the host memory is updated
forward(batch_{i+1})
- OUTPUT IS WRONG: the weight for idx is fetched from the cache, but the cache is outdated.
The fix to this cache inconsistency is to update the cache_locations_batch_i before backward of batch_i,
so that the cache gets updated correctly by the backward pass of TBE.
"""

if self.prefetch_stream is not None:
# need to wait for the prefetch of next batch,
# so that cache states are valid
torch.cuda.current_stream().wait_stream(self.prefetch_stream)

torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
self.lxu_cache_locking_counter,
self.lxu_cache_locations,
)

linear_cache_indices = self.linear_cache_indices_list.pop(0)
lxu_cache_locations_new = torch.ops.fbgemm.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
False, # not collecting cache stats
self.local_uvm_cache_stats,
)
# self.lxu_cache_locations is updated inplace
torch.ops.fbgemm.lxu_cache_locations_update(
self.lxu_cache_locations,
lxu_cache_locations_new,
)

def _init_uvm_cache_counter(self, cache_sets: int, persistent: bool) -> None:
if self.prefetch_pipeline and persistent:
self.register_buffer(
"lxu_cache_locking_counter",
torch.zeros(
cache_sets,
DEFAULT_ASSOC,
device=self.current_device,
dtype=torch.int32,
),
)
else:
self.register_buffer(
"lxu_cache_locking_counter",
torch.zeros([0, 0], dtype=torch.int32, device=self.current_device),
persistent=persistent,
)

def _init_uvm_cache_stats(self) -> None:
if not self.gather_uvm_cache_stats:
# If uvm_cache_stats is not enabled, register stub entries via buffer to state_dict for TorchScript to JIT properly.
@@ -1696,6 +1841,9 @@ def _init_uvm_cache_stats(self) -> None:
)
self.reset_uvm_cache_stats()

def set_prefetch_stream(self, prefetch_stream: torch.cuda.Stream) -> None:
self.prefetch_stream = prefetch_stream

def reset_cache_states(self) -> None:
if not self.lxu_cache_weights.numel():
return
