In [1]:
# Do this to allow for local imports.
import os
import sys
# Absolute path of the parent of the current working directory. Assumes the
# notebook runs from a subdirectory of the repository, so the parent holds the
# local `tommy2tommy` package imported in the next cell.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    # Insert at the front (rather than append) so this working copy takes
    # precedence over any installed `tommy2tommy` found later on sys.path.
    sys.path.insert(0, module_path)

In [2]:
# Import from required modules.
from tommy2tommy.models.transformer import TransformerLM

import tensorflow as tf

In [3]:
# All hyperparameters for this notebook, collected in a single dict.
config = dict(
    # Model and data.
    vocab_size=32,           # non-separator ids are sampled from [1, vocab_size); 0 is the separator
    length=10,               # total sequence length; generate_input requires it to be even
    num_layers=2,
    d_model=32,
    d_filter=128,
    num_heads=8,
    dropout_rate=0.1,
    ffn_activation='gelu',
    layer_norm_epsilon=1.0e-6,

    # Adam optimizer.
    adam_learning_rate=0.001,
    adam_beta_1=0.9,
    adam_beta_2=0.999,
    adam_epsilon=1.0e-7,

    # Training loop.
    batch_size=32,
    num_epochs=10,
    training_steps=1000,     # used as model.fit's steps_per_epoch, i.e. steps per epoch, not in total
)

In [4]:
# Prepare inputs, create the synthetic datasets.
def generate_input(vocab_size, length):
    """Return an endless generator of synthetic copy-task examples.

    Every example has the layout ``[0, s, 0, s]``, where ``s`` holds
    ``(length - 2) // 2`` tokens drawn uniformly from ``[1, vocab_size)`` and
    token 0 is reserved as the separator. Each example is yielded as an
    ``(inputs, targets)`` pair containing the same tensor twice, since the
    targets are the inputs themselves.

    Args:
        vocab_size: Exclusive upper bound for sampled tokens. Must be at least
            2 so there is at least one non-separator token to sample.
        length: Total length of each example. Must be even and at least 2.

    Returns:
        A generator yielding ``(full_input, full_input)`` pairs of int32
        tensors of shape ``(length,)``.

    Raises:
        ValueError: If ``vocab_size`` or ``length`` is invalid. The check runs
            when ``generate_input`` is called, not on the first ``next()``, so a
            bad config fails before it is buried inside the tf.data pipeline.
    """
    if length < 2 or length % 2 != 0:
        raise ValueError('length must be an even integer >= 2, got {!r}'.format(length))
    if vocab_size < 2:
        # tf.random.uniform needs minval (1) < maxval (vocab_size).
        raise ValueError('vocab_size must be >= 2, got {!r}'.format(vocab_size))

    # Two separators plus two copies of the random half fill `length` exactly.
    half_len = (length - 2) // 2

    def _examples():
        # Infinite stream; the consumer decides how many examples to take.
        while True:
            half_input = tf.random.uniform(
                shape=(half_len,), minval=1, maxval=vocab_size, dtype=tf.int32)
            full_input = tf.concat([[0], half_input, [0], half_input], axis=0)
            yield (full_input, full_input)

    return _examples()

# from_generator cannot infer dtypes or shapes from a Python generator, so both
# are declared for the (inputs, targets) pair. The generator never ends, so
# drop_remainder=True never actually drops anything; it keeps the batch
# dimension static so every batch has the same shape.
# NOTE(review): examples come from tf.random.uniform and no seed is set in this
# cell, so the data differs between runs; consider tf.random.set_seed(...).
training_dataset = (
    tf.data.Dataset.from_generator(
        lambda: generate_input(config['vocab_size'], config['length']),
        output_types=(tf.int32, tf.int32),
        output_shapes=((config['length'],), (config['length'],)),
    )
    .batch(config['batch_size'], drop_remainder=True)
)

In [5]:
# Set up the loss function, should only calculate loss on the copied half of outputs.
def _copied_half(real, pred):
    """Keep only the second half of the sequence axis of targets and logits.

    The split point is half the target sequence length, read from ``real``
    itself rather than from the global ``config``. The loss and metric below
    therefore stay correct if ``config['length']`` changes after compilation,
    or if they are reused on data of another length. For the ``[0, s, 0, s]``
    layout built by ``generate_input``, index ``length // 2`` is the second
    separator, so the kept slice is that separator plus the copied tokens.

    Args:
        real: Targets indexed as ``(batch, position)``.
        pred: Logits indexed as ``(batch, position, vocab)``. Assumed to have the
            same number of positions as ``real``.

    Returns:
        ``(real, pred)`` restricted to positions ``length // 2`` onward.
    """
    split = tf.shape(real)[1] // 2
    return real[:, split:], pred[:, split:, :]


def loss_function(real, pred):
    """Mean sparse cross-entropy over the copied (second) half of the sequence.

    The first half is randomly sampled, so it is excluded from the loss.

    Args:
        real: Target token ids, ``(batch, position)``.
        pred: Unnormalised logits, ``(batch, position, vocab)``; hence
            ``from_logits=True``.

    Returns:
        Scalar mean loss.
    """
    real, pred = _copied_half(real, pred)
    loss = tf.keras.losses.sparse_categorical_crossentropy(real, pred, from_logits=True)
    return tf.reduce_mean(loss)


def accuracy(real, pred):
    """Per-position accuracy on the copied (second) half of the sequence.

    Returns the element-wise result of ``sparse_categorical_accuracy``.
    When used as a compiled metric, Keras reduces it to a mean.
    """
    real, pred = _copied_half(real, pred)
    return tf.keras.metrics.sparse_categorical_accuracy(real, pred)

In [6]:
# Adam configured from the optimizer section of `config`. Transformers usually
# benefit from a learning-rate warmup, but this copy task trains fine without one.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=config['adam_learning_rate'],
    beta_1=config['adam_beta_1'],
    beta_2=config['adam_beta_2'],
    epsilon=config['adam_epsilon'],
)

In [7]:
# Construct the Transformer language model and attach the optimizer, the
# second-half loss and the matching accuracy metric.
# padding_id=-1: the synthetic data contains no padding. Token 0 is a real
# separator, and -1 never occurs because sampled ids are non-negative, so
# presumably nothing gets treated as padding (confirm against TransformerLM).
model = TransformerLM(config, padding_id=-1)
model.compile(loss=loss_function, optimizer=optimizer, metrics=[accuracy])

In [8]:
# Train for num_epochs epochs of training_steps batches each. The dataset is an
# endless stream, so steps_per_epoch is what bounds each epoch. Every batch is
# freshly sampled at random, so a separate validation set would add nothing.
model.fit(
    training_dataset,
    epochs=config['num_epochs'],
    steps_per_epoch=config['training_steps'],
)

Epoch 1/10
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause

Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x203f157c340>

In [9]:
# Raw per-position predictions: argmax over the vocabulary (last) axis.
# The extra leading 0 accounts for the rightward shift in the language model.
# Only the second half of the output is meaningful: the loss only scores
# positions length // 2 onward, so the first half is never trained.
example = tf.constant([[0, 0, 1, 2, 3, 4, 0, 0, 0, 0]])
print(tf.argmax(model.predict(x=example), axis=-1).numpy())

[[0 0 0 0 0 0 1 2 3 4]]


In [10]:
# The correct way to do inference is with a decoder search algorithm such as greedy search or beam search.
from tommy2tommy.utils.search import greedy_search

In [11]:
# Decode autoregressively from a prefix in the training layout [0, s, 0]; the
# search extends it to config['length'] tokens, completing the copy of s.
example = tf.constant([[0, 30, 1, 2, 11, 0]])
print(greedy_search(model, length=config['length'], prefix=example).numpy())

[[ 0 30  1  2 11  0 30  1  2 11]]
