Expert Attention Fixes:
- Allow moe.py to work with tensors that have a "memory_length" dimension.
- Fix an ExpertsAttention bug in moe.py where decoding would break if the layer's input dimension differed from its output dimension.
- Fix a bug in ExpertsEncDecAttention where it performed only self-attention on the decoder side.
- Factor out the expert_computation code so that different query and memory antecedents can easily be used.

PiperOrigin-RevId: 390808726
Mesh TensorFlow Team committed Sep 7, 2021
1 parent 21c4ef3 commit 2456b1b
Showing 4 changed files with 99 additions and 48 deletions.
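
Several of the changes below revolve around tensors whose sequence dimension is named "memory_length" rather than "length" (the key/value side of attention). The sketch that follows is not part of the commit; it only illustrates that renaming pattern, with a hypothetical mesh name and made-up dimension sizes.

```python
import mesh_tensorflow as mtf

# Hypothetical mesh and sizes, for illustration only.
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)
length = mtf.Dimension("length", 8)
d_model = mtf.Dimension("d_model", 16)

# A key/value antecedent with shape [batch, length, d_model].
kv = mtf.zeros(mesh, mtf.Shape([batch, length, d_model]))

# Keys/values are expected to carry a "memory_length" dimension, so the
# "length" dimension is renamed in place; its size is unchanged.
memory_length = mtf.Dimension("memory_length", length.size)
kv = mtf.replace_dimensions(kv, length, memory_length)
print(kv.shape)  # Shape[batch=2, memory_length=8, d_model=16]
```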
96 changes: 55 additions & 41 deletions mesh_tensorflow/transformer/attention.py
@@ -631,10 +631,11 @@ def __init__(self,
                combine_dims=True,
                ensemble_dim=None,
                keep_query_heads_dims=False,
-               fold_scaling_into_initializer=True,
+               fold_scaling_into_initializer=False,
                context=None,
                experts_hparams=None,
-               expert_computation="qkv"):
+               expert_computation="qkv",
+               is_encdec=False):
     super(ExpertsAttentionParams, self).__init__(
         mesh=mesh,
         query_input_dim=query_input_dim,
@@ -655,6 +656,7 @@ def __init__(self,

     self.context = context
     self.expert_computation = expert_computation
+    self.is_encdec = is_encdec

     # Unless we want to compute both q and kv, we can use the normal MoE
     # settings.
@@ -696,9 +698,6 @@ def __init__(self,
     # we want to partition both "experts_hidden" and "heads".
     moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size)

-    tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
-        experts_hparams.hidden_size))
-    tf.logging.info("moe_output_dims: {}".format(moe_output_dims))
     self.moe_layer = mtf.transformer.moe.MoE1D(
         moe_gating=experts_hparams.moe_gating,
         num_experts=experts_hparams.num_experts,
@@ -719,55 +718,70 @@ def __init__(self,
         activation=experts_hparams.activation,
         z_loss=experts_hparams.z_loss)

+  def _replace_d_model_dim(self, t):
+    """Used to replace the `d_model` dim with `heads`."""
+    new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
+    return mtf.reshape(t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
+
+  def _compute_q_with_experts(self, antecedent):
+    q = self.moe_layer.call(self.context, antecedent)
+    q = self._replace_d_model_dim(q)
+    return q
+
+  def _compute_kv_with_experts(self, antecedent):
+    kv = self.moe_layer.call(
+        self.context, antecedent, use_enc_nonpadding=self.is_encdec)
+    kv = self._replace_d_model_dim(kv)
+    return kv
+
   def _compute_merge_qkv(self, antecedent):
     """Computes qkv all in one call using MoE layer."""
-    def _replace_d_model_dim(t):
-      """Used to replace the `d_model` dim with `heads`."""
-      new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
-      return mtf.reshape(
-          t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
+    # This mode assumes query and memory antecedent are the same.
+    qkv = self.moe_layer.call(self.context, antecedent)
+    q, kv = qkv
+    q = self._replace_d_model_dim(q)
+    kv = self._replace_d_model_dim(kv)
+    self._q = q
+    self._kv = kv
+
+  def compute_q(self, query_antecedent):
     if self.expert_computation == "qkv":
-      # NOTE: This assumes querty and memory antecedent are the same
-      qk = self.moe_layer.call(self.context, antecedent)
-      # Split qk here since they went through experts-layers
-      q, k = qk
-      q = _replace_d_model_dim(q)
-      k = _replace_d_model_dim(k)
+      self._compute_merge_qkv(query_antecedent)
+      q = self._q
     elif self.expert_computation == "q":
-      q = self.moe_layer.call(self.context, antecedent)
-      q = _replace_d_model_dim(q)
-      # Compute key/value normally
-      k = mtf.layers.us_einsum(
-          [antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
+      q = self._compute_q_with_experts(query_antecedent)
+    # If computing "kv" with experts, then compute q normally.
     elif self.expert_computation == "kv":
-      k = self.moe_layer.call(self.context, antecedent)
-      k = _replace_d_model_dim(k)
-      # Compute query normally
       q = mtf.layers.us_einsum(
-          [antecedent, self.wq], reduced_dims=[self.query_input_dim])
-    else:
-      raise ValueError("Invalid expert computation mode: {}".format(
-          self.expert_computation))
-
-    # Scale query
+          [query_antecedent, self.wq], reduced_dims=[self.query_input_dim])
     q *= self.key_dim.size ** -0.5
-    self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
-    self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
-
-  def compute_q(self, query_antecedent):
-    self._compute_merge_qkv(query_antecedent)
-    return self._q
+    return mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)

   def compute_k(self, memory_antecedent):
     del memory_antecedent
-    return self._k
+    raise NotImplementedError("ExpertsAttention uses shared_kv = True.")

   def compute_kv(self, memory_antecedent):
-    del memory_antecedent
-    return self._k
+    if self.expert_computation == "qkv":
+      # We have already computed "kv" with "q", so just return its value.
+      kv = self._kv
+      # Check if the "length" dimension should be "memory_length" since both
+      # q and kv were computed using the same antecedent. This is why we must
+      # always have the same query and memory antecedent for the qkv mode.
+      if self.context.length_dim in kv.shape.dims:
+        memory_length = mtf.Dimension(
+            "memory_length", self.context.length_dim.size)
+        kv = mtf.replace_dimensions(
+            kv, self.context.length_dim, memory_length)
+    # If computing "q" with experts, then compute "kv" normally.
+    elif self.expert_computation == "q":
+      kv = mtf.layers.us_einsum(
+          [memory_antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
+    elif self.expert_computation == "kv":
+      kv = self._compute_kv_with_experts(memory_antecedent)
+    kv = mtf.replace_dimensions(kv, kv.shape.dims[-1], self.k_dims)
+    return kv

   def compute_v(self, memory_antecedent):
     del memory_antecedent
     raise NotImplementedError("ExpertsAttention uses shared_kv = True.")

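The `_replace_d_model_dim` helper factored out above relabels the last dimension of the expert-layer output before it is fed into the attention computation. Below is a minimal sketch of that relabeling, with assumed dimension names and sizes rather than anything taken from the diff.

```python
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")  # hypothetical mesh
batch = mtf.Dimension("batch", 2)
length = mtf.Dimension("length", 8)
d_model = mtf.Dimension("d_model", 64)

# Stand-in for an expert-layer output whose last dimension is named "d_model".
moe_out = mtf.zeros(mesh, mtf.Shape([batch, length, d_model]))

# mtf.reshape may rename dimensions as long as the total size is unchanged,
# so the last dimension can be relabeled (here to "heads") without moving data.
heads = mtf.Dimension("heads", moe_out.shape.dims[-1].size)
relabeled = mtf.reshape(
    moe_out, new_shape=mtf.Shape(moe_out.shape.dims[:-1] + [heads]))
print(relabeled.shape)  # Shape[batch=2, length=8, heads=64]
```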
32 changes: 26 additions & 6 deletions mesh_tensorflow/transformer/moe.py
@@ -98,13 +98,30 @@ def __init__(self,
         moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
     self._activation = activation

-  def call(self, context, x, losses=None):
+  def call(self, context, x, losses=None, use_enc_nonpadding=False):
     """Call the layer."""
     if context.model.ensemble_dim:
       raise NotImplementedError("MoE not yet implemented with ensembles")

     has_length_dim = context.length_dim in x.shape.dims
-    if not has_length_dim:
+    has_memory_length_dim = "memory_length" in x.shape.dimension_names
+    # Used for EncDec attention if we have the MoE layer produce the kv.
+    if use_enc_nonpadding:
+      nonpadding = context.nonpadding_encoder
+    else:
+      nonpadding = context.nonpadding
+    # If a memory_length dimension exists, then we make sure the
+    # length dimension of the nonpadding tensor matches it.
+    if (has_memory_length_dim and isinstance(nonpadding, mtf.Tensor)
+        and "length" in nonpadding.shape.dimension_names):
+      old_length_dim = nonpadding.shape.get_dim_by_name("length")
+      new_length_dim = mtf.Dimension("memory_length", old_length_dim.size)
+      nonpadding = mtf.replace_dimensions(
+          nonpadding, old_length_dim, new_length_dim)
+    # Insert a length dimension if one does not exist.
+    # Typically no length dims will occur on the decoder during autoregressive
+    # decoding.
+    if not has_length_dim and not has_memory_length_dim:
       x_shape = x.shape
       shape_with_length = mtf.Shape(
           x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
@@ -124,18 +141,21 @@ def call(self, context, x, losses=None):
         context.variable_dtype,
         layout=context.model.layout,
         mesh_shape=context.model.mesh_shape,
-        nonpadding=context.nonpadding,
+        nonpadding=nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
         token_embeddings=context.input_embeddings)
     if context.losses is not None:
       context.losses.append(loss)
-    if not has_length_dim:
+    if not has_length_dim and not has_memory_length_dim:
+      # Shapes will differ if the input and output dimensions of the layer
+      # do not match.
+      new_y_shape = mtf.Shape(x_shape.dims[:-1] + [output_dim])
       if self._hparams.moe_use_experts_attention:
-        y_reshape = [mtf.reshape(y_out, x_shape) for y_out in y]
+        y_reshape = [mtf.reshape(y_out, new_y_shape) for y_out in y]
         y = y_reshape
       else:
-        y = mtf.reshape(y, x_shape)
+        y = mtf.reshape(y, new_y_shape)
     return y


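The moe.py change above covers the decoding path: during autoregressive decoding the input arrives without a length dimension, and with ExpertsAttention the layer's output dimension can differ from its input dimension, so the output must be reshaped to the new last dimension rather than back to the input shape. A rough sketch of that reshaping with made-up sizes (a stand-in tensor, not the MoE layer itself):

```python
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")  # hypothetical mesh
batch = mtf.Dimension("batch", 2)
d_model = mtf.Dimension("d_model", 16)
output_dim = mtf.Dimension("d_kv", 32)  # hypothetical output dimension

# During decoding, x arrives without a "length" dimension.
x = mtf.zeros(mesh, mtf.Shape([batch, d_model]))

# Insert a length=1 dimension so the layer sees the shape it expects.
shape_with_length = mtf.Shape(
    x.shape.dims[:-1] + [mtf.Dimension("length", 1)] + [x.shape.dims[-1]])
x_expanded = mtf.reshape(x, shape_with_length)

# Stand-in for the layer output: same leading dimensions, new last dimension.
y = mtf.zeros(mesh, mtf.Shape(x_expanded.shape.dims[:-1] + [output_dim]))

# Reshape back without the length dimension, keeping the *output* dimension;
# reshaping to x's original shape would fail here because d_model != d_kv.
new_y_shape = mtf.Shape(x.shape.dims[:-1] + [output_dim])
y = mtf.reshape(y, new_y_shape)
print(y.shape)  # Shape[batch=2, d_kv=32]
```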
11 changes: 11 additions & 0 deletions mesh_tensorflow/transformer/transformer.py
@@ -304,6 +304,17 @@ def nonpadding(self):
       return mtf.cast(
           mtf.not_equal(self.sequence_id, 0), self.activation_dtype)

+  @property
+  def nonpadding_encoder(self):
+    """Tensor with zeros in padding positions and ones elsewhere for encoder."""
+    if self.encoder_sequence_id is None:
+      return None
+    if self.encoder_sequence_id == 1:
+      return 1
+    else:
+      return mtf.cast(
+          mtf.not_equal(self.encoder_sequence_id, 0), self.activation_dtype)
+
   def get_position(self):
     if self.position_is_default:
       return mtf.range(self.mesh, self.length_dim, tf.int32)
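The new `nonpadding_encoder` property mirrors `nonpadding` but is built from the encoder's sequence ids, so the MoE layer can mask encoder padding when it produces keys/values for encoder-decoder attention. A small sketch of the mask it computes, using stand-in ids rather than a real Context:

```python
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")  # hypothetical mesh
batch = mtf.Dimension("batch", 1)
length = mtf.Dimension("length", 4)

# Stand-in for encoder_sequence_id; zeros would mark padding positions.
encoder_sequence_id = mtf.ones(mesh, mtf.Shape([batch, length]), dtype=tf.int32)

# 1.0 where the id is nonzero, 0.0 at padding positions.
nonpadding = mtf.cast(mtf.not_equal(encoder_sequence_id, 0), tf.float32)
print(nonpadding.shape)  # Shape[batch=1, length=4]
```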
8 changes: 7 additions & 1 deletion mesh_tensorflow/transformer/transformer_layers.py
@@ -410,6 +410,7 @@ def __init__(self,
                **kwargs):
     super(ExpertsSelfAttention, self).__init__(**kwargs)
     self.expert_computation = expert_computation
+    self.is_encdec = False  # Overridden in ExpertsEncDecAttention.
     self._hparams = mtf.transformer.moe.HParams(
         moe_gating=moe_gating,
         num_experts=num_experts,
@@ -465,7 +466,8 @@ def make_params(self, context):
         fold_scaling_into_initializer=self.fold_scaling_into_initializer,
         context=context,
         experts_hparams=self._hparams,
-        expert_computation=self.expert_computation)
+        expert_computation=self.expert_computation,
+        is_encdec=self.is_encdec)


@gin.configurable
@@ -475,6 +477,10 @@ class ExpertsEncDecAttention(ExpertsSelfAttention):
   def __init__(self, relative_attention_type=None, **kwargs):
     super(ExpertsEncDecAttention, self).__init__(
         relative_attention_type=relative_attention_type, **kwargs)
+    self.is_encdec = True
+    if self.expert_computation == "qkv":
+      raise ValueError("ExpertsEncDecAttention must use expert_computation of "
+                       "q or kv.")

   def _get_memory_antecedent(self, context):
     return context.encoder_output
