From 2456b1b65229b7671bd86c4752fb616e0bb4c5d8 Mon Sep 17 00:00:00 2001 From: Mesh TensorFlow Team Date: Sat, 14 Aug 2021 10:27:06 -0700 Subject: [PATCH] Expert Attention Fixes: - Allow moe.py to work with a tensor of "memory_length" dimension - Fix Experts Attention bug in moe.py where it would break during decoding if the input dimension was different than the output dimension. - Fix bug in ExpertsEncDecAttention where it was only doing Self-Attention on the decoder side. - Factorize expert_computation code to easily allow for using different query and memory antecedents PiperOrigin-RevId: 390808726 --- mesh_tensorflow/transformer/attention.py | 96 +++++++++++-------- mesh_tensorflow/transformer/moe.py | 32 +++++-- mesh_tensorflow/transformer/transformer.py | 11 +++ .../transformer/transformer_layers.py | 8 +- 4 files changed, 99 insertions(+), 48 deletions(-) diff --git a/mesh_tensorflow/transformer/attention.py b/mesh_tensorflow/transformer/attention.py index b58020a1..5533241b 100644 --- a/mesh_tensorflow/transformer/attention.py +++ b/mesh_tensorflow/transformer/attention.py @@ -631,10 +631,11 @@ def __init__(self, combine_dims=True, ensemble_dim=None, keep_query_heads_dims=False, - fold_scaling_into_initializer=True, + fold_scaling_into_initializer=False, context=None, experts_hparams=None, - expert_computation="qkv"): + expert_computation="qkv", + is_encdec=False): super(ExpertsAttentionParams, self).__init__( mesh=mesh, query_input_dim=query_input_dim, @@ -655,6 +656,7 @@ def __init__(self, self.context = context self.expert_computation = expert_computation + self.is_encdec = is_encdec # Unless we want to compute both q and kv, we can use the normal MoE # settings. @@ -696,9 +698,6 @@ def __init__(self, # we want to partition both "experts_hidden" and "heads". moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size) - tf.logging.info("ExpertsAttention moe_hidden_size: {}".format( - experts_hparams.hidden_size)) - tf.logging.info("moe_output_dims: {}".format(moe_output_dims)) self.moe_layer = mtf.transformer.moe.MoE1D( moe_gating=experts_hparams.moe_gating, num_experts=experts_hparams.num_experts, @@ -719,55 +718,70 @@ def __init__(self, activation=experts_hparams.activation, z_loss=experts_hparams.z_loss) + def _replace_d_model_dim(self, t): + """Used to replace the `d_model` dim with `heads`.""" + new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size) + return mtf.reshape(t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim])) + + def _compute_q_with_experts(self, antecedent): + q = self.moe_layer.call(self.context, antecedent) + q = self._replace_d_model_dim(q) + return q + + def _compute_kv_with_experts(self, antecedent): + kv = self.moe_layer.call( + self.context, antecedent, use_enc_nonpadding=self.is_encdec) + kv = self._replace_d_model_dim(kv) + return kv + def _compute_merge_qkv(self, antecedent): """Computes qkv all in one call using MoE layer.""" - def _replace_d_model_dim(t): - """Used to replace the `d_model` dim with `heads`.""" - new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size) - return mtf.reshape( - t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim])) + # This mode assumes query and memory antecedent are the same. 
+ qkv = self.moe_layer.call(self.context, antecedent) + q, kv = qkv + q = self._replace_d_model_dim(q) + kv = self._replace_d_model_dim(kv) + self._q = q + self._kv = kv + + def compute_q(self, query_antecedent): if self.expert_computation == "qkv": - # NOTE: This assumes querty and memory antecedent are the same - qk = self.moe_layer.call(self.context, antecedent) - # Split qk here since they went through experts-layers - q, k = qk - q = _replace_d_model_dim(q) - k = _replace_d_model_dim(k) + self._compute_merge_qkv(query_antecedent) + q = self._q elif self.expert_computation == "q": - q = self.moe_layer.call(self.context, antecedent) - q = _replace_d_model_dim(q) - # Compute key/value normally - k = mtf.layers.us_einsum( - [antecedent, self.wkv], reduced_dims=[self.memory_input_dim]) + q = self._compute_q_with_experts(query_antecedent) + # If computing "kv" with experts, then compute q normally. elif self.expert_computation == "kv": - k = self.moe_layer.call(self.context, antecedent) - k = _replace_d_model_dim(k) - # Compute query normally q = mtf.layers.us_einsum( - [antecedent, self.wq], reduced_dims=[self.query_input_dim]) - else: - raise ValueError("Invalid expert computation mode: {}".format( - self.expert_computation)) - - # Scale query + [query_antecedent, self.wq], reduced_dims=[self.query_input_dim]) q *= self.key_dim.size ** -0.5 - self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims) - self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims) - - def compute_q(self, query_antecedent): - self._compute_merge_qkv(query_antecedent) - return self._q + return mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims) def compute_k(self, memory_antecedent): - del memory_antecedent - return self._k + raise NotImplementedError("ExpertsAttention uses shared_kv = True.") def compute_kv(self, memory_antecedent): - del memory_antecedent - return self._k + if self.expert_computation == "qkv": + # We have already computing "kv" with "q", so just return its value. + kv = self._kv + # Check if the "length" dimension should be "memory_length" since both + # q and kv were computed using the same antecedent. This is why we must + # always have the same query and memory antecedent for the qkv mode. + if self.context.length_dim in kv.shape.dims: + memory_length = mtf.Dimension( + "memory_length", self.context.length_dim.size) + kv = mtf.replace_dimensions( + kv, self.context.length_dim, memory_length) + # If computing "q" with experts, then compute "kv" normally. 
+ elif self.expert_computation == "q": + kv = mtf.layers.us_einsum( + [memory_antecedent, self.wkv], reduced_dims=[self.memory_input_dim]) + elif self.expert_computation == "kv": + kv = self._compute_kv_with_experts(memory_antecedent) + kv = mtf.replace_dimensions(kv, kv.shape.dims[-1], self.k_dims) + return kv def compute_v(self, memory_antecedent): - del memory_antecedent raise NotImplementedError("ExpertsAttention uses shared_kv = True.") diff --git a/mesh_tensorflow/transformer/moe.py b/mesh_tensorflow/transformer/moe.py index 557c3ea3..702a2fd0 100644 --- a/mesh_tensorflow/transformer/moe.py +++ b/mesh_tensorflow/transformer/moe.py @@ -98,13 +98,30 @@ def __init__(self, moe_top_n_num_experts_per_token=top_n_num_experts_per_token) self._activation = activation - def call(self, context, x, losses=None): + def call(self, context, x, losses=None, use_enc_nonpadding=False): """Call the layer.""" if context.model.ensemble_dim: raise NotImplementedError("MoE not yet implemented with ensembles") has_length_dim = context.length_dim in x.shape.dims - if not has_length_dim: + has_memory_length_dim = "memory_length" in x.shape.dimension_names + # Used for EncDec attention if we have the MoE layer produce the kv. + if use_enc_nonpadding: + nonpadding = context.nonpadding_encoder + else: + nonpadding = context.nonpadding + # If a memory_length dimension exists, then we make sure the + # length dimension of the nonpadding tensor matches it. + if (has_memory_length_dim and isinstance(nonpadding, mtf.Tensor) + and "length" in nonpadding.shape.dimension_names): + old_length_dim = nonpadding.shape.get_dim_by_name("length") + new_length_dim = mtf.Dimension("memory_length", old_length_dim.size) + nonpadding = mtf.replace_dimensions( + nonpadding, old_length_dim, new_length_dim) + # Insert a length dimension if one does not exist. + # Typically no length dims will occur on the decoder during autoregressive + # decoding. + if not has_length_dim and not has_memory_length_dim: x_shape = x.shape shape_with_length = mtf.Shape( x_shape.dims[:-1] + [mtf.Dimension("length", 1)] @@ -124,18 +141,21 @@ def call(self, context, x, losses=None): context.variable_dtype, layout=context.model.layout, mesh_shape=context.model.mesh_shape, - nonpadding=context.nonpadding, + nonpadding=nonpadding, activation=self._activation, num_microbatches=context.num_microbatches, token_embeddings=context.input_embeddings) if context.losses is not None: context.losses.append(loss) - if not has_length_dim: + if not has_length_dim and not has_memory_length_dim: + # Shapes will differ if the input and output dimension of the layer do not + # match. 
+ new_y_shape = mtf.Shape(x_shape.dims[:-1] + [output_dim]) if self._hparams.moe_use_experts_attention: - y_reshape = [mtf.reshape(y_out, x_shape) for y_out in y] + y_reshape = [mtf.reshape(y_out, new_y_shape) for y_out in y] y = y_reshape else: - y = mtf.reshape(y, x_shape) + y = mtf.reshape(y, new_y_shape) return y diff --git a/mesh_tensorflow/transformer/transformer.py b/mesh_tensorflow/transformer/transformer.py index f8c20d9a..204e8cdb 100644 --- a/mesh_tensorflow/transformer/transformer.py +++ b/mesh_tensorflow/transformer/transformer.py @@ -304,6 +304,17 @@ def nonpadding(self): return mtf.cast( mtf.not_equal(self.sequence_id, 0), self.activation_dtype) + @property + def nonpadding_encoder(self): + """Tensor with zeros in padding positions and ones elsewhere for encoder.""" + if self.encoder_sequence_id is None: + return None + if self.encoder_sequence_id == 1: + return 1 + else: + return mtf.cast( + mtf.not_equal(self.encoder_sequence_id, 0), self.activation_dtype) + def get_position(self): if self.position_is_default: return mtf.range(self.mesh, self.length_dim, tf.int32) diff --git a/mesh_tensorflow/transformer/transformer_layers.py b/mesh_tensorflow/transformer/transformer_layers.py index 159acd35..541e3136 100644 --- a/mesh_tensorflow/transformer/transformer_layers.py +++ b/mesh_tensorflow/transformer/transformer_layers.py @@ -410,6 +410,7 @@ def __init__(self, **kwargs): super(ExpertsSelfAttention, self).__init__(**kwargs) self.expert_computation = expert_computation + self.is_encdec = False # Overrided in ExpertsEncDecAttention self._hparams = mtf.transformer.moe.HParams( moe_gating=moe_gating, num_experts=num_experts, @@ -465,7 +466,8 @@ def make_params(self, context): fold_scaling_into_initializer=self.fold_scaling_into_initializer, context=context, experts_hparams=self._hparams, - expert_computation=self.expert_computation) + expert_computation=self.expert_computation, + is_encdec=self.is_encdec) @gin.configurable @@ -475,6 +477,10 @@ class ExpertsEncDecAttention(ExpertsSelfAttention): def __init__(self, relative_attention_type=None, **kwargs): super(ExpertsEncDecAttention, self).__init__( relative_attention_type=relative_attention_type, **kwargs) + self.is_encdec = True + if self.expert_computation == "qkv": + raise ValueError("ExpertsEncDecAttention must use expert_computation of " + "q or kv.") def _get_memory_antecedent(self, context): return context.encoder_output
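
Note (illustration only, not part of the patch): the sketch below is a
framework-free model of how the fixed compute_q / compute_kv dispatch behaves
for the three expert_computation modes ("qkv", "q", "kv"). The helper names
dense_project and experts_layer are hypothetical stand-ins for
mtf.layers.us_einsum and moe.MoE1D.call, and the real "qkv" mode returns both
projections from a single MoE call rather than two calls sharing a gate; only
the control flow mirrors the patch, including why "qkv" requires the query and
memory antecedents to be the same tensor, which is also why compute_kv renames
"length" to "memory_length" in that mode.

import numpy as np


def dense_project(x, w):
  """Stand-in for mtf.layers.us_einsum: a plain dense projection."""
  return x @ w


def experts_layer(x, w_experts, w_gate):
  """Stand-in for moe.MoE1D.call: softmax-gated mix of per-expert projections.

  x: [batch, length, d_model], w_experts: [experts, d_model, d_out],
  w_gate: [d_model, experts].
  """
  logits = np.einsum("bld,de->ble", x, w_gate)
  gates = np.exp(logits - logits.max(axis=-1, keepdims=True))
  gates = gates / gates.sum(axis=-1, keepdims=True)
  per_expert = np.einsum("bld,edo->bleo", x, w_experts)
  return np.einsum("bleo,ble->blo", per_expert, gates)


def compute_q_and_kv(query_antecedent, memory_antecedent, mode, wq, wkv,
                     wq_experts, wkv_experts, w_gate, key_dim_size):
  """Mirrors the post-fix control flow of compute_q / compute_kv."""
  if mode == "qkv":
    # Both projections go through the experts layer, so the query and memory
    # antecedents must be the same tensor (self-attention only).
    assert query_antecedent is memory_antecedent
    q = experts_layer(query_antecedent, wq_experts, w_gate)
    kv = experts_layer(query_antecedent, wkv_experts, w_gate)
  elif mode == "q":
    # Experts produce q; kv is a normal dense projection of the memory
    # antecedent, which may come from the encoder (enc-dec attention).
    q = experts_layer(query_antecedent, wq_experts, w_gate)
    kv = dense_project(memory_antecedent, wkv)
  elif mode == "kv":
    # Experts produce kv (using the encoder nonpadding in the enc-dec case);
    # q is a normal dense projection of the query antecedent.
    q = dense_project(query_antecedent, wq)
    kv = experts_layer(memory_antecedent, wkv_experts, w_gate)
  else:
    raise ValueError("Invalid expert computation mode: {}".format(mode))
  q = q * key_dim_size ** -0.5  # scale the query, as compute_q does
  return q, kv

Because the "qkv" mode only makes sense when the two antecedents coincide,
ExpertsEncDecAttention (where the memory antecedent is the encoder output)
rejects it and accepts only "q" or "kv", matching the ValueError added in
transformer_layers.py.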