In [2]:
import tensorflow as tf
from keras.datasets import imdb
from keras.utils import pad_sequences
from tensorflow.keras import layers
from tensorflow.keras import models
from keras.layers import *
from keras.models import *


In [3]:
class TransformerBlock(layers.Layer):
  """Transformer encoder block: self-attention plus a position-wise feed-forward net.

  Each of the two sub-layers is followed by dropout, a residual connection and
  layer normalization (post-norm arrangement).

  Args:
    embed_dim: Size of the last axis of the block's input and output; also
      used as `key_dim` of the attention layer.
    num_heads: Number of attention heads.
    ff_dim: Width of the hidden Dense layer in the feed-forward network.
    rate: Dropout rate applied after attention and after the feed-forward net.
    **kwargs: Standard Keras Layer arguments (e.g. `name`, `dtype`).
  """

  def __init__(self,embed_dim,num_heads,ff_dim,rate = 0.1,**kwargs):
    super().__init__(**kwargs)

    # Kept so get_config() can reproduce the constructor call.
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.ff_dim = ff_dim
    self.rate = rate

    # Every sub-layer comes from the same `tensorflow.keras` namespace as the
    # base class, so the block never mixes two Keras implementations.
    self.att = layers.MultiHeadAttention(num_heads=num_heads,key_dim = embed_dim)
    self.ffn = models.Sequential([
        layers.Dense(ff_dim,activation = 'relu'),
        layers.Dense(embed_dim)
        ])

    self.layersnorm1 = layers.LayerNormalization(epsilon = 1e-6)
    self.layersnorm2 = layers.LayerNormalization(epsilon = 1e-6)

    self.dropout1 = layers.Dropout(rate)
    self.dropout2 = layers.Dropout(rate)

  def call(self,inputs,training = None):
    """Apply the block to `inputs`.

    Args:
      inputs: Tensor whose last axis has size `embed_dim` (required by the
        residual additions below).
      training: Forwarded to the attention and dropout layers. `None` lets
        Keras decide from the calling context.

    Returns:
      Tensor with the same shape as `inputs`.
    """
    # Self-attention: the sequence attends to itself (query == value).
    attn_output = self.att(inputs,inputs,training = training)

    attn_output = self.dropout1(attn_output,training = training)
    out1 = self.layersnorm1(inputs+attn_output)  # residual + norm

    ffn_output = self.ffn(out1)
    ffn_output = self.dropout2(ffn_output,training = training)

    return self.layersnorm2(out1+ffn_output)  # residual + norm

  def get_config(self):
    """Return the constructor arguments so a saved model can rebuild this layer."""
    config = super().get_config()
    config.update({
        'embed_dim': self.embed_dim,
        'num_heads': self.num_heads,
        'ff_dim': self.ff_dim,
        'rate': self.rate,
    })
    return config


In [4]:
class TokenandPositionembedding(layers.Layer):
  """Sum of a learned token embedding and a learned position embedding.

  Args:
    maxlen: Number of rows in the position-embedding table; position indices
      passed to it run from 0 to (sequence length - 1).
    vocab_size: Number of rows in the token-embedding table.
    embed_dim: Size of each embedding vector.
    **kwargs: Standard Keras Layer arguments (e.g. `name`, `dtype`).
  """

  def __init__(self,maxlen,vocab_size,embed_dim,**kwargs):
    super().__init__(**kwargs)

    # Kept so get_config() can reproduce the constructor call.
    self.maxlen = maxlen
    self.vocab_size = vocab_size
    self.embed_dim = embed_dim

    self.token_emb = layers.Embedding(input_dim = vocab_size,output_dim = embed_dim)
    self.pos_emb = layers.Embedding(input_dim = maxlen,output_dim = embed_dim)


  def call(self,x):
    """Embed token ids `x` and add the embedding of each position.

    Args:
      x: Tensor of token ids; its last axis is treated as the sequence axis.

    Returns:
      Token embeddings plus position embeddings (broadcast over leading axes).
    """
    # Sequence length is read from the input at call time, not from the
    # constructor argument.
    seq_len = tf.shape(x)[-1]
    positions = tf.range(start =0,limit = seq_len,delta =1)
    positions =self.pos_emb(positions)

    x = self.token_emb(x)

    return x+positions

  def get_config(self):
    """Return the constructor arguments so a saved model can rebuild this layer."""
    config = super().get_config()
    config.update({
        'maxlen': self.maxlen,
        'vocab_size': self.vocab_size,
        'embed_dim': self.embed_dim,
    })
    return config

In [5]:
# Dataset configuration.
vocab_size=20000 # Only consider top 20k words
maxlen=200 # Only consider the first 200 words of each movie review

# imdb.load_data returns two (sequences, labels) splits; the second one is
# used as validation data throughout this notebook.
(X_train,y_train),(X_val,y_val)=imdb.load_data(num_words=vocab_size)
print(len(X_train),"Training Sequences")
print(len(X_val),"Validation sequences")

# Pad/truncate every review to exactly `maxlen` token ids so the sequences
# form a rectangular array.
X_train=pad_sequences(X_train,maxlen=maxlen)
X_val=pad_sequences(X_val,maxlen=maxlen)

25000 Training Sequences
25000 Validation sequences


In [6]:
# Display the first row of X_train as a sanity check.
X_train[:1]

array([[    5,    25,   100,    43,   838,   112,    50,   670,     2,
            9,    35,   480,   284,     5,   150,     4,   172,   112,
          167,     2,   336,   385,    39,     4,   172,  4536,  1111,
           17,   546,    38,    13,   447,     4,   192,    50,    16,
            6,   147,  2025,    19,    14,    22,     4,  1920,  4613,
          469,     4,    22,    71,    87,    12,    16,    43,   530,
           38,    76,    15,    13,  1247,     4,    22,    17,   515,
           17,    12,    16,   626,    18, 19193,     5,    62,   386,
           12,     8,   316,     8,   106,     5,     4,  2223,  5244,
           16,   480,    66,  3785,    33,     4,   130,    12,    16,
           38,   619,     5,    25,   124,    51,    36,   135,    48,
           25,  1415,    33,     6,    22,    12,   215,    28,    77,
           52,     5,    14,   407,    16,    82, 10311,     8,     4,
          107,   117,  5952,    15,   256,     4,     2,     7,  3766,
      

In [7]:
# Model hyperparameters.
embed_dim = 32  # Embedding size for each token
num_heads= 2    # Number of attention heads
ff_dim = 32     # Hidden layer size of the feed-forward net inside the transformer block


# Input: one sequence of `maxlen` token ids per example.
inputs = layers.Input(shape=(maxlen,))
embedding_layer = TokenandPositionembedding(maxlen,vocab_size,embed_dim)

x= embedding_layer(inputs)

transformer_block = TransformerBlock(embed_dim,num_heads,ff_dim)
x = transformer_block(x)

# Classification head: average over the sequence axis, then a small MLP.
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20,activation = 'relu')(x)
x = layers.Dropout(0.1)(x)

# Two-way softmax over the classes; pairs with a sparse categorical loss.
outputs = layers.Dense(2,activation = 'softmax')(x)

# Build the model from the same `tensorflow.keras` namespace as the layers
# above rather than from the wildcard-imported standalone `keras` name.
model = models.Model(inputs = inputs,outputs = outputs)

In [8]:
# Integer class labels with a 2-unit softmax output -> sparse categorical cross-entropy.
model.compile(optimizer = 'adam',loss = 'sparse_categorical_crossentropy',metrics = ['accuracy'])

# validation_data is passed as an (inputs, targets) tuple, the form documented
# for Model.fit.
# NOTE(review): the recorded history shows val_loss rising after the first
# epoch while training loss keeps falling; consider fewer epochs or early
# stopping.
history = model.fit(X_train,y_train,batch_size = 32,epochs= 10,validation_data=(X_val,y_val))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [9]:
# Evaluate on the validation arrays; displays the loss followed by the compiled metric (accuracy).
model.evaluate(X_val,y_val)



[1.1192463636398315, 0.8324800133705139]

In [10]:
# Model outputs for every validation sequence (one 2-way softmax row per input).
y_pred = model.predict(X_val)



In [12]:
import pandas as pd
# Tabulate the per-epoch metrics recorded by model.fit (one row per epoch).
pd.DataFrame(history.history)

Unnamed: 0,loss,accuracy,val_loss,val_accuracy
0,0.386625,0.81668,0.291627,0.872
1,0.198331,0.92464,0.314006,0.87312
2,0.128021,0.9534,0.360826,0.8654
3,0.083103,0.972,0.465782,0.85336
4,0.055873,0.98208,0.59606,0.84344
5,0.037295,0.98892,0.635949,0.84568
6,0.02888,0.99148,0.708026,0.83744
7,0.0294,0.99088,0.715428,0.83752
8,0.019442,0.99472,0.889971,0.83552
9,0.015559,0.99544,1.119246,0.83248
