In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization
import numpy as np
import os
import re
import string
import random

In [2]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Mask the upper half of the dot product matrix in self attention.
    This prevents flow of information from future tokens to current token.
    1's in the lower triangle, counting from the lower right corner.
    """
    i = tf.range(n_dest)[:, None]
    j = tf.range(n_src)
    m = i >= j - n_src + n_dest
    mask = tf.cast(m, dtype)
    mask = tf.reshape(mask, [1, n_dest, n_src])
    mult = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
    )
    return tf.tile(mask, mult)


class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
        attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)

In [3]:
class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

In [4]:
vocab_size = 20000  # Only consider the top 20k words
maxlen = 80  # Max sequence size
embed_dim = 256  # Embedding size for each token
num_heads = 2  # Number of attention heads
feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer


def create_model():
    inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    x = embedding_layer(inputs)
    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
    x = transformer_block(x)
    outputs = layers.Dense(vocab_size)(x)
    model = keras.Model(inputs=inputs, outputs=[outputs, x])
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        "adam", loss=[loss_fn, None],
    )  # No loss and optimization based on word embeddings from transformer block
    return model

In [5]:
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:--  0:00:01 --:--:--     0
  0     0    0     0    0     0      0      0 --:--:--  0:00:02 --:--:--     0
  0 80.2M    0     0    0     0      0      0 --:--:--  0:00:03 --:--:--     0
  0 80.2M    0 32768    0     0   8247      0  2:50:00  0:00:03  2:49:57  8260
  0 80.2M    0  112k    0     0  22813      0  1:01:27  0:00:05  1:01:22 22846
  0 80.2M    0  240k    0     0  40553      0  0:34:34  0:00:06  0:34:28 54335
  0 80.2M    0  416k    0     0  62136      0  0:22:33  0:00:06  0:22:27 99296
  0 80.2M    0  752k    0     0  97380      0  0:14:23  0:00:07  0:14:16  156k
  1 80.2M    1 1280k    0     0   142k      0  0:09:35  0:00:08  0:09:27  250k
  2 80.2M    2 2032k    0     0   206k      0  0:06

In [6]:
batch_size = 128

# The dataset contains each review in a separate text file
# The text files are present in four different folders
# Create a list all files
filenames = []
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
for dir in directories:
    for f in os.listdir(dir):
        filenames.append(os.path.join(dir, f))

print(f"{len(filenames)} files")

# Create a dataset from text files
random.shuffle(filenames)
text_ds = tf.data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """ Remove html line-break tags and handle punctuation """
    lowercased = tf.strings.lower(input_string)
    stripped_html = tf.strings.regex_replace(lowercased, "<br />", " ")
    return tf.strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")


# Create a vectorization layer and adapt it to the text
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=maxlen + 1,
)
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices


def prepare_lm_inputs_labels(text):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    word at position (i+1). The model will use all words up till position (i)
    to predict the next word.
    """
    text = tf.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y


text_ds = text_ds.map(prepare_lm_inputs_labels)
text_ds = text_ds.prefetch(tf.data.AUTOTUNE)

50000 files


In [7]:
class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.
    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings, obtained from the TextVectorization layer.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        start_tokens = [_ for _ in self.start_tokens]
        if (epoch + 1) % self.print_every != 0:
            return
        num_tokens_generated = 0
        tokens_generated = []
        while num_tokens_generated <= self.max_tokens:
            pad_len = maxlen - len(start_tokens)
            sample_index = len(start_tokens) - 1
            if pad_len < 0:
                x = start_tokens[:maxlen]
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = start_tokens + [0] * pad_len
            else:
                x = start_tokens
            x = np.array([x])
            y, _ = self.model.predict(x)
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            start_tokens.append(sample_token)
            num_tokens_generated = len(tokens_generated)
        txt = " ".join(
            [self.detokenize(_) for _ in self.start_tokens + tokens_generated]
        )
        print(f"generated text:\n{txt}\n")


# Tokenize starting prompt
word_to_index = {}
for index, word in enumerate(vocab):
    word_to_index[word] = index

start_prompt = "this movie is"
start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]
num_tokens_generated = 40
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)

In [8]:
model = create_model()

model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])

Epoch 1/25
generated text:
this movie is not an interesting story , and a lot of [UNK] in the story , and [UNK] , and a young woman . the film was a good thing about the story about it 's a [UNK] that 's [UNK] , [UNK]

391/391 - 42s - loss: 5.5762 - dense_2_loss: 5.5762 - 42s/epoch - 107ms/step
Epoch 2/25
generated text:
this movie is a great movie , the acting is horrible and the script that it is terrible . i am not to say that i 'm not sure the movie . the acting is bad . the only reason it could have been

391/391 - 39s - loss: 4.7082 - dense_2_loss: 4.7082 - 39s/epoch - 101ms/step
Epoch 3/25
generated text:
this movie is about a young girl who 's lives in a house and is a very good one of the most boring movies ever made . not bad , bad , it 's a waste of time . . . .i 'm a

391/391 - 40s - loss: 4.4572 - dense_2_loss: 4.4572 - 40s/epoch - 101ms/step
Epoch 4/25


generated text:
this movie is one of the worst movies i have ever seen , but it just seems to me and my girlfriend are friends . i can 't watch it . i have to admit that i can safely say that the only thing

391/391 - 39s - loss: 4.2985 - dense_2_loss: 4.2985 - 39s/epoch - 101ms/step
Epoch 5/25
generated text:
this movie is a complete disappointment . it is very well done , with its not . what really did with the story line was so much . the film is about [UNK] . the movie is about two young lovers - - -

391/391 - 39s - loss: 4.1776 - dense_2_loss: 4.1776 - 39s/epoch - 101ms/step
Epoch 6/25
generated text:
this movie is a must -see for everyone , and i have to watch it . i love horror movies . i don 't think i can tell you that if it had to go and enjoy the movie . this movie has a

391/391 - 40s - loss: 4.0778 - dense_2_loss: 4.0778 - 40s/epoch - 101ms/step
Epoch 7/25


generated text:
this movie is the worst movie i 've ever seen in a few movies in my life , i 'm not so sure that it seems a lot more than a bit as it does . the plot is a weak . the acting

391/391 - 39s - loss: 3.9920 - dense_2_loss: 3.9920 - 39s/epoch - 101ms/step
Epoch 8/25
generated text:
this movie is very very well paced . the acting of this story . the characters are all over and over the top , the direction and the plot is very poor as the acting . the writing is wooden , the acting was

391/391 - 40s - loss: 3.9174 - dense_2_loss: 3.9174 - 40s/epoch - 101ms/step
Epoch 9/25
generated text:
this movie is really the best movie of all time . it 's a great movie . a movie , and a few times . it 's very entertaining , a classic . . [UNK] " i have no idea what it 's got

391/391 - 40s - loss: 3.8516 - dense_2_loss: 3.8516 - 40s/epoch - 101ms/step
Epoch 10/25
generated text:
this movie is very much more than a documentary , and shows like a real life . it doesn 't seem like a document

generated text:
this movie is a bit too predictable for me . it is not funny at all in it . the characters are just plain silly and simple but it 's the best movie i have ever seen . the acting is terrible and i

391/391 - 40s - loss: 3.7392 - dense_2_loss: 3.7392 - 40s/epoch - 101ms/step
Epoch 12/25
generated text:
this movie is very bad , not even a very good movie . its bad language that is a bad film , but it doesn 't matter what is really that bad in it . the film is like the plague . the city

391/391 - 40s - loss: 3.6906 - dense_2_loss: 3.6906 - 40s/epoch - 101ms/step
Epoch 13/25
generated text:
this movie is really good . . i can 't believe how bad i even think about this movie . [UNK] [UNK] " was one of the best . it was so awful that it was not the best movie i have ever seen

391/391 - 39s - loss: 3.6464 - dense_2_loss: 3.6464 - 39s/epoch - 101ms/step
Epoch 14/25


generated text:
this movie is a great movie ! it 's so good that it shows how people can be on the top notch and that is a good movie . the actors are great , and you don 't want it to keep you laughing

391/391 - 39s - loss: 3.6051 - dense_2_loss: 3.6051 - 39s/epoch - 101ms/step
Epoch 15/25
generated text:
this movie is one of the funniest movies i 've seen in a long time . i saw it on cable one night in a long time when i was in the theater with it , i thought it would be funny and not

391/391 - 39s - loss: 3.5682 - dense_2_loss: 3.5682 - 39s/epoch - 101ms/step
Epoch 16/25
generated text:
this movie is a great example of an actor who is a good actor . he does not know what it has a great actor . his charisma , and the acting is very well -cast and he is very funny in this movie

391/391 - 40s - loss: 3.5332 - dense_2_loss: 3.5332 - 40s/epoch - 101ms/step
Epoch 17/25
generated text:
this movie is a waste of time and time . it is a complete waste of time . it has nothing to do to do w

Epoch 18/25
generated text:
this movie is an abomination who was a real good movie , but not a good film . it 's a story , but in the plot , the film is a joke and the execution is completely unbelievable , but in this is

391/391 - 39s - loss: 3.4709 - dense_2_loss: 3.4709 - 39s/epoch - 101ms/step
Epoch 19/25
generated text:
this movie is a perfect example of what kind of an old hollywood movie that is about as bad as this one . the best movies of all time , most people in their minds or they need to make the story . if

391/391 - 39s - loss: 3.4425 - dense_2_loss: 3.4425 - 39s/epoch - 101ms/step
Epoch 20/25
generated text:
this movie is not a bad movie , but it does not deserve any honorable justice to this piece of [UNK] . it does a very good job in this movie . this is not a [UNK] that is the movie . there are

391/391 - 39s - loss: 3.4168 - dense_2_loss: 3.4168 - 39s/epoch - 101ms/step
Epoch 21/25


generated text:
this movie is very well acted . it has some very good performances , very good direction . it 's not that good and it 's not worth the time . if you 're a good thing , you 'll be warned by fans

391/391 - 39s - loss: 3.3920 - dense_2_loss: 3.3920 - 39s/epoch - 101ms/step
Epoch 22/25
generated text:
this movie is an absolutely wonderful piece of crap . it is filled with balls that can hardly ever hit the face ? how the mouths the graphics are horrendous . it is awful , and the only real thing that is a bad

391/391 - 40s - loss: 3.3696 - dense_2_loss: 3.3696 - 40s/epoch - 101ms/step
Epoch 23/25
generated text:
this movie is an absolute must see if you have ever seen . i would be able to buy it . i was curious as some of the worst i have ever seen . the movie is bad , it was shot in nyc

391/391 - 40s - loss: 3.3481 - dense_2_loss: 3.3481 - 40s/epoch - 101ms/step
Epoch 24/25


generated text:
this movie is a good example of what a great movie ! the acting is great , and it has been said it . the script is also excellent and the acting is wonderful , especially for everyone who wants to see the characters

391/391 - 40s - loss: 3.3277 - dense_2_loss: 3.3277 - 40s/epoch - 101ms/step
Epoch 25/25
generated text:
this movie is about a wealthy woman who can 't help her to see the movie , her mother and the family moves into a remote house where her daughter asks her brother is kidnapped and has a connection to them in the woods

391/391 - 40s - loss: 3.3082 - dense_2_loss: 3.3082 - 40s/epoch - 102ms/step


<keras.callbacks.History at 0x1c6db329360>

In [46]:
start_prompt = "i am proud of"
start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]
max_tokens = 400

top_k = 10
k = top_k
def sample_from(logits):
    logits, indices = tf.math.top_k(logits, k=k, sorted=True)
    indices = np.asarray(indices).astype("int32")
    preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]
    preds = np.asarray(preds).astype("float32")
    return np.random.choice(indices, p=preds)
    
start_tokens = [_ for _ in start_tokens]
num_tokens_generated = 0
tokens_generated = []
while num_tokens_generated <= max_tokens:
    pad_len = maxlen - len(start_tokens)
    sample_index = len(start_tokens) - 1
    if pad_len < 0:
        x = start_tokens[:maxlen]
        sample_index = maxlen - 1
    elif pad_len > 0:
        x = start_tokens + [0] * pad_len
    else:
        x = start_tokens
    x = np.array([x])
    y, _ = model.predict(x)
    sample_token = sample_from(y[0][sample_index])
    tokens_generated.append(sample_token)
    start_tokens.append(sample_token)
    num_tokens_generated = len(tokens_generated)







In [47]:
 txt = " ".join([vocab[_] for _ in start_tokens + tokens_generated])

In [48]:
txt

"i am proud of this film and i like it was a good movie and i was very interested to watch the characters . i was surprised by this movie . the plot was just bad . i am not going to watch the film for the [UNK] . i was expecting another d -movie . it 's a fun movie . the acting was great . . . . the acting was not that bad but the script was terrible just a just so just just a good terrible not bad just terrible a just just just terrible not not terrible just good terrible so very bad good really just bad really bad bad so really just bad good really good really very terrible terrible a terrible just really good just just just good awful terrible really terrible just a just terrible terrible really just terrible really bad bad terrible very a really so terrible really just just bad very just really terrible so bad awful awful just not awful good awful not terrible a just just bad just awful a so awful just terrible so just just terrible not really bad good just just just just a just ve