Implement multi-pass prefetch for memory efficiency
Summary:
## Context

Memory snapshots show significant memory usage during the prefetch kernels (specifically `linearize_cache_indices` and `lru_cache_populate`), estimated at roughly 6x the input size.

Unfortunately, because these kernels run on a dedicated stream, that memory cannot be reused by any other stream without a performance penalty.

So we need to lower the peak prefetch memory usage as much as possible.

## MultiPass Prefetch (MPP)
Multi-pass prefetch trades a bit of extra running time for lower peak memory during prefetch: the intermediate memory usage of every function in the prefetch path is `O(N)`, so we split the total prefetched indices (`N`) across multiple passes to reduce each pass's peak temporary usage. Subsequent passes recycle the memory used by the first pass, so they do not further increase the memory footprint.
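
For illustration, here is a minimal sketch of the splitting step (the helper name `split_for_prefetch` is hypothetical; the actual logic lives in `get_prefetch_passes` in the diff below):

```python
import torch

def split_for_prefetch(indices: torch.Tensor, num_passes: int):
    # Split the prefetched indices into contiguous chunks so each pass only
    # touches about N / num_passes elements; the O(N)-sized temporaries in the
    # prefetch kernels then shrink proportionally (buffer reuse across passes
    # happens in the caching allocator / kernel path, not shown here).
    n = indices.size(0)
    pass_size = (n + num_passes - 1) // num_passes  # ceil(N / M)
    # Each entry is (slice of indices, offset of the slice in the full tensor).
    return [(indices[i : i + pass_size], i) for i in range(0, n, pass_size)]
```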

**Benefit**

With this turned on, the peak memory usage drops from `6 * input_size` to `(6 / M) * input_size`, where `M` is the configured number of passes. For example, with the default of `M = 12` passes, the peak drops from 6x to 0.5x of the input size.

**Overhead**
Overall, the larger the configured `M`, the slower prefetch becomes, but the overhead is acceptable:

- **Efficiency regression**: Prefetch takes longer because the cache lookup is repeated for duplicate indices. Previously, indices were deduplicated before being looked up; now an index may be looked up multiple times if its duplicates fall into different passes.
   - The regression is insignificant overall, as the major cost is data movement between DDR and HBM, and each row is still copied only once even if its index is duplicated across passes.
   - The regression is also likely hidden from actual training performance, since prefetch happens on a separate stream; as long as it does not run long enough to block the sparse backward pass, it is invisible.
- **Spamming the CUDA launch queue**: CUDA allows at most 1024 pending kernels, and the CPU blocks if more are submitted. If each pass's kernels are very small, we can easily flood the launch queue and significantly hurt QPS. We mitigate this by enforcing a minimum number of elements per pass (see the sketch after this list).
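
As a rough illustration of the mitigation, the sketch below (the helper `effective_num_passes` is illustrative, not part of the patch) shows how the minimum pass size caps the number of passes, using the defaults from `MultiPassPrefetchConfig` in the diff:

```python
import math

def effective_num_passes(
    n_indices: int,
    num_passes: int = 12,
    min_splitable_pass_size: int = 6 * 1024 * 1024,
) -> int:
    # Mirror of the splitting rule: each pass holds at least
    # min_splitable_pass_size elements, so small inputs produce few passes
    # (and few extra kernel launches).
    if n_indices <= num_passes or num_passes == 1:
        return 1
    pass_size = max(math.ceil(n_indices / num_passes), min_splitable_pass_size)
    return math.ceil(n_indices / pass_size)

print(effective_num_passes(1_000_000))    # 1: below the minimum pass size, no split
print(effective_num_passes(20_000_000))   # 4: capped by min_splitable_pass_size
print(effective_num_passes(120_000_000))  # 12: full num_passes
```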

## What's in the patch?
1. Add a multi-pass prefetch config to the TBE interface (see the usage sketch below). It defaults to `None` for full backward compatibility.
2. Modify `lru_find_uncached` to make it idempotent: if we try to lock the same id multiple times within a single timestep (but across multiple passes), the lock counter is increased only once.
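
For reference, a hedged usage sketch of the new config: the `multipass_prefetch_config`, `prefetch_pipeline`, and `cache_algorithm` arguments come from this patch, while the embedding spec values are illustrative placeholders and the import paths assume a recent fbgemm_gpu layout.

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    CacheAlgorithm,
    EmbeddingLocation,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    MultiPassPrefetchConfig,
    SplitTableBatchedEmbeddingBagsCodegen,
)

tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_rows, embedding_dim, location, compute_device) -- placeholder values
        (1_000_000, 128, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
    ],
    cache_algorithm=CacheAlgorithm.LRU,  # multi-pass prefetch is LRU-only
    prefetch_pipeline=True,              # multi-pass prefetch requires the prefetch pipeline
    multipass_prefetch_config=MultiPassPrefetchConfig(
        num_passes=6,                    # larger M => lower peak memory, slower prefetch
        min_splitable_pass_size=6 * 1024 * 1024,
    ),
)
```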

Differential Revision: D56908989
levythu authored and facebook-github-bot committed May 7, 2024
1 parent e83e81a commit 133e6dc
Showing 4 changed files with 188 additions and 64 deletions.
226 changes: 163 additions & 63 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -143,6 +143,20 @@ class UVMCacheStatsIndex(enum.IntEnum):
num_conflict_misses = 5


@dataclass
class MultiPassPrefetchConfig:
# Number of passes to split indices tensor into. Actual number of passes may
# be less if indices tensor is too small to split.
num_passes: int = 12

# The minimal number of elements in the indices tensor for it to be split into
# two passes. This is useful to prevent too many prefetch kernels from spamming
# the CUDA launch queue.
# The default 6M indices means 6M * 8 * 6 = approx. 300MB of memory overhead
# per pass.
min_splitable_pass_size: int = 6 * 1024 * 1024


def construct_split_state(
embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]],
rowwise: bool,
@@ -390,6 +404,7 @@ def __init__( # noqa C901
# Embedding table names that are contained in this TBE.
table_names: Optional[List[str]] = None,
optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
self.uuid = str(uuid.uuid4())
@@ -403,12 +418,35 @@ def __init__( # noqa C901
self.prefetch_pipeline: bool = prefetch_pipeline
self.lock_cache_line: bool = self.prefetch_pipeline
self.use_uniq_cache_locations_bwd: bool = self.prefetch_pipeline
self.multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = (
multipass_prefetch_config
)

if record_cache_metrics is not None:
self.record_cache_metrics = record_cache_metrics
else:
self.record_cache_metrics = RecordCacheMetrics(False, False)

if multipass_prefetch_config:
assert (
prefetch_pipeline
), "Multipass prefetch makes no sense in non-prefetch mode."
assert (
cache_algorithm == CacheAlgorithm.LRU
), "Multipass prefetch is only supported in LRU cache."
assert (
multipass_prefetch_config.num_passes > 0
), f"num_passes must be positive, get {multipass_prefetch_config.num_passes}"
assert (
multipass_prefetch_config.min_splitable_pass_size > 0
), f"min_splitable_pass_size must be positive, get {multipass_prefetch_config.min_splitable_pass_size}"
assert (
not self.record_cache_metrics.record_cache_miss_counter
and not self.record_cache_metrics.record_tablewise_cache_miss
), self.log(
"Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"
)

self.embedding_specs = embedding_specs
(rows, dims, locations, compute_devices) = zip(*embedding_specs)
T_ = len(self.embedding_specs)
@@ -926,6 +964,48 @@ def _register_nonpersistent_buffers(self, prefix: str) -> None:
persistent=False,
)

@staticmethod
def get_prefetch_passes(
multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
input_tensor: Tensor,
output_tensor: Tensor,
) -> List[Tuple[Tensor, Tensor, int]]:
"""
Given input (the indices to forward), return the segmentation for each pass
in the format of (input[start_idx:end_idx], output[start_idx:end_idx], start_idx).
Caller should guarantee input and output have the same size on dimension 0.
The returned segments are guaranteed to completely and non-overlappingly cover the input tensor.
In non-multipass-prefetch mode, it returns the input/output tensor itself.
"""
if multipass_prefetch_config is None:
return [(input_tensor, output_tensor, 0)]
mpp_config: MultiPassPrefetchConfig = multipass_prefetch_config

N = input_tensor.size(0)
if N <= mpp_config.num_passes or mpp_config.num_passes == 1:
# One row per pass, just don't split
return [(input_tensor, output_tensor, 0)]

pass_size: int = max(
(N + mpp_config.num_passes - 1) // mpp_config.num_passes,
mpp_config.min_splitable_pass_size,
)
ret: List[Tuple[Tensor, Tensor, int]] = []
for i in range(0, N, pass_size):
# start_idx must be less than end_idx
start_idx = i
end_idx = min(i + pass_size, N)
ret.append(
(
input_tensor[start_idx:end_idx],
output_tensor[start_idx:end_idx],
start_idx,
)
)
return ret

def get_states(self, prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
if not hasattr(self, f"{prefix}_physical_placements"):
raise DoesNotHavePrefix()
@@ -1195,7 +1275,12 @@ def forward( # noqa: C901
self._report_tbe_mem_usage()

if len(self.timesteps_prefetched) == 0:
self._prefetch(indices, offsets, vbe_metadata)
# In forward, we don't enable multi-pass prefetch as we want the process
# to be as fast as possible and memory usage doesn't matter (will be recycled
# by dense fwd/bwd)
self._prefetch(
indices, offsets, vbe_metadata, multipass_prefetch_config=None
)

if len(self.timesteps_prefetched) > 0:
self.timesteps_prefetched.pop(0)
@@ -1510,7 +1595,12 @@ def prefetch(
offsets,
batch_size_per_feature_per_rank,
)
self._prefetch(indices, offsets, vbe_metadata)
self._prefetch(
indices,
offsets,
vbe_metadata,
multipass_prefetch_config=self.multipass_prefetch_config,
)
if forward_stream is not None:
self._prefetch_tensors_record_stream(forward_stream)

@@ -1519,6 +1609,7 @@ def _prefetch(
indices: Tensor,
offsets: Tensor,
vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
) -> None:
if not is_torchdynamo_compiling():
# Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -1535,81 +1626,90 @@
self.local_uvm_cache_stats.zero_()
self._report_io_size_count("prefetch_input", indices)

linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
self.cache_hash_size_cumsum,
indices,
offsets,
vbe_metadata.B_offsets if vbe_metadata is not None else None,
vbe_metadata.max_B if vbe_metadata is not None else -1,
)

if (
self.record_cache_metrics.record_cache_miss_counter
or self.record_cache_metrics.record_tablewise_cache_miss
final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
for (
partial_indices,
partial_lxu_cache_locations,
base_offset,
) in self.get_prefetch_passes(
multipass_prefetch_config, indices, final_lxu_cache_locations
):
lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
self.cache_hash_size_cumsum,
partial_indices,
offsets,
vbe_metadata.B_offsets if vbe_metadata is not None else None,
vbe_metadata.max_B if vbe_metadata is not None else -1,
base_offset,
)
if self.record_cache_metrics.record_cache_miss_counter:
self._update_cache_miss_counter(
lxu_cache_locations, linear_cache_indices

if (
self.record_cache_metrics.record_cache_miss_counter
or self.record_cache_metrics.record_tablewise_cache_miss
):
lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
if self.record_cache_metrics.record_cache_miss_counter:
self._update_cache_miss_counter(
lxu_cache_locations, linear_cache_indices
)
if self.record_cache_metrics.record_tablewise_cache_miss:
self._update_tablewise_cache_miss(
lxu_cache_locations, linear_cache_indices, offsets
)

if self.cache_algorithm == CacheAlgorithm.LRU:
torch.ops.fbgemm.lru_cache_populate(
self.weights_uvm,
self.cache_hash_size_cumsum,
self.total_cache_hash_size,
self.cache_index_table_map,
self.weights_offsets,
self.D_offsets,
linear_cache_indices,
self.lxu_cache_state,
self.lxu_cache_weights,
self.timestep,
self.lxu_state,
self.stochastic_rounding,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
self.lock_cache_line,
self.lxu_cache_locking_counter,
)
if self.record_cache_metrics.record_tablewise_cache_miss:
self._update_tablewise_cache_miss(
lxu_cache_locations, linear_cache_indices, offsets
elif self.cache_algorithm == CacheAlgorithm.LFU:
torch.ops.fbgemm.lfu_cache_populate(
self.weights_uvm,
self.cache_hash_size_cumsum,
self.total_cache_hash_size,
self.cache_index_table_map,
self.weights_offsets,
self.D_offsets,
linear_cache_indices,
self.lxu_cache_state,
self.lxu_cache_weights,
self.lxu_state,
self.stochastic_rounding,
)

if self.cache_algorithm == CacheAlgorithm.LRU:
torch.ops.fbgemm.lru_cache_populate(
self.weights_uvm,
self.cache_hash_size_cumsum,
self.total_cache_hash_size,
self.cache_index_table_map,
self.weights_offsets,
self.D_offsets,
torch.ops.fbgemm.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
self.lxu_cache_weights,
self.timestep,
self.lxu_state,
self.stochastic_rounding,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
self.lock_cache_line,
self.lxu_cache_locking_counter,
)
elif self.cache_algorithm == CacheAlgorithm.LFU:
torch.ops.fbgemm.lfu_cache_populate(
self.weights_uvm,
self.cache_hash_size_cumsum,
self.total_cache_hash_size,
self.cache_index_table_map,
self.weights_offsets,
self.D_offsets,
linear_cache_indices,
self.lxu_cache_state,
self.lxu_cache_weights,
self.lxu_state,
self.stochastic_rounding,
lxu_cache_locations_output=partial_lxu_cache_locations,
)

assert (
len(self.lxu_cache_locations_list) < self.max_prefetch_depth
), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()"

lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)

self.lxu_cache_locations_list.append(lxu_cache_locations)
self.lxu_cache_locations_list.append(final_lxu_cache_locations)

if self.gather_uvm_cache_stats:
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
5 changes: 4 additions & 1 deletion fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu
@@ -124,8 +124,11 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel(
const bool found = ::__ldg((&lxu_cache_state[cache_set][0]) + slot) == idx;
if (found) {
// mark it as recently accessed so we don't evict.
const bool already_locked = lru_state[cache_set][slot] == time_stamp;
lru_state[cache_set][slot] = time_stamp;
if (lock_cache_line) {
// Don't lock the line one more time if we have locked it in the same
// batch (timestamp)
if (lock_cache_line && !already_locked) {
lxu_cache_locking_counter[cache_set][slot] += 1;
}
}
3 changes: 3 additions & 0 deletions fbgemm_gpu/test/tbe/cache/cache_common.py
@@ -24,6 +24,7 @@
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
ComputeDevice,
MultiPassPrefetchConfig,
SplitTableBatchedEmbeddingBagsCodegen,
)

@@ -96,6 +97,7 @@ def generate_cache_tbes(
stochastic_rounding: bool = False,
gather_uvm_cache_stats: bool = False,
reporter_config: Optional[TestingStatsReporterConfig] = None,
multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
) -> Tuple[
SplitTableBatchedEmbeddingBagsCodegen,
SplitTableBatchedEmbeddingBagsCodegen,
@@ -153,6 +155,7 @@ def generate_cache_tbes(
cache_precision=weights_cache_precision,
gather_uvm_cache_stats=gather_uvm_cache_stats,
stats_reporter_config=reporter_config,
multipass_prefetch_config=multipass_prefetch_config,
)

if use_int_weight:
