Skip to content

Commit

Permalink
Fix failure in windowed_sampling_test.jax in OSS.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 453202511
  • Loading branch information
jburnim authored and tensorflower-gardener committed Jun 6, 2022
1 parent f8107c1 commit 8e72c11
Showing 1 changed file with 3 additions and 2 deletions.
Expand Up @@ -269,8 +269,9 @@ def step_broadcast(step_size):
shard_axis_names = pinned_model.experimental_shard_axis_names
if any(tf.nest.flatten(shard_axis_names)):
shard_axis_names = nest.flatten_up_to(
initial_transformed_position, pinned_model._model_flatten( # pylint: disable=protected-access
shard_axis_names))
initial_transformed_position,
list(pinned_model._model_flatten(shard_axis_names))) # pylint: disable=protected-access

else:
# No active shard axis names
shard_axis_names = None
Expand Down

0 comments on commit 8e72c11

Please sign in to comment.