<a href="https://colab.research.google.com/github/sahaja-kanuri/Transformer_XL/blob/main/Transformer_XL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initializing

In [1]:
#GPU Info:

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Tue Jun  4 09:33:23 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   49C    P8              13W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 56.9 gigabytes of available RAM

You are using a high-RAM runtime!


In [3]:
import tensorflow as tf
import numpy as np
import keras

from keras.models import Model
from keras.layers import Dense, Dropout, LayerNormalization, Embedding
from keras.layers import Softmax, Activation, Add
from keras.optimizers import Adam
from keras.losses import SparseCategoricalCrossentropy

In [4]:
!pip install einops
from einops import rearrange #, repeat, pack, unpack, einsum



In [5]:
# Install pre-requisites:
!pip install datasets
from datasets import load_dataset

!pip install sacremoses
from transformers import TransfoXLTokenizer
import os
os.environ["TRUST_REMOTE_CODE"] = "True"

# Load the Dataset:
wikitext103 = load_dataset('wikitext', 'wikitext-103-v1')
print('The wikitext103 dataset: {wikitext103}')

train_wikitext103 = wikitext103['train']['text']
validation_wikitext103 = wikitext103['validation']['text']
test_wikitext103 = wikitext103['test']['text']
print("\n No. of training examples: ", len(train_wikitext103))
print("\n No. of validation examples: ", len(validation_wikitext103))
print("\n No. of test examples: ", len(test_wikitext103))

Collecting sacremoses
  Using cached sacremoses-0.1.1-py3-none-any.whl (897 kB)
Installing collected packages: sacremoses
Successfully installed sacremoses-0.1.1


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/722k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/156M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/156M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/655k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

The wikitext103 dataset: {wikitext103}

 No. of training examples:  1801350

 No. of validation examples:  3760

 No. of test examples:  4358


In [6]:
# Load the pre-trained tokenizer:
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl/transfo-xl-wt103',
                # padding_side='right', truncation_side='right',
                pad_token='pad_token')
vocab_size = len(tokenizer)
print(f'Vocab size is: {vocab_size}')

vocab.pkl:   0%|          | 0.00/9.14M [00:00<?, ?B/s]

vocab.bin:   0%|          | 0.00/9.14M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/856 [00:00<?, ?B/s]

`TransfoXL` was deprecated due to security issues linked to `pickle.load` in `TransfoXLTokenizer`. See more details on this model's documentation page: `https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/transfo-xl.md`.


Vocab size is: 267736


In [7]:
# Identifying how much of the dataset is empty
word_counts = list(map(len, wikitext103['train']['text']))

print("Percentiles of character count in Wikitext103:")
length_quantiles = [0, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 1.0]
for percentile, length in zip(length_quantiles, np.quantile(word_counts, length_quantiles)):
  print(f" * {percentile:.1%}: {int(length):,} characters")

Percentiles of character count in Wikitext103:
 * 0.0%: 0 characters
 * 10.0%: 0 characters
 * 25.0%: 0 characters
 * 50.0%: 33 characters
 * 75.0%: 556 characters
 * 90.0%: 921 characters
 * 99.0%: 1,601 characters
 * 100.0%: 7,064 characters


In [8]:
wikitext_nonempty = wikitext103.filter(lambda x: len(x['text']) >= 10)
wikitext_nonempty

Filter:   0%|          | 0/4358 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3760 [00:00<?, ? examples/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 2881
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1161735
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 2461
    })
})

In [9]:
train_wikitext103 = wikitext_nonempty['train']['text']
validation_wikitext103 = wikitext_nonempty['validation']['text']
test_wikitext103 = wikitext_nonempty['test']['text']
print("\n No. of training examples after filtering: ", len(train_wikitext103))
print("\n No. of validation examples after filtering: ", len(validation_wikitext103))
print("\n No. of test examples after filtering: ", len(test_wikitext103))


 No. of training examples after filtering:  1161735

 No. of validation examples after filtering:  2461

 No. of test examples after filtering:  2881


In [10]:
# Initializing Transformer XL model dimensions:
# Note that 'vocab_size' has been pre-determined above using the TransfoXLTokenizer

max_sequence_length = 32  # maximum sequence length
train_batch_size = 16 # no. of sequences to process at once
# validation_batch_size = 32
num_heads = 4 # no. of attention heads
head_dim = 16 # size of each attention head
embed_dim = 64 # num_heads * head_dim # embedding dimention or d_model
max_iters = 50 # no. of iterations of gradient descent
learning_rate = 1e-3 # Adam optimization's learning rate
num_layers = 4 # no. of layers of Transformer decoder/block
num_buckets = 6 # relating to the logarithmic distribution in relative positional encodings

In [11]:
# Tokenizes the dataset:
tokenized_text_train = tokenizer(train_wikitext103,
                                 truncation=True, max_length=max_sequence_length+1,
                                 padding='max_length',
                                 return_attention_mask=True,
                                 return_overflowing_tokens=True,
                                 return_tensors='tf')
tokenized_text_validation = tokenizer(validation_wikitext103,
                                 truncation=True, max_length=max_sequence_length+1,
                                 padding='max_length',
                                 return_attention_mask=True,
                                 return_overflowing_tokens=True,
                                 return_tensors='tf')
tokenized_text_test = tokenizer(test_wikitext103,
                                 truncation=True, max_length=max_sequence_length+1,
                                 padding='max_length',
                                 return_attention_mask=True,
                                 return_overflowing_tokens=True,
                                 return_tensors='tf')

# Model

## Getting Batches

In [12]:
def get_batch(split, max_sequence_length, train_batch_size):
    '''
    Generates a batch of training/validation/test examples by selecting
    'train_batch_size' number of sequences at random of length
    'max_sequence_length'

    Arguments:
    split -- 'train' or 'validation' or 'test'
    max_sequence_length -- maximum length of the sequence
    train_batch_size -- number of sequences in a batch

    Outputs:
    x -- (train_batch_size, max_sequence_length)
    y -- (train_batch_size, max_sequence_length)
    '''

    # Tokenized dataset:
    if split == 'train':
        tokenized_data = tokenized_text_train
    elif split == 'validation':
        tokenized_data = tokenized_text_validation
    elif split == 'test':
        tokenized_data = tokenized_text_test

    # The output tokenized text consists of 0s at the positions where the text
    # has been padded to 'max_sequence_length+1':
    input_attention_mask = tokenized_data['attention_mask']
    tokenized_text = tokenized_data['input_ids']
    #shape: (len(data), max_sequence_length+1)
    padded_tokenized_text = tf.cast(tf.multiply(tokenized_text,
                                    input_attention_mask), dtype=tf.float32)
    #shape: (len(data), max_sequence_length+1)

    # Creates an array of random numbers of shape (train_batch_size,)
    # between 0 (incl.) len(tokenized_data) (excl.)
    ix = tf.random.uniform(shape=(train_batch_size,), minval=0,
                           maxval=len(tokenized_data), dtype=tf.int32)

    # Pick only the rows of the 1st argument given in the 2nd argument, ix
    # and off-sets the columns by 1 in y compared to x:
    x = tf.gather(padded_tokenized_text[:, :max_sequence_length], ix) #shape: (train_batch_size, max_sequence_length)
    y = tf.gather(padded_tokenized_text[:, 1:max_sequence_length+1], ix) #shape: (train_batch_size, max_sequence_length)

    return x, y #shape: (train_batch_size, max_sequence_length)

## Relative Positional Encoding layer class

**Relative Positional Embedding**

- Positional embeddings are added to the QK embeddings during attention
- Relative position embeddings identify, for each input example, how far away all the other tokens are from a specific token of interest
- Instead of giving each token a relative position index of n that is n positions away from our token of interest, T5 relative position "buckets" some tokens into the same index
- First we create this set of indices. then the indices are matched to an embedding layer of weight values. These values are then added to the QK embeddings during attention. The positional embeddings are trained with the network.


One sentence that is 'max_sequence_length' tokens long which gets split into 'max_sequence_length' number of examples ^

In [13]:
#1.Construct a relative position matrix
#2.For offsets larger than what we want, start to spread offset values logarithmically into a finite amount of buckets. (Past a certian max value (128) we'll just map everything to one value)
#3.Initialize embedding weights that we will assign offset values to
#4.Now the relative position matrix is mapped to these weights
#5.This matrix gets added to our attention when we perform self-attention. Our self-attention now incorporates as a piece of information the relative positions between tokens

# Relative Positional Encoding:

class relative_positional_encoding_xl(tf.keras.layers.Layer):
    '''
    Constructs a relative position matrix *for one head* to be added to the Attention score

    Arguments:
        rp_scale -- a normalising scale factor
        num_buckets -- default value
        num_heads -- default value = 1
        rp_max_distance -- default value
        sequence_length -- query length/ input sequence length,
                           default value = max_sequence_length

    Returns:
        relative_position_values * self.scale -- shape: (1, num_heads, sequence_length, 2*sequence_length)
    '''
    def __init__(self, rp_scale, num_buckets=num_buckets, num_heads=1,
                 rp_max_distance=max_sequence_length):
        super().__init__()
        # rp stands for relative position
        self.scale = rp_scale
        self.num_buckets = num_buckets
        self.num_heads = num_heads #=1 for single head of attention, later we concatenate for MHA
        self.rp_max_distance = rp_max_distance
        self.relative_position_embedding = Embedding(self.num_buckets, self.num_heads)

    def relative_position_bucket(self, relative_position_matrix):

        # For a decoder model, masking upper triangle after inverting sign:
        n = -relative_position_matrix
        n = tf.math.maximum(n, tf.zeros_like(n))

        # T5 modifications: half of the buckets are for exact increments in position
        max_exact = self.num_buckets//2

        # MASK 1:
        is_small = n < max_exact

        # MASK 2:
        val_if_large = max_exact + (tf.math.log(tf.cast(n, tf.float32) / max_exact) / tf.math.log(self.rp_max_distance / max_exact) * (self.num_buckets - max_exact) )
        val_if_large = tf.cast(val_if_large, dtype=tf.int32)
        val_if_large = tf.minimum(val_if_large, tf.fill(tf.shape(val_if_large), self.num_buckets - 1))

        return tf.where(is_small, n, val_if_large)

    def call(self, sequence_length):
        sequence_pos = tf.range(sequence_length)
        sequence_pos = tf.reshape(sequence_pos, [sequence_pos.shape[0], 1])
        context_pos = tf.range(2*sequence_length)
        #context_pos = tf.range((-sequence_length, sequence_length)
        rel_pos = context_pos - sequence_pos

         # There is no logarithm scaling on the first three, i.e, it stays 3 2 1 0:
        rel_pos_indices = self.relative_position_bucket(rel_pos)

        # We use the above matrix to index into a positional embedding matrix
        # which we initialise randomly to be trained by the model
        # which then gets added to the Attention Scores.
        # There is one of these for each head
        rel_pos_values = self.relative_position_embedding(rel_pos_indices)

        # Need to reshape from (sequence, context, heads) -> (batch, heads, sequence, context):
        rel_pos_values = rearrange(rel_pos_values, 'i j h -> () h i j')
        #equivalently:
        #rel_pos_values = tf.expand_dims(tf.transpose(rel_pos_values, perm=[2,0,1]), axis=0)

        return rel_pos_values * self.scale

## Attention layer class

In [14]:
class Attention_Head_xl(tf.keras.layers.Layer):
    '''
    Computes Attention score (weights) for one head of self-attention,
    generates output to be concatenated in Multi-Head Attention,
    and caches Extended Context

    Arguments:
        head_size -- Size of each head (=embed_dim/num_heads),

        x -- Input tensor shape: (batch_size, seq_length, embed_dim),
        relative_positions -- defaults to None for input to the 1st sequence,
        shape: (1, num_heads=1, seq_length, seq_length) for input to the 2nd sequence,
        and (1, num_heads=1, seq_length, 2*seq_length) for the rest,
        extended_context -- Cached keys and values from prior segment (acts as
                            memory) of shape: (batch_size, seq_length, 2, head_size)

    Returns:
        output -- shape: (batch_size, seq_length, head_size)
        weights -- shape: (batch_size, num_heads=1, seq_length, 2*seq_length)
        extended_context_cache -- shape: (batch_size, seq_length, 2, head_size)
    '''

    def __init__(self, head_size, dropout_rate=0.1, layernorm_eps=1e-6):

        super().__init__()

        self.head_size = head_size
        self.scale = self.head_size**-0.5

        self.query = Dense(self.head_size, use_bias=False)
        self.key = Dense(self.head_size, use_bias=False)
        self.value = Dense(self.head_size, use_bias=False)

        # self.mask = self.add_weight(
        #     name='mask',
        #     shape=(seq_length, 2*seq_length),
        #     initializer=tf.keras.initializers.Constant(
        #         tf.linalg.band_part(tf.ones((seq_length, 2*seq_length)), 0, seq_length-1)),
        #     trainable=False)

        self.layer_norm_q = LayerNormalization(epsilon=layernorm_eps)
        self.layer_norm_k = LayerNormalization(epsilon=layernorm_eps)
        self.dropout = Dropout(dropout_rate)

    def call(self, x, relative_positions=None, extended_context=None,
             training=False):
        # maybe an unnecesaary step since these are predefined variables:
        #batch_size, seq_length, head_size = x.shape # embed_dim = x.shape[2], not head size -- check again and again

        # Local Attention:
        q = self.query(x) #(batch_size, seq_length, head_size)
        k = self.key(x) #(batch_size, seq_length, head_size)
        v = self.value(x) #(batch_size, seq_length, head_size)

        # To mitigate model drift/ covariant shift:
        q = self.layer_norm_q(q)
        k = self.layer_norm_k(k)

        if extended_context is not None:
            # Unpack extended_context and concatentate with keys and values
            xl_keys, xl_values = tf.unstack(extended_context, axis=-2)
            # shape: (batch_size, seq_length, head_size)

            xl_seq_length = xl_keys.shape[1]

            # Prepend k and v values along the 'seq_length' axis:
            k = tf.concat([xl_keys, k], axis=-2)
            v = tf.concat([xl_values, v], axis=-2)
            # shape: (batch_size, 2*seq_length, head_size)

        # Compute the dot product of attention scores:
        weights = tf.linalg.matmul(q, k, transpose_b=True) * self.scale
        # if extended_context is not None:
        # shape: (batch_size, seq_length, head_size) * (batch_size, head_size, 2*seq_length)
        # = (batch_size, seq_length, 2*seq_length)
        # the above transposes only the two last axes of the second given tensor

        i, j = weights.shape[-2:] # i = seq_length, j = 2*seq_length

        # Adding relative positional encodings to the attention weights:
        if relative_positions is not None:
            weights = tf.expand_dims(weights, axis=1) # weights for each head (axis=1)
            #shape: = (batch_size, num_heads=1, seq_length, 2*seq_length)

            weights += relative_positions[..., -i:, -j:] # shape of rel_pos: (1, num_heads=1, seq_length, 2*seq_length)
            # shape: (batch_size, num_heads=1, seq_length, 2*seq_length)

            weights = tf.squeeze(weights, axis=1)

        # Creates a look-ahead mask (for a causal language model):
        mask = tf.linalg.band_part(tf.ones([i, j], dtype=tf.bool), -1, j-i) == False
        weights = tf.where(mask, tf.fill((i, j), float('-inf')), weights)

        # Perform softmax along the last axis, i.e. along each row/sequence
        weights = Softmax(axis=-1)(weights) # shape: (batch_size, num_heads=1, seq_length, 2*seq_length)
        weights = self.dropout(weights, training=training)

        # Computes the attention scores ("affinities"):
        output = tf.linalg.matmul(weights, v)
        # (batch_size, seq_length, 2*seq_length) * (batch_size, 2*seq_length, head_size)
        # = (batch_size, seq_length, head_size)

        # Passing on keys and values:
        kv_memory = tf.stack([k,v], axis=-2)
        # shape: (batch_size, 2*seq_length, 2, head_size), if extended_context is not None
        # But we want (batch_size, seq_length, 2, head_size), which is the case when extended_context is None.
        # For this:

        if extended_context is not None:
            #for all sequences except the first:
            xl_memory, current_input = kv_memory[:, :-xl_seq_length, :, :], kv_memory[:, -xl_seq_length:, :, :]
            extended_context_cache = current_input
            # shape: (batch_size, seq_length, 2, head_size)
            # discard xl_memory
        else:
            #for the first sequence, we don't need to split them out:
            extended_context_cache = kv_memory
            # shape: (batch_size, seq_length, 2, head_size)

        return output, weights, tf.stop_gradient(extended_context_cache)

## Multi-Head Attention layer class

In [15]:
class Multi_Head_Attention_xl(tf.keras.layers.Layer):
    '''
    Uses Attention_Head to implement multiple heads of self-attention

    Arguments:
        head_size -- Size of each head (a type of communication channel, C),
        num_heads -- Number of heads,
        x -- Input tensor shape: (batch_size, seq_length, embed_dim)
        relative_positions_mha -- relative positional encoding matrix
        extended_context_mha -- defaults to None, or of shape: (batch_size, seq_length, 2, head_size * num_heads)

    Returns:
        out -- shape: (batch_size, seq_length, head_size * num_heads=embed_dim),
        where head_size * num_heads = embed_dim ,
        extended_context_mha -- shape: (batch_size, seq_length, 2, head_size)
    '''
    def __init__(self, head_size, num_heads, dropout_rate=0.1):

        super().__init__()

        self.head_size = head_size
        self.num_heads = num_heads
        self.embed_dim = self.head_size * self.num_heads

        self.heads = [Attention_Head_xl(self.head_size) for _ in range(self.num_heads)]
        self.head = Attention_Head_xl(self.head_size)
        self.projection = Dense(self.embed_dim)
        self.dropout = Dropout(dropout_rate)

    def call(self, x, relative_positions_mha=None, extended_context_mha=None, training=False):
        '''
        training: boolean, i.e. either 'True' or 'False'
        Let extended_context_mha stay None when initialising, since for the
        first sequence, there will be no prior context.
        '''

        '''
        Alternatively:
        for i, head in enumerate(self.heads):
            #print(f'i = {i}')
            output, _, extended_context_mha = head(x,
                                                    relative_positions=relative_positions_mha,
                                                    extended_context=extended_context_mha,
                                                    training=training)
            # shape of output: (batch_size, seq_length, head_size)
            out.append(output)
        '''
        out = []
        for i in range(self.num_heads):
            # pass x and the encoder output through a stack of decoder layers and
            # save the attention weights of block 1 and 2
            output, _, extended_context_mha = self.heads[i](x,
                                                    relative_positions=relative_positions_mha,
                                                    extended_context=extended_context_mha,
                                                    training=training)
            out.append(output)

        out = tf.concat(out, axis=-1)
        # shape: (batch_size, seq_length, head_size * num_heads=embed_dim)

        out = self.projection(out)
        out = self.dropout(out, training=training)

        return out, extended_context_mha

## Feed Forward Neural Network layer class

In [16]:
class FeedForward(tf.keras.layers.Layer):
    '''
    Implements a feed-forward network

    Arguments:
        embed_dim -- Size of the input embedding
        x -- Input tensor shape: (batch_size, seq_length, head_size * num_heads=embed_dim)

    Returns:
        output -- shape: (batch_size, seq_length, embed_dim)
    '''
    def __init__(self, embed_dim, dropout_rate=0.1):
        super().__init__()
        full_connected_dim = 4 * embed_dim
        self.ffn = tf.keras.Sequential([
                    Dense(full_connected_dim, activation='relu'),
                    Dropout(dropout_rate),
                    Dense(embed_dim),
                    Dropout(dropout_rate)
                    ])
    def call(self, x, training=False):
        return self.ffn(x, training=training)

## Transformer Decoder layer class

In [17]:
class Transformer_xl_Decoder(tf.keras.layers.Layer):
    '''
    Implements the transformer decoder block

    Arguments:
        num_heads -- Number of heads
        head_size -- size of each head
        x -- Input tensor shape: (batch_size, seq_length, embed_dim)
        rel_pos_decoder -- defaults to None, relative positional encoding matrix
        ext_context -- defaults to None, of shape:
                        (batch_size, seq_length, 2, head_size * num_heads)

    Returns:
        output -- shape: (batch_size, seq_length, embed_dim)
        ext_context -- shape: (batch_size, seq_length, 2, head_size)
    '''
    def __init__(self, num_heads, head_size, layernorm_eps=1e-6):

        super().__init__()

        self.embed_dim = num_heads * head_size

        self.Multi_Head_Attention = Multi_Head_Attention_xl(head_size, num_heads)
        self.FeedForward = FeedForward(self.embed_dim)
        self.layernorm1 = LayerNormalization(epsilon=layernorm_eps)
        self.layernorm2 = LayerNormalization(epsilon=layernorm_eps)

    def call(self, x, rel_pos_decoder=None, ext_context=None, training=False):
        # Skip connection/ residual connection,
        # layer norm is applied before mha, a slight deviation from
        # the original Attention paper:

        residual = tf.identity(x)
        #residual = tf.expand_dims(x, axis=-1)
        # shape of residual: (batch_size, seq_length)

        x, ext_context = self.Multi_Head_Attention(self.layernorm1(x),
                                                relative_positions_mha=rel_pos_decoder,
                                                extended_context_mha=ext_context,
                                                training=training)
        #shape of x: (batch_size, seq_length, head_size * num_heads)

        x = Add()([residual, x])

        # Skip connection/ residual connection:
        x = Add()([x, self.FeedForward(self.layernorm2(x), training=training)])
        #shape: (batch_size, seq_length, embed_dim)
        return x, ext_context

## Transformer XL model class

In [18]:
class Transformer_xl(tf.keras.models.Model):
    '''
    A Model class consisting of all trainable layers, implementing the Transformer
    block 'num_layers' number of times

    Arguments:
        num_layers -- Number of layers of the Transformer Block
        embed_dim -- Input embedding dimension
        num_heads -- Number of heads in Multi-Head Attention
        seq_length -- Maximum sequence length of each input sentence
        rp_scale -- a normalising scale factor for the Relative Positional Encoding matrix

        idx -- Input training data, shape: (batch_size, seq_length), which is the
                output of the 'get_batch' function
        ext_context -- defaults to None, of shape:
                        (batch_size, seq_length, 2, head_size * num_heads)

    Returns:
        logits -- shape: (batch_size, sequence_length, vocab_size)

    '''
    def __init__(self, num_layers, embed_dim, num_heads, seq_length,
                rp_scale, layernorm_eps=1e-6):
        super().__init__()

        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.seq_length = seq_length
        self.num_heads = num_heads
        self.head_size = embed_dim//num_heads
        self.rp_scale = rp_scale

        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = Embedding(vocab_size, self.embed_dim)
        self.position_embedding_table = Embedding(self.seq_length, self.embed_dim)

        # Generates the relative positional encoding matrix
        self.rel_pos_enc = relative_positional_encoding_xl(self.rp_scale)

        self.decoder_layers = [Transformer_xl_Decoder(num_heads= self.num_heads,
                                                      head_size = self.head_size)
                                for _ in range(self.num_layers)]

        self.head = Dense(vocab_size)
        self.layernorm = LayerNormalization(epsilon=layernorm_eps) #final layer norm

    def call(self, idx, ext_context=None, training=False):
        B, T = idx.shape # B= batch_size, T = seq_length

        tok_emb = self.token_embedding_table(idx) # (B,T,embed_dim)
        pos_emb = self.position_embedding_table(tf.range(T)) # (T,embed_dim)
        x = tok_emb + pos_emb #  (B,T,embed_dim)

        rel_pos = self.rel_pos_enc.call(T)
        # shape: (1, num_heads, T, 2*T)

        for i in range(self.num_layers):
            # pass x and the encoder output through a stack of decoder layers and
            # save the attention weights of block 1 and 2
            x, ext_context = self.decoder_layers[i](x, rel_pos_decoder=rel_pos,
                                                    ext_context=ext_context, training=training)
            # (B,T,embed_dim)

        x = self.layernorm(x) # (B,T,embed_dim)
        logits = self.head(x) # (B,T,vocab_size)

        # B, T, C = logits.shape
        # logits = logits.view(B*T, C)
        # targets = targets.view(B*T)
        # loss = CategoricalCrossentropy()(targets, logits)

        return logits

    def generate(self, idx, max_new_tokens):
        '''
        Generates 'max_new_tokens' number of tokens for an input, 'idx'
        '''
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.seq_length:]
            # get the predictions
            logits = self(idx_cond) #logits, loss previously
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = Softmax(axis=-1)(logits) # (B, C)
            # sample from the distribution
            idx_next = tf.random.categorical(logits=probs,  num_samples=1) # (B, 1)
            idx_next = tf.cast(idx_next, tf.int32)
            #input to tf.concat must be of type int32

            # append sampled index to the running sequence
            idx = tf.concat([idx, idx_next], axis=1) # (B, T+1)

        return idx

# Training

In [19]:
model = Transformer_xl(num_layers=num_layers, embed_dim=embed_dim, num_heads=num_heads,
                       seq_length=max_sequence_length, rp_scale = 0.5)

In [20]:
xb, yb = get_batch('train', max_sequence_length=max_sequence_length,
                       train_batch_size=train_batch_size)
modeltemp = model(xb)
print(f'Model output logits shape: {modeltemp.shape}')
model.summary()

Model output logits shape: (16, 32, 267736)
Model: "transformer_xl"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       multiple                  17135104  
                                                                 
 embedding_1 (Embedding)     multiple                  2048      
                                                                 
 relative_positional_encodi  multiple                  0 (unused)
 ng_xl (relative_positional                                      
 _encoding_xl)                                                   
                                                                 
 transformer_xl__decoder (T  multiple                  50048     
 ransformer_xl_Decoder)                                          
                                                                 
 transformer_xl__decoder_1   multiple                  50048     
 (Transf

In [21]:
for iter in range(max_iters):
    # Sample a batch of data for training:
    xb, yb = get_batch('train', max_sequence_length=max_sequence_length,
                       train_batch_size=train_batch_size)

    # Evaluate the loss
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=SparseCategoricalCrossentropy(),
        metrics=['accuracy']
        )

    print(f'{iter+1}. Training:')
    # Fit the model to the training data:
    model.fit(xb, yb)

    # Sample a batch of data for testing:
    xb_valid, yb_valid = get_batch('validation', max_sequence_length=max_sequence_length,
                       train_batch_size=train_batch_size)

    print(f'Validation:')
    loss_validation = model.evaluate(xb_valid, yb_valid)

1. Training:
Validation:
2. Training:
Validation:
3. Training:
Validation:
4. Training:
Validation:
5. Training:




Validation:




6. Training:




Validation:




7. Training:
Validation:
8. Training:
Validation:
9. Training:
Validation:
10. Training:
Validation:
11. Training:
Validation:
12. Training:
Validation:
13. Training:
Validation:
14. Training:
Validation:
15. Training:
Validation:
16. Training:
Validation:
17. Training:
Validation:
18. Training:
Validation:
19. Training:
Validation:
20. Training:
Validation:
21. Training:
Validation:
22. Training:
Validation:
23. Training:
Validation:
24. Training:
Validation:
25. Training:
Validation:
26. Training:
Validation:
27. Training:
Validation:
28. Training:
Validation:
29. Training:
Validation:
30. Training:
Validation:
31. Training:
Validation:
32. Training:
Validation:
33. Training:
Validation:
34. Training:
Validation:
35. Training:
Validation:
36. Training:
Validation:
37. Training:
Validation:
38. Training:
Validation:
39. Training:
Validation:
40. Training:
Validation:
41. Training:
Validation:
42. Training:
Validation:
43. Training:
Validation:
44. Training:
Validation:
45. Training:
V

In [22]:
context = tf.zeros((1, 1), dtype=tf.int32)
# prediction = model.predict(context)
# print(prediction.shape)

In [23]:
# generate from the model
generated_tokens = model.generate(context, max_new_tokens=100)
print(tokenizer.decode(generated_tokens[0].numpy().tolist(), skip_special_tokens=True))

Sourou Squalor Montgomeryville carboxy ticketless Moala Lagoons Exhibitor 8x Coronal Ã¥rs nucifera hybridus Mattey Harmonious Portuondo numeration Ke7 Harthama CPJ DAN cAMP Naguabo Preserving Tyger perforator Nanib physiognomy allayed Bitterroot murage Carnifax grazier Tuned Kah pollinating Barium Geeks subclass antecedent Biosphere Endara grotesqueness portamento GrÃ¶na Wenck Tuculescu Peyote Namjoshi Groupies Krul FATHERcopy hnefatafl Kojiro Stimula TWV pathognomonic tidily Handbot ElvisFest Mutare Mamut logothete Cuala Odobenus Iosefo Scheemda biphasic upperwing illegitimately Fushun Bergenhus TECH Incredibly Truiden Nicalis Pakis Eltham dystonia Masuko Crucifixus RMC Tweep carbons capacitive semiconducting Japonesque Savoie Sorachi Maternal Kingstown Didi analgesic Rotem Cultivated Hegel Challenges Valleys Vindication diction
