In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

import pickle

In [2]:
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        attention_output = self.attention(query=inputs, value=inputs, key=inputs)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config


class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.add = layers.Add()  # instead of `+` to preserve mask
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, use_causal_mask=True
        )
        out_1 = self.layernorm_1(self.add([inputs, attention_output_1]))

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
        )
        out_2 = self.layernorm_2(self.add([out_1, attention_output_2]))

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(self.add([out_2, proj_output]))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config

In [3]:
def data_generator(batch_size=32, max_sentence_length=20):
    while True:
        encoder_inputs = []
        targets = []

        for _ in range(batch_size):
            idx = np.random.randint(len(tokenized_sentences))
            sentence = tokenized_sentences[idx]
            id_sentence = [word_to_id[word] for word in sentence if word in word_to_id]

            if len(id_sentence) <= 1:
                continue  # Skip sentences that are too short

            # Input sequence
            input_sequence = id_sentence[:max_sentence_length]
            input_sequence = input_sequence + [word_to_id['<PAD>']] * (max_sentence_length - len(input_sequence))

            # Target sequence with shifted tokens
            target_sequence = id_sentence[1:max_sentence_length] + [word_to_id['<stop>']]
            target_sequence = target_sequence + [word_to_id['<PAD>']] * (max_sentence_length - len(target_sequence))

            encoder_inputs.append(input_sequence)
            targets.append(target_sequence)

        yield (
            {
                "encoder_inputs": np.array(encoder_inputs),
            },
            np.array(targets),
        )

In [4]:
embed_dim = 128
latent_dim = 512
num_heads = 4

vocab_size = 50003
sequence_length = 40
batch_size = 64

encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
x = layers.Dropout(0.5)(x)
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

decoder_outputs = decoder([decoder_inputs, encoder_outputs])
transformer = keras.Model(
    [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
)

In [5]:
# load the word_to_id mapping from a file
with open('utils/word_to_id_mapping.pkl', 'rb') as f:
    word_to_id = pickle.load(f)

# load the id_to_word mapping from a file
with open('utils/id_to_word_mapping.pkl', 'rb') as f:
    id_to_word = pickle.load(f)
    
# load the tokenized sentences from a file
with open('utils/tokenized_sentences.pkl', 'rb') as f:
    tokenized_sentences = pickle.load(f)

# load the sentences from a file
with open('utils/sentences.pkl', 'rb') as f:
    sentences = pickle.load(f)

In [6]:
word_to_id['<start>'] = 50001
word_to_id['<stop>'] = 50002

id_to_word[50001] = '<start>'
id_to_word[50002] = '<stop>'

In [7]:
def data_generator(batch_size=32, max_sentence_length=20):
    while True:
        encoder_inputs = []
        decoder_inputs = []
        targets = []

        for _ in range(batch_size):
            idx = np.random.randint(len(tokenized_sentences))
            sentence = tokenized_sentences[idx]
            id_sentence = [word_to_id[word] for word in sentence if word in word_to_id]

            if len(id_sentence) <= 1:
                continue  # Skip sentences that are too short

            encoder_input = id_sentence[:max_sentence_length]
            encoder_input = encoder_input + [word_to_id['<PAD>']] * (max_sentence_length - len(encoder_input))

            decoder_input = [word_to_id['<start>']] + id_sentence[:max_sentence_length - 2] + [word_to_id['<stop>']]
            decoder_input = decoder_input + [word_to_id['<PAD>']] * (max_sentence_length - len(decoder_input))

            target = id_sentence[1:max_sentence_length - 1] + [word_to_id['<stop>']]
            target = target + [word_to_id['<PAD>']] * (max_sentence_length - len(target))

            encoder_inputs.append(encoder_input)
            decoder_inputs.append(decoder_input)
            targets.append(target)

        yield (
            {
                "encoder_inputs": np.array(encoder_inputs),
                "decoder_inputs": np.array(decoder_inputs),
            },
            np.array(targets),
        )

data_generator = data_generator(batch_size)

In [8]:
class CustomCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):

        max_decoded_sentence_length = 20
        sentence = ['dawno', 'dawno', 'temu', 'było'] 
        tokenized_input_sentence = [word_to_id[word] for word in sentence if word in word_to_id]

        decoded_sentence = [word_to_id["<start>"]]
        for i in range(max_decoded_sentence_length):
            tokenized_target_sentence = decoded_sentence + [word_to_id['<PAD>']] * (max_decoded_sentence_length - len(decoded_sentence))
            tokenized_target_sentence = np.array(tokenized_target_sentence).reshape(1, -1)
            tokenized_input_sentence = np.array(tokenized_input_sentence).reshape(1, -1)

            predictions = self.model([tokenized_input_sentence, tokenized_target_sentence])
            sampled_token_index = np.argmax(predictions[0, i, :])
            sampled_token = id_to_word[sampled_token_index]
            
            decoded_sentence.append(sampled_token_index)

            # if sampled_token == "<stop>":
            #     break

        # Convert the decoded sequence to words
        decoded_sentence = [id_to_word[token_id] for token_id in decoded_sentence]   
        
        sentence = " ".join(sentence)
        decoded_sentence = " ".join(decoded_sentence)
        print()
        print(sentence, decoded_sentence)

    
checkpoint_callback = ModelCheckpoint(filepath='transformer_{epoch:02d}.h5', save_freq=1000)


In [9]:
transformer.summary()
transformer.compile(
    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

Model: "transformer"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 encoder_inputs (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 positional_embedding (Position  (None, None, 128)   6405504     ['encoder_inputs[0][0]']         
 alEmbedding)                                                                                     
                                                                                                  
 decoder_inputs (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 transformer_encoder (Transform  (None, None, 128)   396032      ['positional_embedding[

In [10]:
epochs = 100
transformer.fit(data_generator, steps_per_epoch=1000, epochs=epochs, callbacks=[CustomCallback(), checkpoint_callback])


Epoch 1/100
dawno dawno temu było <start> dawno temu było <stop> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
Epoch 2/100
dawno dawno temu było <start> dawno temu było <stop> <PAD> <PAD> <stop> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
Epoch 3/100

KeyboardInterrupt: 