38 changes: 32 additions & 6 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -86,19 +86,19 @@ class EvictionPolicy(NamedTuple):
None # feature_score_counter_decay_rates for each table if eviction strategy is feature score
)
training_id_eviction_trigger_count: Optional[list[int]] = (
None # training_id_eviction_trigger_count for each table
None # Number of training IDs that, when exceeded, will trigger eviction for each table.
)
training_id_keep_count: Optional[list[int]] = (
None # training_id_keep_count for each table
None # Target number of training IDs to retain in each table after eviction.
)
l2_weight_thresholds: Optional[list[float]] = (
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
)
threshold_calculation_bucket_stride: Optional[float] = (
0.2 # threshold_calculation_bucket_stride if eviction strategy is feature score
0.2 # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
)
threshold_calculation_bucket_num: Optional[int] = (
1000000 # 1M, threshold_calculation_bucket_num if eviction strategy is feature score
1000000 # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction.
)
interval_for_insufficient_eviction_s: int = (
# wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
@@ -114,10 +114,16 @@ class EvictionPolicy(NamedTuple):
24 * 3600 # 1 day, interval for feature statistics decay
)
meta_header_lens: Optional[list[int]] = None # metaheader length for each table
eviction_free_mem_threshold_gb: Optional[int] = (
None # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
)
eviction_free_mem_check_interval_batch: Optional[int] = (
None # Number of batches between checks for free memory threshold when using free_mem trigger mode.
)

def validate(self) -> None:
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4], (
"eviction_trigger_mode must be 0, 1, 2, 3 or 4 "
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
"eviction_trigger_mode must be 0, 1, 2, 3, 4, 5"
f"actual {self.eviction_trigger_mode}"
)
if self.eviction_trigger_mode == 0:
@@ -143,6 +149,13 @@ def validate(self) -> None:
assert (
self.training_id_eviction_trigger_count is not None
), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
elif self.eviction_trigger_mode == 5:
assert (
self.eviction_free_mem_threshold_gb is not None
), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
assert (
self.eviction_free_mem_check_interval_batch is not None
), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"

if self.eviction_strategy == 0:
assert self.ttls_in_mins is not None, (
@@ -240,6 +253,19 @@ def validate(self) -> None:
), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"


class KVZCHEvictionTBEConfig(NamedTuple):
# Eviction trigger mode for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
kvzch_eviction_trigger_mode: Optional[int] = None
# Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
eviction_free_mem_threshold_gb: Optional[int] = None
# Number of batches between checks for free memory threshold when using free_mem trigger mode.
eviction_free_mem_check_interval_batch: Optional[int] = None
# The width of each feature score bucket used for threshold calculation in feature score-based eviction.
threshold_calculation_bucket_stride: Optional[float] = None
# Total number of feature score buckets used for threshold calculation in feature score-based eviction.
threshold_calculation_bucket_num: Optional[int] = None


class BackendType(enum.IntEnum):
SSD = 0
DRAM = 1
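The KVZCHEvictionTBEConfig added above bundles the free_mem trigger settings into a single config object. A minimal sketch, not part of this diff, of how it might be populated for the free_mem trigger mode (mode 5); the values below are illustrative assumptions, not defaults from this change:

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHEvictionTBEConfig

# Values are illustrative assumptions, not defaults from this change.
free_mem_eviction_config = KVZCHEvictionTBEConfig(
    kvzch_eviction_trigger_mode=5,               # 5 = free_mem trigger
    eviction_free_mem_threshold_gb=32,           # evict once free host memory drops below 32 GB
    eviction_free_mem_check_interval_batch=100,  # re-check free memory every 100 batches
)
```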
144 changes: 128 additions & 16 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -18,8 +18,9 @@
import time
from functools import cached_property
from math import floor, log2
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, ClassVar, Optional, Union
import torch # usort:skip
import weakref

# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -34,6 +35,7 @@
BoundsCheckMode,
CacheAlgorithm,
EmbeddingLocation,
EvictionPolicy,
get_bounds_check_version_for_platform,
KVZCHParams,
PoolingMode,
@@ -54,6 +56,8 @@
from torch import distributed as dist, nn, Tensor # usort:skip
from dataclasses import dataclass

import psutil

from torch.autograd.profiler import record_function

from ..cache import get_unique_indices_v2
@@ -100,6 +104,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
_local_instance_index: int = -1
res_params: RESParams
table_names: list[str]
_all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
_first_instance_ref: ClassVar[Optional[weakref.ref]] = None
_eviction_triggered: ClassVar[bool] = False

def __init__(
self,
@@ -179,6 +186,7 @@ def __init__(
table_names: Optional[list[str]] = None,
use_rowwise_bias_correction: bool = False, # For Adam use
optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006
pg: Optional[dist.ProcessGroup] = None,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -567,6 +575,10 @@ def __init__(
# loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
self.load_state_dict: bool = False

SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)

# create tbe unique id using rank index | local tbe idx
if tbe_unique_id == -1:
SSDTableBatchedEmbeddingBags._local_instance_index += 1
@@ -584,6 +596,7 @@ def __init__(
self.tbe_unique_id = tbe_unique_id
self.l2_cache_size = l2_cache_size
logging.info(f"tbe_unique_id: {tbe_unique_id}")
self.enable_free_mem_trigger_eviction: bool = False
if self.backend_type == BackendType.SSD:
logging.info(
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
@@ -688,25 +701,31 @@ def __init__(
if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
else self.l2_cache_size
)
kv_zch_params = self.kv_zch_params
eviction_policy = self.kv_zch_params.eviction_policy
if eviction_policy.eviction_trigger_mode == 5:
# If trigger mode is free_mem(5), populate config
self.set_free_mem_eviction_trigger_config(eviction_policy)

# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
eviction_policy.eviction_trigger_mode, # eviction trigger mode: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold, 5: feature score
eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
self.kv_zch_params.eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
self.kv_zch_params.eviction_policy.training_id_keep_count, # training_id_keep_count for each table
self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
eviction_policy.training_id_keep_count, # training_id_keep_count for each table
eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
table_dims.tolist() if table_dims is not None else None,
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s,
eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
eviction_policy.interval_for_insufficient_eviction_s,
eviction_policy.interval_for_sufficient_eviction_s,
eviction_policy.interval_for_feature_statistics_decay_s,
)
self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
self.cache_row_dim,
@@ -1065,6 +1084,8 @@ def __init__(

self.bounds_check_version: int = get_bounds_check_version_for_platform()

self._pg = pg

@cached_property
def cache_row_dim(self) -> int:
"""
@@ -2042,6 +2063,9 @@ def _prefetch( # noqa C901
if dist.get_rank() == 0:
self._report_kv_backend_stats()

# May trigger eviction before the backend fetch if the free_mem trigger mode is enabled
self.may_trigger_eviction()

# Fetch data from SSD
if linear_cache_indices.numel() > 0:
self.record_function_via_dummy_profile(
@@ -4650,3 +4674,91 @@ def direct_write_embedding(
)

# Return control to the main stream without waiting for the backend operation to complete

def get_free_cpu_memory_gb(self) -> float:
mem = psutil.virtual_memory()
return mem.available / (1024**3)

@classmethod
def trigger_evict_in_all_tbes(cls) -> None:
for tbe in cls._all_tbe_instances:
tbe.ssd_db.trigger_feature_evict()

@classmethod
def tbe_has_ongoing_eviction(cls) -> bool:
for tbe in cls._all_tbe_instances:
if tbe.ssd_db.is_evicting():
return True
return False

def set_free_mem_eviction_trigger_config(
self, eviction_policy: EvictionPolicy
) -> None:
self.enable_free_mem_trigger_eviction = True
self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
assert (
eviction_policy.eviction_free_mem_check_interval_batch is not None
), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode"
self.eviction_free_mem_check_interval_batch: int = (
eviction_policy.eviction_free_mem_check_interval_batch
)
assert (
eviction_policy.eviction_free_mem_threshold_gb is not None
), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode"
self.eviction_free_mem_threshold_gb: int = (
eviction_policy.eviction_free_mem_threshold_gb
)
logging.info(
f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
)

def may_trigger_eviction(self) -> None:
def is_first_tbe() -> bool:
first = SSDTableBatchedEmbeddingBags._first_instance_ref
return first is not None and first() is self

# We assume that an eviction round finishes within one free-mem check interval,
# so by the time we reach this check, all evictions in all TBEs should be finished.
# We only need to run the check from the first TBE because all TBEs share the same
# free memory; once the first TBE detects that eviction is needed, it triggers
# eviction in every TBE in _all_tbe_instances.
if (
self.enable_free_mem_trigger_eviction
and self.step % self.eviction_free_mem_check_interval_batch == 0
and self.training
and is_first_tbe()
):
if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
SSDTableBatchedEmbeddingBags._eviction_triggered = False

free_cpu_mem_gb = self.get_free_cpu_memory_gb()
local_evict_trigger = int(
free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
)
tensor_flag = torch.tensor(
local_evict_trigger,
device=self.current_device,
dtype=torch.int,
)
world_size = dist.get_world_size(self._pg)
if world_size > 1:
dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
global_evict_trigger = tensor_flag.item()
else:
global_evict_trigger = local_evict_trigger
if (
global_evict_trigger >= 1
and SSDTableBatchedEmbeddingBags._eviction_triggered
):
logging.info(
f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true"
)
if (
global_evict_trigger >= 1
and not SSDTableBatchedEmbeddingBags._eviction_triggered
):
SSDTableBatchedEmbeddingBags._eviction_triggered = True
SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
logging.info(
f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
)
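may_trigger_eviction reduces each rank's free-memory flag with a sum so that every rank reaches the same eviction decision even when only one rank is low on memory. A standalone sketch of that synchronization pattern, assuming an initialized default process group; the function name and values here are illustrative, not from this diff:

```python
import psutil
import torch
import torch.distributed as dist


def any_rank_below_threshold(threshold_gb: float, device: torch.device) -> bool:
    # Each rank measures its own free host memory.
    free_gb = psutil.virtual_memory().available / (1024**3)
    flag = torch.tensor(int(free_gb < threshold_gb), device=device, dtype=torch.int)
    # Summing the flags gives every rank the same view: a result >= 1 means at
    # least one rank crossed the threshold, so all ranks trigger eviction together.
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return int(flag.item()) >= 1
```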
23 changes: 13 additions & 10 deletions fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
@@ -1177,17 +1177,8 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {

void compact() override {}

void trigger_feature_evict(
std::optional<uint32_t> inplace_update_ts = std::nullopt) {
void trigger_feature_evict() {
if (feature_evict_) {
if (inplace_update_ts.has_value() &&
feature_evict_config_.value()->trigger_strategy_ ==
EvictTriggerStrategy::BY_TIMESTAMP_THRESHOLD) {
auto* tt_evict = dynamic_cast<TimeThresholdBasedEvict<weight_type>*>(
feature_evict_.get());
CHECK(tt_evict != nullptr);
tt_evict->set_eviction_timestamp_threshold(inplace_update_ts.value());
}
feature_evict_->trigger_evict();
}
}
@@ -1223,6 +1214,11 @@
}
break;
}
case EvictTriggerMode::FREE_MEM: {
// For free_mem eviction, all trigger conditions are checked on the frontend;
// there is no check in the backend.
return;
}
default:
break;
}
@@ -1271,6 +1267,13 @@
}
}

bool is_evicting() override {
if (feature_evict_) {
return feature_evict_->is_evicting();
}
return false;
}

// for inference only, this logs the total hit/miss count
// this should be called at the end of full/delta snapshot chunk by chunk
// update
@@ -179,6 +179,14 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
impl_->set_backend_return_whole_row(backend_return_whole_row);
}

void trigger_feature_evict() {
impl_->trigger_feature_evict();
}

bool is_evicting() {
return impl_->is_evicting();
}

void set_feature_score_metadata_cuda(
at::Tensor indices,
at::Tensor count,
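The trigger_feature_evict and is_evicting methods exposed on the wrapper above are what the Python side calls through the torch.classes.fbgemm.DramKVEmbeddingCacheWrapper custom class (see trigger_evict_in_all_tbes and tbe_has_ongoing_eviction in training.py). A minimal usage sketch; `ssd_db` stands for the wrapper instance held by a TBE:

```python
# Sketch only: `ssd_db` is the DramKVEmbeddingCacheWrapper created in
# SSDTableBatchedEmbeddingBags.__init__ (self._ssd_db in training.py).
if not ssd_db.is_evicting():        # backend reports no eviction in flight
    ssd_db.trigger_feature_evict()  # start a new eviction round
```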