# Credit - https://arminnorouzi.github.io/posts/2023/05/blog-post-13/

'''
A GAN is a type of neural network that consists of two networks: a generator and a discriminator.
The generator tries to create new data samples that are similar to the input data, while the discriminator tries to distinguish between real and fake data samples.

The Transformer is used for language translation, text summarization and language modelling. It consists of an encoder and a decoder that work together to process an input sequence and generate an output sequence. The encoder processes the input sequence and produces a hidden representation of the input; the decoder then takes that hidden representation and generates the output sequence.

GANs are used for generative tasks, while Transformers are used for tasks related to NLP: a GAN generates new samples, while a Transformer transforms an input sequence into an output sequence.

BENEFITS OF TRANSFORMERS OVER RNN AND LSTM
-------------------------------------------

> Long-Term dependencies
> Parallelization
> Handle variable-length inputs
> Attention-based mechanism

'''

In [None]:
# Notebook shell commands: install/upgrade TensorFlow and TensorFlow Text, then TensorFlow Datasets.
!pip install -q -U tensorflow-text tensorflow
!pip install tensorflow_datasets

In [None]:
import logging
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf
import tensorflow_text


In [None]:
# Load the TED Portuguese->English translation dataset. with_info=True makes
# tfds.load return the dataset metadata as a second value; as_supervised=True
# requests 2-tuples, which the cells below unpack as (pt, en).
examples,metadata = tfds.load('ted_hrlr_translate/pt_to_en',
                              with_info=True,
                               as_supervised=True)
# Keep the training and validation splits under their own names.
train_examples,val_examples = examples['train'],examples['validation']

In [None]:
# Visualize a few example sentence pairs from the training split.
# pt_examples / en_examples stay bound after the loop and are reused by later cells.
for pt_examples,en_examples in train_examples.batch(3).take(1):
    print('> Examples in Portugese:')
    for sentence in pt_examples.numpy():
        # Each element is a bytes object; decode it for display.
        print(sentence.decode('utf-8'))
    print()

    print('> Examples in English:')
    for sentence in en_examples.numpy():
        print(sentence.decode('utf-8'))


In [None]:
# Setting up the tokenizer
# model_name is used both to build the download URL and, in the next cell,
# the path the SavedModel is loaded from.
model_name = 'ted_hrlr_translate_pt_en_converter'
# Download <model_name>.zip into the local 'data' directory (empty cache_subdir)
# and extract it there.
tf.keras.utils.get_file(
    f'{model_name}.zip',
    f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip',
    cache_dir='data',cache_subdir='',extract=True
)

In [None]:
# Load the extracted SavedModel; the rest of the notebook uses its `.pt` and `.en` attributes as tokenizers.
tokenizers = tf.saved_model.load(f'data/{model_name}')

In [None]:
# Checking the tokenizer on examples: first print the raw English strings.
# en_examples is the batch left bound by the earlier visualization loop
# (assumes the cells are run in order).
print('> This is a batch of strings:')
for en in en_examples.numpy():
    print(en.decode('utf-8'))


In [None]:
# Tokenize the English batch into token IDs and print one list of IDs per sentence.
encoded = tokenizers.en.tokenize(en_examples)
print('> This is a padded-batch of token IDs:')
for row in encoded.to_list():
    print(row)

In [None]:
# Convert the token IDs back to text to check the tokenize/detokenize round trip.
round_trip = tokenizers.en.detokenize(encoded)
print('> This is human-redable text:')
for line in round_trip.numpy():
    print(line.decode('utf-8'))

In [None]:
# Map each token ID to its token text; the bare `tokens` expression on the
# last line displays the result as the notebook cell output.
print('> This is the text split into tokens:')
tokens = tokenizers.en.lookup(encoded)
tokens

In [None]:
# Collect the number of tokens per example, for both languages, over the
# whole training split (processed in batches of 1024).
# NOTE: this loop rebinds pt_examples / en_examples used by earlier cells.
lengths = []
for pt_examples,en_examples in train_examples.batch(1024):
    pt_tokens = tokenizers.pt.tokenize(pt_examples)
    lengths.append(pt_tokens.row_lengths())

    en_tokens = tokenizers.en.tokenize(en_examples)
    lengths.append(en_tokens.row_lengths())
    # Progress indicator: one dot per processed batch.
    print('.',end='',flush=True)

# Flatten the per-batch length vectors into a single 1-D array.
all_lengths = np.concatenate(lengths)

In [None]:
# Histogram of tokens per example, using 100 bins of width 5 over [0, 500].
plt.hist(all_lengths,np.linspace(0,500,101))
# Re-apply the current y-limits so the vertical marker lines drawn below
# do not rescale the axis.
plt.ylim(plt.ylim())
avg_length = all_lengths.mean()
# Vertical line at the average length.
plt.plot([avg_length,avg_length],plt.ylim())
max_length = max(all_lengths)
# Vertical line at the maximum length.
plt.plot([max_length,max_length],plt.ylim())
# The trailing semicolon keeps the notebook from echoing the return value of plt.title.
plt.title(f'Maximum tokens per example : {max_length} and average tokens per example : {avg_length}');


In [None]:
# Data pipeline setup
# Upper bound on the number of tokens kept per sentence.
MAX_TOKENS=128
def prepare_batch(pt,en):
    """
    Tokenize a batch of Portuguese/English sentence pairs for teacher-forced training.

    Args:
        pt: A batch of Portuguese sentences (passed to `tokenizers.pt.tokenize`).
        en: A batch of English sentences (passed to `tokenizers.en.tokenize`).

    Returns:
        A nested tuple `((pt, en_inputs), en_labels)`:
          * `pt` - Portuguese token IDs, truncated to MAX_TOKENS per sentence.
          * `en_inputs` - English token IDs with the last token dropped
            (what the decoder is fed).
          * `en_labels` - English token IDs with the first token dropped, i.e.
            `en_inputs` shifted left by one (what the decoder must predict).
        Each element is the result of `.to_tensor()` on the tokenizer output.
    """
    # Source side: tokenize, truncate, then convert to a dense tensor.
    pt_tokens = tokenizers.pt.tokenize(pt)[:, :MAX_TOKENS]
    encoder_inputs = pt_tokens.to_tensor()

    # Target side: keep one extra token so that inputs and labels, which are
    # offset by one position, each end up with at most MAX_TOKENS tokens.
    en_tokens = tokenizers.en.tokenize(en)[:, :(MAX_TOKENS + 1)]
    decoder_inputs = en_tokens[:, :-1].to_tensor()
    decoder_labels = en_tokens[:, 1:].to_tensor()

    return (encoder_inputs, decoder_inputs), decoder_labels

In [None]:
# Shuffle-buffer size and batch size read by make_batches().
BUFFER_SIZE = 20000
BATCH_SIZE = 64

In [None]:
def make_batches(ds):
    """
    Turn a dataset of (pt, en) sentence pairs into a training-ready pipeline.

    The dataset is shuffled with a buffer of BUFFER_SIZE, grouped into batches
    of BATCH_SIZE, mapped through `prepare_batch`, and prefetched.

    Args:
        ds: The dataset to transform; it must provide `shuffle`, `batch`,
            `map` and `prefetch`.

    Returns:
        The transformed dataset.
    """
    shuffled = ds.shuffle(BUFFER_SIZE)
    batched = shuffled.batch(BATCH_SIZE)
    # Tokenization is applied per batch, after batching.
    prepared = batched.map(prepare_batch, tf.data.AUTOTUNE)
    return prepared.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Build the batched training and validation pipelines.
train_batches = make_batches(train_examples)
val_batches = make_batches(val_examples)

# Pull a single batch; pt, en and en_labels stay bound after the loop and are
# used as sample inputs by the later test cells.
for (pt,en),en_labels in train_batches.take(1):
    break

print(f'pt.shape : {pt.shape}')
print(f'en_labels.shape : {en_labels.shape}')


Attention layers do not depend on the order of the tokens in the input sequence,
because the model contains no recurrent or convolutional layers that would inherently capture the sequence order.

To overcome this, the Transformer model adds a positional encoding to the embedding vectors.
The positional encoding uses a set of sines and cosines at different frequencies across the sequence.

https://www.youtube.com/watch?v=dichIcUZfOw



In [None]:
def positional_encoding(length,depth):
    '''
    Generate a sinusoidal position-encoding matrix.

    The first half of the columns holds sines and the second half cosines
    (concatenated, not interleaved), each half spanning the same set of
    frequencies.

    Args:
        length: Number of positions (rows) to generate.
        depth: Dimensionality of the encoding (columns).

    Returns:
        A float32 tf.Tensor of shape `(length, depth)`.
    '''
    half_depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]
    # Normalised frequency index in [0, 1) for each sin/cos pair.
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth

    angle_rates = 1 / (10000**depths)
    angle_rads = positions * angle_rates

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)

    # For an odd depth, np.arange(depth / 2) rounds up, so the concatenation
    # has depth + 1 columns; trim so the result is always (length, depth).
    # For an even depth this slice is a no-op.
    pos_encoding = pos_encoding[:, :int(depth)]

    return tf.cast(pos_encoding, dtype=tf.float32)
    
    
    

In [None]:
# Build an encoding table for 2048 positions with depth 512 and visualize it.
pos_encoding = positional_encoding(length=2048,depth=512)

print(pos_encoding.shape)

# Plot the transposed table so that position runs along x and depth along y.
plt.pcolormesh(pos_encoding.numpy().T,cmap='RdBu')
plt.ylabel('Depth')
plt.xlabel('Position')
plt.colorbar()
plt.show()

In [None]:
# Normalize every position vector to unit L2 norm, then take its dot product
# with the vector for position 1000.
pos_encoding /= tf.norm(pos_encoding,axis=1,keepdims=True)
p = pos_encoding[1000]
dots = tf.einsum('pd,d -> p',pos_encoding,p)
# Top panel: dot products over all positions.
plt.subplot(2,1,1)
plt.plot(dots)
plt.ylim([0,1])
# Two vertical marker lines at x=950 and x=1050; the NaN entry breaks the
# polyline so the two segments are not connected.
plt.plot([950,950,float('nan'),1050,1050],
        [0,1,float('nan'),0,1],color='k',label='Zoom')
plt.legend()
# Bottom panel: the same curve restricted to positions 950..1050.
plt.subplot(2,1,2)
plt.plot(dots)
plt.xlim([950,1050])
plt.ylim([0,1])

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    '''
    Token embedding plus positional encoding.

    Token IDs are looked up in an embedding table (ID 0 is treated as padding
    via `mask_zero=True`), scaled by sqrt(d_model), and summed with a
    precomputed positional-encoding table covering 2048 positions.
    '''
    def __init__(self,vocab_size,d_model):
        '''
        Args:
            vocab_size: Number of rows in the embedding table.
            d_model: Width of the embedding vectors.
        '''
        super().__init__()
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(
            vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def compute_mask(self,*args,**kwargs):
        '''Delegate mask computation to the wrapped embedding layer.'''
        return self.embedding.compute_mask(*args,**kwargs)

    def call(self,x):
        '''Embed token IDs `x` and add the positional encoding for each position.'''
        # The second axis of x is used as the sequence length.
        seq_len = tf.shape(x)[1]
        embedded = self.embedding(x)
        scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        scaled = embedded * scale
        # Only the first seq_len rows of the table are needed; the leading
        # new axis lets the table broadcast over the first dimension.
        return scaled + self.pos_encoding[tf.newaxis, :seq_len, :]
        

In [None]:
# One embedding layer per language, sized from each tokenizer's vocabulary, with d_model=512.
embed_pt = PositionalEmbedding(vocab_size=tokenizers.pt.get_vocab_size(),d_model=512)
embed_en = PositionalEmbedding(vocab_size=tokenizers.en.get_vocab_size(),d_model=512)

# Embed the sample batch taken from train_batches; pt_emb / en_emb are reused by later test cells.
pt_emb = embed_pt(pt)
en_emb = embed_en(en)

In [None]:
# Display the mask Keras attached to the embedding output.
# NOTE(review): _keras_mask is a private Keras attribute; presumably it is the
# padding mask produced by mask_zero=True - confirm for the Keras version in use.
en_emb._keras_mask

In [None]:
class BaseAttention(tf.keras.layers.Layer):
    """Common sub-layers for the attention blocks defined below.

    Holds a multi-head attention layer, a layer normalization and an Add
    layer. Subclasses implement `call` and wire these three together.
    """
    def __init__(self,**kwargs):
        # All keyword arguments are forwarded to MultiHeadAttention; none are
        # passed to the base Layer constructor.
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm  = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()
        
    

In [None]:
# Cross-attention layer: the decoder attends over the encoder output.
class CrossAttention(BaseAttention):
    """Attention block whose queries come from `x` and keys/values from `context`."""
    def call(self,x,context):
        """
        Args:
            x: Query sequence.
            context: Sequence used as both key and value.

        Returns:
            Layer-normalized sum of `x` and the attention output.
        """
        attended, scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True,
        )

        # Keep the scores from the most recent call on the instance.
        self.last_attn_scores = scores

        # Residual connection followed by layer normalization.
        return self.layernorm(self.add([x, attended]))

In [None]:
# Smoke-test CrossAttention on the sample embeddings.
sample_ca = CrossAttention(num_heads=2,key_dim=512)

print(pt_emb.shape)
print(en_emb.shape)
print(sample_ca(en_emb,pt_emb).shape)
# The output length is the length of the query sequence, not the length of the context (key/value) sequence.

In [None]:
class GlobalSelfAttention(BaseAttention):
    """Self-attention block: `x` serves as query, key and value (no causal mask requested)."""
    def call(self,x):
        """Return the layer-normalized sum of `x` and its self-attention output."""
        attended = self.mha(query=x, key=x, value=x)
        # Residual connection followed by layer normalization.
        return self.layernorm(self.add([x, attended]))

In [None]:
# Smoke-test GlobalSelfAttention on the Portuguese sample embeddings.
sample_gsa = GlobalSelfAttention(num_heads=2,key_dim=512)

print(pt_emb.shape)
print(sample_gsa(pt_emb).shape)
# The output tensor has the same shape as the input.

In [None]:
# Causal self-attention layer: the decoder's self-attention.
# Used when the output at each time step may depend only on earlier time
# steps and never on later ones; the causal mask enforces that constraint
# during decoding.
class CausalSelfAttention(BaseAttention):
    """Self-attention block that calls the attention layer with `use_causal_mask=True`."""
    def call(self,x):
        """Return the layer-normalized sum of `x` and its causally masked self-attention output."""
        attended = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        # Residual connection followed by layer normalization.
        return self.layernorm(self.add([x, attended]))

In [None]:
# Smoke-test CausalSelfAttention on the English sample embeddings.
sample_csa = CausalSelfAttention(num_heads=2,key_dim=512)

print(en_emb.shape)
print(sample_csa(en_emb).shape)

In [None]:
# The output for early sequence elements does not depend on later elements,
# so it should not matter whether you trim the sequence before or after applying the layer.
# out1 trims the input to 3 tokens first; out2 trims the layer output to 3 positions.
out1 = sample_csa(embed_en(en[:,:3]))
out2 = sample_csa(embed_en(en))[:,:3]

# Largest absolute difference between the two results.
tf.reduce_max(abs(out1-out2)).numpy()

In [None]:
# The feed-forward network.
# Used in Transformer-based models to process each token representation.
class FeedForward(tf.keras.layers.Layer):
    """Two dense layers plus dropout, wrapped in a residual connection and layer normalization."""
    def __init__(self,d_model,diff,dropout_rate=0.1):
        """
        Args:
            d_model: Output width of the second dense layer.
            diff: Width of the first (ReLU) dense layer.
            dropout_rate: Rate of the dropout applied after the second dense layer.
        """
        super().__init__()
        stack = [
            tf.keras.layers.Dense(diff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ]
        self.seq = tf.keras.Sequential(stack)
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self,x):
        """Return the layer-normalized sum of `x` and the feed-forward output."""
        transformed = self.seq(x)
        return self.layer_norm(self.add([x, transformed]))

In [None]:
# Smoke-test FeedForward (d_model=512, hidden width 2048) on the English sample embeddings.
sample_ffn = FeedForward(512,2048)
print(en_emb.shape)
print(sample_ffn(en_emb).shape)

In [None]:
#The Encoder
'''
The Encoder consists of a PositionalEmbedding layer at the input and a stack of EncoderLayer layers.
Each EncoderLayer contains a GlobalSelfAttention layer and a FeedForward layer.
'''
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: global self-attention followed by a feed-forward network."""
    def __init__(self,*,d_model,num_heads,dff,dropout_rate=0.1):
        """
        Args:
            d_model: Model width; also used as the attention `key_dim`.
            num_heads: Number of attention heads.
            dff: Hidden width of the feed-forward network.
            dropout_rate: Dropout rate used by both the attention layer and
                the feed-forward network.
        """
        super().__init__()
        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate
        )
        # Forward dropout_rate so that a non-default rate is honoured by the
        # feed-forward block too, instead of silently falling back to 0.1.
        self.ffn = FeedForward(d_model,dff,dropout_rate=dropout_rate)

    def call(self,x):
        """Apply self-attention and then the feed-forward network to `x`."""
        x = self.self_attention(x)
        x = self.ffn(x)
        return x
    

In [None]:
class Encoder(tf.keras.layers.Layer):
    """Positional embedding, dropout, then a stack of `num_layers` EncoderLayer blocks."""
    def __init__(self,*,num_layers,d_model,num_heads,dff,vocab_size,dropout_rate=0.1):
        """
        Args:
            num_layers: Number of EncoderLayer blocks.
            d_model: Model width.
            num_heads: Number of attention heads per block.
            dff: Hidden width of each block's feed-forward network.
            vocab_size: Vocabulary size for the input embedding.
            dropout_rate: Dropout rate for the embedding dropout and the blocks.
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)

        layer_kwargs = dict(
            d_model=d_model,
            num_heads=num_heads,
            dff=dff,
            dropout_rate=dropout_rate,
        )
        self.enc_layers = [
            EncoderLayer(**layer_kwargs) for _ in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def  call(self,x):
        """Embed token IDs `x` and run them through every encoder block in order."""
        hidden = self.dropout(self.pos_embedding(x))
        for layer in self.enc_layers:
            hidden = layer(hidden)
        return hidden
        

In [None]:
# Testing the Encoder: build a 4-layer encoder on the CPU device and run the
# sample Portuguese batch through it with training=False.
with tf.device("CPU"):
    sample_encoder = Encoder(num_layers=4,
                             d_model=512,
                             num_heads=8,
                             dff=2048,
                             vocab_size=8500)
    sample_encoder_output = sample_encoder(pt,training=False)

    # Compare the input shape with the output shape.
    print(pt.shape)
    print(sample_encoder_output.shape)

In [None]:
#The Decoder
class DecoderLayer(tf.keras.layers.Layer):
    '''
    A single decoder block of a Transformer-based architecture: causal
    self-attention, cross-attention over the encoder output, then a
    feed-forward network.
    '''
    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 dropout_rate=0.1):
        '''
        Args:
            d_model: Model width; also used as the attention `key_dim`.
            num_heads: Number of attention heads.
            dff: Hidden width of the feed-forward network.
            dropout_rate: Dropout rate used by both attention layers and the
                feed-forward network.
        '''
        super().__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate
        )

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate
        )

        # Forward dropout_rate so that a non-default rate is honoured by the
        # feed-forward block too, instead of silently falling back to 0.1.
        self.ffn = FeedForward(d_model,dff,dropout_rate=dropout_rate)

    def call(self,x,context):
        '''
        Args:
            x: Decoder-side sequence.
            context: Encoder output attended to by the cross-attention layer.

        Returns:
            The transformed decoder-side sequence.
        '''
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x,context=context)
        # Cache the last cross-attention scores for plotting later.
        # Note the singular attribute name: Decoder reads `last_attn_score`.
        self.last_attn_score = self.cross_attention.last_attn_scores
        x = self.ffn(x)
        return x

In [None]:
# The decoder class
class Decoder(tf.keras.layers.Layer):
    """Positional embedding, dropout, then a stack of `num_layers` DecoderLayer blocks."""
    def __init__(self,*,num_layers,d_model,num_heads,dff,vocab_size,dropout_rate=0.1):
        """
        Args:
            num_layers: Number of DecoderLayer blocks.
            d_model: Model width.
            num_heads: Number of attention heads per block.
            dff: Hidden width of each block's feed-forward network.
            vocab_size: Vocabulary size for the target embedding.
            dropout_rate: Dropout rate for the embedding dropout and the blocks.
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        layer_kwargs = dict(
            d_model=d_model,
            num_heads=num_heads,
            dff=dff,
            dropout_rate=dropout_rate,
        )
        self.dec_layers = [
            DecoderLayer(**layer_kwargs) for _ in range(num_layers)
        ]
        # Filled in by call() with the last block's cross-attention scores.
        self.last_attn_scores = None

    def call(self,x,context):
        """
        Args:
            x: Target-side token IDs.
            context: Encoder output passed to every decoder block.

        Returns:
            The output of the final decoder block.
        """
        hidden = self.dropout(self.pos_embedding(x))

        for layer in self.dec_layers:
            hidden = layer(hidden, context)

        # Expose the final block's cached cross-attention scores.
        self.last_attn_scores = self.dec_layers[-1].last_attn_score
        return hidden
            
        
        

In [None]:
# Testing the decoder: build a 4-layer decoder on the CPU device and run the
# English sample batch through it, using the Portuguese sample embeddings as context.
with tf.device("CPU"):
    sample_decoder = Decoder(num_layers=4,
                             d_model=512,
                             num_heads=8,
                             dff = 2048,
                             vocab_size=8000)

    output = sample_decoder(
        x = en,
        context = pt_emb
    )

    # Compare the input shapes with the output shape.
    print(en.shape)
    print(pt_emb.shape)
    print(output.shape)


In [None]:
# Shape of the cross-attention scores cached by the decoder's last block during the call above.
sample_decoder.last_attn_scores.shape

In [None]:
# The Transformer
class Transformer(tf.keras.Model):
    """Encoder-decoder model followed by a dense projection to the target vocabulary."""
    def __init__(self,*,num_layers,d_model,num_heads,dff,input_vocab_size,target_vocab_size,dropout_rate =  0.1):
        """
        Args:
            num_layers: Number of blocks in both the encoder and the decoder.
            d_model: Model width.
            num_heads: Number of attention heads per block.
            dff: Hidden width of the feed-forward networks.
            input_vocab_size: Vocabulary size for the encoder embedding.
            target_vocab_size: Vocabulary size for the decoder embedding and
                the width of the final dense layer.
            dropout_rate: Dropout rate passed to the encoder and decoder.
        """
        super().__init__()
        shared_kwargs = dict(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            dff=dff,
            dropout_rate=dropout_rate,
        )
        self.encoder = Encoder(vocab_size=input_vocab_size, **shared_kwargs)
        self.decoder = Decoder(vocab_size=target_vocab_size, **shared_kwargs)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self,inputs):
        """
        Args:
            inputs: A pair `(context, x)` - the encoder input and the decoder input.

        Returns:
            The output of the final dense layer (one score per target-vocabulary entry).
        """
        source, target = inputs
        encoded = self.encoder(source)
        decoded = self.decoder(target, encoded)
        logits = self.final_layer(decoded)

        # Drop the Keras mask from the output if one is attached; a missing
        # attribute is fine and deliberately ignored.
        try:
            del logits._keras_mask
        except AttributeError:
            pass

        return logits

In [None]:
# Hyperparameters of the model that is actually trained below.
num_layers  =  4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.1

# Vocabulary sizes come from the tokenizers; .numpy() converts them to plain numeric values.
transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=tokenizers.pt.get_vocab_size().numpy(),
    target_vocab_size=tokenizers.en.get_vocab_size().numpy(),
    dropout_rate=dropout_rate
)

In [None]:
# Run the sample batch through the model once (this also builds its weights) and print the shapes.
output = transformer((pt,en))
print(en.shape)
print(pt.shape)
print(output.shape)

In [None]:
# Print the Keras summary of the model.
transformer.summary()

In [None]:
# Training
# Optimizer: learning-rate schedule
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule: rsqrt(d_model) * min(rsqrt(step), step * warmup_steps**-1.5)."""
    def __init__(self,d_model,warmup_steps=4000):
        """
        Args:
            d_model: Model width; stored as a float32 tensor.
            warmup_steps: Step count used in the `step * warmup_steps**-1.5` term.
        """
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self,step):
        """Return the learning rate for the given training `step`."""
        step = tf.cast(step, dtype=tf.float32)
        # The first term shrinks with the step; the second grows linearly with it.
        decay_term = tf.math.rsqrt(step)
        warmup_term = step * (self.warmup_steps ** -1.5)

        scale = tf.math.rsqrt(self.d_model)
        return scale * tf.math.minimum(decay_term, warmup_term)

In [None]:
# Adam optimizer driven by the custom schedule (default warmup_steps) for the model's d_model.
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate,beta_1=0.9,beta_2=0.98,epsilon=1e-9)

In [None]:
# Plot the schedule's learning rate for steps 0..39999.
plt.plot(learning_rate(tf.range(40000,dtype=tf.float32)))
plt.ylabel('Learning Rate')
plt.xlabel('Train Step')

In [None]:
def masked_loss(label,pred):
    """
    Sparse categorical cross-entropy averaged over non-padding tokens only.

    Args:
        label: Integer token IDs; positions equal to 0 are treated as padding
            and excluded from the loss.
        pred: Logits for each position (the loss is built with
            `from_logits=True`).

    Returns:
        A scalar: the sum of per-token losses over non-padding positions
        divided by the number of non-padding positions.
    """
    mask = label != 0
    # reduction='none' keeps one loss value per token so padding can be masked out.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,reduction='none'
    )
    loss = loss_object(label,pred)

    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    # Sum (not mean) before dividing by the token count. Taking the mean here
    # would divide by the total element count first and then again by the
    # token count, shrinking the loss and making it depend on batch size.
    return tf.reduce_sum(loss)/tf.reduce_sum(mask)

In [None]:
def masked_accuracy(label,pred):
    """
    Fraction of non-padding tokens that were predicted correctly.

    Args:
        label: Integer token IDs; positions equal to 0 are treated as padding.
        pred: Per-position scores; the prediction is the argmax over axis 2.

    Returns:
        A float32 scalar: correct non-padding predictions divided by the
        number of non-padding positions.
    """
    predicted_ids = tf.argmax(pred, axis=2)
    label = tf.cast(label, predicted_ids.dtype)

    not_padding = label != 0
    # A hit only counts where the label is not padding.
    correct = (label == predicted_ids) & not_padding

    n_correct = tf.reduce_sum(tf.cast(correct, dtype=tf.float32))
    n_tokens = tf.reduce_sum(tf.cast(not_padding, dtype=tf.float32))
    return n_correct / n_tokens

In [None]:
# Training: compile the model with the masked loss and accuracy, then fit for
# 20 epochs, validating on val_batches. Everything runs under the CPU device scope.
with tf.device("CPU"):
    transformer.compile(loss=masked_loss,
                        optimizer=optimizer,
                        metrics=[masked_accuracy])

    transformer.fit(train_batches,
                    epochs=20,
                    validation_data=val_batches)

In [None]:
#Testing
#Translator  Class
class Translator(tf.Module):
    """Greedy decoder that turns a Portuguese sentence into English with a trained transformer."""
    def __init__(self,tokenizers,transformer):
        """
        Args:
            tokenizers: Object exposing `.pt` and `.en` tokenizers.
            transformer: The trained model; called as `transformer([encoder_input, decoder_input])`.
        """
        self.tokenizers = tokenizers
        self.transformer = transformer

    def __call__(self,sentence,max_length=MAX_TOKENS):
        """
        Translate `sentence` by repeatedly appending the most likely next token.

        Args:
            sentence: A tf.Tensor holding the Portuguese sentence; a scalar
                is expanded to a batch of one.
            max_length: Maximum number of tokens to generate.

        Returns:
            A tuple `(text, tokens, attention_weights)`: the detokenized
            translation, its tokens, and the decoder's last cross-attention
            scores for the generated sequence.
        """
        assert isinstance(sentence,tf.Tensor)
        if len(sentence.shape) == 0:
            sentence = sentence[tf.newaxis]
        encoder_input = self.tokenizers.pt.tokenize(sentence).to_tensor()

        # Tokenizing an empty string yields just the start and end tokens.
        start_end = self.tokenizers.en.tokenize([''])[0]
        start = start_end[0][tf.newaxis]
        end = start_end[1][tf.newaxis]

        # Growing buffer of generated token IDs, seeded with the start token.
        output_array = tf.TensorArray(dtype=tf.int64,size=0,dynamic_size=True)
        output_array = output_array.write(0,start)

        for i in tf.range(max_length):
            output = tf.transpose(output_array.stack())
            predictions = self.transformer([encoder_input,output],training=False)

            # Keep only the prediction for the last position of the sequence ...
            predictions = predictions[:,-1:,:]
            # ... and pick the highest-scoring vocabulary entry.
            predicted_id = tf.argmax(predictions,axis=-1)

            output_array = output_array.write(i+1,predicted_id[0])

            if predicted_id == end:
                break

        # Everything below runs once, after generation has finished (either
        # the end token was produced or max_length was reached).
        output = tf.transpose(output_array.stack())
        text = self.tokenizers.en.detokenize(output)[0]

        tokens = self.tokenizers.en.lookup(output)[0]

        # Re-run the model on the final sequence (minus its last token) so the
        # cached attention scores correspond to the full translation.
        self.transformer([encoder_input,output[:,:-1]],training=False)
        attention_weights = self.transformer.decoder.last_attn_scores
        return text,tokens,attention_weights
            
            

In [None]:
# Wrap the tokenizers and the trained model in a Translator.
translator=Translator(tokenizers,transformer)

In [None]:
def print_translation(sentence,tokens,ground_truth):
    """
    Print the input sentence, the predicted translation and the reference.

    Args:
        sentence: The source sentence, printed as-is.
        tokens: The prediction; must provide `.numpy()` returning UTF-8 bytes.
        ground_truth: The reference translation, printed as-is.
    """
    print(f'{"Input:":15s}: {sentence}')
    # Decode after the first line is printed, matching the original ordering.
    prediction = tokens.numpy().decode("utf-8")
    print(f'{"Prediction:":15s}: {prediction}')
    print(f'{"Ground truth:":15s}: {ground_truth}')