From a3649c930af5cdf6b129aac3ef5d2f882e887be1 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Fri, 15 Aug 2025 10:27:05 -0700 Subject: [PATCH] Add unit tests for Full Adam, pt 1 Summary: - Add unit tests for Full Adam Reviewed By: emlin Differential Revision: D79975184 --- .../tbe/ssd/ssd_split_tbe_training_test.py | 442 ++++++++++++++++-- 1 file changed, 401 insertions(+), 41 deletions(-) diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 5282e8ef49..d54e05d8f5 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -707,6 +707,7 @@ def generate_ssd_tbes( enable_raw_embedding_streaming=enable_raw_embedding_streaming, backend_type=backend_type, res_params=res_params, + optimizer_state_dtypes=optimizer_state_dtypes, ).cuda() if bulk_init_chunk_size > 0 and lazy_bulk_init_enabled: @@ -1580,12 +1581,35 @@ def test_ssd_emb_state_dict( ) @given( + **default_st, + **{ + "m1_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), + "m2_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), + }, + num_buckets=st.integers(min_value=10, max_value=15), + enable_optimizer_offloading=st.booleans(), + backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]), bulk_init_chunk_size=st.sampled_from([0, 204800]), lazy_bulk_init_enabled=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_emb_state_dict_partial_rowwise_adam( - self, bulk_init_chunk_size: int, lazy_bulk_init_enabled: bool + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + pooling_mode: PoolingMode, + weights_precision: SparseType, + output_dtype: SparseType, + m1_dtype: SparseType, + m2_dtype: SparseType, + bulk_init_chunk_size: int, + lazy_bulk_init_enabled: bool, + # pyre-ignore[2] + **kwargs, ) -> None: # Constants lr = 0.5 @@ -1595,15 +1619,6 @@ def test_ssd_emb_state_dict_partial_rowwise_adam( beta2 = 0.99 weight_decay = 0.01 - T = 4 - B = 10 - D = 128 - L = 10 - log_E = 4 - weights_precision = SparseType.FP32 - output_dtype = SparseType.FP32 - pooling_mode = PoolingMode.SUM - # Generate embedding modules and inputs ( emb, @@ -1614,7 +1629,7 @@ def test_ssd_emb_state_dict_partial_rowwise_adam( B, log_E, L, - False, # weighted + weighted, lr=lr, eps=eps, weight_decay=weight_decay, @@ -1629,6 +1644,10 @@ def test_ssd_emb_state_dict_partial_rowwise_adam( share_table=True, bulk_init_chunk_size=bulk_init_chunk_size, lazy_bulk_init_enabled=lazy_bulk_init_enabled, + optimizer_state_dtypes={ + "momentum1": m1_dtype, + "momentum2": m2_dtype, + }, ) Es = [emb.embedding_specs[t][0] for t in range(T)] @@ -1725,6 +1744,166 @@ def test_ssd_emb_state_dict_partial_rowwise_adam( rtol=tolerance, ) + @given( + **default_st, + **{ + "m1_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), + "m2_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), + }, + num_buckets=st.integers(min_value=10, max_value=15), + enable_optimizer_offloading=st.booleans(), + backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]), + bulk_init_chunk_size=st.sampled_from([0, 204800]), + lazy_bulk_init_enabled=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_ssd_emb_state_dict_adam( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + pooling_mode: PoolingMode, + weights_precision: SparseType, + output_dtype: SparseType, + m1_dtype: SparseType, + m2_dtype: SparseType, + bulk_init_chunk_size: int, + lazy_bulk_init_enabled: bool, + # pyre-ignore[2] + **kwargs, + ) -> None: + # Constants + lr = 0.5 + eps = 0.2 + ssd_shards = 2 + beta1 = 0.9 + beta2 = 0.99 + weight_decay = 0.01 + + # Generate embedding modules and inputs + ( + emb, + emb_ref, + ) = self.generate_ssd_tbes( + T, + D, + B, + log_E, + L, + weighted, + lr=lr, + eps=eps, + weight_decay=weight_decay, + beta1=beta1, + beta2=beta2, + ssd_shards=ssd_shards, + optimizer=OptimType.ADAM, + cache_set_scale=0.2, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + output_dtype=output_dtype, + share_table=True, + bulk_init_chunk_size=bulk_init_chunk_size, + lazy_bulk_init_enabled=lazy_bulk_init_enabled, + optimizer_state_dtypes={ + "momentum1": m1_dtype, + "momentum2": m2_dtype, + }, + ) + + Es = [emb.embedding_specs[t][0] for t in range(T)] + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank, + ) = self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + trigger_bounds_check=True, + ) + + # Execute forward + output_ref_list, output = self.execute_ssd_forward_( + emb, + emb_ref, + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + B, + L, + False, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + + # Execute backward + self.execute_ssd_backward_( + output_ref_list, + output, + B, + D, + pooling_mode, + batch_size_per_feature_per_rank, + ) + + emb.flush() + + tolerance = 1.0e-2 + + split_optimizer_states = self.split_optimizer_states_(emb) + + # Compare emb state dict with expected values from nn.EmbeddingBag + emb_state_dict, _, _, _ = emb.split_embedding_weights(no_snapshot=False) + for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): + (m1, m2) = split_optimizer_states[t] + # Some optimizers have non-float momentum values + # pyre-ignore[16] + ref_grad = emb_ref[f].weight.grad.cpu().to_dense() + ref_weights = emb_ref[f].weight.cpu() + + # Compare momentum2 values: (1 - beta2) * dL^2 + m2_ref = ref_grad.pow(2) * (1.0 - beta2) + self.assert_close_(m2, m2_ref) + + # Compare momentum1 values: (1 - beta1) * dL + m1_ref = ref_grad * (1.0 - beta1) + self.assert_close_(m1, m1_ref) + + # Bias corrections + iter_ = emb.iter.item() + v_hat_t = m2_ref / (1 - beta2**iter_) + m_hat_t = m1_ref / (1 - beta1**iter_) + + # Weight update + ref_weights_updated = ( + torch.addcdiv( + ref_weights, + value=-lr, + tensor1=m_hat_t, + tensor2=v_hat_t.sqrt_().add_(eps), + ) + - lr * weight_decay * ref_weights + ) + + # Compare weights + torch.testing.assert_close( + # pyre-fixme [16] + emb_state_dict[t].full_tensor().float(), + ref_weights_updated, + atol=tolerance, + rtol=tolerance, + ) + def execute_ssd_cache_pipeline_( # noqa C901 self, T: int, @@ -2405,26 +2584,14 @@ def test_kv_emb_state_dict( self.assertTrue(len(metadata_list[table_index].size()) == 2) @given( + **default_st, **{ - "T": st.integers(min_value=1, max_value=10), - "D": st.integers(min_value=2, max_value=128), - "B": st.integers(min_value=1, max_value=128), - "log_E": st.integers(min_value=3, max_value=5), - "L": st.integers(min_value=0, max_value=20), - "weighted": st.just(False), - "cache_set_scale": st.just(0.0), - "pooling_mode": st.just(PoolingMode.NONE), - "weights_precision": st.sampled_from([SparseType.FP16, SparseType.FP32]), - "output_dtype": st.sampled_from([SparseType.FP16, SparseType.FP32]), "m1_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), "m2_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), - "share_table": st.just(False), - "trigger_bounds_check": st.just(False), - "mixed_B": st.just(False), }, - num_buckets=st.just(10), - enable_optimizer_offloading=st.just(True), - backend_type=st.sampled_from([BackendType.DRAM]), + num_buckets=st.integers(min_value=10, max_value=15), + enable_optimizer_offloading=st.booleans(), + backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_kv_emb_state_dict_partial_rowwise_adam( @@ -2624,24 +2791,217 @@ def test_kv_emb_state_dict_partial_rowwise_adam( ) @given( + **default_st, **{ - "T": st.integers(min_value=1, max_value=10), - "D": st.integers(min_value=2, max_value=128), - "B": st.integers(min_value=1, max_value=128), - "log_E": st.integers(min_value=3, max_value=5), - "L": st.integers(min_value=0, max_value=20), - "weighted": st.just(False), - "cache_set_scale": st.just(0.0), - "pooling_mode": st.just(PoolingMode.NONE), - "weights_precision": st.sampled_from([SparseType.FP16, SparseType.FP32]), - "output_dtype": st.sampled_from([SparseType.FP16, SparseType.FP32]), "m1_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), "m2_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), - "share_table": st.just(False), - "trigger_bounds_check": st.just(False), - "mixed_B": st.just(False), }, - num_buckets=st.just(10), + num_buckets=st.integers(min_value=10, max_value=15), + enable_optimizer_offloading=st.booleans(), + backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_kv_emb_state_dict_adam( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, + weights_precision: SparseType, + output_dtype: SparseType, + m1_dtype: SparseType, + m2_dtype: SparseType, + share_table: bool, + trigger_bounds_check: bool, + mixed_B: bool, + num_buckets: int, + enable_optimizer_offloading: bool, + backend_type: BackendType, + ) -> None: + assume(not weighted or pooling_mode == PoolingMode.SUM) + # VBE is currently not supported for PARTIAL_ROWWISE_ADAM optimizer + assume(not mixed_B) + # Don't stimulate boundary check cases + trigger_bounds_check = False + + # Constants + lr = 0.5 + eps = 0.2 + ssd_shards = 2 + beta1 = 0.9 + beta2 = 0.99 + weight_decay = 0.01 + + # Generate embedding modules and inputs + ( + emb, + emb_ref, + Es, + _, + bucket_offsets, + bucket_sizes, + ) = self.generate_kvzch_tbes( + T, + D, + B, + log_E, + L, + weighted, + lr=lr, + eps=eps, + weight_decay=weight_decay, + beta1=beta1, + beta2=beta2, + ssd_shards=ssd_shards, + optimizer=OptimType.ADAM, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + output_dtype=output_dtype, + share_table=share_table, + num_buckets=num_buckets, + enable_optimizer_offloading=enable_optimizer_offloading, + backend_type=backend_type, + optimizer_state_dtypes={ + "momentum1": m1_dtype, + "momentum2": m2_dtype, + }, + ) + + # Generate inputs + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank, + ) = self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + trigger_bounds_check=trigger_bounds_check, + bucket_offsets=bucket_offsets, + bucket_sizes=bucket_sizes, + is_kv_tbes=True, + ) + + # Execute forward + output_ref_list, output = self.execute_ssd_forward_( + emb, + emb_ref, + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + B, + L, + weighted, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + + # Execute backward + self.execute_ssd_backward_( + output_ref_list, + output, + B, + D, + pooling_mode, + batch_size_per_feature_per_rank, + ) + + emb.flush() + + split_optimizer_states = [] + + # Compare emb state dict with expected values from nn.EmbeddingBag + emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list, _ = ( + emb.split_embedding_weights(no_snapshot=False, should_flush=True) + ) + + for s in emb.split_optimizer_states( + bucket_asc_ids_list, no_snapshot=False, should_flush=True + ): + split_optimizer_states.append(s) + + # Compare optimizer states + for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): + (m1, m2) = split_optimizer_states[t] + # Some optimizers have non-float momentum values + # pyre-ignore[16] + ref_grad = emb_ref[f].weight.grad.cpu().to_dense() + ref_weights = emb_ref[f].weight.cpu() + + # Compare momentum2 values: (1 - beta2) * dL^2 + m2_ref = ref_grad.pow(2) * (1.0 - beta2) + # Get only the subset of rows based on bucket_asc_ids_list[t] + # pyre-ignore [16] + m2_ref = m2_ref[bucket_asc_ids_list[t].view(-1)] + self.assert_close_(m2, m2_ref) + + # Compare momentum1 values: (1 - beta1) * dL + m1_ref = ref_grad * (1.0 - beta1) + # Get only the subset of rows based on bucket_asc_ids_list[t] + m1_ref = m1_ref[bucket_asc_ids_list[t].view(-1)] + # Print which rows are different between m1 and m1_ref + print_different_rows( + m1, m1_ref, atol=1.0e-2, rtol=1.0e-2, name1="m1", name2="m1_ref" + ) + self.assert_close_(m1, m1_ref) + + #################################################################### + # Compare weight values + #################################################################### + + # Re-index the weights according to bucket ids + ref_weights = ref_weights[bucket_asc_ids_list[t].view(-1)] + + # Bias corrections + iter_ = emb.iter.item() + v_hat_t = m2_ref / (1 - beta2**iter_) + m_hat_t = m1_ref / (1 - beta1**iter_) + + # Manually update the ref weights + ref_weights_updated = ( + torch.addcdiv( + ref_weights, + value=-lr, + tensor1=m_hat_t, + tensor2=v_hat_t.sqrt_().add_(eps), + ) + - lr * weight_decay * ref_weights + ) + + # Fetch the updated weights from SSDTableBatchedEmbeddingBags + emb_w = ( + emb_state_dict_list[t] + .narrow(0, 0, bucket_asc_ids_list[t].size(0)) + .float() + ) + + # Compare weights + tolerance = 1.0e-2 + torch.testing.assert_close( + emb_w, + ref_weights_updated, + atol=tolerance, + rtol=tolerance, + ) + + @given( + **default_st, + **{ + "m1_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), + "m2_dtype": st.sampled_from([SparseType.BF16, SparseType.FP32]), + }, + num_buckets=st.integers(min_value=10, max_value=15), enable_optimizer_offloading=st.just(True), backend_type=st.sampled_from([BackendType.DRAM]), )