diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index bd43100cb0..01832dfbc1 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -86,19 +86,19 @@ class EvictionPolicy(NamedTuple): None # feature_score_counter_decay_rates for each table if eviction strategy is feature score ) training_id_eviction_trigger_count: Optional[list[int]] = ( - None # training_id_eviction_trigger_count for each table + None # Number of training IDs that, when exceeded, will trigger eviction for each table. ) training_id_keep_count: Optional[list[int]] = ( - None # training_id_keep_count for each table + None # Target number of training IDs to retain in each table after eviction. ) l2_weight_thresholds: Optional[list[float]] = ( None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm ) threshold_calculation_bucket_stride: Optional[float] = ( - 0.2 # threshold_calculation_bucket_stride if eviction strategy is feature score + 0.2 # The width of each feature score bucket used for threshold calculation in feature score-based eviction. ) threshold_calculation_bucket_num: Optional[int] = ( - 1000000 # 1M, threshold_calculation_bucket_num if eviction strategy is feature score + 1000000 # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction. ) interval_for_insufficient_eviction_s: int = ( # wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient @@ -114,10 +114,16 @@ class EvictionPolicy(NamedTuple): 24 * 3600 # 1 day, interval for feature statistics decay ) meta_header_lens: Optional[list[int]] = None # metaheader length for each table + eviction_free_mem_threshold_gb: Optional[int] = ( + None # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode. + ) + eviction_free_mem_check_interval_batch: Optional[int] = ( + None # Number of batches between checks for free memory threshold when using free_mem trigger mode. + ) def validate(self) -> None: - assert self.eviction_trigger_mode in [0, 1, 2, 3, 4], ( - "eviction_trigger_mode must be 0, 1, 2, 3 or 4 " + assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], ( + "eviction_trigger_mode must be 0, 1, 2, 3, 4, 5" f"actual {self.eviction_trigger_mode}" ) if self.eviction_trigger_mode == 0: @@ -143,6 +149,13 @@ def validate(self) -> None: assert ( self.training_id_eviction_trigger_count is not None ), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4" + elif self.eviction_trigger_mode == 5: + assert ( + self.eviction_free_mem_threshold_gb is not None + ), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5" + assert ( + self.eviction_free_mem_check_interval_batch is not None + ), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5" if self.eviction_strategy == 0: assert self.ttls_in_mins is not None, ( diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 32fb3991f7..59ea7f3b70 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -18,8 +18,9 @@ import time from functools import cached_property from math import floor, log2 -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, ClassVar, Optional, Union import torch # usort:skip +import weakref # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers @@ -34,6 +35,7 @@ BoundsCheckMode, CacheAlgorithm, EmbeddingLocation, + EvictionPolicy, get_bounds_check_version_for_platform, KVZCHParams, PoolingMode, @@ -54,6 +56,8 @@ from torch import distributed as dist, nn, Tensor # usort:skip from dataclasses import dataclass +import psutil + from torch.autograd.profiler import record_function from ..cache import get_unique_indices_v2 @@ -100,6 +104,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module): _local_instance_index: int = -1 res_params: RESParams table_names: list[str] + _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet() + _first_instance_ref: ClassVar[weakref.ref] = None + _eviction_triggered: ClassVar[bool] = False def __init__( self, @@ -179,6 +186,7 @@ def __init__( table_names: Optional[list[str]] = None, use_rowwise_bias_correction: bool = False, # For Adam use optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006 + pg: Optional[dist.ProcessGroup] = None, ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() @@ -567,6 +575,10 @@ def __init__( # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend self.load_state_dict: bool = False + SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self) + if SSDTableBatchedEmbeddingBags._first_instance_ref is None: + SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self) + # create tbe unique id using rank index | local tbe idx if tbe_unique_id == -1: SSDTableBatchedEmbeddingBags._local_instance_index += 1 @@ -584,6 +596,7 @@ def __init__( self.tbe_unique_id = tbe_unique_id self.l2_cache_size = l2_cache_size logging.info(f"tbe_unique_id: {tbe_unique_id}") + self.enable_free_mem_trigger_eviction: bool = False if self.backend_type == BackendType.SSD: logging.info( f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, " @@ -688,25 +701,31 @@ def __init__( if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb else self.l2_cache_size ) + kv_zch_params = self.kv_zch_params + eviction_policy = self.kv_zch_params.eviction_policy + if eviction_policy.eviction_trigger_mode == 5: + # If trigger mode is free_mem(5), populate config + self.set_free_mem_eviction_trigger_config(eviction_policy) + # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters. eviction_config = torch.classes.fbgemm.FeatureEvictConfig( - self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count - self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score - self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration + eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count + eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score + eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util - self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp - self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter - self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter - self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score - self.kv_zch_params.eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table - self.kv_zch_params.eviction_policy.training_id_keep_count, # training_id_keep_count for each table - self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm + eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp + eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter + eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter + eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score + eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table + eviction_policy.training_id_keep_count, # training_id_keep_count for each table + eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm table_dims.tolist() if table_dims is not None else None, - self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score - self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score - self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s, - self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s, - self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s, + eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score + eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score + eviction_policy.interval_for_insufficient_eviction_s, + eviction_policy.interval_for_sufficient_eviction_s, + eviction_policy.interval_for_feature_statistics_decay_s, ) self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper( self.cache_row_dim, @@ -1065,6 +1084,8 @@ def __init__( self.bounds_check_version: int = get_bounds_check_version_for_platform() + self._pg = pg + @cached_property def cache_row_dim(self) -> int: """ @@ -2042,6 +2063,9 @@ def _prefetch( # noqa C901 if dist.get_rank() == 0: self._report_kv_backend_stats() + # May trigger eviction if free mem trigger mode enabled before get cuda + self.may_trigger_eviction() + # Fetch data from SSD if linear_cache_indices.numel() > 0: self.record_function_via_dummy_profile( @@ -4650,3 +4674,91 @@ def direct_write_embedding( ) # Return control to the main stream without waiting for the backend operation to complete + + def get_free_cpu_memory_gb(self) -> float: + mem = psutil.virtual_memory() + return mem.available / (1024**3) + + @classmethod + def trigger_evict_in_all_tbes(cls) -> None: + for tbe in cls._all_tbe_instances: + tbe.ssd_db.trigger_feature_evict() + + @classmethod + def tbe_has_ongoing_eviction(cls) -> bool: + for tbe in cls._all_tbe_instances: + if tbe.ssd_db.is_evicting(): + return True + return False + + def set_free_mem_eviction_trigger_config( + self, eviction_policy: EvictionPolicy + ) -> None: + self.enable_free_mem_trigger_eviction = True + self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode + assert ( + eviction_policy.eviction_free_mem_check_interval_batch is not None + ), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode" + self.eviction_free_mem_check_interval_batch: int = ( + eviction_policy.eviction_free_mem_check_interval_batch + ) + assert ( + eviction_policy.eviction_free_mem_threshold_gb is not None + ), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode" + self.eviction_free_mem_threshold_gb: int = ( + eviction_policy.eviction_free_mem_threshold_gb + ) + logging.info( + f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}" + ) + + def may_trigger_eviction(self) -> None: + def is_first_tbe() -> bool: + first = SSDTableBatchedEmbeddingBags._first_instance_ref + return first is not None and first() is self + + # We assume that the eviction time is less than free mem check interval time + # So every time we reach this check, all evictions in all tbes should be finished. + # We only need to check the first tbe because all tbes share the same free mem, + # once the first tbe detect need to trigger eviction, it will call trigger func + # in all tbes from _all_tbe_instances + if ( + self.enable_free_mem_trigger_eviction + and self.step % self.eviction_free_mem_check_interval_batch == 0 + and self.training + and is_first_tbe() + ): + if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction(): + SSDTableBatchedEmbeddingBags._eviction_triggered = False + + free_cpu_mem_gb = self.get_free_cpu_memory_gb() + local_evict_trigger = int( + free_cpu_mem_gb < self.eviction_free_mem_threshold_gb + ) + tensor_flag = torch.tensor( + local_evict_trigger, + device=self.current_device, + dtype=torch.int, + ) + world_size = dist.get_world_size(self._pg) + if world_size > 1: + dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg) + global_evict_trigger = tensor_flag.item() + else: + global_evict_trigger = local_evict_trigger + if ( + global_evict_trigger >= 1 + and SSDTableBatchedEmbeddingBags._eviction_triggered + ): + logging.warning( + f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true" + ) + if ( + global_evict_trigger >= 1 + and not SSDTableBatchedEmbeddingBags._eviction_triggered + ): + SSDTableBatchedEmbeddingBags._eviction_triggered = True + SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes() + logging.info( + f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction" + ) diff --git a/fbgemm_gpu/requirements.txt b/fbgemm_gpu/requirements.txt index c1f0bb92ff..dcd13bfcd9 100644 --- a/fbgemm_gpu/requirements.txt +++ b/fbgemm_gpu/requirements.txt @@ -29,3 +29,4 @@ setuptools_git_versioning tabulate patchelf fairscale +psutil diff --git a/fbgemm_gpu/requirements_genai.txt b/fbgemm_gpu/requirements_genai.txt index 59741362a5..722de8de37 100644 --- a/fbgemm_gpu/requirements_genai.txt +++ b/fbgemm_gpu/requirements_genai.txt @@ -30,3 +30,4 @@ setuptools_git_versioning tabulate patchelf fairscale +psutil diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 3f2848d4a3..4d1d2895a6 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -1212,6 +1212,11 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { } break; } + case EvictTriggerMode::FREE_MEM: { + // For free mem eviction, all conditions checked in frontend, no check + // option in backend + return; + } default: break; } diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h index e0443ee640..5637224754 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h @@ -34,7 +34,8 @@ enum class EvictTriggerMode { ITERATION, // Trigger based on iteration steps MEM_UTIL, // Trigger based on memory usage MANUAL, // Manually triggered by upstream - ID_COUNT // Trigger based on id count + ID_COUNT, // Trigger based on id count + FREE_MEM, // Trigger based on free memory }; inline std::string to_string(EvictTriggerMode mode) { switch (mode) { @@ -48,6 +49,8 @@ inline std::string to_string(EvictTriggerMode mode) { return "MANUAL"; case EvictTriggerMode::ID_COUNT: return "ID_COUNT"; + case EvictTriggerMode::FREE_MEM: + return "FREE_MEM"; } } @@ -184,6 +187,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { eviction_trigger_stats_log += "]"; break; } + case EvictTriggerMode::FREE_MEM: { + break; + } default: throw std::runtime_error("Unknown evict trigger mode"); } @@ -202,7 +208,6 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { case EvictTriggerStrategy::BY_FEATURE_SCORE: { CHECK(feature_score_counter_decay_rates_.has_value()); - CHECK(training_id_eviction_trigger_count_.has_value()); CHECK(training_id_keep_count_.has_value()); CHECK(threshold_calculation_bucket_stride_.has_value()); CHECK(threshold_calculation_bucket_num_.has_value()); @@ -210,8 +215,6 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { LOG(INFO) << "eviction config, trigger mode:" << to_string(trigger_mode_) << eviction_trigger_stats_log << ", strategy: " << to_string(trigger_strategy_) - << ", training_id_eviction_trigger_count: " - << training_id_eviction_trigger_count_.value() << ", training_id_keep_count:" << training_id_keep_count_.value() << ", ttls_in_mins: " << ttls_in_mins_.value()