43 changes: 23 additions & 20 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -186,7 +186,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
"""
)
split_precomputation += """
split_precomputation += """
// Define the rowwise adagrad optimizer state struct view
struct [[maybe_unused]] OptimizerState {
at::acc_type<cache_t, true> momentum;
@@ -197,17 +197,17 @@ def rowwise_adagrad() -> Dict[str, Any]:

at::acc_type<cache_t, true> multiplier = 0.0;
at::acc_type<cache_t, true> correction = 0.0;
if (threadIdx.x == 0) {
auto new_sum_square_grads = g_avg_square;
        // Update the optimizer state. Use optimizer state offloading only if
// SSD and if enabled by the user
if (enable_optimizer_offloading) {
// Fetch the pointer to the optimizer state along the cache row
auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
new_sum_square_grads += optimizer->momentum;
optimizer->momentum = new_sum_square_grads;

} else {
new_sum_square_grads += momentum1[idx];
momentum1[idx] = new_sum_square_grads;
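
Note on the offloading path above: when `enable_optimizer_offloading` is set, the rowwise Adagrad accumulator is read and written through an `OptimizerState` view laid over the cache row instead of the separate `momentum1` tensor. A minimal host-side sketch of that struct-view pattern; `CacheRow` and its `optimizer_state_ptr()` are illustrative names, not the FBGEMM API:

```cpp
// Host-side sketch of the "optimizer state view over a cache row" pattern.
// CacheRow and optimizer_state_ptr() are made-up names; the real code
// templates this over cache_t and the row layout.
#include <cstddef>
#include <iostream>
#include <vector>

struct OptimizerState {
  float momentum;  // rowwise Adagrad keeps one accumulator per row
};

struct CacheRow {
  std::vector<float> data;  // D embedding weights followed by optimizer state
  std::size_t D;

  // Reinterpret the bytes just past the D weights as the optimizer state.
  OptimizerState* optimizer_state_ptr() {
    return reinterpret_cast<OptimizerState*>(data.data() + D);
  }
};

int main() {
  CacheRow row{std::vector<float>(4 + 1, 0.0f), 4};
  const float g_avg_square = 0.25f;

  // Offloaded path: accumulate into the state stored on the row itself,
  // instead of indexing a separate momentum1 tensor.
  auto* optimizer = row.optimizer_state_ptr();
  float new_sum_square_grads = g_avg_square + optimizer->momentum;
  optimizer->momentum = new_sum_square_grads;

  std::cout << "row momentum = " << optimizer->momentum << "\n";  // 0.25
}
```
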
@@ -570,14 +570,17 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
if (adjustment_enabled) {
if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3)
if (counter_halflife < 0) {
if (counter_halflife == -1) {
adjusted_multiplier = multiplier * sqrtf(row_counter[idx] * 1.0);
exp_reg_correction = 1.0 - weight_decay * learning_rate;
const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
adjusted_multiplier *= lazy_multiplier;
exp_reg_correction *= lazy_multiplier;
}
else if (counter_halflife == -2) {
adjusted_multiplier = min(learning_rate * powf(row_counter[idx] * 1.0, 1.0), adjustment_ub) / (sqrtf(new_sum_square_grads) + eps);
}
exp_reg_correction = 1.0 - weight_decay * learning_rate;
const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
adjusted_multiplier *= lazy_multiplier;
exp_reg_correction *= lazy_multiplier;
} else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
} else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
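
For reference, the two AdagradW branches above compute the adjusted multiplier differently depending on `counter_halflife`. A small scalar sketch of both formulas; all input values here are placeholders, since in the kernel they come from the surrounding optimizer state:

```cpp
// Scalar sketch of the two AdagradW multiplier variants selected by
// counter_halflife. Input values are placeholders for illustration only.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float learning_rate = 0.01f;
  const float eps = 1.0e-8f;
  const float adjustment_ub = 0.1f;
  const float new_sum_square_grads = 2.0f;  // accumulated squared gradients
  const float row_counter = 100.0f;         // per-row update count
  const float multiplier =
      learning_rate / (std::sqrt(new_sum_square_grads) + eps);

  // counter_halflife == -1: scale the usual Adagrad multiplier by sqrt(counter).
  const float m_halflife_m1 = multiplier * std::sqrt(row_counter);

  // counter_halflife == -2: cap learning_rate * counter at adjustment_ub,
  // then apply the Adagrad denominator 1 / (sqrt(G) + eps).
  const float m_halflife_m2 =
      std::min(learning_rate * row_counter, adjustment_ub) /
      (std::sqrt(new_sum_square_grads) + eps);

  std::printf("halflife=-1: %f  halflife=-2: %f\n",
              m_halflife_m1, m_halflife_m2);
}
```
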
Expand Down Expand Up @@ -1040,8 +1043,8 @@ def adam() -> Dict[str, Any]:
DEVICE_INLINE momentum2_ph_t* momentum2_ptr(const int32_t D) {
// Cast to uintptr_t for pointer arithmetic
auto addr = reinterpret_cast<uintptr_t>(momentum1_ptr() + D);
        // Cast back to momentum2_ph_t* and return
return reinterpret_cast<momentum2_ph_t *>(addr);
}
};
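
The `momentum2_ptr` accessor above relies on the D momentum1 values and the single momentum2 value being packed back to back in one row, with momentum2 possibly stored in a different precision. A standalone sketch of that packed layout, using stand-in types rather than the generated placeholder types:

```cpp
// Sketch of the packed Adam state layout implied by the accessor above:
// D momentum1 values followed by one momentum2 value. The *_ph_t aliases
// are stand-ins for the generated placeholder types.
#include <cstdint>
#include <cstdio>

using momentum1_ph_t = float;
using momentum2_ph_t = std::uint16_t;  // stand-in for a 16-bit momentum2 dtype

struct AdamState {
  momentum1_ph_t momentum1[1];  // actually D elements long in the cache row

  momentum1_ph_t* momentum1_ptr() { return momentum1; }

  momentum2_ph_t* momentum2_ptr(const std::int32_t D) {
    // Step over the D momentum1 values, then reinterpret the tail bytes.
    auto addr = reinterpret_cast<std::uintptr_t>(momentum1_ptr() + D);
    return reinterpret_cast<momentum2_ph_t*>(addr);
  }
};

int main() {
  constexpr std::int32_t D = 4;
  // Raw storage for D momentum1 values plus one momentum2 value.
  alignas(momentum1_ph_t) unsigned char buf[D * sizeof(momentum1_ph_t) +
                                             sizeof(momentum2_ph_t)] = {};
  auto* state = reinterpret_cast<AdamState*>(buf);
  *state->momentum2_ptr(D) = 42;
  std::printf("momentum2 = %u\n", unsigned{*state->momentum2_ptr(D)});
}
```
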
@@ -1179,16 +1182,16 @@ def partial_rowwise_adam() -> Dict[str, Any]:
struct OptimizerState {
// momentum2 is a single value placed at the beginning of the struct
momentum2_ph_t momentum2;

// momentum1 is an array of values placed after momentum2, aligned to 4-byte boundary
// to support mixed state precision (e.g. FP32 momentum1 and FP16 momentum2)
alignas(4) momentum1_ph_t momentum1[1];

// momentum2_ptr returns a pointer to the beginning of the struct
DEVICE_INLINE momentum2_ph_t* momentum2_ptr() {
return &momentum2;
}

// momentum1_ptr returns a pointer to the beginning of the momentum1 array
DEVICE_INLINE momentum1_ph_t* momentum1_ptr() {
return momentum1;
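
A quick illustration of what the `alignas(4)` on `momentum1` does in this layout: with 16-bit stand-in types for both states, the array would otherwise start at offset 2 rather than a 4-byte boundary. Offsets for other dtype combinations differ; this only shows the mechanism:

```cpp
// Sketch of the alignment effect only. state16_t is a stand-in for a 16-bit
// state dtype; the real placeholder types come from the codegen.
#include <cstddef>
#include <cstdint>
#include <cstdio>

using state16_t = std::uint16_t;

struct Packed {
  state16_t momentum2;
  state16_t momentum1[1];             // lands at offset 2
};

struct Aligned {
  state16_t momentum2;
  alignas(4) state16_t momentum1[1];  // forced to a 4-byte boundary
};

int main() {
  std::printf("Packed : momentum1 at offset %zu\n", offsetof(Packed, momentum1));
  std::printf("Aligned: momentum1 at offset %zu\n", offsetof(Aligned, momentum1));
}
```
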
@@ -1231,11 +1234,11 @@ def partial_rowwise_adam() -> Dict[str, Any]:
// Create a Vec4T for momentum1 values - either directly from momentum1_start
// or from a temporary aligned buffer if optimizer offloading is enabled
Vec4T<momentum1_ph_t> m_t;

if (enable_optimizer_offloading) {
// When offloading is enabled, we need to ensure proper alignment, so
// first copy to a temporary aligned array before loading to Vec4T
m_t = vec4_load_unaligned(momentum1_start + d);
m_t.mul_(beta1);
m_t.fma_(grad, 1.0 - beta1);
vec4_store_unaligned(m_t, momentum1_start + d);
@@ -1247,7 +1250,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
m_t.fma_(grad, 1.0 - beta1);
m_t.store(&momentum1_start[d]);
}

// Update weights using the momentum values
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
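
The offloading branch above goes through `vec4_load_unaligned` and `vec4_store_unaligned` because the momentum1 values on the cache row are not guaranteed to be 16-byte aligned. A plain C++ sketch of that load, update, store sequence, with a local `Vec4` standing in for FBGEMM's `Vec4T` and the unaligned helpers assumed to do an element-wise copy:

```cpp
// Sketch of the unaligned 4-wide load / EMA update / store pattern. Vec4 and
// the two helpers are local stand-ins, not the FBGEMM implementations.
#include <cstdio>

struct Vec4 {
  float x, y, z, w;
};

// Element-wise copy, so the source pointer needs no 16-byte alignment.
Vec4 vec4_load_unaligned(const float* p) {
  return Vec4{p[0], p[1], p[2], p[3]};
}

void vec4_store_unaligned(const Vec4& v, float* p) {
  p[0] = v.x; p[1] = v.y; p[2] = v.z; p[3] = v.w;
}

int main() {
  const float beta1 = 0.9f;
  float momentum1[8] = {};  // stand-in row; real layout has no alignment guarantee
  const float grad[4] = {0.1f, 0.2f, 0.3f, 0.4f};
  const int d = 4;

  Vec4 m_t = vec4_load_unaligned(momentum1 + d);
  // m_t = beta1 * m_t + (1 - beta1) * grad, the momentum EMA from the kernel
  m_t.x = beta1 * m_t.x + (1.0f - beta1) * grad[0];
  m_t.y = beta1 * m_t.y + (1.0f - beta1) * grad[1];
  m_t.z = beta1 * m_t.z + (1.0f - beta1) * grad[2];
  m_t.w = beta1 * m_t.w + (1.0f - beta1) * grad[3];
  vec4_store_unaligned(m_t, momentum1 + d);

  std::printf("updated momentum1[4] = %f\n", momentum1[4]);  // 0.01
}
```
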
27 changes: 21 additions & 6 deletions fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -108,6 +108,7 @@ def execute_backward_optimizers_( # noqa C901
optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
use_rowwise_bias_correction: bool = False,
counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None,
counter_halflife: int = -1,
) -> None:
# NOTE: limit (T * B * L * D) to avoid timeout for CPU version!

@@ -297,7 +298,7 @@ def execute_backward_optimizers_( # noqa C901
else:
counter_based_regularization = CounterBasedRegularizationDefinition(
counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
counter_halflife=-1,
counter_halflife=counter_halflife,
adjustment_iter=-1,
adjustment_ub=0.1,
learning_rate_mode=LearningRateMode.EQUAL,
@@ -893,11 +894,22 @@ def _get_wts_from_counter_adagrad_using_counter(
adjustment_iter > 0 and iter_ > adjustment_iter
):
if counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
adjusted_multiplier = torch.where(
row_counter > 0,
multiplier * torch.sqrt(row_counter),
torch.Tensor([0.0]),
)
if counter_halflife == -1:
adjusted_multiplier = torch.where(
row_counter > 0,
multiplier * torch.sqrt(row_counter),
torch.Tensor([0.0]),
)
elif counter_halflife == -2:
adjusted_multiplier = torch.where(
row_counter > 0,
torch.minimum(
torch.tensor([learning_rate]) * row_counter,
torch.tensor([adjustment_ub]),
)
/ denom,
torch.tensor([0.0]),
)
exp_reg_correction = torch.where(
row_counter > 0,
1.0 - weight_decay * learning_rate,
Expand Down Expand Up @@ -1177,6 +1189,7 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901
CounterWeightDecayMode.ADAGRADW,
]
),
counter_halflife=st.sampled_from([-1, -2]),
)
@settings(
verbosity=VERBOSITY,
@@ -1201,6 +1214,7 @@ def test_backward_optimizers_adagrad( # noqa C901
use_cpu: bool,
weight_decay_mode: WeightDecayMode,
counter_weight_decay_mode: CounterWeightDecayMode,
counter_halflife: int,
) -> None:
if (
pooling_mode == PoolingMode.NONE
@@ -1222,6 +1236,7 @@ def test_backward_optimizers_adagrad( # noqa C901
use_cpu,
weight_decay_mode,
counter_weight_decay_mode=counter_weight_decay_mode,
counter_halflife=counter_halflife,
)

@given(