Skip to content

Commit

Permalink
Raise error when weight decay is set in TPUEmbeddingV2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631579045
  • Loading branch information
pineapplejuice233 authored and tensorflower-gardener committed May 7, 2024
1 parent 6e278a4 commit 92261ba
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tensorflow/python/tpu/tpu_embedding_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,19 @@ def __init__(
.format(self._strategy)
)

# TODO(pineapplejuice233): Remove this once weight decay is supported.
for table in self._table_config:
if (
table.optimizer.weight_decay_factor is not None
or table.optimizer.multiply_weight_decay_factor_by_learning_rate
is not None
):
raise NotImplementedError(
"weight_decay_factor and"
" multiply_weight_decay_factor_by_learning_rate are not supported"
f" yet. But found in table {table.name} setting."
)

self._num_sc_per_chip = (
self._strategy.extended.tpu_hardware_feature.num_embedding_devices_per_chip
)
Expand Down
21 changes: 21 additions & 0 deletions tensorflow/python/tpu/tpu_embedding_v3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,27 @@ def step(data):
):
self.assertAllEqual(per_feature_result, per_feature_result_cpu)

def test_raise_error_when_weight_decay_is_set(self):
feature_config = tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='watched', output_shape=[16]
)

resolver = tpu_cluster_resolver.TPUClusterResolver(tpu='')
remote.connect_to_cluster(resolver)
tpu_cluster_resolver.initialize_tpu_system(resolver)
strategy = tpu_strategy.TPUStrategy(resolver)

with self.assertRaises(NotImplementedError):
with strategy.scope():
tpu_embedding_v3.TPUEmbeddingV2(
feature_config=feature_config,
optimizer=tpu_embedding_v2_utils.SGD(
learning_rate=1.0,
weight_decay_factor=0.1,
multiply_weight_decay_factor_by_learning_rate=True,
),
)


if __name__ == '__main__':
v2_compat.enable_v2_behavior()
Expand Down

0 comments on commit 92261ba

Please sign in to comment.