
Commit
Fix the bug where disable_table_stacking is not respected in TPUEmbeddingV2.

PiperOrigin-RevId: 631447524
pineapplejuice233 authored and tensorflower-gardener committed May 7, 2024
1 parent f619248 commit 5f1eb4d
Showing 2 changed files with 63 additions and 44 deletions.
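In short: before this commit, `disable_table_stacking` only emitted a warning near the top of `_stack_tables_with_same_table_dim_and_optimizer` and the stacking algorithm ran anyway; the fix moves the check so the flag short-circuits stacking into one single-table stack per table. A minimal, runnable sketch of the fixed control flow (a toy model, not the actual function; `stack_tables_sketch` is an invented name, and the real grouping logic, which buckets tables by dim/optimizer/quantization and calls `_pywrap_tpu_embedding.stack_tables`, is elided):

import logging
from typing import Any, List

def stack_tables_sketch(
    table_config: List[Any], disable_table_stacking: bool
) -> List[List[Any]]:
  """Toy model of the fixed control flow; real grouping logic is elided."""
  if disable_table_stacking:
    logging.warning("Table stacking is disabled.")
    # The fix: every table becomes its own single-table stack, and the
    # stacking algorithm below is skipped entirely.
    return [[table] for table in table_config]
  # Placeholder for the real algorithm, which groups tables that share
  # (dim, optimizer, quantization_config) and stacks each group.
  return [list(table_config)]

assert stack_tables_sketch(["a", "b"], True) == [["a"], ["b"]]
assert stack_tables_sketch(["a", "b"], False) == [["a", "b"]]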
90 changes: 47 additions & 43 deletions tensorflow/python/tpu/tpu_embedding_v3.py
@@ -320,17 +320,16 @@ def _stack_tables_with_same_table_dim_and_optimizer(
   if sparse_core_embedding_config:
     disable_table_stacking = sparse_core_embedding_config.disable_table_stacking
 
-  if disable_table_stacking:
-    logging.warn("Table stacking is disabled.")
-
   s = TableStacking()
 
   # Round the table sizes to be divisible by the number of SCs.
   num_shards = num_partitions * num_sc_per_partition * 8
 
   s.table_to_padding_columns = {}
   s.table_to_padding_rows = {}
 
+  table_name_to_table = {}
   for table in table_config:
+    table_name_to_table[table.name] = table
     extra_rows = (
         num_shards - (table.vocabulary_size % num_shards)
     ) % num_shards
@@ -364,49 +363,54 @@ def _stack_tables_with_same_table_dim_and_optimizer(
     table.vocabulary_size += extra_rows
     table.dim += extra_cols
 
-  table_names = []
-  table_widths = []
-  table_heights = []
-  table_num_samples = []
-  table_groups = []
-
-  table_data_to_group = {}
-  table_to_num_samples = {table.name: 0 for table in table_config}
-  table_name_to_table = {}
-  for _, feature in flat_features:
-    table_to_num_samples[feature.table.name] += functools.reduce(
-        operator.mul, feature.output_shape
-    )
-
-  for table in table_config:
-    table_name_to_table[table.name] = table
-    key = (
-        table.dim,
-        table.optimizer,
-        repr(table.quantization_config) if table.quantization_config else None,
-    )
-    if key not in table_data_to_group:
-      table_data_to_group[key] = len(table_data_to_group)
-    table_groups.append(table_data_to_group[key])
-    table_names.append(table.name)
-    table_widths.append(table.dim)
-    table_heights.append(table.vocabulary_size)
-    table_num_samples.append(table_to_num_samples[table.name])
-
-  table_stacks_by_name = _pywrap_tpu_embedding.stack_tables(
-      table_heights,
-      table_widths,
-      table_num_samples,
-      table_groups,
-      table_names,
-      num_partitions,
-  )
-
-  table_stacks = [
-      [table_name_to_table[table_name] for table_name in stack_by_name]
-      for stack_by_name in table_stacks_by_name
-  ]
+  if disable_table_stacking:
+    logging.warn("Table stacking is disabled.")
+    table_stacks = [[table] for table in table_config]
+  else:
+    table_names = []
+    table_widths = []
+    table_heights = []
+    table_num_samples = []
+    table_groups = []
+
+    table_data_to_group = {}
+    table_to_num_samples = {table.name: 0 for table in table_config}
+    for _, feature in flat_features:
+      table_to_num_samples[feature.table.name] += functools.reduce(
+          operator.mul, feature.output_shape
+      )
+
+    for table in table_config:
+      key = (
+          table.dim,
+          table.optimizer,
+          repr(table.quantization_config)
+          if table.quantization_config
+          else None,
+      )
+      if key not in table_data_to_group:
+        table_data_to_group[key] = len(table_data_to_group)
+      table_groups.append(table_data_to_group[key])
+      table_names.append(table.name)
+      table_widths.append(table.dim)
+      table_heights.append(table.vocabulary_size)
+      table_num_samples.append(table_to_num_samples[table.name])
+
+    table_stacks_by_name = _pywrap_tpu_embedding.stack_tables(
+        table_heights,
+        table_widths,
+        table_num_samples,
+        table_groups,
+        table_names,
+        num_partitions,
+    )
+
+    table_stacks = [
+        [table_name_to_table[table_name] for table_name in stack_by_name]
+        for stack_by_name in table_stacks_by_name
+    ]
 
   s.table_name_to_table = table_name_to_table
   # Store the mapping between stacked table names to the actual tableConfigs.
   s.stacked_table_to_tables = {}
   # Store the mapping between table to name of the stacked table which
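With the fix in place, callers opt out of stacking through `SparseCoreEmbeddingConfig`. A usage sketch under stated assumptions: `strategy` is an existing `tf.distribute.TPUStrategy` and `feature_config` is already built, mirroring the test setup below; this is illustrative, not new API surface.

from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_embedding_v3

# Assumed to exist already: `strategy` (a tf.distribute.TPUStrategy) and
# `feature_config` (the feature/table configuration for the model).
sparse_core_embedding_config = tpu_embedding_v3.SparseCoreEmbeddingConfig(
    disable_table_stacking=True,  # now actually honored by TPUEmbeddingV2
    max_ids_per_chip_per_sample=64,
    allow_id_dropping=False,
)

with strategy.scope():
  mid_level_api = tpu_embedding_v3.TPUEmbeddingV2(
      feature_config=feature_config,
      optimizer=tpu_embedding_v2_utils.SGD(learning_rate=1.0),
      sparse_core_embedding_config=sparse_core_embedding_config,
  )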
17 changes: 16 additions & 1 deletion tensorflow/python/tpu/tpu_embedding_v3_test.py
@@ -321,10 +321,17 @@ def test_two_features_two_tables_stacked_lookup_with_csr_input(self):
     dist_iter = iter(dist)
     data = next(dist_iter)
 
+    sparse_core_embedding_config = tpu_embedding_v3.SparseCoreEmbeddingConfig(
+        disable_table_stacking=False,
+        max_ids_per_chip_per_sample=64,
+        allow_id_dropping=False,
+    )
+
     with strategy.scope():
       mid_level_api = tpu_embedding_v3.TPUEmbeddingV2(
           feature_config=feature_config,
           optimizer=tpu_embedding_v2_utils.SGD(learning_rate=1.0),
+          sparse_core_embedding_config=sparse_core_embedding_config,
       )
     self.assertLen(mid_level_api.embedding_tables, 1)

@@ -581,10 +588,18 @@ def test_two_feature_two_tables_stacked_backwards_pass_with_csr_input(self):
     dist_iter = iter(dist)
     data = next(dist_iter)
 
+    sparse_core_embedding_config = tpu_embedding_v3.SparseCoreEmbeddingConfig(
+        disable_table_stacking=False,
+        max_ids_per_chip_per_sample=64,
+        allow_id_dropping=False,
+    )
+
     with strategy.scope():
       optimizer = tpu_embedding_v2_utils.SGD(learning_rate=1.0)
       mid_level_api = tpu_embedding_v3.TPUEmbeddingV2(
-          feature_config=feature_config, optimizer=optimizer
+          feature_config=feature_config,
+          optimizer=optimizer,
+          sparse_core_embedding_config=sparse_core_embedding_config,
       )
       mid_level_api.build()
       random1 = np.random.uniform(size=(16, self.embedding_dim)).astype(
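Both tests pin `disable_table_stacking=False`, so the stacking path is still exercised and the existing `assertLen(mid_level_api.embedding_tables, 1)` expectation holds. As a hedged sketch of the behavior this commit fixes (hypothetical code, not part of the commit; `feature_config` is assumed to define two tables sharing dim and optimizer):

# Hypothetical counterpart assertion (not in this commit): with stacking
# disabled, each TableConfig should surface as its own embedding table.
config = tpu_embedding_v3.SparseCoreEmbeddingConfig(disable_table_stacking=True)
with strategy.scope():
  api = tpu_embedding_v3.TPUEmbeddingV2(
      feature_config=feature_config,
      optimizer=tpu_embedding_v2_utils.SGD(learning_rate=1.0),
      sparse_core_embedding_config=config,
  )
self.assertLen(api.embedding_tables, 2)  # one entry per unstacked table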
