Skip to content

Commit

Permalink
Add more extensive top-2 logging.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 392699230
  • Loading branch information
Mesh TensorFlow Team committed Aug 24, 2021
1 parent 1f381ce commit 21c4ef3
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions mesh_tensorflow/transformer/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,55 @@ def _top_2_gating(
position_in_expert_2 = mtf.reduce_sum(
position_in_expert_2, reduced_dim=experts_dim)

if train:
# Gate entropy.
if importance is not None:
raw_gates *= mtf.to_float(mtf.greater(importance, 0.0))
entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
reduced_dim=experts_dim)
batch_entropy = mtf.reduce_mean(entropy)
mtf.scalar_summary(name + "/entropy", batch_entropy)

# Mean top-1 and top-2 normalized gate probabilities.
if importance is not None:
gate_2 *= mtf.to_float(mtf.greater(importance, 0.0))
mtf.scalar_summary("top1_gate_normalized", mtf.reduce_mean(gate_1))
mtf.scalar_summary("top2_gate_normalized", mtf.reduce_mean(gate_2))
top1_routed = mtf.reduce_sum(mask_1_flat)
top2_routed = mtf.reduce_sum(mask_2_flat)
importance = mtf.cast(importance, dtype=top1_routed.dtype)

# What fraction of the top-1 and top-2 tokens are being routed to any
# expert.
mtf.scalar_summary("top1_fraction_routed",
top1_routed / mtf.reduce_sum(importance))
mtf.scalar_summary("top2_fraction_routed",
top2_routed / mtf.reduce_sum(importance))
# One or zero if that token got routed anywhere.
total_routed = mtf.reduce_sum(mtf.minimum(
mask_1_flat + mask_2_flat, mtf.ones_like(top1_routed)))
mtf.scalar_summary("all_fraction_routed",
total_routed / mtf.reduce_sum(importance))
mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

# Log what fraction of tokens are going to each expert.
def _log_per_expert_fraction(mask, name):
# mask: [batch, group, experts]
tokens_per_expert = mtf.reduce_sum(mask, output_shape=[experts_dim])
total_routed = mtf.reduce_sum(tokens_per_expert)
expert_fraction = mtf.to_float(tokens_per_expert / total_routed)
split_fractions = mtf.split(
expert_fraction,
split_dim=experts_dim,
num_or_size_splits=experts_dim.size)
for fraction in split_fractions:
mtf.scalar_summary(name + "_experts/" + fraction.name.replace(":", "/"),
mtf.reduce_mean(fraction))

_log_per_expert_fraction(mask_1, "top1")
_log_per_expert_fraction(mask_2, "top2")
_log_per_expert_fraction(mask_1 + mask_2, "all")

# [batch, group, experts, expert_capacity]
combine_tensor = (
gate_1 * mask_1_flat
Expand Down

0 comments on commit 21c4ef3

Please sign in to comment.