diff --git a/mesh_tensorflow/transformer/moe.py b/mesh_tensorflow/transformer/moe.py
index 8ad53663..82fe9455 100644
--- a/mesh_tensorflow/transformer/moe.py
+++ b/mesh_tensorflow/transformer/moe.py
@@ -409,7 +409,8 @@ def transformer_moe_layer_v1(
         train=train,
         variable_dtype=variable_dtype,
         importance=nonpadding,
-        num_microbatches=num_microbatches)
+        num_microbatches=num_microbatches,
+        token_embeddings=token_embeddings)
   elif hparams.moe_gating == "switch":
     dispatch_tensor, combine_tensor, loss = _switch_gating(
         inputs=inputs,
@@ -432,7 +433,8 @@ def transformer_moe_layer_v1(
         train=train,
         variable_dtype=variable_dtype,
         importance=nonpadding,
-        num_microbatches=num_microbatches)
+        num_microbatches=num_microbatches,
+        token_embeddings=token_embeddings)
   elif hparams.moe_gating == "switch_max":
     dispatch_tensor, combine_tensor, loss = _switch_max_gating(
         inputs=inputs,
@@ -443,7 +445,8 @@ def transformer_moe_layer_v1(
         train=train,
         variable_dtype=variable_dtype,
         importance=nonpadding,
-        num_microbatches=num_microbatches)
+        num_microbatches=num_microbatches,
+        token_embeddings=token_embeddings)
   elif hparams.moe_gating == "expert_selection":
     dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
         inputs=inputs,
@@ -456,7 +459,8 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         name="expert_selection_gating",
-        num_microbatches=num_microbatches)
+        num_microbatches=num_microbatches,
+        token_embeddings=token_embeddings)
   else:
     raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
 
@@ -866,7 +870,8 @@ def _ntlb_gating(inputs,
                  variable_dtype,
                  importance=None,
                  name="ntlb_gating",
-                 num_microbatches=None):
+                 num_microbatches=None,
+                 token_embeddings=None):
   """Compute Switch gating with no-token-left behind (NTLB) behavior."""
   # SELECT EXPERT
   if train:
@@ -874,12 +879,21 @@ def _ntlb_gating(inputs,
   else:
     policy = hparams.moe_switch_policy_eval
 
+  # The internals of this function run in float32.
+  # bfloat16 seems to reduce quality.
+  gate_inputs = mtf.to_float(inputs)
+
   # Input perturbations
   if train and policy == "input_jitter":
-    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_switch_jitter)
+    gate_inputs = mtf.layers.multiplicative_jitter(
+        gate_inputs, hparams.moe_switch_jitter)
+
+  if hparams.moe_word_embed_mode is not None:
+    gate_inputs = _add_token_emb_to_gate_inputs(
+        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)
 
   gate_logits = mtf.layers.dense(
-      inputs,
+      gate_inputs,
       experts_dim,
       use_bias=False,
       expert_dims=outer_expert_dims,
@@ -1004,7 +1018,7 @@ def _ntlb_gating(inputs,
 def _switch_max_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim, hparams,
     train, variable_dtype, importance=None, name="switch_max_gating",
-    num_microbatches=None):
+    num_microbatches=None, token_embeddings=None):
   """Compute Switch gating."""
   # TODO(barretzoph,liamfedus): Refactor switch_max, switch and ntlb to limit
   # code resuse.
@@ -1027,6 +1041,10 @@ def _switch_max_gating(
     gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                    hparams.moe_switch_jitter)
 
+  if hparams.moe_word_embed_mode is not None:
+    gate_inputs = _add_token_emb_to_gate_inputs(
+        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)
+
   gate_logits = mtf.layers.dense(
       gate_inputs,
       experts_dim,
@@ -1123,7 +1141,7 @@ def _expert_selection_gating(
     inputs, outer_expert_dims, experts_dim, group_size_dim,
     expert_capacity_dim, hparams, train, variable_dtype, importance=None,
     name="expert_selection_gating", num_microbatches=None,
-    normalize_by_num_experts_routed=True):
+    normalize_by_num_experts_routed=True, token_embeddings=None):
   """Compute gating where each expert chooses what tokens it wants."""
   # Select the randomization policy.
   if train:
@@ -1143,6 +1161,10 @@ def _expert_selection_gating(
     gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                    hparams.moe_switch_jitter)
 
+  if hparams.moe_word_embed_mode is not None:
+    gate_inputs = _add_token_emb_to_gate_inputs(
+        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)
+
   # Compute expert logits for each token.
   # gate_logits shape: [outer_batch, batch, group, expert_unsplit]
   gate_logits = mtf.layers.dense(
@@ -1252,8 +1274,6 @@ def _switch_gating(
   # The internals of this function run in float32.
   # bfloat16 seems to reduce quality.
   gate_inputs = mtf.to_float(inputs)
-  if hparams.moe_word_embed_mode is not None:
-    token_embeddings = mtf.to_float(token_embeddings)
 
   # Input perturbations
   if policy == "input_dropout":
@@ -1265,19 +1285,9 @@ def _switch_gating(
     gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                    hparams.moe_switch_jitter)
 
-  if hparams.moe_word_embed_mode == "concat":
-    gate_inputs = mtf.concat(
-        [gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
-  elif hparams.moe_word_embed_mode == "concat_stop_grad":
-    token_embeddings = mtf.stop_gradient(token_embeddings)
-    gate_inputs = mtf.concat(
-        [gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
-  elif hparams.moe_word_embed_mode == "add":
-    gate_inputs += token_embeddings
-  elif hparams.moe_word_embed_mode == "add_stop_grad":
-    gate_inputs += mtf.stop_gradient(token_embeddings)
-  elif hparams.moe_word_embed_mode == "embed_only":
-    gate_inputs = token_embeddings
+  if hparams.moe_word_embed_mode is not None:
+    gate_inputs = _add_token_emb_to_gate_inputs(
+        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)
   gate_logits = mtf.layers.dense(
       gate_inputs,
       experts_dim,
@@ -1397,7 +1407,7 @@ def _switch_gating(
 def _top_2_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim, hparams,
     train, variable_dtype, importance=None, name="top_2_gating",
-    num_microbatches=None):
+    num_microbatches=None, token_embeddings=None):
   """Compute gating for mixture-of-experts in TensorFlow.
 
   Note: until the algorithm and inferface solidify, we pass in a hyperparameters
@@ -1451,6 +1461,9 @@ def _top_2_gating(
     importance: an optional tensor with shape [<batch_dims>, group_size_dim]
     name: an optional string
     num_microbatches: number of microbatches.
+    token_embeddings: an optional tensor with shape
+      [<batch_dims>, group_size_dim, input_dim] that is the input
+      word embeddings.
 
   Returns:
     dispatch_tensor: a Tensor with shape
@@ -1468,6 +1481,10 @@ def _top_2_gating(
   # The internals of this function run in float32.
   # bfloat16 seems to reduce quality.
   gate_inputs = mtf.to_float(inputs)
+  if hparams.moe_word_embed_mode is not None:
+    gate_inputs = _add_token_emb_to_gate_inputs(
+        gate_inputs, token_embeddings, hparams.moe_word_embed_mode)
+
   raw_gates = mtf.layers.dense(
       gate_inputs, experts_dim, use_bias=False,
       expert_dims=outer_expert_dims,
@@ -1598,6 +1615,30 @@ def _top_2_gating(
   return dispatch_tensor, combine_tensor, loss
 
 
+def _add_token_emb_to_gate_inputs(
+    gate_inputs, token_embeddings, moe_word_embed_mode):
+  """Add token_embeddings to gate_inputs based on moe_word_embed_mode."""
+
+  token_embeddings = mtf.to_float(token_embeddings)
+  if moe_word_embed_mode == "concat":
+    gate_inputs = mtf.concat(
+        [gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
+  elif moe_word_embed_mode == "concat_stop_grad":
+    token_embeddings = mtf.stop_gradient(token_embeddings)
+    gate_inputs = mtf.concat(
+        [gate_inputs, token_embeddings], gate_inputs.shape.dims[-1].name)
+  elif moe_word_embed_mode == "add":
+    gate_inputs += token_embeddings
+  elif moe_word_embed_mode == "add_stop_grad":
+    gate_inputs += mtf.stop_gradient(token_embeddings)
+  elif moe_word_embed_mode == "embed_only":
+    gate_inputs = token_embeddings
+  else:
+    raise ValueError("Unimplemented moe word embed mode: {}".format(
+        moe_word_embed_mode))
+  return gate_inputs
+
+
 def set_default_moe_hparams(hparams):
   """Add necessary hyperparameters for mixture-of-experts."""
   hparams.moe_num_experts = 16
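Reviewer note: below is a minimal NumPy sketch of what the new shared helper `_add_token_emb_to_gate_inputs` does to the router input before the gating `mtf.layers.dense` call. The function name `combine_gate_inputs` and the plain-array setup are illustrative assumptions, not part of the patch; the real helper operates on `mtf.Tensor`s, concatenates along the named last dimension, and uses `mtf.stop_gradient` for the `*_stop_grad` modes.

```python
import numpy as np


def combine_gate_inputs(gate_inputs, token_embeddings, mode):
  """Illustrative stand-in for _add_token_emb_to_gate_inputs on plain arrays.

  gate_inputs:      [group_size, d_model] router input (token hidden states).
  token_embeddings: [group_size, d_model] word embeddings for the same tokens.
  mode:             one of the moe_word_embed_mode values handled in the patch.
  NumPy has no stop_gradient, so the *_stop_grad modes collapse onto their
  plain counterparts here; in mtf they differ only in what receives gradients.
  """
  token_embeddings = token_embeddings.astype(np.float32)
  if mode in ("concat", "concat_stop_grad"):
    # Router sees [hidden state ; word embedding] along the feature axis,
    # so the gating dense layer gets a wider input dimension.
    return np.concatenate([gate_inputs, token_embeddings], axis=-1)
  elif mode in ("add", "add_stop_grad"):
    return gate_inputs + token_embeddings
  elif mode == "embed_only":
    # Router ignores the hidden state and routes on the embedding alone.
    return token_embeddings
  raise ValueError("Unimplemented moe word embed mode: {}".format(mode))


x = np.random.randn(4, 8).astype(np.float32)    # hidden states entering the MoE layer
emb = np.random.randn(4, 8).astype(np.float32)  # token embeddings, same positions
print(combine_gate_inputs(x, emb, "concat").shape)      # (4, 16)
print(combine_gate_inputs(x, emb, "add").shape)         # (4, 8)
print(combine_gate_inputs(x, emb, "embed_only").shape)  # (4, 8)
```

One point worth checking in review: the `concat*` modes widen the input to the gating dense layer, so the router variable shape depends on `moe_word_embed_mode`, while `add*` and `embed_only` keep it at the model dimension.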