In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Warm start embedding matrix with changing vocabulary

This tutorial introduces `tf.keras.utils.warmstart_embedding_matrix`. You will train a simple Keras model for a sentiment classification task with a base vocabulary. You will then learn how to warm-start training when the vocabulary changes, so that the model keeps improving instead of starting over.

## Embedding matrix

Embeddings give us a way to use an efficient, dense representation in which similar vocabulary tokens have a similar encoding. They are trainable parameters (weights learned by the model during training, in the same way a model learns weights for a dense layer). Embeddings commonly range from 8 dimensions (for small datasets) up to 1024 dimensions (for large datasets). A higher-dimensional embedding can capture fine-grained relationships between words, but takes more data to learn.
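
As a rough illustration of the idea (a toy sketch, not how Keras implements the layer), an embedding matrix can be viewed as a lookup table with one row per vocabulary index. The table values and the `embed` helper below are invented for this example.

```python
# Toy embedding matrix: one row per vocabulary index, embedding_dim = 3.
# The values are made up; real values are learned during training.
embedding_matrix = [
    [0.0, 0.0, 0.0],    # index 0: padding
    [0.1, -0.2, 0.3],   # index 1: "good"
    [0.1, -0.1, 0.4],   # index 2: "great"
    [-0.5, 0.6, -0.2],  # index 3: "bad"
]


def embed(token_ids):
    """Look up one embedding row per token id (hypothetical helper)."""
    return [embedding_matrix[i] for i in token_ids]


vectors = embed([2, 3, 1])
print(len(vectors), len(vectors[0]))  # number of tokens, embedding dimension
```

Here "good" and "great" were given similar rows on purpose, to mirror the idea that similar tokens end up with similar encodings.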


### Vocabulary

The set of unique words used in the text corpus is referred to as the vocabulary. We can use the vocabulary to find the number of times each word appears in the corpus, which helps us analyze which words are more common. The vocabulary also allows us to represent each piece of text by the specific words that appear in it.
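
A minimal sketch of building a vocabulary from a tiny corpus, using only the Python standard library. The corpus and the tie-breaking rule (alphabetical order) are assumptions made for the example; the `TextVectorization` layer used later builds its vocabulary for you.

```python
from collections import Counter

corpus = ["the movie was great", "the movie was bad", "great acting"]

# Count how often each whitespace-separated token appears in the corpus.
counts = Counter(token for line in corpus for token in line.split())

# Order tokens by descending frequency (ties broken alphabetically) and
# assign each one an integer index.
vocab = sorted(counts, key=lambda tok: (-counts[tok], tok))
token_to_index = {tok: i for i, tok in enumerate(vocab)}

# Represent a piece of text by the indices of the words that appear in it.
encoded = [token_to_index[tok] for tok in "the movie was great".split()]
print(vocab)
print(encoded)
```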

### Why warm-start the embedding matrix?

A model is trained with a set of embeddings that represents a given vocabulary. If the model needs to be updated or improved when the vocabulary is extended, changed, or reshuffled, the model architecture previously had to change (because the embedding layer's `input_dim` would change). As a consequence, previously trained embeddings could not be reused and training had to start from scratch.

The `tf.keras.utils.warmstart_embedding_matrix` utility warm-starts the embedding matrix when the vocabulary changes between a previously saved checkpoint and the current model. A vocabulary change could mean that the new vocabulary has a different size, that it has been reshuffled, or that new tokens have been added to the old vocabulary. If the vocabulary size changes, the size of the embedding matrix changes too. The utility remaps the old vocabulary's embeddings to their positions in the new embedding matrix.
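
The remapping idea can be sketched in plain Python. This is a conceptual illustration under simplifying assumptions, not the library's implementation: `remap_embeddings` is a hypothetical helper, and the random range used for new tokens is chosen only for the example.

```python
import random


def remap_embeddings(base_vocab, new_vocab, base_embeddings, seed=0):
    """Toy remap: reuse rows for known tokens, randomly initialize the rest."""
    rng = random.Random(seed)
    dim = len(base_embeddings[0])
    base_index = {token: i for i, token in enumerate(base_vocab)}
    new_embeddings = []
    for token in new_vocab:
        if token in base_index:
            # Token existed before: copy its trained row to the new position.
            new_embeddings.append(list(base_embeddings[base_index[token]]))
        else:
            # New token: start from small random values (assumed range).
            new_embeddings.append([rng.uniform(-0.05, 0.05) for _ in range(dim)])
    return new_embeddings


base_vocab = ["", "[UNK]", "the", "movie"]
base_embeddings = [[0.0, 0.0], [0.1, 0.1], [0.2, 0.3], [0.4, 0.5]]
new_vocab = ["", "[UNK]", "movie", "the", "sequel"]  # reshuffled and extended

new_embeddings = remap_embeddings(base_vocab, new_vocab, base_embeddings)
print(len(new_embeddings))                      # one row per token in new_vocab
print(new_embeddings[2] == base_embeddings[3])  # "movie" keeps its trained row
```

The new matrix has one row per token of the new vocabulary, so its first dimension changes whenever the vocabulary size does.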

## Setup

In [None]:
# install tf-nightly as `warmstart_embedding_matrix` is only available in nightly
! pip install -q tf-nightly
# uninstall nightly tensorboardx and reinstall to work with tf-nightly
! pip uninstall --yes tb-nightly tensorboardX tensorboard
! pip install tensorboard

In [None]:
import io
import numpy as np
import os
import re
import shutil
import string
import tensorflow as tf

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D
from tensorflow.keras.layers import TextVectorization

### Load Dataset
You will use the [Large Movie Review Dataset](http://ai.stanford.edu/~amaas/data/sentiment/) throughout this tutorial. You will train a sentiment classifier model on this dataset and in the process learn embeddings from scratch. To read more about loading a dataset from scratch, see the [Loading text tutorial](https://www.tensorflow.org/tutorials/load_data/text).

Download the dataset using the Keras file utility and take a look at the directories.

In [None]:
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

dataset = tf.keras.utils.get_file(
    "aclImdb_v1.tar.gz", url, untar=True, cache_dir=".", cache_subdir=""
)

dataset_dir = os.path.join(os.path.dirname(dataset), "aclImdb")
os.listdir(dataset_dir)

Take a look at the `train/` directory. It has `pos` and `neg` folders with movie reviews labelled as positive and negative respectively. You will use reviews from `pos` and `neg` folders to train a binary classification model.

In [None]:
train_dir = os.path.join(dataset_dir, "train")
os.listdir(train_dir)

The `train` directory also contains an `unsup` folder of unlabeled reviews, which should be removed before creating the training dataset.

In [None]:
remove_dir = os.path.join(train_dir, "unsup")
shutil.rmtree(remove_dir)

Next, create a `tf.data.Dataset` using `tf.keras.utils.text_dataset_from_directory`. You can read more about using this utility in this [text classification tutorial](https://www.tensorflow.org/tutorials/keras/text_classification). 

Use the `train` directory to create both train and validation datasets with a split of 20% for validation.

In [None]:
batch_size = 1024
seed = 123
train_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="training",
    seed=seed,
)
val_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="validation",
    seed=seed,
)

### Configure the dataset for performance

You can learn more about `.cache()` and `.prefetch()`, as well as how to cache data to disk in the [data performance guide](https://www.tensorflow.org/guide/data_performance).

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

## Text preprocessing

Next, define the dataset preprocessing steps required for your sentiment classification model. Initialize a `TextVectorization` layer with the desired parameters to vectorize movie reviews. You can learn more about using this layer in the [Text Classification](https://www.tensorflow.org/tutorials/keras/text_classification) tutorial.

In [None]:
# Create a custom standardization function to strip HTML break tags '<br />'.
def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape(string.punctuation), ""
    )


# Vocabulary size and number of words in a sequence.
vocab_size = 10000
sequence_length = 100

# Use the text vectorization layer to normalize, split, and map strings to
# integers. Note that the layer uses the custom standardization defined above.
# Set output_sequence_length because the samples do not all have the same length.
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Make a text-only dataset (no labels) and call adapt to build the vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer.adapt(text_ds)
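
To see what these preprocessing steps do to a single review, here is a plain-Python analogue. It is an illustrative sketch only: `standardize` and `vectorize` are hypothetical helpers, and the four-token vocabulary is invented. It does follow the layer's convention for integer output of reserving index 0 for padding and index 1 for unknown tokens.

```python
import re
import string


def standardize(text):
    """Plain-Python analogue of custom_standardization (illustrative only)."""
    lowered = text.lower()
    no_breaks = lowered.replace("<br />", " ")
    return re.sub("[%s]" % re.escape(string.punctuation), "", no_breaks)


def vectorize(text, vocab, length):
    """Map tokens to indices, then truncate or zero-pad to a fixed length."""
    index = {token: i for i, token in enumerate(vocab)}
    # Tokens missing from the vocabulary map to index 1 (unknown token).
    ids = [index.get(token, 1) for token in standardize(text).split()]
    # Index 0 is used as padding up to the requested length.
    return ids[:length] + [0] * max(0, length - len(ids))


toy_vocab = ["", "[UNK]", "great", "movie"]
review = "Great movie!<br />Loved it."
print(standardize(review))
print(vectorize(review, toy_vocab, length=6))
```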

## Create a classification model

Use the [Keras Functional API](https://www.tensorflow.org/guide/keras/functional) to define the sentiment classification model. 

In [None]:
# build a functional model
embedding_dim = 16
text_model_input = tf.keras.layers.Input(dtype=tf.string, shape=(1,))
text_vectorize_layer = vectorize_layer(text_model_input)
text_embedding_layer = Embedding(vocab_size, embedding_dim, name="embedding")(
    text_vectorize_layer
)
global_avg_pool = GlobalAveragePooling1D()(text_embedding_layer)
dense_1 = Dense(16, activation="relu")(global_avg_pool)
output = Dense(1)(dense_1)
model = Model(inputs=text_model_input, outputs=output)

## Compile and train the model

You will use [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize metrics including loss and accuracy. Create a `tf.keras.callbacks.TensorBoard`.

In [None]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

Compile and train the model using the `Adam` optimizer and `BinaryCrossentropy` loss. 

In [None]:
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
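
The final `Dense(1)` layer has no activation, so the model outputs a raw score (a logit) rather than a probability, which is why the loss is created with `from_logits=True`. The sketch below shows the underlying arithmetic for a single example in plain Python. The helper names and the logit value are assumptions for illustration, and this is not Keras's own code.

```python
import math


def sigmoid(x):
    """Convert a logit into a probability of the positive class."""
    return 1.0 / (1.0 + math.exp(-x))


def bce_from_logit(label, logit):
    """Binary cross-entropy for one example, computed directly from a logit."""
    # Numerically stable form of -[y*log(p) + (1-y)*log(1-p)], p = sigmoid(logit).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


logit = 2.0  # hypothetical raw model output for one review
print(sigmoid(logit) > 0.5)  # a positive logit means "more likely positive"
print(bce_from_logit(1.0, logit) < bce_from_logit(0.0, logit))  # label 1 fits better
```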

In [None]:
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    callbacks=[tensorboard_callback],
)

With this approach, the model reaches a validation accuracy of around 85%.

Note: Your results may be a bit different, depending on how weights were randomly initialized before training the embedding layer. 

You can look into the model summary to learn more about each layer of the model.

In [None]:
model.summary()
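
The `Param #` column of the summary can be checked by hand. The sketch below recomputes the trainable parameter counts for the layer sizes used in this tutorial; the helper functions are hypothetical, and the second vocabulary size (10200) is the enlarged vocabulary used later in the tutorial.

```python
def embedding_params(vocab_size, embedding_dim):
    """One embedding_dim-sized row per vocabulary index."""
    return vocab_size * embedding_dim


def dense_params(inputs, units):
    """Weight matrix plus one bias per unit."""
    return inputs * units + units


embedding_dim = 16
# Dense(16) on the pooled 16-dim vector, followed by Dense(1).
rest = dense_params(embedding_dim, 16) + dense_params(16, 1)

for vocab_size in (10000, 10200):
    emb = embedding_params(vocab_size, embedding_dim)
    print(vocab_size, emb, emb + rest)
```

Only the embedding layer's count depends on the vocabulary size, which is why a vocabulary change affects that layer alone.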

Visualize the model metrics in TensorBoard.

In [None]:
# docs_infra: no_execute
%load_ext tensorboard
%tensorboard --logdir logs

## Vocabulary remapping

Scenario: The vocabulary size has changed. The new vocabulary may contain new words, contain fewer words, or be reshuffled. The embedding layer needs to be remapped and updated.

Get the base vocabulary and embedding matrix.

In [None]:
embedding_weights_base = model.get_layer("embedding").get_weights()[0]
vocab_base = vectorize_layer.get_vocabulary()

Define a new vectorization layer to generate a new, bigger vocabulary.

In [None]:
# Vocabulary size and number of words in a sequence.
vocab_size_new = 10200
sequence_length = 100

vectorize_layer_new = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size_new,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Make a text-only dataset (no labels) and call adapt to build the vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer_new.adapt(text_ds)

# get new vocab
vocab_new = vectorize_layer_new.get_vocabulary()
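
Before remapping, it can help to see how two vocabularies differ. The sketch below uses a hypothetical `compare_vocabularies` helper on toy vocabularies; it is an illustration written for this example, not part of Keras.

```python
def compare_vocabularies(base_vocab, new_vocab):
    """Summarize how new_vocab differs from base_vocab (hypothetical helper)."""
    base_index = {token: i for i, token in enumerate(base_vocab)}
    new_index = {token: i for i, token in enumerate(new_vocab)}
    shared = set(base_index) & set(new_index)
    return {
        "added": sorted(set(new_index) - set(base_index)),
        "removed": sorted(set(base_index) - set(new_index)),
        "moved": sorted(t for t in shared if base_index[t] != new_index[t]),
    }


toy_base = ["", "[UNK]", "the", "movie", "plot"]
toy_new = ["", "[UNK]", "movie", "the", "sequel"]
diff = compare_vocabularies(toy_base, toy_new)
print(diff)
```

In the notebook, calling `compare_vocabularies(vocab_base, vocab_new)` would give the same kind of summary for the real vocabularies.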

Generate updated embeddings using the `warmstart_embedding_matrix` utility.

In [None]:
# generate updated embedding matrix
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
    base_vocabulary=vocab_base,
    new_vocabulary=vocab_new,
    base_embeddings=embedding_weights_base,
    new_embeddings_initializer="uniform",
)
# update model variable
updated_embedding_variable = tf.Variable(updated_embedding)

**OR**

If you already have an embedding matrix that you would like to use to initialize the new embedding matrix, pass `keras.initializers.Constant` as the `new_embeddings_initializer`. Uncomment the following block to try this out.

In [None]:
'''
# generate updated embedding matrix
new_embedding = np.random.rand(len(vocab_new), 16)
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
            base_vocabulary=vocab_base,
            new_vocabulary=vocab_new,
            base_embeddings=embedding_weights_base,
            new_embeddings_initializer=tf.keras.initializers.Constant(
                new_embedding
            )
        )
# update model variable
updated_embedding_variable = tf.Variable(updated_embedding)
'''

Verify that the embedding matrix shape has changed to reflect the new vocabulary.

In [None]:
updated_embedding_variable.shape

Now that we have the updated embedding matrix, the next step is to update the layer weights.

In [None]:
model.get_layer("embedding").embeddings = updated_embedding_variable

# Verify updated weights shape
# The new weights shape should reflect new vocab size
model.get_layer("embedding").get_weights()[0].shape

Modify the model architecture to use the new text vectorization layer.

In [None]:
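# Rebuild the model so that it uses the new vectorization layer and an
# embedding layer sized for the new vocabulary.
text_vectorize_layer_new = vectorize_layer_new(text_model_input)
embedding_layer_new = Embedding(len(vocab_new), embedding_dim, name="embedding")
x = embedding_layer_new(text_vectorize_layer_new)
# Load the warm-started matrix into the new embedding layer.
embedding_layer_new.set_weights([updated_embedding_variable.numpy()])
# Reuse the trained pooling and dense layers (the last three model layers).
for layer in model.layers[-3:]:
    x = layer(x)
model = Model(inputs=text_model_input, outputs=x)

# view model summary and check updated Param # for Embedding layer.
model.summary()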

We have successfully updated the model to accept a new vocabulary. The embedding layer now maps old vocabulary words to their old embeddings and initializes embeddings for the new vocabulary words, which are yet to be learned. The learned weights of the rest of the model remain the same, so the model is warm-started and continues training from where it left off.

Let us verify that the remapping worked. Get the index of the vocabulary word "the", which is present in both the base and the new vocabulary, and compare the embedding values. They should be equal.

In [None]:
# A word that appears in both the base and the new vocabulary
example_old_vocab_word = "the"  # index 2
base_vocab_index = vectorize_layer(example_old_vocab_word)[0]
new_vocab_index = vectorize_layer_new(example_old_vocab_word)[0]
print(
    model.get_layer("embedding")(new_vocab_index)
    == embedding_weights_base[base_vocab_index]
)
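
The check above looks at a single word. A broader check compares the rows of every token shared by both vocabularies. The sketch below shows that idea on toy data in plain Python; `mismatched_tokens` is a hypothetical helper, and the vocabularies and embedding values are invented.

```python
def mismatched_tokens(base_vocab, new_vocab, base_embeddings, new_embeddings):
    """Return shared tokens whose embedding rows differ (hypothetical helper)."""
    new_index = {token: i for i, token in enumerate(new_vocab)}
    return [
        token
        for i, token in enumerate(base_vocab)
        if token in new_index
        and list(base_embeddings[i]) != list(new_embeddings[new_index[token]])
    ]


toy_base_vocab = ["", "[UNK]", "the", "movie"]
toy_base_emb = [[0.0, 0.0], [0.1, 0.1], [0.2, 0.3], [0.4, 0.5]]
toy_new_vocab = ["", "[UNK]", "movie", "the", "sequel"]
toy_new_emb = [[0.0, 0.0], [0.1, 0.1], [0.4, 0.5], [0.2, 0.3], [0.01, -0.02]]

# An empty list means every shared token kept its embedding row.
print(mismatched_tokens(toy_base_vocab, toy_new_vocab, toy_base_emb, toy_new_emb))
```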

## Continue with warm-started training

Notice how the training is warm-started: the accuracy in the first epoch is around 85%, close to the accuracy at which the previous training ended.

In [None]:
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    callbacks=[tensorboard_callback],
)

## Visualize warm-started training

In [None]:
# docs_infra: no_execute
%reload_ext tensorboard
%tensorboard --logdir logs