## The Transformer architecture

In [1]:
import random
import tensorflow as tf
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.utils import text_dataset_from_directory
from tensorflow.keras.layers import TextVectorization, Bidirectional, LSTM, Dropout, Dense, Layer, Embedding
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, GlobalMaxPooling1D
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
from tensorflow.keras.initializers import Constant
import numpy as np

In [2]:
# Seed every RNG the notebook touches (Python, NumPy, TensorFlow) so weight
# initialization, dropout masks and dataset shuffling repeat across
# Restart & Run All. NOTE: some GPU kernels can still be nondeterministic.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

batch_size = 32
# Pass the seed explicitly so the file shuffle order is reproducible too.
train_ds = text_dataset_from_directory("aclImdb/train", batch_size=batch_size, seed=SEED)
validation_ds = text_dataset_from_directory("aclImdb/val", batch_size=batch_size, seed=SEED)
test_ds = text_dataset_from_directory("aclImdb/test", batch_size=batch_size, seed=SEED)

Found 20000 files belonging to 2 classes.
Found 5000 files belonging to 2 classes.
Found 25000 files belonging to 2 classes.


**Cleaning and vectorizing the data**

In [3]:
def clean_text(text):
    """Normalize raw review text before vectorization.

    Lower-cases the text, turns HTML line breaks into spaces, and replaces every
    run of characters outside a small ASCII whitelist (letters, digits,
    whitespace and ``. , ! ? '``) with a single space.

    Args:
        text: string tensor of raw review text.

    Returns:
        String tensor of the same shape containing the cleaned text.
    """
    text = tf.strings.lower(text)
    # The aclImdb reviews contain literal "<br />" tags. Stripping only their
    # punctuation would leave a spurious "br" token in many reviews.
    text = tf.strings.regex_replace(text, r"<br\s*/?>", " ")
    # Replace (rather than delete) disallowed characters so words joined by them
    # ("well-made", "plot\u2014acting") are split instead of fused into one token.
    # The negated class also covers every non-ASCII character, so no separate
    # non-ASCII pass is needed.
    text = tf.strings.regex_replace(text, r"[^a-zA-Z0-9\s.,!?']+", " ")

    return text

def preprocess_text(text, label):
    """Dataset map function: clean the review text and pass the label through unchanged."""
    text = clean_text(text)

    return text, label

In [4]:
# Apply the text cleaning to each raw split; labels pass through untouched.
train, validation, test = (
    raw_split.map(preprocess_text)
    for raw_split in (train_ds, validation_ds, test_ds)
)

In [5]:
# Vectorization settings: vocabulary cap and fixed output length in tokens.
sequence_length = 600
max_tokens = 20000
vectorizer = TextVectorization(
    max_tokens=max_tokens,
    output_mode="int",
    # Every review is padded or truncated to exactly this many token ids.
    output_sequence_length=sequence_length,
)

In [6]:
# Build the vocabulary from the training texts only; labels are dropped.
train_text = train.map(lambda review, _label: review)
vectorizer.adapt(train_text)

In [7]:
def vectorize_split(dataset):
    """Map a (text, label) dataset to (token ids, label) batches.

    AUTOTUNE lets tf.data choose the map parallelism instead of a hard-coded 4.
    ``prefetch`` overlaps vectorizing the next batch with training on the current
    one. Element order is unchanged.
    """
    return (
        dataset
        .map(lambda x, y: (vectorizer(x), y), num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )

int_train_ds = vectorize_split(train)
int_val_ds = vectorize_split(validation)
int_test_ds = vectorize_split(test)

**Implementing the Transformer encoder as a subclassed `Layer`**

In [8]:
class TransformerEncoder(Layer):
    """One Transformer encoder block.

    Self-attention is followed by a two-layer position-wise feed-forward
    projection. Each sub-layer is wrapped in a residual connection and then
    layer-normalized (post-norm ordering).

    Args:
        embed_dim: width of the incoming token vectors; also the output width.
        dense_dim: hidden width of the feed-forward projection.
        num_heads: number of attention heads.
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can rebuild the layer on load.
        self.embed_dim, self.dense_dim, self.num_heads = embed_dim, dense_dim, num_heads
        # Sub-layers keep their attribute names and creation order so saved
        # weights continue to line up with this layer.
        self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = Sequential([
            Dense(dense_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm_1 = LayerNormalization()
        self.layernorm_2 = LayerNormalization()

    def call(self, inputs, mask=None):
        # Insert a query axis so the key-padding mask broadcasts over every query
        # position. Assumes mask is (batch, seq); TODO confirm for other callers.
        attention_mask = None if mask is None else mask[:, tf.newaxis, :]
        attended = self.attention(inputs, inputs, attention_mask=attention_mask)
        normed = self.layernorm_1(inputs + attended)

        return self.layernorm_2(normed + self.dense_proj(normed))

    def get_config(self):
        # Base Layer config plus the constructor arguments needed to re-instantiate.
        return {
            **super().get_config(),
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        }

**Implementing positional embedding**

In [12]:
class PositionalEmbedding(Layer):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Also emits a padding mask (token id 0 is treated as padding) for downstream
    mask-aware layers.

    Args:
        sequence_length: size of the position table. Longer inputs are not
            supported, because positions past it have no embedding row.
        input_dim: vocabulary size of the token table.
        output_dim: embedding width shared by both tables.
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Attribute names and creation order are preserved for checkpoint compatibility.
        self.token_embeddings = Embedding(input_dim=input_dim, output_dim=output_dim)
        self.position_embeddings = Embedding(input_dim=sequence_length, output_dim=output_dim)
        self.sequence_length, self.input_dim, self.output_dim = sequence_length, input_dim, output_dim

    def call(self, inputs):
        # Positions 0 .. length-1 along the last axis, shared across the batch
        # and broadcast onto the token vectors.
        positions = tf.range(tf.shape(inputs)[-1])

        return self.token_embeddings(inputs) + self.position_embeddings(positions)

    def compute_mask(self, inputs, mask=None):
        # True where a real token is present, False at padding id 0.
        return tf.not_equal(inputs, 0)

    def get_config(self):
        # Base Layer config plus the constructor arguments needed to re-instantiate.
        return {
            **super().get_config(),
            "output_dim": self.output_dim,
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
        }

In [13]:
# Model hyperparameters.
# The token-embedding table must cover every id the vectorizer can emit, so read
# the size from the adapted vectorizer instead of re-typing 20000. This keeps
# the two in sync if max_tokens changes. `sequence_length` is deliberately not
# redefined: the position table must match the vectorizer's
# output_sequence_length set in the vectorization cell.
vocab_size = vectorizer.vocabulary_size()
embed_dim = 256
num_heads = 2
dense_dim = 32

**Combining the Transformer encoder with positional embedding**

In [14]:
# Token ids -> token + position embeddings -> one encoder block
# -> max-pool over the sequence -> dropout -> sigmoid probability.
inputs = Input(shape=(None,), dtype="int64")
embedded = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)
encoded = TransformerEncoder(embed_dim, dense_dim, num_heads)(embedded)
pooled = GlobalMaxPooling1D()(encoded)
regularized = Dropout(0.5)(pooled)
outputs = Dense(1, activation="sigmoid")(regularized)

model = Model(inputs, outputs)
model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, None)]            0         
                                                                 
 positional_embedding (Posit  (None, None, 256)        5273600   
 ionalEmbedding)                                                 
                                                                 
 transformer_encoder (Transf  (None, None, 256)        543776    
 ormerEncoder)                                                   
                                                                 
 global_max_pooling1d (Globa  (None, 256)              0         
 lMaxPooling1D)                                                  
                                                                 
 dropout (Dropout)           (None, 256)               0         
                                                             

In [15]:
# Keep only the checkpoint with the best validation loss (ModelCheckpoint's default monitor).
checkpoint = ModelCheckpoint("full_transformer_encoder.keras", save_best_only=True)
callbacks = [checkpoint]
model.fit(
    int_train_ds,
    validation_data=int_val_ds,
    epochs=20,
    callbacks=callbacks,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x1f22a3cac40>

In [16]:
# Restore the saved checkpoint. The custom layers must be registered so Keras can deserialize them.
custom_layers = {
    "TransformerEncoder": TransformerEncoder,
    "PositionalEmbedding": PositionalEmbedding,
}
model = load_model("full_transformer_encoder.keras", custom_objects=custom_layers)

In [17]:
# evaluate() returns [loss, accuracy] because the model was compiled with metrics=["accuracy"].
test_loss, test_accuracy = model.evaluate(int_test_ds)
print(f"Test acc: {test_accuracy:.3f}")

Test acc: 0.882
