add max norm support to PARTIAL_ROWWISE_ADAM (#2567)
Summary:

This adds max-norm support to the partial rowwise Adam optimizer: rows whose post-update L2 norm exceeds max_norm are rescaled down to max_norm.

Differential Revision: D57018951
iamzainhuda authored and facebook-github-bot committed May 13, 2024
1 parent ef263d5 commit 4b3a742
Showing 2 changed files with 67 additions and 2 deletions.
45 changes: 44 additions & 1 deletion fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1004,6 +1004,48 @@ def adam() -> Dict[str, Any]:


def partial_rowwise_adam() -> Dict[str, Any]:
    split_post_update = """
    if (max_norm > 0.0) {
        CUDA_KERNEL_ASSERT(!(std::is_same<emb_t, uint8_t>::value && !cache_weights)); // not supported for uint8 yet
        // compute weight norm
        at::acc_type<cache_t, true> weight_sum_square = 0.0;
        for (int32_t vec = 0;
             vec < max_vecs && (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D;
             ++vec) {
          const int32_t d = (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH;
          Vec4TAcc<cache_t> weight_new = weight_row_template.load(d, qparams_template);
          weight_sum_square
              += weight_new.acc.x * weight_new.acc.x
               + weight_new.acc.y * weight_new.acc.y
               + weight_new.acc.z * weight_new.acc.z
               + weight_new.acc.w * weight_new.acc.w;
        }
        const at::acc_type<cache_t, true> weight_norm =
            sqrtf(GROUP_REDUCE_ALL_SUM(weight_sum_square, at::acc_type<cache_t, true>));
        // scale by max_norm if weight_norm exceeds max_norm
        at::acc_type<cache_t, true> multiplier;
        if (threadIdx.x == 0) {
          multiplier = weight_norm > max_norm ? max_norm / weight_norm : 1.0f;
        }
        multiplier = SHFL_SYNC(multiplier, 0);
        if (weight_norm > max_norm) {
          for (int32_t vec = 0;
               vec < max_vecs && (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D;
               ++vec) {
            const int32_t d = (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH;
            Vec4TAcc<cache_t> weight_new = weight_row_template.load(d, qparams_template);
            weight_new.acc.x *= multiplier;
            weight_new.acc.y *= multiplier;
            weight_new.acc.z *= multiplier;
            weight_new.acc.w *= multiplier;
            weight_row_template.store(weight_new, d, qparams_new); // qparams_new not used if embedding is not int8
          }
        }
    }
    """
    split_precomputation = """
    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
    """
@@ -1065,11 +1107,12 @@ def partial_rowwise_adam() -> Dict[str, Any]:
                OptimItem(ArgType.FLOAT, "beta2"),
                OptimItem(ArgType.FLOAT, "weight_decay"),
                OptimItem(ArgType.INT, "iter"),
                OptimItem(ArgType.FLOAT, "max_norm", 0.0),
            ]
        ),
        "split_precomputation": split_precomputation,
        "split_weight_update": split_weight_update,
        "split_post_update": "",
        "split_post_update": split_post_update,
        "split_weight_update_cpu": split_weight_update_cpu,
        "has_cpu_support": False,
        "has_gpu_support": True,
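For reference, the per-row clipping that the new split_post_update template performs can be written in plain PyTorch. This is an illustrative sketch only (clip_rows_to_max_norm is not part of the commit); it mirrors the reference computation the test below uses to check the kernel: rows whose post-update L2 norm exceeds max_norm are scaled by max_norm / norm, and max_norm == 0.0 (the default) disables clipping.

import torch

def clip_rows_to_max_norm(weights: torch.Tensor, max_norm: float) -> torch.Tensor:
    # Scale any embedding row whose L2 norm exceeds max_norm back down to max_norm.
    if max_norm <= 0.0:
        return weights  # max_norm == 0.0 (the default) disables clipping
    norms = weights.norm(p=2, dim=1, keepdim=True)
    multiplier = torch.where(norms > max_norm, max_norm / norms, torch.ones_like(norms))
    return weights * multiplier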
24 changes: 23 additions & 1 deletion fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -92,6 +92,7 @@ def execute_backward_optimizers_( # noqa C901
        weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE,
        uvm_non_rowwise_momentum: bool = False,
        optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
        max_norm: float = 0.0,
    ) -> None:
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)
@@ -125,7 +126,11 @@ def execute_backward_optimizers_( # noqa C901
                ]
            )
        )

        # max_norm is only applicable to PARTIAL_ROWWISE_ADAM GPU version
        assume(
            max_norm == 0.0
            or (not use_cpu and optimizer == OptimType.PARTIAL_ROWWISE_ADAM)
        )
        assume(pooling_mode == PoolingMode.SUM or not weighted)
        # No bag ops only work on GPUs, no mixed, no weighted
        assume(not use_cpu or pooling_mode != PoolingMode.NONE)
@@ -288,6 +293,7 @@ def execute_backward_optimizers_( # noqa C901
optimizer_kwargs["beta2"] = beta2
optimizer_kwargs["weight_decay"] = weight_decay
optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes
optimizer_kwargs["max_norm"] = max_norm

if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
optimizer_kwargs["eps"] = eps
@@ -514,6 +520,19 @@ def execute_backward_optimizers_( # noqa C901
                    )
                    - lr * weight_decay * bs[t].weight.cpu()
                )

                if rowwise and max_norm > 0:
                    grads = bs[t].weight.grad.cpu().to_dense()
                    non_zero_grads = grads.abs().sum(dim=1, keepdim=True) > 0
                    weights_norm = (
                        weights_ref.norm(dim=1, keepdim=True) * non_zero_grads
                    )
                    weights_ref = torch.where(
                        weights_norm > max_norm,
                        weights_ref * max_norm / weights_norm,
                        weights_ref,
                    )

                torch.testing.assert_close(
                    weights_new.index_select(dim=0, index=xs[t].view(-1)).cpu(),
                    weights_ref.index_select(dim=0, index=xs[t].view(-1).cpu()),
@@ -765,6 +784,7 @@ def _get_wts_from_counter_adagrad_using_cowclip(
        ),
        use_cpu=use_cpu_strategy(),
        uvm_non_rowwise_momentum=st.booleans(),
        max_norm=st.floats(min_value=0.01, max_value=1.0),
    )
    @settings(
        verbosity=VERBOSITY,
@@ -787,6 +807,7 @@ def test_backward_optimizers_adam( # noqa C901
        pooling_mode: PoolingMode,
        use_cpu: bool,
        uvm_non_rowwise_momentum: bool,
        max_norm: float,
    ) -> None:
        self.execute_backward_optimizers_(
            T,
Expand All @@ -802,6 +823,7 @@ def test_backward_optimizers_adam( # noqa C901
            pooling_mode,
            use_cpu,
            uvm_non_rowwise_momentum=uvm_non_rowwise_momentum,
            max_norm=max_norm,
        )

    @given(
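For context, a minimal usage sketch (not part of this diff) of how max_norm might be supplied to the table-batched embedding op once this change lands. The keyword names mirror the test's optimizer_kwargs, the table shape is hypothetical, and the exact constructor signature should be treated as an assumption.

from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Hypothetical single table: 10000 rows, 64-dim embeddings, on GPU.
emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(10000, 64, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)],
    optimizer=OptimType.PARTIAL_ROWWISE_ADAM,
    learning_rate=0.01,
    eps=1e-8,
    beta1=0.9,
    beta2=0.999,
    weight_decay=0.0,
    max_norm=1.0,  # rows whose post-update L2 norm exceeds 1.0 are scaled back to 1.0
)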
