Commit 1668afb
- Fix casting for NTLB.
- Factorize word embedding routing code and add to other routing algorithms.

PiperOrigin-RevId: 385845773
Mesh TensorFlow Team committed Jul 20, 2021
1 parent 38e58c4 commit 1668afb
Showing 1 changed file with 66 additions and 25 deletions.
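At a glance, the commit threads a token_embeddings argument through each gating algorithm in moe.py (_ntlb_gating, _switch_max_gating, _expert_selection_gating, _switch_gating, and _top_2_gating), moves the per-mode word-embedding mixing that previously lived only in _switch_gating into a shared _add_token_emb_to_gate_inputs helper, and has NTLB gating cast its inputs to float32 before computing gate logits. The sketch below is a plain-NumPy illustration of what the shared helper computes for each moe_word_embed_mode value; it is not the Mesh TensorFlow code itself (which operates on mtf.Tensor objects and uses mtf.stop_gradient), and the shapes are invented for the example.

```python
import numpy as np

def add_token_emb_to_gate_inputs_np(gate_inputs, token_embeddings, mode):
    """Plain-NumPy illustration of the moe_word_embed_mode options."""
    token_embeddings = token_embeddings.astype(np.float32)
    if mode in ("concat", "concat_stop_grad"):
        # The router sees both the layer activations and the word embedding,
        # concatenated along the feature (last) axis.
        return np.concatenate([gate_inputs, token_embeddings], axis=-1)
    elif mode in ("add", "add_stop_grad"):
        # Sum the two signals; the feature width is unchanged.
        return gate_inputs + token_embeddings
    elif mode == "embed_only":
        # Route purely on the word embedding, ignoring the layer activations.
        return token_embeddings
    raise ValueError("Unimplemented moe word embed mode: {}".format(mode))

# Illustrative shapes: [batch=2, group=4, d_model=8].
x = np.random.randn(2, 4, 8).astype(np.float32)
emb = np.random.randn(2, 4, 8).astype(np.float32)
print(add_token_emb_to_gate_inputs_np(x, emb, "concat").shape)  # (2, 4, 16)
print(add_token_emb_to_gate_inputs_np(x, emb, "add").shape)     # (2, 4, 8)
```

The *_stop_grad modes differ from their counterparts only in how gradients flow in the real graph: the router sees the embedding values, but no gradient is propagated back into the embedding table.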
91 changes: 66 additions & 25 deletions mesh_tensorflow/transformer/moe.py
@@ -409,7 +409,8 @@ def transformer_moe_layer_v1(
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "switch":
dispatch_tensor, combine_tensor, loss = _switch_gating(
inputs=inputs,
@@ -432,7 +433,8 @@ def transformer_moe_layer_v1(
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "switch_max":
dispatch_tensor, combine_tensor, loss = _switch_max_gating(
inputs=inputs,
@@ -443,7 +445,8 @@ def transformer_moe_layer_v1(
train=train,
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches)
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
elif hparams.moe_gating == "expert_selection":
dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
inputs=inputs,
@@ -456,7 +459,8 @@ def transformer_moe_layer_v1(
variable_dtype=variable_dtype,
importance=nonpadding,
name="expert_selection_gating",
num_microbatches=num_microbatches)
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
else:
raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

@@ -866,20 +870,30 @@ def _ntlb_gating(inputs,
variable_dtype,
importance=None,
name="ntlb_gating",
num_microbatches=None):
num_microbatches=None,
token_embeddings=None):
"""Compute Switch gating with no-token-left behind (NTLB) behavior."""
# SELECT EXPERT
if train:
policy = hparams.moe_switch_policy_train
else:
policy = hparams.moe_switch_policy_eval

# The internals of this function run in float32.
# bfloat16 seems to reduce quality.
gate_inputs = mtf.to_float(inputs)

# Input perturbations
if train and policy == "input_jitter":
inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_switch_jitter)
gate_inputs = mtf.layers.multiplicative_jitter(
gate_inputs, hparams.moe_switch_jitter)

if hparams.moe_word_embed_mode is not None:
gate_inputs = _add_token_emb_to_gate_inputs(
gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

gate_logits = mtf.layers.dense(
inputs,
gate_inputs,
experts_dim,
use_bias=False,
expert_dims=outer_expert_dims,
@@ -1004,7 +1018,7 @@ def _ntlb_gating(inputs,
def _switch_max_gating(
inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
hparams, train, variable_dtype, importance=None, name="switch_max_gating",
num_microbatches=None):
num_microbatches=None, token_embeddings=None):
"""Compute Switch gating."""
# TODO(barretzoph,liamfedus): Refactor switch_max, switch and ntlb to limit
# code duplication.
@@ -1027,6 +1041,10 @@ def _switch_max_gating(
gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
hparams.moe_switch_jitter)

if hparams.moe_word_embed_mode is not None:
gate_inputs = _add_token_emb_to_gate_inputs(
gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

gate_logits = mtf.layers.dense(
gate_inputs,
experts_dim,
@@ -1123,7 +1141,7 @@ def _expert_selection_gating(
inputs, outer_expert_dims, experts_dim, group_size_dim,
expert_capacity_dim, hparams, train, variable_dtype, importance=None,
name="expert_selection_gating", num_microbatches=None,
normalize_by_num_experts_routed=True):
normalize_by_num_experts_routed=True, token_embeddings=None):
"""Compute gating where each expert chooses what tokens it wants."""
# Select the randomization policy.
if train:
@@ -1143,6 +1161,10 @@ def _expert_selection_gating(
gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
hparams.moe_switch_jitter)

if hparams.moe_word_embed_mode is not None:
gate_inputs = _add_token_emb_to_gate_inputs(
gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

# Compute expert logits for each token.
# gate_logits shape: [outer_batch, batch, group, expert_unsplit]
gate_logits = mtf.layers.dense(
@@ -1252,8 +1274,6 @@ def _switch_gating(
# The internals of this function run in float32.
# bfloat16 seems to reduce quality.
gate_inputs = mtf.to_float(inputs)
if hparams.moe_word_embed_mode is not None:
token_embeddings = mtf.to_float(token_embeddings)

# Input perturbations
if policy == "input_dropout":
@@ -1265,19 +1285,9 @@ def _switch_gating(
gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
hparams.moe_switch_jitter)

if hparams.moe_word_embed_mode == "concat":
gate_inputs = mtf.concat(
[gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
elif hparams.moe_word_embed_mode == "concat_stop_grad":
token_embeddings = mtf.stop_gradient(token_embeddings)
gate_inputs = mtf.concat(
[gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
elif hparams.moe_word_embed_mode == "add":
gate_inputs += token_embeddings
elif hparams.moe_word_embed_mode == "add_stop_grad":
gate_inputs += mtf.stop_gradient(token_embeddings)
elif hparams.moe_word_embed_mode == "embed_only":
gate_inputs = token_embeddings
if hparams.moe_word_embed_mode is not None:
gate_inputs = _add_token_emb_to_gate_inputs(
gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

gate_logits = mtf.layers.dense(
gate_inputs,
@@ -1397,7 +1407,7 @@ def _switch_gating(
def _top_2_gating(
inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
hparams, train, variable_dtype, importance=None, name="top_2_gating",
num_microbatches=None):
num_microbatches=None, token_embeddings=None):
"""Compute gating for mixture-of-experts in TensorFlow.
Note: until the algorithm and interface solidify, we pass in a hyperparameters
@@ -1451,6 +1461,9 @@ def _top_2_gating(
importance: an optional tensor with shape [<batch_dims>, group_size_dim]
name: an optional string
num_microbatches: number of microbatches.
token_embeddings: an optional tensor with shape
[<batch_dims>, group_size_dim, input_dim] that is the input
word embeddings.
Returns:
dispatch_tensor: a Tensor with shape
@@ -1468,6 +1481,10 @@ def _top_2_gating(
# bfloat16 seems to reduce quality.
gate_inputs = mtf.to_float(inputs)

if hparams.moe_word_embed_mode is not None:
gate_inputs = _add_token_emb_to_gate_inputs(
gate_inputs, token_embeddings, hparams.moe_word_embed_mode)

raw_gates = mtf.layers.dense(
gate_inputs, experts_dim, use_bias=False,
expert_dims=outer_expert_dims,
@@ -1598,6 +1615,30 @@ def _top_2_gating(
return dispatch_tensor, combine_tensor, loss


def _add_token_emb_to_gate_inputs(
gate_inputs, token_embeddings, moe_word_embed_mode):
"""Add token_embeddings to gate_inputs based on moe_word_embed_mode."""

token_embeddings = mtf.to_float(token_embeddings)
if moe_word_embed_mode == "concat":
gate_inputs = mtf.concat(
[gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
elif moe_word_embed_mode == "concat_stop_grad":
token_embeddings = mtf.stop_gradient(token_embeddings)
gate_inputs = mtf.concat(
[gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
elif moe_word_embed_mode == "add":
gate_inputs += token_embeddings
elif moe_word_embed_mode == "add_stop_grad":
gate_inputs += mtf.stop_gradient(token_embeddings)
elif moe_word_embed_mode == "embed_only":
gate_inputs = token_embeddings
else:
raise ValueError("Unimplemented moe word embed mode: {}".format(
moe_word_embed_mode))
return gate_inputs


def set_default_moe_hparams(hparams):
"""Add necessary hyperparameters for mixture-of-experts."""
hparams.moe_num_experts = 16
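For reference, here is a minimal sketch of exercising the new helper directly with Mesh TensorFlow tensors. It assumes a mesh_tensorflow installation that includes this commit; the dimension names and sizes are invented, tf.disable_v2_behavior() is assumed to be acceptable in the surrounding program, and the module-private _add_token_emb_to_gate_inputs is called directly only to show the shape behaviour, whereas in the library it is reached through the gating functions above.

```python
import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import moe

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "sketch_mesh")

# Invented dimensions standing in for the layer's batch/group/model dims.
batch = mtf.Dimension("batch", 2)
group = mtf.Dimension("group", 4)
d_model = mtf.Dimension("d_model", 8)
shape = mtf.Shape([batch, group, d_model])

gate_inputs = mtf.import_tf_tensor(mesh, tf.ones([2, 4, 8]), shape)
token_embeddings = mtf.import_tf_tensor(mesh, tf.ones([2, 4, 8]), shape)

# "concat" joins the two tensors along d_model, so the router's feature
# width doubles; "add"/"add_stop_grad" keep the shape unchanged, and
# "embed_only" routes on the embeddings alone.
mixed = moe._add_token_emb_to_gate_inputs(
    gate_inputs, token_embeddings, "concat")
print(mixed.shape)  # d_model is now 16 (8 activation + 8 embedding features)
```

Because "concat" widens the trailing feature dimension, the gating dense layers in this commit consume gate_inputs rather than the raw inputs, which is also what the NTLB hunk above changes.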
