From 4b3a7422a7af862a5fc9a0241ea09cb7725586b6 Mon Sep 17 00:00:00 2001
From: Zain Huda
Date: Mon, 13 May 2024 08:01:47 -0700
Subject: [PATCH] add max norm support to PARTIAL_ROWWISE_ADAM (#2567)

Summary: This adds max norm support to the partial rowwise Adam optimizer.

Differential Revision: D57018951
---
 fbgemm_gpu/codegen/genscript/optimizers.py    | 45 ++++++++++++++++++-
 .../tbe/training/backward_optimizers_test.py  | 24 +++++++++-
 2 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py
index 179b80f745..4c6f419bf7 100644
--- a/fbgemm_gpu/codegen/genscript/optimizers.py
+++ b/fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1004,6 +1004,48 @@ def adam() -> Dict[str, Any]:
 
 
 def partial_rowwise_adam() -> Dict[str, Any]:
+    split_post_update = """
+    if (max_norm > 0.0) {
+        CUDA_KERNEL_ASSERT(!(std::is_same<emb_t, uint8_t>::value && !cache_weights)); // not supported for uint8 yet
+
+        // compute weight norm
+        at::acc_type<cache_t, true> weight_sum_square = 0.0;
+        for (int32_t vec = 0;
+             vec < max_vecs && (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D;
+             ++vec) {
+            const int32_t d = (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH;
+            Vec4TAcc<cache_t> weight_new = weight_row_template.load(d, qparams_template);
+            weight_sum_square
+                += weight_new.acc.x * weight_new.acc.x
+                + weight_new.acc.y * weight_new.acc.y
+                + weight_new.acc.z * weight_new.acc.z
+                + weight_new.acc.w * weight_new.acc.w;
+        }
+        const at::acc_type<cache_t, true> weight_norm =
+            sqrtf(GROUP_REDUCE_ALL_SUM(weight_sum_square, at::acc_type<cache_t, true>));
+
+        // scale by max_norm if weight_norm exceeds max_norm
+        at::acc_type<cache_t, true> multiplier;
+        if (threadIdx.x == 0) {
+            multiplier = weight_norm > max_norm ? max_norm / weight_norm : 1.0f;
+        }
+        multiplier = SHFL_SYNC(multiplier, 0);
+        if (weight_norm > max_norm) {
+            for (int32_t vec = 0;
+                 vec < max_vecs && (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D;
+                 ++vec) {
+                const int32_t d = (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH;
+                Vec4TAcc<cache_t> weight_new = weight_row_template.load(d, qparams_template);
+
+                weight_new.acc.x *= multiplier;
+                weight_new.acc.y *= multiplier;
+                weight_new.acc.z *= multiplier;
+                weight_new.acc.w *= multiplier;
+                weight_row_template.store(weight_new, d, qparams_new); // qparams_new not used if embedding is not int8
+            }
+        }
+    }
+    """
     split_precomputation = """
     at::acc_type<cache_t, true> g_local_sum_square = 0.0;
     """
@@ -1065,11 +1107,12 @@ def partial_rowwise_adam() -> Dict[str, Any]:
                 OptimItem(ArgType.FLOAT, "beta2"),
                 OptimItem(ArgType.FLOAT, "weight_decay"),
                 OptimItem(ArgType.INT, "iter"),
+                OptimItem(ArgType.FLOAT, "max_norm", 0.0),
             ]
         ),
         "split_precomputation": split_precomputation,
         "split_weight_update": split_weight_update,
-        "split_post_update": "",
+        "split_post_update": split_post_update,
         "split_weight_update_cpu": split_weight_update_cpu,
         "has_cpu_support": False,
         "has_gpu_support": True,
diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
index e04f0b612f..282f14f5c3 100644
--- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -92,6 +92,7 @@ def execute_backward_optimizers_(  # noqa C901
         weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE,
         uvm_non_rowwise_momentum: bool = False,
         optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
+        max_norm: float = 0.0,
     ) -> None:
         # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)
@@ -125,7 +126,11 @@ def execute_backward_optimizers_(  # noqa C901
                 ]
             )
         )
-
+        # max_norm is only applicable to PARTIAL_ROWWISE_ADAM GPU version
+        assume(
+            max_norm == 0.0
+            or (not use_cpu and optimizer == OptimType.PARTIAL_ROWWISE_ADAM)
+        )
         assume(pooling_mode == PoolingMode.SUM or not weighted)
         # No bag ops only work on GPUs, no mixed, no weighted
         assume(not use_cpu or pooling_mode != PoolingMode.NONE)
@@ -288,6 +293,7 @@ def execute_backward_optimizers_(  # noqa C901
             optimizer_kwargs["beta2"] = beta2
             optimizer_kwargs["weight_decay"] = weight_decay
             optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes
+            optimizer_kwargs["max_norm"] = max_norm
 
         if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
             optimizer_kwargs["eps"] = eps
@@ -514,6 +520,19 @@ def execute_backward_optimizers_(  # noqa C901
                     ) - lr * weight_decay * bs[t].weight.cpu()
                 )
 
+                if rowwise and max_norm > 0:
+                    grads = bs[t].weight.grad.cpu().to_dense()
+                    non_zero_grads = grads.abs().sum(dim=1, keepdim=True) > 0
+                    weights_norm = (
+                        weights_ref.norm(dim=1, keepdim=True) * non_zero_grads
+                    )
+                    weights_ref = torch.where(
+                        weights_norm > max_norm,
+                        weights_ref * max_norm / weights_norm,
+                        weights_ref,
+                    )
+
                 torch.testing.assert_close(
                     weights_new.index_select(dim=0, index=xs[t].view(-1)).cpu(),
                     weights_ref.index_select(dim=0, index=xs[t].view(-1).cpu()),
@@ -765,6 +784,7 @@ def _get_wts_from_counter_adagrad_using_cowclip(
         ),
         use_cpu=use_cpu_strategy(),
         uvm_non_rowwise_momentum=st.booleans(),
+        max_norm=st.floats(min_value=0.01, max_value=1.0),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -787,6 +807,7 @@ def test_backward_optimizers_adam(  # noqa C901
         pooling_mode: PoolingMode,
         use_cpu: bool,
         uvm_non_rowwise_momentum: bool,
+        max_norm: float,
     ) -> None:
         self.execute_backward_optimizers_(
             T,
@@ -802,6 +823,7 @@ def test_backward_optimizers_adam(  # noqa C901
             pooling_mode,
             use_cpu,
             uvm_non_rowwise_momentum=uvm_non_rowwise_momentum,
+            max_norm=max_norm,
         )
 
     @given(
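
Note (illustration, not part of the patch): the new split_post_update block clips each updated embedding row to an L2 norm of at most max_norm after the Adam step, and the test verifies this against a plain PyTorch reference. Below is a minimal standalone sketch of that reference logic on dense tensors; the function and variable names (clip_rows_to_max_norm, weights, grads) are illustrative only and are not part of the FBGEMM API.

import torch

def clip_rows_to_max_norm(weights: torch.Tensor,
                          grads: torch.Tensor,
                          max_norm: float) -> torch.Tensor:
    # Scale each updated row so its L2 norm is at most max_norm.
    # Rows that received no gradient this step are left untouched,
    # matching the kernel, which runs the post-update only for rows
    # it actually visits.
    if max_norm <= 0.0:
        return weights
    touched = grads.abs().sum(dim=1, keepdim=True) > 0
    norms = weights.norm(dim=1, keepdim=True) * touched
    return torch.where(norms > max_norm, weights * max_norm / norms, weights)

# Example: a 4 x 8 table where only row 1 received a gradient.
weights = 3.0 * torch.randn(4, 8)
grads = torch.zeros(4, 8)
grads[1] = 0.5
clipped = clip_rows_to_max_norm(weights, grads, max_norm=1.0)
print(clipped.norm(dim=1))  # row 1 clipped to <= 1.0, other rows unchanged

With the patch applied, this behavior is requested through the TBE optimizer kwargs: the updated test passes optimizer_kwargs["max_norm"] alongside beta1/beta2/weight_decay when optimizer is OptimType.PARTIAL_ROWWISE_ADAM, and the feature is GPU-only.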