From 51841e3d5878701744f255d580cc7fe06036c8c2 Mon Sep 17 00:00:00 2001 From: Gantaphon Chalumporn Date: Tue, 25 Nov 2025 14:03:57 -0800 Subject: [PATCH] Add support rowwise_adagrad_wtith_counter on CPU (#5146) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2145 Initial support has been added in D81998586. Differential Revision: D87104079 --- fbgemm_gpu/codegen/genscript/optimizers.py | 129 +++++++++++++++++++-- 1 file changed, 119 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index c61e6843f9..f54d07d544 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -600,23 +600,132 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0); """ split_weight_update_cpu = """ + auto offset_idx = momentum1_offsets_data[feature_begin] + idx; + + // Counter update logic with halflife decay + at::acc_type freq = 1.0; + at::acc_type tail_id_threshold_val = tail_id_threshold; + if (max_counter != 0.0) { + if (is_tail_id_thresh_ratio == 1) { + tail_id_threshold_val = std::floor(tail_id_threshold * max_counter); + } + + if (counter_halflife > 0) { + // Decay based on counter_halflife + const auto iter_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx]; + const auto counter_log_rho = std::log(2.0) / counter_halflife; + row_counter_host[offset_idx] = 1.0 + std::exp(-iter_delta * counter_log_rho) * row_counter_host[offset_idx]; + } else if (counter_halflife == 0) { + // Count only 1 (appear or not) + row_counter_host[offset_idx] = 1.0; + } else { + // Count raw appearance without decaying + row_counter_host[offset_idx] += 1.0; + } + } + freq = counter_halflife / row_counter_host[offset_idx]; + + // Compute gradient statistics at::acc_type g_local_sum_square = 0.0; + at::acc_type w_local_sum_square = 0.0; + for (int64_t d = 0; d < D; ++d) { - g_local_sum_square += grad_buffer[d] * grad_buffer[d]; + auto grad = grad_buffer[d]; + // For L2 regularization (weight_decay_mode=1), add weight_decay to gradient before other computation + if (weight_decay_mode == 1) { + grad += weight_decay * host_weights_data[embedding_begin + d]; + } + g_local_sum_square += grad * grad; + + // COW-clip (regularization_mode=4) requires weight norm + if (regularization_mode == 4) { + const auto weight = host_weights_data[embedding_begin + d]; + w_local_sum_square += weight * weight; + } } - auto g_avg_square = g_local_sum_square / D; - auto offset_idx = momentum1_offsets_data[feature_begin] + idx; + + const auto g_sum_square = g_local_sum_square; + const auto g_avg_square = g_sum_square / D; + const auto w_sum_square = w_local_sum_square; + + // Update momentum at::acc_type new_sum_square_grads = momentum1_host[offset_idx] + g_avg_square; momentum1_host[offset_idx] = new_sum_square_grads; - at::acc_type multiplier; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - const auto iter_delta = iter * 1.0 - prev_iter_host[offset_idx]; + const auto multiplier = learning_rate / (std::sqrt(new_sum_square_grads) + eps); + const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter); + + // Compute adjusted multiplier and regularization correction + at::acc_type adjusted_multiplier = 0.0; + at::acc_type exp_reg_correction = 0.0; + + if (regularization_mode == 3) { + // Counter-based regularization (regularization_mode=3) + adjusted_multiplier = multiplier; + if (learning_rate_mode >= 0 && adjustment_enabled) { + if (row_counter_host[offset_idx] > tail_id_threshold_val) { + if (learning_rate_mode == 0) { + adjusted_multiplier = multiplier * std::max(std::min(std::pow(max_counter / (row_counter_host[offset_idx] + 1.0), adjustment_ub), 10.0), 1.0); + } else if (learning_rate_mode == 1) { + adjusted_multiplier = multiplier * std::min(std::max(std::pow((row_counter_host[offset_idx] + 1.0) / max_counter, adjustment_ub), 0.1), 1.0); + } else if (learning_rate_mode == 2) { + adjusted_multiplier = learning_rate / (std::sqrt(adjustment_ub * row_counter_host[offset_idx]) + eps); + } + } + } + } else if (regularization_mode == 4) { + // COW-clip (regularization_mode=4) + const auto clip_thresh = row_counter_host[offset_idx] * std::max(weight_norm_coefficient * std::sqrt(w_sum_square), lower_bound); + adjusted_multiplier = std::min(1.0f, static_cast(clip_thresh / std::sqrt(g_sum_square))) * multiplier; + } else { + // Default: no special regularization + adjusted_multiplier = multiplier; + } + + // Compute regularization correction + exp_reg_correction = 1.0; + if (regularization_mode == 3) { + // Counter-based regularization (regularization_mode=3) + if (adjustment_enabled) { + if (weight_decay_mode == 3) { + // AdagradW (weight_decay_mode=3) + if (counter_halflife == -1) { + adjusted_multiplier = multiplier * std::sqrt(row_counter_host[offset_idx] * 1.0); + } else if (counter_halflife == -2) { + adjusted_multiplier = std::min(static_cast(learning_rate * std::pow(row_counter_host[offset_idx] * 1.0, 1.0)), adjustment_ub) / (std::sqrt(new_sum_square_grads) + eps); + } + exp_reg_correction = 1.0 - weight_decay * learning_rate; + const auto lazy_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx]; + const auto lazy_multiplier = std::pow(exp_reg_correction, std::min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0); + adjusted_multiplier *= lazy_multiplier; + exp_reg_correction *= lazy_multiplier; + } else if (weight_decay_mode == 2) { + // Decoupled weight decay (weight_decay_mode=2) + exp_reg_correction = 1.0 - freq * weight_decay * learning_rate; + } else if (weight_decay_mode == 1) { + // L2 regularization (coupled wd) + exp_reg_correction = 1.0 - freq * weight_decay * multiplier; + } + } + } else if (regularization_mode == 4) { + // COW-clip (regularization_mode=4) + if (weight_decay_mode == 2) { + // Decoupled weight decay (weight_decay_mode=2) + exp_reg_correction = 1.0 - weight_decay * learning_rate; + } else if (weight_decay_mode == 1) { + // L2 regularization (coupled wd) + exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier; + } + } else { + // Default regularization + exp_reg_correction = 1.0; + } + + // Update prev_iter prev_iter_host[offset_idx] = iter * 1.0; - const auto exp_reg = 1.0 / (weight_decay * multiplier + 1.0); - const auto exp_reg_correction = powf(exp_reg, iter_delta); + + // Apply weight updates for (int64_t d = 0; d < D; ++d) { - const auto weight = host_weights_data[embedding_begin + d]; - host_weights_data[embedding_begin + d] = exp_reg_correction * weight - exp_reg * multiplier * grad_buffer[d]; + host_weights_data[embedding_begin + d] = exp_reg_correction * host_weights_data[embedding_begin + d] - adjusted_multiplier * grad_buffer[d]; } """