# Attention is All You Need

In [1]:
import numpy as np
import seaborn as sns
import tensorflow as tf

import re 
import os
from datetime import datetime

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

### Declare Static Variables

These hyperparameters are mostly taken from the Google paper "Attention Is All You Need".

In [2]:
# Static hyperparameters for the notebook.
# NOTE(review): within the visible notebook only BATCH_SIZE is referenced later;
# the other four constants look like leftovers from an earlier model - confirm
# before removing them.
EMBEDDING_DIM = 256
ATTENTION_UNITS = 10
ENCODER_UNITS = 1024
DECODER_UNITS = 1024
BATCH_SIZE = 64  # number of examples per batch when the training Dataset is batched

### Load Data

In [3]:
def _read_text(path):
    """Return the entire contents of the text file at `path`.

    The file is opened inside a `with` block so the handle is closed as soon
    as the read finishes, instead of being left open until garbage collection.
    """
    with open(path) as handle:
        return handle.read()


# Training split (formal / informal sentences, one sentence per line).
# NOTE(review): presumably line i of both files is the same sentence in the two
# styles - verify against the dataset documentation.
formal = _read_text('../Data/Supervised Data/Entertainment_Music/S_Formal_EM_Train.txt')
informal = _read_text('../Data/Supervised Data/Entertainment_Music/S_Informal_EM_Train.txt')

# Held-out validation/test split.
formal_holdout = _read_text('../Data/Supervised Data/Entertainment_Music/S_Formal_EM_ValTest.txt')
informal_holdout = _read_text('../Data/Supervised Data/Entertainment_Music/S_Informal_EM_ValTest.txt')

In [4]:
def process_sequence(seq):
    """Space out punctuation and wrap the sentence in <start>/<end> markers.

    Every occurrence of one of the characters . , ! ? ( ) is surrounded by
    spaces so it becomes its own whitespace-separated token, then any run of
    two or more whitespace characters is collapsed to a single space.

    Args:
        seq: the raw sentence as a string.

    Returns:
        The string '<start> ' + cleaned sentence + ' <end>'. A sentence that
        ends in punctuation keeps the space added after that punctuation, so
        two spaces precede '<end>' in that case.
    """
    s = re.sub(r'([.,!?()])', r' \1 ', seq)
    # Raw string: '\s' inside a normal literal is an invalid escape sequence
    # and triggers a Deprecation/SyntaxWarning on current Python versions.
    s = re.sub(r'\s{2,}', ' ', s)

    return '<start> ' + s + ' <end>'

In [80]:
def process_seq_target_input(seq):
    """Space out punctuation and append the <end> marker (no <start> marker).

    Applies the same cleaning as `process_sequence`: each of . , ! ? ( ) is
    surrounded by spaces and runs of whitespace are collapsed to one space.

    NOTE(review): the result is the `process_sequence` output with the leading
    '<start>' removed, i.e. the sequence shifted LEFT by one token. A decoder
    input for teacher forcing is normally shifted right ('<start>' kept,
    '<end>' dropped) - confirm how callers use this function.

    Args:
        seq: the raw sentence as a string.

    Returns:
        The cleaned sentence followed by ' <end>'.
    """
    s = re.sub(r'([.,!?()])', r' \1 ', seq)
    # Raw string: '\s' inside a normal literal is an invalid escape sequence
    # and triggers a Deprecation/SyntaxWarning on current Python versions.
    s = re.sub(r'\s{2,}', ' ', s)

    return s + ' <end>'

In [81]:
# Training corpora: one processed string per newline-separated line of the raw text.
f_corpus = list(map(process_sequence, formal.split('\n')))
f_corpus_input = list(map(process_seq_target_input, formal.split('\n')))
if_corpus = list(map(process_sequence, informal.split('\n')))

# Held-out corpora, processed the same way as the training sentences.
f_holdout = list(map(process_sequence, formal_holdout.split('\n')))
if_holdout = list(map(process_sequence, informal_holdout.split('\n')))

### Preprocess data

In [84]:
def tokenize(corpus, maxlen=30):
    """Fit a word-level tokenizer on `corpus` and return padded id sequences.

    Args:
        corpus: list of sentences (strings) to fit on and encode.
        maxlen: length every encoded sequence is padded or truncated to.
            Defaults to 30, the value that used to be hard-coded here.

    Returns:
        A tuple `(padded_seqs, tokenizer)`: the 2-D integer array with one row
        per sentence and `maxlen` columns, and the fitted `Tokenizer`.
    """
    # '<' and '>' are not in the filter list, so the '<start>' / '<end>'
    # markers are kept as tokens; words missing from the vocabulary map to
    # the '<OOV>' token.
    tokenizer = Tokenizer(filters='!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n', oov_token='<OOV>')
    tokenizer.fit_on_texts(corpus)

    seqs = tokenizer.texts_to_sequences(corpus)
    # Zeros are appended after the sentence (padding='post'). Truncation keeps
    # pad_sequences' default, which removes tokens from the FRONT of sequences
    # longer than `maxlen`.
    padded_seqs = pad_sequences(seqs, maxlen=maxlen, padding='post')
    return padded_seqs, tokenizer

In [85]:
# Source side: informal sentences.
input_train, input_tokenizer = tokenize(if_corpus)

# Target side: encode the formal sentences ONCE with ONE tokenizer.
# Fitting a second tokenizer on the '<start>'-less corpus assigns different
# integer ids to the same words (ids are ranked by frequency and '<start>'
# only exists in one corpus), so decoder inputs and labels would not share a
# vocabulary.
_target_full, target_tokenizer = tokenize(f_corpus)
target_input_tokenizer = target_tokenizer  # kept as an alias for existing references

# Teacher forcing: the decoder input is the label sequence shifted right.
#   decoder input : <start> w1 w2 ... wn <end> ...   (last column dropped)
#   label         : w1 w2 ... wn <end> 0 ...         (first column dropped)
# so position t of the input is used to predict token t+1 of the sentence.
target_input_train = _target_full[:, :-1]
target_train = _target_full[:, 1:]

In [86]:
# Shuffle buffer spans the whole training set; each epoch is the number of full batches.
buffer_size = len(input_train)
steps_per_epoch = buffer_size // BATCH_SIZE

# +1 because tokenizer word indices start at 1 and id 0 is the padding value.
input_vocab_size = 1 + len(input_tokenizer.word_index)
target_vocab_size = 1 + len(target_tokenizer.word_index)

# Each dataset element is (encoder input, decoder input, decoder target).
train = (
    tf.data.Dataset
    .from_tensor_slices((input_train, target_input_train, target_train))
    .shuffle(buffer_size)
    .batch(BATCH_SIZE, drop_remainder=True)
)

In [88]:
# Pull a single batch from the dataset to use for shape checks further down.
(example_input_batch,
 example_target_input_batch,
 example_target_batch) = next(iter(train))

## Positional Embedding

In [10]:
def positional_embedding(p, model_size):
    """Return the sinusoidal positional encoding for one position.

    Implements
        PE(p, 2k)   = sin(p / 10000 ** (2k / model_size))
        PE(p, 2k+1) = cos(p / 10000 ** (2k / model_size))
    so each even/odd pair of dimensions shares one frequency. The earlier
    loop used the odd index itself in the exponent, which gave every cosine
    term a different frequency from its sine partner.

    Args:
        p: the position (index of the token in the sequence).
        model_size: number of dimensions in the encoding.

    Returns:
        A float NumPy array of shape (1, model_size).
    """
    dims = np.arange(model_size)
    # Exponent 2k / model_size, where k = dim // 2 is the index of the pair.
    pair_exponent = (dims - dims % 2) / model_size
    angles = p / np.power(10000.0, pair_exponent)

    # Even dimensions take the sine, odd dimensions the cosine.
    p_emb = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles))
    return p_emb[np.newaxis, :]

MODEL_SIZE = 128
max_length = input_train.shape[1]

# Stack one (1, MODEL_SIZE) encoding per position into a
# (max_length, MODEL_SIZE) float32 constant.
pes = tf.constant(
    np.concatenate(
        [positional_embedding(position, MODEL_SIZE) for position in range(max_length)],
        axis=0,
    ),
    dtype=tf.float32,
)

## Multi-Head Attention

Multi-head attention computes
$$ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O $$
where
$$ head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$
and attention is the scaled dot-product
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

In [97]:
class MultiHeadAttention(tf.keras.Model):
    def __init__(self, model_size, h):
        super(MultiHeadAttention, self).__init__()
        self.query_size = model_size // h
        self.key_size = model_size // h
        self. value_size = model_size // h
        self.h = h
        self.wq = [tf.keras.layers.Dense(self.query_size) for _ in range(h)]
        self.wk = [tf.keras.layers.Dense(self.key_size) for _ in range(h)]
        self.wv = [tf.keras.layers.Dense(self.value_size) for _ in range(h)]
        self.wo = tf.keras.layers.Dense(model_size)

    def __one_head_attention(self, query, value, i, mask=None):
        """run for each query, value, key in h"""
        # query shape: (batch_size, query_length, model_size)
        # value shape: (batch_size, value_length, model_size)
        score = tf.matmul(self.wq[i](query), self.wk[i](value), transpose_b=True)

        # eq(1) from AAYN
        d_k = tf.math.sqrt(tf.cast(self.key_size, dtype=tf.float32))

        # score shape: (batch_size, query_length, value_length)
        score /= d_k
        
        # apply mask
        if mask:
            score *= mask
            score = tf.where(tf.equal(score, 0), tf.ones_like(score)*1e-6, score)

        # attention shape: (batch_size, query_length, value_length)
        attention = tf.nn.softmax(score, axis=2)

        # context shape: (batch_size, query_length, value_length)
        head = tf.matmul(attention, value)

        return head 

    def call(self, query, value, mask=None):
        """This computes the multi head attention by calling for each h"""
        # compute one head attention for each head
        multi_head = [self.__one_head_attention(query, value, i) for i in range(self.h)]

        # concat all heads 
        multi_head = tf.concat(multi_head, axis=2)

        # multi_head shape: (batch_size, query_length, model_size)
        mutli_head = self.wo(multi_head)

        return mutli_head

## Encoder

In [55]:
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, model_size, num_layers, h):
        super(Encoder, self).__init__()
        self.model_size = model_size
        self.num_layers = num_layers
        self.h = h

        self.embedding = tf.keras.layers.Embedding(vocab_size, model_size)

        self.mha = [MultiHeadAttention(model_size, h) for _ in range(num_layers)]
        self.mha_norm = [tf.keras.layers.BatchNormalization() for _ in range(num_layers)]

        self.FFN_l1 = [tf.keras.layers.Dense(4 * model_size, activation='relu') for _ in range(num_layers)]
        self.FFN_l2 = [tf.keras.layers.Dense(model_size) for _ in range(num_layers)]
        self.FFN_norm = [tf.keras.layers.BatchNormalization() for _ in range(num_layers)]

    def call(self, seq, mask=None):
        
        E_out = self.embedding(seq)
        E_out += pes[:seq.shape[1], :]

        # MultiHeadAttention
        for i in range(self.num_layers):
            mha_out = self.mha[i](E_out, E_out, mask)
            mha_out = self.mha_norm[i](E_out + mha_out)

            # Feed Forward Network
            FFN_out = self.FFN_l2[i](self.FFN_l1[i](mha_out))

            #  add and norm
            FFN_out = self.FFN_norm[i](FFN_out + mha_out)

        return FFN_out

## Decoder

In [72]:
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, model_size, num_layers, h):
        super(Decoder, self).__init__()
        self.model_size = model_size
        self.num_layers = num_layers
        self.h = h

        self.embedding = tf.keras.layers.Embedding(vocab_size, model_size)

        self.mha1 = [MultiHeadAttention(model_size, h) for _ in range(num_layers)]
        self.mha1_norm = [tf.keras.layers.BatchNormalization() for _ in range(num_layers)]
        self.mha2 = [MultiHeadAttention(model_size, h) for _ in range(num_layers)]
        self.mha2_norm = [tf.keras.layers.BatchNormalization() for _ in range(num_layers)]

        self.FFN_l1 = [tf.keras.layers.Dense(4 * model_size) for _ in range(num_layers)]
        self.FFN_l2 = [tf.keras.layers.Dense(model_size) for _ in range(num_layers)]
        self.FFN_norm = [tf.keras.layers.BatchNormalization() for _ in range(num_layers)]

        self.fc = tf.keras.layers.Dense(vocab_size)

    def call(self, seq, enc_opt, mask=None):
        E_out = self.embedding(seq)
        E_out *= pes[:seq.shape[1], :]
        
        for i in range(self.num_layers):
            # Define mask
            pad_mask = tf.linalg.band_part(tf.ones((len(seq), len(seq))), -1, 0)
            
            # First MHA layer
            mha1_out = self.mha1[i](E_out, E_out, pad_mask)
            mha1_out = self.mha1_norm[i](mha1_out + E_out)
            
            # Second MHA layer
            mha2_out = self.mha2[i](E_out, enc_opt, pad_mask)
            mha2_out = self.mha2_norm[i](mha2_out + mha1_out)
            
            # FFN
            FFN_out = self.FFN_l2[i](self.FFN_l1[i](mha2_out))
            FFN_out = self.FFN_norm[i](FFN_out + mha2_out)
        
            output = self.fc(FFN_out)
        
        return output

In [68]:
# Small configuration for the demo run: heads per attention block and layers per stack.
H, NUM_LAYERS = 2, 2

# First sequence of the example batch on the source and target side.
example_input_sequence, example_output_sequence = (
    example_input_batch[0],
    example_target_batch[0],
)

In [73]:
# Build the two halves of the model with the shared size / depth / head settings.
encoder = Encoder(
    vocab_size=input_vocab_size,
    model_size=MODEL_SIZE,
    num_layers=NUM_LAYERS,
    h=H,
)
decoder = Decoder(
    vocab_size=target_vocab_size,
    model_size=MODEL_SIZE,
    num_layers=NUM_LAYERS,
    h=H,
)

In [74]:
# Add a leading batch axis of size 1: (seq_length,) -> (1, seq_length).
ex = tf.reshape(example_input_sequence, (1, -1))
ex1 = tf.reshape(example_output_sequence, (1, -1))

In [75]:
# Smoke test: run one example through the encoder, then feed the encoder
# output together with one target-side sequence to the decoder.
enc_output = encoder(ex)
dec_output = decoder(ex1, enc_output)

## Define Training

In [102]:
optimizer = tf.keras.optimizers.Adam()
# reduction='none' keeps one loss value per token. With the default reduction
# the loss is already averaged to a scalar, so padding positions could not be
# masked out afterwards.
static_loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_func(real, preds):
    """Mean cross-entropy over the non-padding target tokens.

    Args:
        real: integer tensor of target token ids, shape
            (batch_size, seq_length); id 0 is treated as padding.
        preds: logits of shape (batch_size, seq_length, vocab_size).

    Returns:
        Scalar tensor: the summed loss of the non-padding positions divided
        by the number of non-padding positions.
    """
    # per-token loss, shape (batch_size, seq_length)
    loss = static_loss(real, preds)

    # create mask: 1.0 where the target is a real token, 0.0 on padding
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    mask = tf.cast(mask, dtype=loss.dtype)

    # zero the padding positions before summing
    loss *= mask

    return tf.reduce_sum(loss) / tf.reduce_sum(mask)

In [107]:
@tf.function
def train_step(in_seq, targ_in_seq, targ_out_seq):
    """Run one optimisation step on a single batch.

    Args:
        in_seq: encoder input batch.
        targ_in_seq: decoder input batch.
        targ_out_seq: decoder target batch the predictions are scored against.

    Returns:
        The batch loss divided by the length of the target sequences.
    """
    # Forward pass under the tape so the loss can be differentiated.
    with tf.GradientTape() as tape:
        encoded = encoder(in_seq)
        logits = decoder(targ_in_seq, encoded)
        loss = loss_func(targ_out_seq, logits)

    # Encoder and decoder weights are updated together.
    trainable = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    seq_length = targ_out_seq.shape[1]
    return loss / seq_length

In [109]:
NUM_EPOCHS = 10
start = datetime.now()  # start of the whole training run
for epoch in range(NUM_EPOCHS):
    # Timer is restarted every epoch so the message below reports the duration
    # of this epoch only, not the time elapsed since training began.
    epoch_start = datetime.now()
    total_loss = 0
    for batch, (inpt, targ_inpt, targ_out) in enumerate(train.take(steps_per_epoch)):
        batch_loss = train_step(inpt, targ_inpt, targ_out)
        total_loss += batch_loss

        # Progress report every 100 batches. This check belongs inside the
        # batch loop; placed after it, it only ever saw the last batch index.
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,
                                                         batch,
                                                         batch_loss.numpy()))
    print('Epoch {} Loss {:.4f}'.format(epoch+1,
                                        total_loss/steps_per_epoch))
    # total_seconds() so the printed number matches the "seconds" label.
    print('Time taken for 1 epoch {} seconds\n'.format(
        (datetime.now() - epoch_start).total_seconds()))

KeyboardInterrupt: 