diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 980b274e58..7141d78e10 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -186,7 +186,7 @@ def rowwise_adagrad() -> Dict[str, Any]: g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw; """ ) - split_precomputation += """ + split_precomputation += """ // Define the rowwise adagrad optimizer state struct view struct [[maybe_unused]] OptimizerState { at::acc_type momentum; @@ -197,17 +197,17 @@ def rowwise_adagrad() -> Dict[str, Any]: at::acc_type multiplier = 0.0; at::acc_type correction = 0.0; - if (threadIdx.x == 0) { + if (threadIdx.x == 0) { auto new_sum_square_grads = g_avg_square; - - // Update the optimizer state. Use optimizer state offloading only if + + // Update the optimizer state. Use optimizer state offloading only if // SSD and if enabled by the user if (enable_optimizer_offloading) { // Fetch the pointer to the optimizer state along the cache row auto* optimizer = weight_row_template.template optimizer_state_ptr(); new_sum_square_grads += optimizer->momentum; optimizer->momentum = new_sum_square_grads; - + } else { new_sum_square_grads += momentum1[idx]; momentum1[idx] = new_sum_square_grads; @@ -570,14 +570,17 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3) if (adjustment_enabled) { if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3) - if (counter_halflife < 0) { + if (counter_halflife == -1) { adjusted_multiplier = multiplier * sqrtf(row_counter[idx] * 1.0); - exp_reg_correction = 1.0 - weight_decay * learning_rate; - const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx]; - const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0); - adjusted_multiplier *= lazy_multiplier; - exp_reg_correction *= lazy_multiplier; } + else if (counter_halflife == -2) { + adjusted_multiplier = min(learning_rate * powf(row_counter[idx] * 1.0, 1.0), adjustment_ub) / (sqrtf(new_sum_square_grads) + eps); + } + exp_reg_correction = 1.0 - weight_decay * learning_rate; + const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx]; + const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0); + adjusted_multiplier *= lazy_multiplier; + exp_reg_correction *= lazy_multiplier; } else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2) exp_reg_correction = 1.0 - freq * weight_decay * learning_rate; } else if (weight_decay_mode == 1) { // L2 regularization (coupled wd) @@ -1040,8 +1043,8 @@ def adam() -> Dict[str, Any]: DEVICE_INLINE momentum2_ph_t* momentum2_ptr(const int32_t D) { // Cast to uintptr_t for pointer arithmetic auto addr = reinterpret_cast(momentum1_ptr() + D); - - // Cast back to momentum2_ph_t* and return + + // Cast back to momentum2_ph_t* and return return reinterpret_cast(addr); } }; @@ -1179,16 +1182,16 @@ def partial_rowwise_adam() -> Dict[str, Any]: struct OptimizerState { // momentum2 is a single value placed at the beginning of the struct momentum2_ph_t momentum2; - + // momentum1 is an array of values placed after momentum2, aligned to 4-byte boundary // to support mixed state precision (e.g. FP32 momentum1 and FP16 momentum2) alignas(4) momentum1_ph_t momentum1[1]; - + // momentum2_ptr returns a pointer to the beginning of the struct DEVICE_INLINE momentum2_ph_t* momentum2_ptr() { return &momentum2; } - + // momentum1_ptr returns a pointer to the beginning of the momentum1 array DEVICE_INLINE momentum1_ph_t* momentum1_ptr() { return momentum1; @@ -1231,11 +1234,11 @@ def partial_rowwise_adam() -> Dict[str, Any]: // Create a Vec4T for momentum1 values - either directly from momentum1_start // or from a temporary aligned buffer if optimizer offloading is enabled Vec4T m_t; - + if (enable_optimizer_offloading) { - // When offloading is enabled, we need to ensure proper alignment, so + // When offloading is enabled, we need to ensure proper alignment, so // first copy to a temporary aligned array before loading to Vec4T - m_t = vec4_load_unaligned(momentum1_start + d); + m_t = vec4_load_unaligned(momentum1_start + d); m_t.mul_(beta1); m_t.fma_(grad, 1.0 - beta1); vec4_store_unaligned(m_t, momentum1_start + d); @@ -1247,7 +1250,7 @@ def partial_rowwise_adam() -> Dict[str, Any]: m_t.fma_(grad, 1.0 - beta1); m_t.store(&momentum1_start[d]); } - + // Update weights using the momentum values weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x); weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y); diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index 771267c0b8..14477b3883 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -108,6 +108,7 @@ def execute_backward_optimizers_( # noqa C901 optimizer_state_dtypes: Optional[dict[str, SparseType]] = None, use_rowwise_bias_correction: bool = False, counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None, + counter_halflife: int = -1, ) -> None: # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! @@ -297,7 +298,7 @@ def execute_backward_optimizers_( # noqa C901 else: counter_based_regularization = CounterBasedRegularizationDefinition( counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW, - counter_halflife=-1, + counter_halflife=counter_halflife, adjustment_iter=-1, adjustment_ub=0.1, learning_rate_mode=LearningRateMode.EQUAL, @@ -893,11 +894,22 @@ def _get_wts_from_counter_adagrad_using_counter( adjustment_iter > 0 and iter_ > adjustment_iter ): if counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW: - adjusted_multiplier = torch.where( - row_counter > 0, - multiplier * torch.sqrt(row_counter), - torch.Tensor([0.0]), - ) + if counter_halflife == -1: + adjusted_multiplier = torch.where( + row_counter > 0, + multiplier * torch.sqrt(row_counter), + torch.Tensor([0.0]), + ) + elif counter_halflife == -2: + adjusted_multiplier = torch.where( + row_counter > 0, + torch.minimum( + torch.tensor([learning_rate]) * row_counter, + torch.tensor([adjustment_ub]), + ) + / denom, + torch.tensor([0.0]), + ) exp_reg_correction = torch.where( row_counter > 0, 1.0 - weight_decay * learning_rate, @@ -1177,6 +1189,7 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901 CounterWeightDecayMode.ADAGRADW, ] ), + counter_halflife=st.sampled_from([-1, -2]), ) @settings( verbosity=VERBOSITY, @@ -1201,6 +1214,7 @@ def test_backward_optimizers_adagrad( # noqa C901 use_cpu: bool, weight_decay_mode: WeightDecayMode, counter_weight_decay_mode: CounterWeightDecayMode, + counter_halflife: int, ) -> None: if ( pooling_mode == PoolingMode.NONE @@ -1222,6 +1236,7 @@ def test_backward_optimizers_adagrad( # noqa C901 use_cpu, weight_decay_mode, counter_weight_decay_mode=counter_weight_decay_mode, + counter_halflife=counter_halflife, ) @given(