In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
from dataclasses import dataclass #This module provides a decorator and functions for automatically adding generated special methods such as __init__() and __repr__() to user-defined classes
import pandas as pd
import numpy as np
import glob #Returns a list of files that match the given pattern(s).
import re
from pprint import pprint

In [None]:
@dataclass
class Config:
  MAX_LEN=256
  BATCH_SIZE=32
  LR=0.001
  VOCAB_SIZE=30000
  EMBED_DIM=128
  NUM_HEAD=8
  FF_DIM=128
  NUM_LAYERS=1
  

In [None]:
config=Config()

In [None]:
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz #transferring data specified with URL syntax
!tar -xf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  60.5M      0  0:00:01  0:00:01 --:--:-- 60.5M


In [None]:
def get_text_list_from_files(files):
  text_list=[]
  for name in files:
    with open(name) as f:
      for line in f:
        text_list.append(line)
  return text_list

In [None]:
def get_data_from_text_files(folder_name):
  pos_files=glob.glob("aclImdb/"+folder_name+"/pos/*.txt")
  pos_texts=get_text_list_from_files(pos_files)
  neg_files=glob.glob("aclImdb/"+folder_name+"/neg/*.txt")
  neg_texts=get_text_list_from_files(neg_files)
  df=pd.DataFrame({"review":pos_texts+neg_texts,"sentiment":[0]*len(pos_texts)+[1]*len(neg_texts),})
  df=df.sample(len(df)).reset_index(drop=True)
  return df

In [None]:
train_df=get_data_from_text_files("train")
test_df=get_data_from_text_files("test")

In [None]:
all_data = train_df.append(test_df)

In [None]:
train_df

Unnamed: 0,review,sentiment
0,"This is a very grim, hard hitting, even brutal...",0
1,What was this supposed to be? A remake of Fish...,1
2,"It's hard to rate films like this, because do ...",0
3,"All I can say is, first movie this season that...",0
4,Gundam Wing to me happens to be a good anime. ...,0
...,...,...
24995,This movie is wonderful. It always has been al...,0
24996,"The French film ""Extension Du Domaine De La Lu...",1
24997,Three Russian aristocrats soak up the decadenc...,1
24998,"This movie was made in 1948, but it still ring...",0


In [None]:
def custom_standardization(input_data):
  lowercase=tf.strings.lower(input_data)
  stripped_html=tf.strings.regex_replace(lowercase,"<br />"," ")
  return tf.strings.regex_replace(stripped_html,"[%s]"%re.escape("!#$%&'()*+,-./:;<=>?@\^_`{|}~"),"")
  

In [None]:
special_tokens=["[MASK]"]
len(special_tokens)

1

In [None]:
def get_vectorize_layer(texts,vocab_size,max_seq,special_tokens=["[MASK"]):
  vectorize_layer=TextVectorization(max_tokens=vocab_size,output_mode="int",standardize=custom_standardization,output_sequence_length=max_seq)
  vectorize_layer.adapt(texts)
  vocab=vectorize_layer.get_vocabulary()
  vocab=vocab[2:vocab_size-len(special_tokens)]+["[mask]"]
  vectorize_layer.set_vocabulary(vocab)
  return vectorize_layer

In [None]:
all_data

Unnamed: 0,review,sentiment
0,"This is a very grim, hard hitting, even brutal...",0
1,What was this supposed to be? A remake of Fish...,1
2,"It's hard to rate films like this, because do ...",0
3,"All I can say is, first movie this season that...",0
4,Gundam Wing to me happens to be a good anime. ...,0
...,...,...
24995,This film is a completely inaccurate depiction...,1
24996,<br /><br />Well-known comedians meekly admit ...,0
24997,I got subjected to this pile one Wednesday aft...,1
24998,I've discovered this movie accidentally and it...,0


In [None]:
vectorize_layer=get_vectorize_layer(all_data.review.values.tolist(),config.VOCAB_SIZE,config.MAX_LEN,special_tokens=["[mask]"],)

In [None]:
print(all_data.review.values.tolist()) #convert values and append in the list

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
mask_token_id=vectorize_layer(["mask"]).numpy()[0][0]


In [None]:
mask_token_id

2368

In [None]:
list1=[1,2,3]+[4]

In [None]:
list1

[1, 2, 3, 4]

In [None]:
vectorize_layer

<tensorflow.python.keras.layers.preprocessing.text_vectorization.TextVectorization at 0x7fc048da77b8>

In [None]:
def encode(texts):
  encoded_texts=vectorize_layer(texts)
  return encoded_texts.numpy()

In [None]:
def get_masked_input_and_labels(encoded_texts):
  inp_mask=np.random.rand(*encoded_texts.shape)<0.15
  inp_mask[encoded_texts<=2]=False
  labels=-1*np.ones(encoded_texts.shape,dtype=int)
  labels[inp_mask]=encoded_texts[inp_mask] 
  encoded_texts_masked=np.copy(encoded_texts)
  inp_mask_2mask=inp_mask&(np.random.rand(*encoded_texts.shape)<0.90) 
#  print(encoded_texts_masked[0])
#  print(encoded_texts_masked[1])
#  print(labels[0])
#  print(labels[1])
  print(inp_mask_2mask[0])
  print(inp_mask_2mask[1])
  encoded_texts_masked[inp_mask_2mask]=mask_token_id
 # print(encoded_texts_masked[0])
 # print(encoded_texts_masked[1])
  inp_mask_2random=inp_mask_2mask & (np.random.rand(*encoded_texts.shape)<1/9)
  encoded_texts_masked[inp_mask_2random]=np.random.randint(3,mask_token_id,inp_mask_2random.sum())
  sample_weights=np.ones(labels.shape)
  sample_weights[labels==-1]=0
  y_labels=np.copy(encoded_texts)
  return encoded_texts_masked,y_labels,sample_weights  

In [None]:
a=np.random.rand(3,3)<0.15
a

array([[False, False, False],
       [ True, False, False],
       [False, False, False]])

In [None]:
a[0][1]=False

In [None]:
a[1][2]=False

In [None]:
a

array([[False, False, False],
       [ True, False, False],
       [False, False, False]])

In [None]:
x_train=encode(train_df.review.values)
y_train=train_df.sentiment.values
train_classifier_ds=(tf.data.Dataset.from_tensor_slices((x_train,y_train)).shuffle(1000).batch(config.BATCH_SIZE))

In [None]:
x_train

array([[   11,     7,     4, ...,     0,     0,     0],
       [   48,    13,    11, ...,     0,     0,     0],
       [   29,   263,     6, ...,  1263,    23,   179],
       ...,
       [  283,  1511, 17753, ...,     0,     0,     0],
       [   11,    17,    13, ...,     0,     0,     0],
       [ 1476,    10,   173, ...,     3,     1,   452]])

In [None]:
y_train

array([0, 1, 0, ..., 1, 0, 0])

In [None]:
x_test=encode(test_df.review.values)
y_test=test_df.sentiment.values
test_classifier_ds=tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(config.BATCH_SIZE)


In [None]:
test_raw_classifier_ds = tf.data.Dataset.from_tensor_slices(
    (test_df.review.values, y_test)
).batch(config.BATCH_SIZE)

In [None]:
test_raw_classifier_ds

<BatchDataset shapes: ((None,), (None,)), types: (tf.string, tf.int64)>

In [None]:
print(all_data.review.values)

["Japanese indie film with humor and philosophy where the three main characters run literally almost through the entire film, chasing each other due to strange circumstances and comical coincidence. As they are running, we see what is going on in their minds and how they got where they are at the moment. The act of running is a metaphor for these down-on-their luck people's lives. In some way, what they're really chasing for is not what they were originally chasing, but for meaning in their lives and an escape from their personal problems and broken dreams. Running makes them all feel truly alive. The big life-altering running adventure comes to an end when they accidentally get in the middle of something big, violent, and so absurd that it's funny in a clever way. One of my favorite films of all time by genius director Sabu."
 'This must have been one of the worst movies I have ever seen.<br /><br />I have to disagree with another commenter, who said the special effects were okay. I f

In [None]:
x_all_review=encode(all_data.review.values)
x_masked_train,y_masked_labels,sample_weights=get_masked_input_and_labels(x_all_review)


[False False False False False False False False  True False  True False
  True False False False False False False False False False  True False
  True False  True False  True False  True False  True False  True False
 False False False False False False False False False False False False
  True  True False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False Fa

In [None]:
mask_token_id

2368

In [None]:
x_train[0]

array([  923,  2782,    19,    15,   457,     3,  4072,   114,     2,
         283,   274,   100,   527,  1168,   210,   138,     2,   424,
          19,  3214,   247,    76,   668,     6,   676,  2201,     3,
        2825,  4914,    14,    34,    23,   632,    73,    64,    48,
           7,   164,    20,     8,    63,  2373,     3,    86,    34,
         183,   114,    34,    23,    30,     2,   549,     2,   497,
           5,   632,     7,     4,  4881,    16,   129,     1,  2124,
        1942,   465,     8,    46,    96,    48,   492,    62,  3214,
          16,     7,    21,    48,    34,    66,  1798,  3214,    18,
          16,  1173,     8,    63,   465,     3,    33,  1035,    35,
          63,   905,   680,     3,  1964,  1433,   632,   157,    90,
          31,   233,   351,  1158,     2,   196,     1,   632,  1230,
         261,     6,    33,   125,    50,    34,  2413,    75,     8,
           2,   750,     5,   137,   196,  1108,     3,    37,  1805,
          12,    29,

In [None]:
print(sample_weights[0])

[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1.
 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


In [None]:
x_masked_train[0]

array([  923,   634,    19,    15,   457,     3,  4072,   114,     2,
         283,   274,   100,   527,  1168,   210,   138,     2,   424,
          19,  3214,   247,    76,   668,  2368,   676,  2201,     3,
        2825,  4914,  2368,    34,    23,   632,    73,    64,  2368,
           7,   164,    20,  2368,    63,  2373,     3,    86,    34,
         183,   114,  2368,    23,    30,     2,   549,     2,   497,
           5,   632,     7,     4,  4881,    16,  2275,     1,  2124,
        1942,   465,     8,    46,    96,    48,   492,    62,  3214,
          16,     7,    21,    48,    34,    66,  1798,  3214,    18,
          16,  1173,     8,  2368,   465,     3,    33,  1035,    35,
          63,   905,   680,     3,  1964,  1433,   632,   157,    90,
          31,   233,   351,  1158,     2,   196,     1,   632,  1230,
         261,     6,    33,   125,    50,    34,  2368,    75,     8,
           2,   750,     5,   137,   196,  2368,     3,    37,  1805,
          12,    29,

In [None]:
y_masked_labels[0]

array([  923,  2782,    19,    15,   457,     3,  4072,   114,     2,
         283,   274,   100,   527,  1168,   210,   138,     2,   424,
          19,  3214,   247,    76,   668,     6,   676,  2201,     3,
        2825,  4914,    14,    34,    23,   632,    73,    64,    48,
           7,   164,    20,     8,    63,  2373,     3,    86,    34,
         183,   114,    34,    23,    30,     2,   549,     2,   497,
           5,   632,     7,     4,  4881,    16,   129,     1,  2124,
        1942,   465,     8,    46,    96,    48,   492,    62,  3214,
          16,     7,    21,    48,    34,    66,  1798,  3214,    18,
          16,  1173,     8,    63,   465,     3,    33,  1035,    35,
          63,   905,   680,     3,  1964,  1433,   632,   157,    90,
          31,   233,   351,  1158,     2,   196,     1,   632,  1230,
         261,     6,    33,   125,    50,    34,  2413,    75,     8,
           2,   750,     5,   137,   196,  1108,     3,    37,  1805,
          12,    29,

In [None]:
sample_weights[0]

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
       0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
       0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0.

In [None]:
aa=np.random.rand(25000,256)<0.15

In [None]:
aa


array([[False, False,  True, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ...,  True, False, False],
       [False,  True, False, ..., False, False, False],
       [False, False, False, ..., False,  True, False]])

In [None]:
x_train[0]

array([  923,  2782,    19,    15,   457,     3,  4072,   114,     2,
         283,   274,   100,   527,  1168,   210,   138,     2,   424,
          19,  3214,   247,    76,   668,     6,   676,  2201,     3,
        2825,  4914,    14,    34,    23,   632,    73,    64,    48,
           7,   164,    20,     8,    63,  2373,     3,    86,    34,
         183,   114,    34,    23,    30,     2,   549,     2,   497,
           5,   632,     7,     4,  4881,    16,   129,     1,  2124,
        1942,   465,     8,    46,    96,    48,   492,    62,  3214,
          16,     7,    21,    48,    34,    66,  1798,  3214,    18,
          16,  1173,     8,    63,   465,     3,    33,  1035,    35,
          63,   905,   680,     3,  1964,  1433,   632,   157,    90,
          31,   233,   351,  1158,     2,   196,     1,   632,  1230,
         261,     6,    33,   125,    50,    34,  2413,    75,     8,
           2,   750,     5,   137,   196,  1108,     3,    37,  1805,
          12,    29,

In [None]:
y_masked_labels[0]

array([  923,  2782,    19,    15,   457,     3,  4072,   114,     2,
         283,   274,   100,   527,  1168,   210,   138,     2,   424,
          19,  3214,   247,    76,   668,     6,   676,  2201,     3,
        2825,  4914,    14,    34,    23,   632,    73,    64,    48,
           7,   164,    20,     8,    63,  2373,     3,    86,    34,
         183,   114,    34,    23,    30,     2,   549,     2,   497,
           5,   632,     7,     4,  4881,    16,   129,     1,  2124,
        1942,   465,     8,    46,    96,    48,   492,    62,  3214,
          16,     7,    21,    48,    34,    66,  1798,  3214,    18,
          16,  1173,     8,    63,   465,     3,    33,  1035,    35,
          63,   905,   680,     3,  1964,  1433,   632,   157,    90,
          31,   233,   351,  1158,     2,   196,     1,   632,  1230,
         261,     6,    33,   125,    50,    34,  2413,    75,     8,
           2,   750,     5,   137,   196,  1108,     3,    37,  1805,
          12,    29,

In [None]:
labb=-1*np.ones((5,5),dtype=int)

In [None]:
c=[[11,22,33,44,55],
   [2,3,4,5,6],
   [3,4,5,6,7],
   [4,5,6,7,8],
   [5,6,7,8,9]]

In [None]:
ab=np.random.rand(5,5)<0.15

In [None]:
labb

array([[-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1]])

In [None]:
ab

array([[False, False,  True, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False]])

In [None]:
ab.astype(int)

array([[0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]])

In [None]:
c

[[11, 22, 33, 44, 55],
 [2, 3, 4, 5, 6],
 [3, 4, 5, 6, 7],
 [4, 5, 6, 7, 8],
 [5, 6, 7, 8, 9]]

In [None]:
labb[ab]

array([-1])

In [None]:
labb[ab.astype(int)]

array([[[-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1]],

       [[-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1]],

       [[-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1]],

       [[-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1]],

       [[-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1]]])

In [None]:
!pip install tf-nightly



In [None]:
def bert_module(query, key, value, i):
    # Multi headed self-attention
    attention_output = layers.MultiHeadAttention(
        num_heads=config.NUM_HEAD,
        key_dim=config.EMBED_DIM // config.NUM_HEAD,
        name="encoder_{}/multiheadattention".format(i),
    )(query, key, value)
    attention_output = layers.Dropout(0.1, name="encoder_{}/att_dropout".format(i))(
        attention_output
    )
    attention_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}/att_layernormalization".format(i)
    )(query + attention_output)

    # Feed-forward layer
    ffn = keras.Sequential(
        [
            layers.Dense(config.FF_DIM, activation="relu"),
            layers.Dense(config.EMBED_DIM),
        ],
        name="encoder_{}/ffn".format(i),
    )
    ffn_output = ffn(attention_output)
    ffn_output = layers.Dropout(0.1, name="encoder_{}/ffn_dropout".format(i))(
        ffn_output
    )
    sequence_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}/ffn_layernormalization".format(i)
    )(attention_output + ffn_output)
    return sequence_output


def get_pos_encoding_matrix(max_len, d_emb):
    pos_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
            if pos != 0
            else np.zeros(d_emb)
            for pos in range(max_len)
        ]
    )
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1
    return pos_enc


loss_fn = keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE
)
loss_tracker = tf.keras.metrics.Mean(name="loss")


class MaskedLanguageModel(tf.keras.Model):
    def train_step(self, inputs):
        if len(inputs) == 3:
            features, labels, sample_weight = inputs
        else:
            features, labels = inputs
            sample_weight = None

        with tf.GradientTape() as tape:
            predictions = self(features, training=True)
            loss = loss_fn(labels, predictions, sample_weight=sample_weight)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics
        loss_tracker.update_state(loss, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker]


def create_masked_language_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype=tf.int64)

    word_embeddings = layers.Embedding(
        config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
    )(inputs)
    position_embeddings = layers.Embedding(
        input_dim=config.MAX_LEN,
        output_dim=config.EMBED_DIM,
        weights=[get_pos_encoding_matrix(config.MAX_LEN, config.EMBED_DIM)],
        name="position_embedding",
    )(tf.range(start=0, limit=config.MAX_LEN, delta=1))
    embeddings = word_embeddings + position_embeddings

    encoder_output = embeddings
    for i in range(config.NUM_LAYERS):
        encoder_output = bert_module(encoder_output, encoder_output, encoder_output, i)

    mlm_output = layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(
        encoder_output
    )
    mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")

    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    mlm_model.compile(optimizer=optimizer)
    return mlm_model


id2token = dict(enumerate(vectorize_layer.get_vocabulary()))
token2id = {y: x for x, y in id2token.items()}


class MaskedTextGenerator(keras.callbacks.Callback):
    def __init__(self, sample_tokens, top_k=5):
        self.sample_tokens = sample_tokens
        self.k = top_k

    def decode(self, tokens):
        return " ".join([id2token[t] for t in tokens if t != 0])

    def convert_ids_to_tokens(self, id):
        return id2token[id]

    def on_epoch_end(self, epoch, logs=None):
        prediction = self.model.predict(self.sample_tokens)

        masked_index = np.where(self.sample_tokens == mask_token_id)
        masked_index = masked_index[1]
        mask_prediction = prediction[0][masked_index]

        top_indices = mask_prediction[0].argsort()[-self.k :][::-1]
        values = mask_prediction[0][top_indices]

        for i in range(len(top_indices)):
            p = top_indices[i]
            v = values[i]
            tokens = np.copy(sample_tokens[0])
            tokens[masked_index[0]] = p
            result = {
                "input_text": self.decode(sample_tokens[0].numpy()),
                "prediction": self.decode(tokens),
                "probability": v,
                "predicted mask token": self.convert_ids_to_tokens(p),
            }
            pprint(result)


sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])
generator_callback = MaskedTextGenerator(sample_tokens.numpy())

bert_masked_model = create_masked_language_bert_model()
bert_masked_model.summary()


Model: "masked_bert_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 256)]        0                                            
__________________________________________________________________________________________________
word_embedding (Embedding)      (None, 256, 128)     3840000     input_2[0][0]                    
__________________________________________________________________________________________________
tf.__operators__.add_1 (TFOpLam (None, 256, 128)     0           word_embedding[0][0]             
__________________________________________________________________________________________________
encoder_0/multiheadattention (M (None, 256, 128)     66048       tf.__operators__.add_1[0][0]     
                                                                 tf.__operators__.

In [None]:
x_all_review = encode(all_data.review.values)
x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(
    x_all_review
)

mlm_ds = tf.data.Dataset.from_tensor_slices(
    (x_masked_train, y_masked_labels, sample_weights)
)
mlm_ds = mlm_ds.shuffle(1000).batch(config.BATCH_SIZE)

In [None]:
next(iter(mlm_ds))

(<tf.Tensor: shape=(32, 256), dtype=int64, numpy=
 array([[2368,  416,  143, ...,    0,    0,    0],
        [  10,   68,   64, ...,    0,    0,    0],
        [  11, 9410,    5, ...,    0,    0,    0],
        ...,
        [2368,  279,   10, ...,    0,    0,    0],
        [   2,   17,    7, ...,    2, 2038,    3],
        [  10, 2368, 1191, ...,    0,    0,    0]])>,
 <tf.Tensor: shape=(32, 256), dtype=int64, numpy=
 array([[1611,  416,  143, ...,    0,    0,    0],
        [  10,   68,   64, ...,    0,    0,    0],
        [  11, 9410,    5, ...,    0,    0,    0],
        ...,
        [  60,  279,   10, ...,    0,    0,    0],
        [   2,   17,    7, ...,    2, 1080,    3],
        [  10,  229, 1191, ...,    0,    0,    0]])>,
 <tf.Tensor: shape=(32, 256), dtype=float64, numpy=
 array([[1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0

In [None]:
y_masked_labels

array([[  923,  2782,    19, ...,     0,     0,     0],
       [   11,   213,    25, ...,     0,     0,     0],
       [  802,  2886,    22, ...,     0,     0,     0],
       ...,
       [ 3758, 17382,     7, ...,     0,     0,     0],
       [ 1263, 16414,  2168, ...,     0,     0,     0],
       [   11,  2271,     1, ...,     0,     0,     0]])

In [None]:
mask_token_id

2368

In [None]:
x_masked_train[0]

array([  923,  1587,  2368,    15,   457,  2368,  4072,   114,     2,
         283,   274,   100,   527,  1168,   210,  2368,     2,  2368,
          19,  3214,   247,    76,   668,     6,   676,  2201,     3,
        2368,  4914,    14,    34,    23,   632,    73,  2368,    48,
           7,   164,    20,     8,    63,  2373,  2368,    86,    34,
         183,  2368,    34,    23,    30,     2,   549,     2,   497,
           5,  2368,     7,     4,  4881,    16,   129,     1,  2124,
        1942,  2368,     8,    46,    96,    48,   492,    62,  3214,
          16,     7,    21,  2368,   747,    66,  1698,  3214,    18,
          16,  1173,     8,    63,   465,     3,    33,  1035,    35,
        2368,   905,   680,     3,  1964,  1433,   632,   157,    90,
          31,   233,   351,  1158,     2,   196,     1,   632,  1230,
         261,     6,  2368,   125,    50,    34,  2413,  1800,     8,
           2,  2368,     5,  2368,   196,  1108,     3,    37,  1805,
          12,    29,

In [None]:
bert_masked_model.fit(mlm_ds, epochs=5, callbacks=[generator_callback])
bert_masked_model.save("bert_mlm_imdb.h5")

Epoch 1/5
  45/1563 [..............................] - ETA: 2:30:10 - loss: 8.2065

KeyboardInterrupt: ignored

In [None]:
mlm_model = keras.models.load_model(
    "bert_mlm_imdb.h5", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
)
pretrained_bert_model = tf.keras.Model(
    mlm_model.input, mlm_model.get_layer("encoder_0/ffn_layernormalization").output
)

# Freeze it
pretrained_bert_model.trainable = False


def create_classifier_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype=tf.int64)
    sequence_output = pretrained_bert_model(inputs)
    pooled_output = layers.GlobalMaxPooling1D()(sequence_output)
    hidden_layer = layers.Dense(64, activation="relu")(pooled_output)
    outputs = layers.Dense(1, activation="sigmoid")(hidden_layer)
    classifer_model = keras.Model(inputs, outputs, name="classification")
    optimizer = keras.optimizers.Adam()
    classifer_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return classifer_model


classifer_model = create_classifier_bert_model()
classifer_model.summary()

# Train the classifier with frozen BERT stage
classifer_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)

# Unfreeze the BERT model for fine-tuning
pretrained_bert_model.trainable = True
optimizer = keras.optimizers.Adam()
classifer_model.compile(
    optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
)
classifer_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,