Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Make TransformerLM train reasonably well in trax. Adding loss and met…
Browse files Browse the repository at this point in the history
…ric masking and dropout refactor in Transformer.

PiperOrigin-RevId: 239692595
  • Loading branch information
Lukasz Kaiser authored and Copybara-Service committed Mar 21, 2019
1 parent eedd6d7 commit 0d840ee
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 39 deletions.
2 changes: 1 addition & 1 deletion tensor2tensor/trax/configs/resnet50_imagenet_8gb.gin
Expand Up @@ -42,5 +42,5 @@ train.eval_steps = 20
train.inputs = @trax.inputs.inputs
train.model = @trax.models.Resnet50
train.optimizer = @trax.optimizers.momentum
train.train_steps = 500000
train.train_steps = 1000000
train.lr_schedule = @learning_rate.EvalAdjustingSchedule
18 changes: 11 additions & 7 deletions tensor2tensor/trax/configs/transformer_lm1b_8gb.gin
Expand Up @@ -5,28 +5,32 @@ import tensor2tensor.trax.trax

# Parameters for batch_fun:
# ==============================================================================
batch_fun.batch_size = 32
batch_fun.eval_batch_size = 32
batch_fun.batch_size = 128
batch_fun.eval_batch_size = 128

# Parameters for inputs:
# ==============================================================================
inputs.data_dir = None
inputs.dataset_name = 't2t_languagemodel_lm1b32k'

# Parameters for mask:
# ==============================================================================
mask.mask_id = 0

# Parameters for MultifactorSchedule:
# ==============================================================================
MultifactorSchedule.constant = 0.05
MultifactorSchedule.constant = 0.1
MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay'
MultifactorSchedule.warmup_steps = 8000

# Parameters for preprocess_fun:
# ==============================================================================
preprocess_fun.max_target_length = 256
preprocess_fun.max_target_length = 512

# Parameters for train:
# ==============================================================================
train.eval_frequency = 1000
train.eval_steps = 1
train.eval_steps = 5
train.inputs = @trax.inputs.inputs
train.model = @trax.models.TransformerLM
train.run_debug_step = False
Expand All @@ -38,10 +42,10 @@ train_and_eval_batches.input_name = 'targets'

# Parameters for TransformerLM:
# ==============================================================================
TransformerLM.dropout = 0.1
TransformerLM.dropout = 0.2
TransformerLM.feature_depth = 512
TransformerLM.feedforward_depth = 2048
TransformerLM.max_len = 256
TransformerLM.max_len = 512
TransformerLM.mode = 'train'
TransformerLM.num_heads = 8
TransformerLM.num_layers = 6
Expand Down
6 changes: 4 additions & 2 deletions tensor2tensor/trax/inputs.py
Expand Up @@ -191,10 +191,12 @@ def batch_fun(dataset, training, shapes, target_names,
if variable_target_shapes:
bucket_boundaries = [bucket_length // 4, bucket_length // 2,
bucket_length, bucket_length * 2,
bucket_length * 4, bucket_length * 8]
bucket_length * 4, bucket_length * 8,
bucket_length * 16]
bucket_batch_sizes = [cur_batch_size * 4, cur_batch_size * 2,
cur_batch_size, cur_batch_size // 2,
cur_batch_size // 4, cur_batch_size // 8, 1]
cur_batch_size // 4, cur_batch_size // 8,
max(1, cur_batch_size // 16), 1]
buckets = (bucket_boundaries, bucket_batch_sizes)

if buckets:
Expand Down
53 changes: 28 additions & 25 deletions tensor2tensor/trax/models/transformer.py
Expand Up @@ -29,7 +29,7 @@ def TransformerEncoder(mode='train', # pylint: disable=invalid-name
feature_depth=512,
feedforward_depth=2048,
num_heads=8,
dropout=0.9):
dropout=0.1):
"""Transformer Encoder Stack.
Args:
Expand All @@ -38,20 +38,22 @@ def TransformerEncoder(mode='train', # pylint: disable=invalid-name
feature_depth: int: depth of embedding
feedforward_depth: int: depth of feed-forward layer
num_heads: int: number of attention heads
dropout: float: dropout rate - Stax follows TF's KEEP probability convention
dropout: float: dropout rate (how much to drop out; note that stax follows
Tensorflow's keep_rate convention, so we use 1 - dropout in calls below)
Returns:
A staxlayer for implementing a raw Transformer encoder stack. No embedding
or positional signals are added by this layer.
"""
keep_rate = 1.0 - dropout
# Multi-headed Attention and Feed-forward layers
multi_attention = stax.MultiHeadedAttention(
feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

feed_forward = stax.serial(
stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
stax.Relu,
stax.Dropout(dropout, mode=mode),
stax.Dropout(keep_rate, mode=mode),
stax.Dense(feature_depth, W_init=stax.xavier_uniform())
)

Expand All @@ -74,11 +76,11 @@ def encoder(embedded_source, source_mask):
stax.Identity, # value
source_mask), # attention mask
multi_attention,
stax.Dropout(dropout, mode=mode)),
stax.Dropout(keep_rate, mode=mode)),
# feed-forward
stax.residual(stax.LayerNorm(feature_depth),
feed_forward,
stax.Dropout(dropout, mode=mode))
stax.Dropout(keep_rate, mode=mode))
)
return stax.serial(
embedded_source,
Expand All @@ -95,8 +97,8 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
feature_depth=512,
feedforward_depth=2048,
num_heads=8,
dropout=0.9,
max_len=256):
dropout=0.1,
max_len=512):
"""Transformer language model (only uses the decoder part of Transformer).
Args:
Expand All @@ -106,20 +108,21 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
feature_depth: int: depth of embedding
feedforward_depth: int: depth of feed-forward layer
num_heads: int: number of attention heads
dropout: float: dropout rate - Stax follows TF's KEEP probability convention
dropout: float: dropout rate (how much to drop out)
max_len: int: maximum symbol length for positional encoding
Returns:
init and apply.
"""
keep_rate = 1.0 - dropout
# Multi-headed Attention and Feed-forward layers
multi_attention = stax.MultiHeadedAttention(
feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

feed_forward = stax.serial(
stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
stax.Relu,
stax.Dropout(dropout, mode=mode),
stax.Dropout(keep_rate, mode=mode),
stax.Dense(feature_depth, W_init=stax.xavier_uniform())
)

Expand All @@ -132,18 +135,18 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
stax.Identity, # value
stax.CausalMask(axis=-2)), # attention mask
multi_attention,
stax.Dropout(dropout, mode=mode)),
stax.Dropout(keep_rate, mode=mode)),
# feed-forward
stax.residual(stax.LayerNorm(feature_depth),
feed_forward,
stax.Dropout(dropout, mode=mode))
stax.Dropout(keep_rate, mode=mode))
)

return stax.serial(
stax.ShiftRight(),
stax.Embedding(feature_depth, vocab_size),
stax.PositionalEncoding(feature_depth, max_len=max_len),
stax.Dropout(dropout, mode=mode),
stax.Dropout(keep_rate, mode=mode),
stax.repeat(decoder_layer, num_layers),
stax.LayerNorm(feature_depth),
stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
Expand All @@ -158,7 +161,7 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
feature_depth=512,
feedforward_depth=2048,
num_heads=8,
dropout=0.9,
dropout=0.1,
shared_embedding=True,
max_len=200,
return_evals=False):
Expand All @@ -172,7 +175,7 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
feature_depth: int: depth of embedding
feedforward_depth: int: depth of feed-forward layer
num_heads: int: number of attention heads
dropout: float: dropout rate - Stax follows TF's KEEP probability convention
dropout: float: dropout rate (how much to drop out)
shared_embedding: bool: specify whether source/target embeddings are tied.
max_len: int: maximum symbol length for positional encoding
return_evals: bool: whether to generate decode-time evaluation functions
Expand All @@ -182,11 +185,11 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
the 'evals' functions that itself returns a namedtuple containing evaluation
functions for the trained encoder, decoder, and generator substax.
"""

keep_rate = 1.0 - dropout
# Input embedding and positional encoding
inject_position = stax.serial(
stax.PositionalEncoding(feature_depth, max_len=max_len),
stax.Dropout(dropout, mode=mode)
stax.Dropout(keep_rate, mode=mode)
)
if shared_embedding:
assert source_vocab_size == target_vocab_size
Expand All @@ -202,12 +205,12 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name

# Multi-headed Attention and Feed-forward layers
multi_attention = stax.MultiHeadedAttention(
feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

feed_forward = stax.serial(
stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
stax.Relu,
stax.Dropout(dropout, mode=mode),
stax.Dropout(keep_rate, mode=mode),
stax.Dense(feature_depth, W_init=stax.xavier_uniform())
)

Expand All @@ -231,11 +234,11 @@ def encoder(source, source_mask):
stax.Identity, # value
source_mask), # attention mask
multi_attention,
stax.Dropout(dropout, mode=mode)),
stax.Dropout(keep_rate, mode=mode)),
# feed-forward
stax.residual(stax.LayerNorm(feature_depth),
feed_forward,
stax.Dropout(dropout, mode=mode))
stax.Dropout(keep_rate, mode=mode))
)
return stax.serial(
source,
Expand Down Expand Up @@ -266,19 +269,19 @@ def decoder(memory, target, target_mask, memory_mask):
stax.Identity, # value
target_mask), # attention mask
multi_attention,
stax.Dropout(dropout, mode=mode)),
stax.Dropout(keep_rate, mode=mode)),
# target attends to encoded source
stax.residual(stax.LayerNorm(feature_depth),
stax.multiplex(stax.Identity, # query
memory, # key
memory, # value
memory_mask), # attention mask
multi_attention,
stax.Dropout(dropout, mode=mode)),
stax.Dropout(keep_rate, mode=mode)),
# feed-forward
stax.residual(stax.LayerNorm(feature_depth),
feed_forward,
stax.Dropout(dropout, mode=mode))
stax.Dropout(keep_rate, mode=mode))
)
return stax.serial(
target,
Expand Down
20 changes: 16 additions & 4 deletions tensor2tensor/trax/trax.py
Expand Up @@ -46,26 +46,38 @@
from tensorflow.io import gfile


@gin.configurable(blacklist=["inputs", "targets"])
def masked_mean(inputs, targets, mask_id=None):
"""Mean of the inputs but counting only those where targets != mask_id."""
x = inputs.astype(np.float32)
if mask_id is None:
return np.mean(x)
unmask = 1.0 - np.equal(targets, mask_id).astype(np.float32)
return np.sum(x * unmask) / np.sum(unmask)


def accuracy(batch, model_predictions):
"""Calculate accuracy."""
_, targets = batch
predicted_class = np.argmax(model_predictions, axis=-1)
return np.mean(predicted_class == targets)
correct = np.equal(predicted_class, targets)
return masked_mean(correct, targets)


def neg_log_perplexity(batch, model_predictions):
"""Calculate negative log perplexity."""
_, targets = batch
hot_targets = stax.one_hot(targets, model_predictions.shape[-1])
return np.mean(np.sum(model_predictions * hot_targets, axis=-1))
xent = np.sum(model_predictions * hot_targets, axis=-1)
return masked_mean(xent, targets)


def loss(params, batch, model_predict):
"""Calculate loss."""
inputs, targets = batch
preds = model_predict(params, inputs)
return - np.mean(np.sum(preds * stax.one_hot(targets, preds.shape[-1]),
axis=-1))
xent = np.sum(preds * stax.one_hot(targets, preds.shape[-1]), axis=-1)
return - masked_mean(xent, targets)


def log(s, stdout=True):
Expand Down

0 comments on commit 0d840ee

Please sign in to comment.