# End-to-end Masked Language Modeling with BERT

**Author:** [Ankur Singh](https://twitter.com/ankur310794)<br>
**Date created:** 2020/09/18<br>
**Last modified:** 2020/09/18<br>
**Description:** Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.

## Introduction

Masked Language Modeling is a fill-in-the-blank task,
where a model uses the context words surrounding a mask token to try to predict what the
masked word should be.

For an input that contains one or more mask tokens,
the model will generate the most likely substitution for each.

Example:

- Input: "I have watched this [MASK] and it was awesome."
- Output: "I have watched this movie and it was awesome."

Masked language modeling is a great way to train a language
model in a self-supervised setting (without human-annotated labels).
Such a model can then be fine-tuned to accomplish various supervised
NLP tasks.

This example teaches you how to build a BERT model from scratch,
train it with the masked language modeling task,
and then fine-tune this model on a sentiment classification task.

We will use the Keras `TextVectorization` and `MultiHeadAttention` layers
to create a BERT Transformer-Encoder network architecture.

Note: This example should be run with `tf-nightly`.

<br>

### INITIAL SETUP

---

In [None]:
# Import the libraries
import pprint, glob, re
import pandas as pd
import numpy as np
import tensorflow as tf
from dataclasses import dataclass

In [None]:
# Configuration class
@dataclass
class Config:
    
    # Parameters
    MAX_LEN = 256         # Maximum length of input sentence to the model.
    BATCH_SIZE = 32       # Batch size for training.
    LR = 0.001            # Learning rate for Adam optimizer.
    VOCAB_SIZE = 30000    # Size of the vocabulary.
    EMBED_DIM = 128       # Embedding size for embedding matrix
    NUM_HEAD = 8          # Number of attention heads
    FF_DIM = 128          # Hidden layer size in feed forward network inside transformer
    NUM_LAYERS = 1        # Number of transformer layers

# Initialize the configuration
config = Config()
config
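
Note that the attributes above carry no type annotations, so `@dataclass` treats them as plain class attributes rather than fields, which is why the cell displays `Config()`. The sketch below shows an annotated variant, assuming you want the values to show up in the repr and be overridable per instance. `AnnotatedConfig` and `small_config` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, fields


@dataclass
class AnnotatedConfig:  # hypothetical name, not part of the original notebook
    MAX_LEN: int = 256
    BATCH_SIZE: int = 32
    LR: float = 0.001
    VOCAB_SIZE: int = 30000
    EMBED_DIM: int = 128
    NUM_HEAD: int = 8
    FF_DIM: int = 128
    NUM_LAYERS: int = 1


# Annotated attributes become real fields, so they can be overridden per instance
small_config = AnnotatedConfig(NUM_LAYERS=2)
field_names = [f.name for f in fields(small_config)]
print(small_config)
```

The rest of the notebook only reads attributes such as `config.MAX_LEN`, so either form works with the code that follows.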

<br>

### DOWNLOAD AND LOAD DATASET

---

In [None]:
# Download the dataset
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

In [None]:
# Function for getting the list of files
def get_text_list_from_files(files):
    
    # Initialize a list
    text_list = []
    
    # Loop over the files
    for name in files:
        
        # Open the file
        with open(name) as f:
            
            # Loop over the lines
            for line in f:
                
                # Append the line to the list
                text_list.append(line)
                
    return text_list

In [None]:
# Function for getting the data from the text files
def get_data_from_text_files(folder_name):

    # List of files for positive reviews
    pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")
    
    # Get the list of texts from the file
    pos_texts = get_text_list_from_files(pos_files)
    
    # List of files for negative reviews
    neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")
    
    # Get the list of texts from the file
    neg_texts = get_text_list_from_files(neg_files)
    
    # Add the positive and negative reviews to a dataframe
    # (note the label convention: 0 = positive, 1 = negative)
    df = pd.DataFrame(
        {
            "review": pos_texts + neg_texts,
            "sentiment": [0] * len(pos_texts) + [1] * len(neg_texts),
        }
    )
    
    # Shuffle the dataframe
    df = df.sample(len(df)).reset_index(drop=True)
    
    return df

In [None]:
# Get the training and testing dataset
train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")

In [None]:
# Combine the training and testing dataset
all_data = pd.concat([train_df, test_df])  # DataFrame.append was removed in pandas 2.0

<br>

### DATA PREPARATION

---

In [None]:
# Function for custom standardization
def custom_standardization(input_data):
    
    # Lowercase the data
    lowercase = tf.strings.lower(input_data)
    
    # Remove the html tags
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    
    # Remove the punctuations
    stripped_punc = tf.strings.regex_replace(stripped_html, "[%s]" % re.escape(r"!#$%&'()*+,-./:;<=>?@\^_`{|}~"), "")
    
    return stripped_punc
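
To see what this standardization does to a review, here is a plain-Python sketch of the same three steps using the standard `re` module instead of `tf.strings`. `standardize_py` is a hypothetical helper for illustration only; the TensorFlow version above is what the `TextVectorization` layer actually calls.

```python
import re

# Same punctuation set as in custom_standardization (no square brackets)
PUNCTUATION = r"!#$%&'()*+,-./:;<=>?@\^_`{|}~"


def standardize_py(text):
    # 1. Lowercase
    lowered = text.lower()
    # 2. Replace HTML line breaks with a space (before "<", "/" and ">" are stripped)
    no_breaks = lowered.replace("<br />", " ")
    # 3. Remove punctuation
    return re.sub("[%s]" % re.escape(PUNCTUATION), "", no_breaks)


print(standardize_py("I LOVED it!<br />Great, really."))  # i loved it great really
print(standardize_py("this [MASK] was awesome"))          # this [mask] was awesome
```

Because square brackets are not in the punctuation set, the mask token survives standardization and simply becomes the lowercase `[mask]`.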

In [None]:
# Function for getting the vectorize layer
def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[MASK]"]):

    # Vectorize the text
    vectorize_layer = tf.keras.layers.TextVectorization(max_tokens=vocab_size,
                                                        output_mode="int",
                                                        standardize=custom_standardization,
                                                        output_sequence_length=max_seq,)
    
    # Apply to the text
    vectorize_layer.adapt(texts)

    # Get the vocabulary
    vocab = vectorize_layer.get_vocabulary()
    
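    # Drop the padding and OOV entries at indices 0 and 1, trim the vocabulary
    # to leave room for the mask token, and append the mask token at the end
    vocab = vocab[2 : vocab_size - len(special_tokens)] + ["[mask]"]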
    
    # Update the vocabulary
    vectorize_layer.set_vocabulary(vocab)
    
    return vectorize_layer

In [None]:
# Get the vectorize layer
vectorize_layer = get_vectorize_layer(all_data.review.values.tolist(),
                                      config.VOCAB_SIZE,
                                      config.MAX_LEN,
                                      special_tokens=["[mask]"],)

In [None]:
# Mask token id
mask_token_id = vectorize_layer(["[mask]"]).numpy()[0][0]
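
The index arithmetic behind the mask token id can be sketched without TensorFlow. The toy example below assumes that `get_vocabulary()` returns the padding token `""` and `"[UNK]"` at indices 0 and 1, and that `set_vocabulary()` puts them back in front of the list it receives. `rebuild_vocabulary` and `toy_vocab` are hypothetical names.

```python
def rebuild_vocabulary(adapted_vocab, vocab_size, special_tokens):
    # Mirror of the slicing in get_vectorize_layer: skip indices 0 and 1
    # and leave room at the end for the special tokens.
    kept = adapted_vocab[2 : vocab_size - len(special_tokens)]
    # Assumption: set_vocabulary() re-inserts "" and "[UNK]" at indices 0 and 1.
    return ["", "[UNK]"] + kept + special_tokens


toy_vocab = ["", "[UNK]", "the", "movie", "was", "great", "bad"]
final_vocab = rebuild_vocabulary(toy_vocab, vocab_size=6, special_tokens=["[mask]"])
mask_id = final_vocab.index("[mask]")
print(final_vocab, mask_id)
```

Under these assumptions the mask token ends up at index `vocab_size - 1`, which is consistent with the later comment that the mask token is the last entry in the dictionary.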

In [None]:
# Function for encoding the text
def encode(texts):
    
    # Vectorize the text
    encoded_texts = vectorize_layer(texts)
    
    # Convert to numpy
    encoded_texts = encoded_texts.numpy()
    
    return encoded_texts

In [None]:
# Function for getting the masked input and labels
def get_masked_input_and_labels(encoded_texts):
    
    # 15% BERT masking
    inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
    
    # Do not mask special tokens
    inp_mask[encoded_texts <= 2] = False
    
    # Set targets to -1 by default, meaning "ignore"
    labels = -1 * np.ones(encoded_texts.shape, dtype=int)
    
    # Set labels for masked tokens
    labels[inp_mask] = encoded_texts[inp_mask]

    # Prepare input
    encoded_texts_masked = np.copy(encoded_texts)
    
    # Replace 90% of the selected positions with [MASK]; the other 10% keep their original token
    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)
    encoded_texts_masked[inp_mask_2mask] = mask_token_id  # mask token is the last in the dict

    # Of those, overwrite 1/9 with a random token (10% of the selected positions overall)
    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)
    encoded_texts_masked[inp_mask_2random] = np.random.randint(3, mask_token_id, inp_mask_2random.sum())

    # Prepare sample_weights to pass to .fit() method
    sample_weights = np.ones(labels.shape)
    sample_weights[labels == -1] = 0

    # y_labels is the same as encoded_texts, i.e. the original input tokens
    y_labels = np.copy(encoded_texts)

    return encoded_texts_masked, y_labels, sample_weights
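
The nested random draws above are meant to reproduce BERT's 80% `[MASK]` / 10% random token / 10% unchanged split among the selected positions. The following NumPy sketch checks those proportions empirically on toy data. `MASK_ID`, the array size, and the fixed seed are assumptions chosen for illustration, not values from the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, only to make the illustration reproducible
MASK_ID = 99                    # hypothetical mask token id
tokens = rng.integers(3, MASK_ID, size=(200, 256))  # toy encoded texts, no special tokens

# Same three-stage selection as get_masked_input_and_labels
selected = rng.random(tokens.shape) < 0.15
to_mask = selected & (rng.random(tokens.shape) < 0.90)
to_random = to_mask & (rng.random(tokens.shape) < 1 / 9)

masked = tokens.copy()
masked[to_mask] = MASK_ID
masked[to_random] = rng.integers(3, MASK_ID, size=to_random.sum())

# Fractions are measured relative to the selected positions
n_selected = selected.sum()
frac_mask = (to_mask & ~to_random).sum() / n_selected
frac_random = to_random.sum() / n_selected
frac_kept = (selected & ~to_mask).sum() / n_selected
print(f"[MASK]: {frac_mask:.2f}, random: {frac_random:.2f}, unchanged: {frac_kept:.2f}")
```

The printed fractions should land close to 0.80, 0.10 and 0.10, since 0.90 × (1 − 1/9) = 0.80 and 0.90 × 1/9 = 0.10.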

In [None]:
# Encode the input training data
x_train = encode(train_df.review.values)

# Get the output training data
y_train = train_df.sentiment.values

# Convert to tf.data + Shuffle + Set batch size
train_classifier_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1000).batch(config.BATCH_SIZE))

In [None]:
# Encode the input testing data
x_test = encode(test_df.review.values)

# Get the output testing data
y_test = test_df.sentiment.values

# Convert to tf.data + Set batch size
test_classifier_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(config.BATCH_SIZE)

In [None]:
# Convert the raw test set to tf.data + Set batch size
test_raw_classifier_ds = tf.data.Dataset.from_tensor_slices((test_df.review.values, y_test)).batch(config.BATCH_SIZE)

In [None]:
# Prepare data for masked language model
x_all_review = encode(all_data.review.values)
x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(
    x_all_review
)

In [None]:
# Convert masked data to tf.data
mlm_ds = tf.data.Dataset.from_tensor_slices((x_masked_train, y_masked_labels, sample_weights))

# Shuffle + Set batch size
mlm_ds = mlm_ds.shuffle(1000).batch(config.BATCH_SIZE)

<br>

### BERT MODEL

Create BERT model (Pretraining Model) for masked language modeling.

In [None]:
# Implementation of the BERT module
def bert_module(query, key, value, i):
    
    # Multi headed self-attention
    attention_output = tf.keras.layers.MultiHeadAttention(num_heads=config.NUM_HEAD,
                                                          key_dim=config.EMBED_DIM // config.NUM_HEAD,
                                                          name="encoder_{}/multiheadattention".format(i),
                                                         )(query, key, value)
    
    # Dropout   
    attention_output = tf.keras.layers.Dropout(0.1, name="encoder_{}/att_dropout".format(i))(attention_output)
    
    # Add and norm
    attention_output = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="encoder_{}/att_layernormalization".format(i)
                                                         )(query + attention_output)

    # Initialize dense layers
    ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(config.FF_DIM, activation="relu"),
            tf.keras.layers.Dense(config.EMBED_DIM),
        ],
        name="encoder_{}/ffn".format(i),
    )
    
    # Feed-forward
    ffn_output = ffn(attention_output)
    
    # Dropout
    ffn_output = tf.keras.layers.Dropout(0.1, name="encoder_{}/ffn_dropout".format(i))(ffn_output)
    
    # Add and norm
    sequence_output = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="encoder_{}/ffn_layernormalization".format(i)
                                                        )(attention_output + ffn_output)
    
    return sequence_output
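
The `MultiHeadAttention` layer used above is built on scaled dot-product attention. As a rough illustration, here is a single-head NumPy sketch without the learned query/key/value projections or the split into heads that the Keras layer adds. The function name and the toy input are assumptions for illustration.

```python
import numpy as np


def scaled_dot_product_attention(q, k, v):
    # Similarity of every query position with every key position, scaled by sqrt(d_k)
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Softmax over the key axis (shifted by the row max for numerical stability)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Each output position is a weighted average of the value vectors
    return weights @ v, weights


x = np.random.default_rng(0).normal(size=(4, 8))  # toy sequence: 4 tokens, 8 dimensions
out, attn = scaled_dot_product_attention(x, x, x)  # self-attention: query = key = value
print(out.shape, attn.sum(axis=-1))
```

With `EMBED_DIM = 128` and `NUM_HEAD = 8`, the notebook sets `key_dim` to 16, so each head attends in a 16-dimensional subspace.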

In [None]:
# Function for positional encoding
def get_pos_encoding_matrix(max_len, d_emb):
    
    # Position encoding matrix
    pos_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]   
            if pos != 0
            else np.zeros(d_emb)
            for pos in range(max_len)
        ]
    )
    
    # Apply sine to the even columns and cosine to the odd columns
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1
    
    return pos_enc
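
The same matrix can be computed without the nested list comprehension. Below is a vectorized NumPy sketch of the sinusoidal encoding; `sinusoidal_positions` is a hypothetical name. It zeroes row 0 to mirror the function above, whereas the textbook formula would give `cos(0) = 1` in the odd columns of that row.

```python
import numpy as np


def sinusoidal_positions(max_len, d_emb):
    positions = np.arange(max_len)[:, None]  # shape (max_len, 1)
    dims = np.arange(d_emb)[None, :]         # shape (1, d_emb)
    # Column pairs (2i, 2i+1) share the same frequency
    angles = positions / np.power(10000, 2 * (dims // 2) / d_emb)

    enc = np.zeros((max_len, d_emb))
    enc[:, 0::2] = np.sin(angles[:, 0::2])   # even columns: sine
    enc[:, 1::2] = np.cos(angles[:, 1::2])   # odd columns: cosine
    enc[0] = 0.0                             # mirror the notebook: position 0 is all zeros
    return enc


enc = sinusoidal_positions(4, 6)
print(enc.shape)  # (4, 6)
```

Passing this matrix as the initial weights of the position `Embedding` layer gives the model a fixed starting point that it can still adjust during training.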

In [None]:
# Categorical crossentropy loss function (sparse)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

# Loss tracker 
loss_tracker = tf.keras.metrics.Mean(name="loss")

In [None]:
# Class for masked language model
class MaskedLanguageModel(tf.keras.Model):
    
    # Function for training step
    def train_step(self, inputs):
        
        # If there are 3 inputs
        if len(inputs) == 3:
            
            # Unpack the inputs
            features, labels, sample_weight = inputs
        
        # If there are 2 inputs
        else:
            
            # Unpack the inputs
            features, labels = inputs
            
            # Set sample weight to None
            sample_weight = None

        # Tell the model to compute the gradients
        with tf.GradientTape() as tape:
            
            # Get the predictions
            predictions = self(features, training=True)
            
            # Compute the loss
            loss = loss_fn(labels, predictions, sample_weight=sample_weight)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics
        loss_tracker.update_state(loss, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        return {"loss": loss_tracker.result()}

    # List our `Metric` objects here so that `reset_states()` can be called
    # automatically at the start of each epoch or at the start of `evaluate()`.
    # If you don't implement this property, you have to call `reset_states()`
    # yourself at the time of your choosing.
    @property
    def metrics(self):
        return [loss_tracker]
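
The sample weights are what restrict the loss to the positions selected for masking. The NumPy sketch below reproduces that effect on a toy example, assuming the loss returns one value per position multiplied by its weight and the `Mean` metric reports a weighted average. `masked_sparse_ce` and the toy arrays are hypothetical.

```python
import numpy as np


def masked_sparse_ce(labels, probs, sample_weight):
    # Probability the model assigned to the true token at each position
    picked = probs[np.arange(len(labels)), labels]
    # Per-position negative log-likelihood, zeroed where the weight is 0
    per_token = -np.log(picked) * sample_weight
    # Weighted mean over positions, as a Mean metric with sample weights would report
    mean_loss = (per_token * sample_weight).sum() / sample_weight.sum()
    return per_token, mean_loss


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])     # toy softmax outputs: 3 positions, vocabulary of 3
labels = np.array([0, 1, 2])            # true token ids
weights = np.array([0.0, 1.0, 0.0])     # only position 1 was selected for masking

per_token, mean_loss = masked_sparse_ce(labels, probs, weights)
print(per_token, mean_loss)
```

Only position 1 contributes here, so the reported loss equals `-log(0.8)` regardless of how well the model predicts the unmasked positions.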

In [None]:
# Function for creating BERT model for MLM
def create_masked_language_bert_model():
    
    # Input layer
    inputs = tf.keras.layers.Input((config.MAX_LEN,), dtype=tf.int64)

    # Word embedding layer
    word_embeddings = tf.keras.layers.Embedding(config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding")(inputs)
    
    # Position embedding layer
    position_embeddings = tf.keras.layers.Embedding(input_dim=config.MAX_LEN,
                                                    output_dim=config.EMBED_DIM,
                                                    weights=[get_pos_encoding_matrix(config.MAX_LEN, config.EMBED_DIM)],
                                                    name="position_embedding",
                                                    )(tf.range(start=0, limit=config.MAX_LEN, delta=1))
    
    # Add the word and position embeddings
    embeddings = word_embeddings + position_embeddings

    # Get the output of the embedding layer
    encoder_output = embeddings
    
    # Loop over the number of layers
    for i in range(config.NUM_LAYERS):
        
        # Create BERT module
        encoder_output = bert_module(encoder_output, encoder_output, encoder_output, i)

    # Output layer
    mlm_output = tf.keras.layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(encoder_output)
    
    # MLM model
    mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")

    # Adam optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=config.LR)
    
    # Compile the model
    mlm_model.compile(optimizer=optimizer)
    
    return mlm_model

In [None]:
# Id2Token
id2token = dict(enumerate(vectorize_layer.get_vocabulary()))

# Token2Id
token2id = {y: x for x, y in id2token.items()}

In [None]:
# Class for callback function
class MaskedTextGenerator(tf.keras.callbacks.Callback):
    
    # Constructor function
    def __init__(self, sample_tokens, top_k=5):
        
        # Initialization
        self.sample_tokens = sample_tokens
        self.k = top_k

    # Function for decoding tokens
    def decode(self, tokens):
        
        # Return the decoded tokens
        return " ".join([id2token[t] for t in tokens if t != 0])

    # Function for converting ids to tokens
    def convert_ids_to_tokens(self, token_id):
        
        # Return the token
        return id2token[token_id]

    # Function to run at the end of each epoch
    def on_epoch_end(self, epoch, logs=None):
        
        # Make predictions
        prediction = self.model.predict(self.sample_tokens)

        # Get the masked index
        masked_index = np.where(self.sample_tokens == mask_token_id) 
        masked_index = masked_index[1]
        
        # Get the masked prediction
        mask_prediction = prediction[0][masked_index]

        # Get the top k predictions
        top_indices = mask_prediction[0].argsort()[-self.k :][::-1]
        
        # Get the values
        values = mask_prediction[0][top_indices]

        # Loop over the top k predictions
        for i in range(len(top_indices)):
            
            # Get the prediction
            p = top_indices[i]
            
            # Get the value
            v = values[i]
            
            # Get the tokens
            tokens = np.copy(self.sample_tokens[0])
            
            # Replace the masked token with the predicted token
            tokens[masked_index[0]] = p
            
            # Results to print
            result = {
                "input_text": self.decode(self.sample_tokens[0]),  # already a NumPy array
                "prediction": self.decode(tokens),
                "probability": v,
                "predicted mask token": self.convert_ids_to_tokens(p),
            }
            
            # Report
            pprint.pprint(result)
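
The callback picks the most likely replacements with the `argsort()[-k:][::-1]` idiom. A small standalone sketch of that step, using a made-up probability vector over a five-token vocabulary (`top_k` is a hypothetical helper):

```python
import numpy as np


def top_k(probabilities, k):
    # argsort is ascending, so take the last k indices and reverse them
    indices = probabilities.argsort()[-k:][::-1]
    return indices, probabilities[indices]


probs = np.array([0.10, 0.50, 0.05, 0.30, 0.05])  # toy distribution for one masked position
indices, values = top_k(probs, 3)
print(indices.tolist())  # [1, 3, 0]
print(values.tolist())   # [0.5, 0.3, 0.1]
```

In the callback, each of these indices is mapped back to a word through `id2token` and substituted for the mask in the sample sentence.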

In [None]:
# Vectorize the sample text (for callback)
sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])

# Setup the callback function
generator_callback = MaskedTextGenerator(sample_tokens.numpy())

In [None]:
# Create the BERT model for MLM task
bert_masked_model = create_masked_language_bert_model()

# Model summary
bert_masked_model.summary()

<br>

### TRAINING

---

In [None]:
# Train the model
bert_masked_model.fit(mlm_ds, epochs=5, callbacks=[generator_callback])

# Save the model
bert_masked_model.save("bert_mlm_imdb.h5")

<br>

### FINE-TUNING

---

We will now fine-tune the self-supervised model on a downstream sentiment classification task.
To do this, let's create a classifier by adding a pooling layer and a `Dense` layer on top of the
pretrained BERT features.
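
The pooling step collapses the per-token features into one vector per review. As a rough NumPy illustration of what global max pooling over the sequence axis does (the tiny array is made up for the example):

```python
import numpy as np

# Toy encoder output with shape (batch=1, sequence length=3, embedding dim=2)
sequence_output = np.array([[[0.1, -2.0],
                             [0.7, 0.3],
                             [-0.5, 1.2]]])

# Global max pooling keeps, for every feature, its largest value across all tokens
pooled = sequence_output.max(axis=1)
print(pooled.shape)     # (1, 2)
print(pooled.tolist())  # [[0.7, 1.2]]
```

The result has a fixed size regardless of sequence length, which is what allows a `Dense` classification head to be stacked on top.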

In [None]:
# Load pretrained bert model
mlm_model = tf.keras.models.load_model("bert_mlm_imdb.h5", custom_objects={"MaskedLanguageModel": MaskedLanguageModel})

# Construct the pretrained bert model
pretrained_bert_model = tf.keras.Model(mlm_model.input, mlm_model.get_layer("encoder_0/ffn_layernormalization").output)

# Freeze the layers
pretrained_bert_model.trainable = False

In [None]:
# Function for creating the classifier model
def create_classifier_bert_model():
    
    # Input layer
    inputs = tf.keras.layers.Input((config.MAX_LEN,), dtype=tf.int64)
    
    # Feed the input to the pretrained bert model
    sequence_output = pretrained_bert_model(inputs)
    
    # Global max pooling
    pooled_output = tf.keras.layers.GlobalMaxPooling1D()(sequence_output)
    
    # Hidden dense layer
    hidden_layer = tf.keras.layers.Dense(64, activation="relu")(pooled_output)
    
    # Output layer
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(hidden_layer)
    
    # Construct the model
    classifer_model = tf.keras.Model(inputs, outputs, name="classification")
    
    # Adam optimizer
    optimizer = tf.keras.optimizers.Adam()
    
    # Compile the model
    classifer_model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
    
    return classifer_model

In [None]:
# Create the classifier model
classifer_model = create_classifier_bert_model()

# Model summary
classifer_model.summary()

In [None]:
# Train the classifier while the BERT encoder is frozen
classifer_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)

In [None]:
# Unfreeze the BERT model for fine-tuning
pretrained_bert_model.trainable = True

# Adam optimizer
optimizer = tf.keras.optimizers.Adam()

# Compile the model
classifer_model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

# Train the model (fine-tuning)
classifer_model.fit(train_classifier_ds,
                    epochs=5,
                    validation_data=test_classifier_ds,
)

<br>

### PREDICTION

--- 

When you deploy a model, it is best if it already includes its preprocessing pipeline, so that you don't have to reimplement the preprocessing logic in your production environment. Let's create an end-to-end model that incorporates the `TextVectorization` layer and accepts raw strings as input, and then evaluate it.

In [None]:
# Function for end-to-end prediction
def get_end_to_end(model):
    
    # Input layer
    inputs_string = tf.keras.Input(shape=(1,), dtype="string")
    
    # Vectorize the input
    indices = vectorize_layer(inputs_string)
    
    # Feed to the model
    outputs = model(indices)
    
    # Construct the model
    end_to_end_model = tf.keras.Model(inputs_string, outputs, name="end_to_end_model")
    
    # Adam optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=config.LR)
    
    # Compile the model
    end_to_end_model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
    
    return end_to_end_model


In [None]:
# Initialize the end-to-end model
end_to_end_classification_model = get_end_to_end(classifer_model)

# Make a prediction on a raw sample string
sample_text = ["I have watched this [mask] and it was awesome"]
end_to_end_classification_model.predict(sample_text)

<br>

### EVALUATION

--- 


In [None]:
# Initialize the end-to-end model
end_to_end_classification_model = get_end_to_end(classifer_model)

# Model evaluation
end_to_end_classification_model.evaluate(test_raw_classifier_ds)