# Encoder-Decoder MT Attention Network

In [None]:
import numpy as np
import seaborn as sns
import tensorflow as tf

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

### Load Data

In [None]:
def _read_text(path):
    """ Return the full contents of the text file at `path` """
    # A context manager closes the handle deterministically; the explicit
    # encoding keeps the decoded text identical across platforms.
    with open(path, encoding='utf-8') as handle:
        return handle.read()


# Training split: one sentence per line, formal and informal files line-aligned
# (alignment is assumed from how the corpora are paired later -- confirm).
formal = _read_text('../Data/Supervised Data/Entertainment_Music/S_Formal_EM_Train.txt')
informal = _read_text('../Data/Supervised Data/Entertainment_Music/S_Informal_EM_Train.txt')

# Held-out validation/test split.
formal_holdout = _read_text('../Data/Supervised Data/Entertainment_Music/S_Formal_EM_ValTest.txt')
informal_holdout = _read_text('../Data/Supervised Data/Entertainment_Music/S_Informal_EM_ValTest.txt')

In [None]:
# Wrap every sentence in start/end markers. The markers must be separated from
# the sentence by whitespace: the tokenizer splits on spaces and its filter
# list does not strip '<' or '>', so without the spaces the markers would fuse
# with the first/last word and '<start>' would not become a token of its own.
f_corpus = ['<start> ' + seq + ' <end>' for seq in formal.split('\n')]
if_corpus = ['<start> ' + seq + ' <end>' for seq in informal.split('\n')]

f_holdout = ['<start> ' + seq + ' <end>' for seq in formal_holdout.split('\n')]
if_holdout = ['<start> ' + seq + ' <end>' for seq in informal_holdout.split('\n')]

### Preprocess data

In [None]:
def tokenize(corpus):
    """ Fit a tokenizer on `corpus` and encode it as post-padded id sequences.

    Returns a `(padded_seqs, tokenizer)` pair: the integer-encoded corpus,
    padded at the end of each sequence, and the fitted tokenizer.
    """
    # The filter list leaves '<' and '>' intact; unknown words map to '<OOV>'.
    fitted = Tokenizer(
        oov_token='<OOV>',
        filters='!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n',
    )
    fitted.fit_on_texts(corpus)

    encoded = pad_sequences(fitted.texts_to_sequences(corpus), padding='post')
    return encoded, fitted

In [None]:
# Fit one tokenizer per side: the informal corpus is the model input and the
# formal corpus is the target.
input_train, input_tokenizer = tokenize(if_corpus)
target_train, target_tokenizer = tokenize(f_corpus)

In [None]:
# Display the word -> integer id mapping learned by the input tokenizer.
input_tokenizer.word_index

In [None]:
# Training hyperparameters.
BATCH_SIZE = 64
BUFFER_SIZE = len(input_train)  # shuffle buffer spans the whole training set
steps_per_epoch = BUFFER_SIZE // BATCH_SIZE
embedding_dim = 256
units = 1024

# One extra slot on top of the learned vocabulary; id 0 is the padding value
# produced by pad_sequences and masked out in the loss.
input_vocab_size = 1 + len(input_tokenizer.word_index)
target_vocab_size = 1 + len(target_tokenizer.word_index)

# Shuffled dataset of (input, target) pairs in fixed-size batches; the final
# partial batch is dropped so every batch has exactly BATCH_SIZE rows.
train = (
    tf.data.Dataset.from_tensor_slices((input_train, target_train))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

In [None]:
# Pull a single batch from the dataset and display the shapes of its input and
# target tensors.
example_input_batch, example_target_batch = next(iter(train))
example_input_batch.shape, example_target_batch.shape

### Encoder

In [None]:
class Encoder(tf.keras.Model):
    """ Embedding layer followed by a GRU that returns both the per-step
    outputs and the final state.
    """

    def __init__(self, vocab_size, embedding_dim, encoder_units, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.encoder_units = encoder_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            self.encoder_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )

    def call(self, x, hidden):
        """ Embed the token ids in `x` and run the GRU starting from `hidden`.

        Returns `(output, state)`: the GRU output at every step and the final
        GRU state.
        """
        embedded = self.embedding(x)
        sequence_output, final_state = self.gru(embedded, initial_state=hidden)
        return sequence_output, final_state

    def initialize_hidden_state(self):
        """ Return an all-zero state of shape (batch_size, encoder_units) """
        state_shape = (self.batch_size, self.encoder_units)
        return tf.zeros(state_shape)

In [None]:
# Build the encoder over the input vocabulary.
encoder = Encoder(input_vocab_size, embedding_dim, units, BATCH_SIZE)

In [None]:
# Smoke-run the encoder on the example batch, starting from a zero state.
sample_hidden = encoder.initialize_hidden_state()
sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)

In [None]:
# Display the integer id the target tokenizer assigned to the word 'with'.
target_tokenizer.word_index['with']

### Attention Layer

In [None]:
class BahdanauAttention(tf.keras.layers.Layer):
    """ Additive attention: score = V(tanh(W1(query) + W2(values))) """

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """ Attend over `values` using `query`.

        Returns `(context_vector, attention_weights)`: the weighted sum of
        `values` over axis 1, and the softmax weights over that same axis.
        """
        # Insert an axis at position 1 so the query broadcasts against
        # `values` (assumed to be (batch, time, features) -- confirm).
        expanded_query = tf.expand_dims(query, 1)

        energy = tf.nn.tanh(self.W1(expanded_query) + self.W2(values))
        attention_weights = tf.nn.softmax(self.V(energy), axis=1)

        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights

In [None]:
# Smoke-run a small (10-unit) attention layer using the encoder's final state
# as the query and the encoder's per-step outputs as the values.
attention_layer = BahdanauAttention(10)
attention_result, attention_weights = attention_layer(sample_hidden, sample_output)

### Decoder

In [None]:
class Decoder(tf.keras.Model):
    """ Attention decoder: embedding -> concat with context -> GRU -> dense
    projection onto the vocabulary.
    """

    def __init__(self, vocab_size, embedding_dim, decoder_units, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.decoder_units = decoder_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            self.decoder_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(self.decoder_units)

    def call(self, x, hidden, encoder_output):
        """ Run one decoding call.

        Returns `(logits, state, attention_weights)`: the dense projection of
        the GRU output with the leading two axes flattened, the GRU's final
        state, and the attention weights over `encoder_output`.
        """
        context_vector, attention_weights = self.attention(hidden, encoder_output)

        # Prepend the context vector to the embedded input along the feature
        # axis.
        embedded = self.embedding(x)
        gru_input = tf.concat([tf.expand_dims(context_vector, 1), embedded], axis=-1)

        # NOTE(review): `hidden` is used only as the attention query; the GRU
        # is called without an initial_state -- confirm this is intended.
        output, state = self.gru(gru_input)

        # Collapse the first two axes so the dense layer sees one row per step.
        flattened = tf.reshape(output, (-1, output.shape[2]))

        logits = self.fc(flattened)
        return logits, state, attention_weights

In [None]:
# Build the decoder over the target vocabulary.
decoder = Decoder(target_vocab_size, embedding_dim, units, BATCH_SIZE)

# Smoke-run the decoder on a random (BATCH_SIZE, 1) input with the sample
# encoder state and outputs; only the first return value is kept.
sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),
                                      sample_hidden, sample_output)

### Optimizer and Loss Function

In [None]:
# Adam with default settings.
optimizer = tf.keras.optimizers.Adam()
# The model emits raw logits (from_logits=True); reduction='none' keeps the
# per-example losses so padding positions can be masked before averaging.
static_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

In [None]:
def loss_function(real, pred):
    """ Mean cross-entropy between `real` ids and `pred` logits, with the
    loss at padding positions (id 0) zeroed before averaging.
    """
    per_example = static_loss(real, pred)

    # 1 where the target is a real token, 0 where it is padding.
    keep = tf.cast(tf.math.not_equal(real, 0), dtype=per_example.dtype)

    # Padding positions still count in the denominator of the mean.
    return tf.reduce_mean(per_example * keep)

### Training

In [None]:
@tf.function
def train_step(inpt, trgt, enc_hidden):
    """ Run one teacher-forced optimisation step on a single batch.

    Encodes `inpt` starting from `enc_hidden`, decodes one target position at
    a time while feeding the ground-truth token as the next decoder input,
    and applies one optimizer update to the encoder and decoder variables.

    Returns the summed per-step loss divided by the target sequence length.
    """
    loss = 0

    with tf.GradientTape() as tape:
        enc_output, enc_hidden = encoder(inpt, enc_hidden)

        # The decoder starts from the encoder's final state.
        dec_hidden = enc_hidden

        # First decoder input: the '<start>' id for every row of the batch.
        dec_input = tf.expand_dims([target_tokenizer.word_index['<start>']] * BATCH_SIZE, 1)

        # Position 0 is the start marker itself, so predictions begin at 1.
        for t in range(1, trgt.shape[1]):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            loss += loss_function(trgt[:, t], predictions)
            # Teacher forcing: feed the true token at step t as the next input.
            dec_input = tf.expand_dims(trgt[:, t], 1)

    batch_loss = loss / int(trgt.shape[1])

    # Gradient computation and the update happen outside the tape context so
    # they are not themselves recorded.
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss

In [None]:
from datetime import datetime

In [None]:
EPOCHS = 10

# The loop saves through `checkpoint` / `checkpoint_prefix`, which are not
# defined anywhere in the visible notebook; define them here so the save call
# resolves. NOTE(review): adjust the prefix if checkpoints belong elsewhere.
checkpoint_prefix = './training_checkpoints/ckpt'
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

for epoch in range(EPOCHS):
    start = datetime.now()

    # Each epoch starts from a fresh zero encoder state.
    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for batch, (inpt, trgt) in enumerate(train.take(steps_per_epoch)):
        batch_loss = train_step(inpt, trgt, enc_hidden)
        total_loss += batch_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,
                                                         batch,
                                                         batch_loss.numpy()))

    # Per-epoch work runs once after all batches, not once per batch.
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch+1,
                                        total_loss/steps_per_epoch))
    # Report elapsed time as a number of seconds, matching the message text.
    elapsed_seconds = (datetime.now() - start).total_seconds()
    print('Time taken for 1 epoch {} seconds\n'.format(elapsed_seconds))