In [None]:
from tensorflow.keras.layers import Layer, Dense, LayerNormalization, Embedding, Dropout, Softmax, Masking
from tensorflow.keras.models import Model, Sequential
import tensorflow as tf
import numpy as np
import re
from tensorflow.keras.callbacks import Callback,ModelCheckpoint,CSVLogger,History
import time

In [None]:
# Notebook-level hyper-parameters.
# NOTE(review): the model classes below take these values as constructor
# arguments; only FFN reads the global `d_model` directly.
d_model = 512 # Model / embedding feature width (also read as a global by FFN's second Dense layer)
N = 6 # Number of stacked encoder layers and of stacked decoder layers
h = 8 # Number of parallel attention heads (presumably passed as `num_heads` -- not used in the cells shown here)
p_drop = 0.1 # Dropout rate
dff=2048 # Units of the first (hidden) layer of the position-wise FFN


In [None]:
class ScaledDotProductAttention(Layer):
    """Scaled dot-product attention with padding and optional causal masking.

    Computes ``softmax(Q K^T / sqrt(d_model / num_heads)) V``. Attention
    between a query and a key is blocked when either position is padded, or
    when ``causal_mask`` marks the pair as illegal. Rows that belong to
    padded query positions are zeroed after the softmax.

    Args:
        d_model: model feature width.
        num_heads: number of attention heads; the scores are divided by
            ``sqrt(d_model / num_heads)``.
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self,d_model, num_heads, **kwargs):
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.scaling_factor = tf.sqrt(d_model/num_heads)

    def call(self, Q,K,V, causal_mask=None, mask=None):
        """Apply attention.

        Args:
            Q, K, V: query, key and value tensors. The masking code below
                assumes rank-3 inputs (batch, length, features).
            causal_mask: optional boolean (query_len, key_len) tensor with
                True at illegal connections.
            mask: padding mask of the query (Keras fills this in from the
                first argument's mask, per the author's testing). When it is
                None every query position is treated as valid.

        Returns:
            The attention-weighted combination of ``V``.
        """
        # The key padding mask travels on the tensor itself. A missing mask
        # is interpreted as "no key position is padded" instead of raising.
        key_mask = getattr(K, "_keras_mask", None)

        scores = tf.matmul(Q, K, transpose_b=True)
        scores = tf.divide(scores, self.scaling_factor)

        # Numeric (1.0 = valid) padding masks, shape (batch, length).
        if mask is None:
            query_valid = tf.ones(tf.shape(Q)[:-1], dtype=tf.float32)
        else:
            query_valid = tf.cast(mask, tf.float32)
        if key_mask is None:
            key_valid = tf.ones(tf.shape(K)[:-1], dtype=tf.float32)
        else:
            key_valid = tf.cast(key_mask, tf.float32)

        # Outer product -> 1.0 only where both query and key are real tokens.
        pair_valid = tf.matmul(query_valid[..., None], key_valid[:, None])

        # True marks an illegal connection.
        blocked = tf.logical_not(tf.cast(pair_valid, dtype=tf.bool))
        if causal_mask is not None:
            blocked = tf.logical_or(causal_mask[None], blocked)

        # Push illegal connections to the most negative float32 so that they
        # receive (numerically) zero weight in the softmax.
        scores += tf.cast(blocked, tf.float32) * tf.float32.min

        weights = tf.nn.softmax(scores, axis=-1)
        # A fully blocked row becomes uniform after softmax; zero the rows of
        # padded queries explicitly.
        weights = weights * query_valid[..., None]

        return tf.matmul(weights, V)
            
            

In [None]:
class MultiHeadAttention(Layer):
    """Multi-head attention built from independent per-head projections.

    Each head owns three bias-free Dense projections (query, key, value) of
    width ``int(d_model / num_heads)``. All heads share one
    ScaledDotProductAttention layer; their outputs are concatenated on the
    feature axis and passed through a final bias-free Dense of width
    ``d_model``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)

        self.num_heads = num_heads
        self.scaledDotProductAttention = ScaledDotProductAttention(d_model, num_heads)

        # Per-head projection layers, index i belongs to head i.
        self.W_Qs = []
        self.W_Ks = []
        self.W_Vs = []

        key_dim = int(d_model/num_heads)   # projected query/key width
        value_dim = key_dim                # projected value width

        for _ in range(num_heads):
            self.W_Qs.append(Dense(units=key_dim, use_bias=False))
            self.W_Ks.append(Dense(units=key_dim, use_bias=False))
            self.W_Vs.append(Dense(units=value_dim, use_bias=False))

        # Output projection applied to the concatenated heads.
        self.W_O = Dense(units=d_model, use_bias=False)

    def call(self, Q,K,V, causal_mask=None):
        # Project the inputs for every head and attend head by head.
        per_head = [
            self.scaledDotProductAttention(
                project_q(Q), project_k(K), project_v(V), causal_mask
            )
            for project_q, project_k, project_v in zip(self.W_Qs, self.W_Ks, self.W_Vs)
        ]

        # Join the heads along the last (feature) axis, then mix them.
        merged = tf.concat(per_head, axis=-1)
        return self.W_O(merged)



In [None]:
class FFN(Layer):
    """Position-wise feed-forward network: Dense(dff, relu) then a linear Dense.

    The second Dense layer is created at build time. Unless ``d_model`` is
    given explicitly, its width equals the last dimension of the input, so
    the output always matches the input width instead of depending on a
    notebook-level global.

    Args:
        dff: number of units of the hidden (first) layer.
        d_model: optional output width. ``None`` (default) means "same as the
            input's last dimension".
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self, dff, d_model=None, **kwargs):
        super(FFN, self).__init__(**kwargs)
        self.d_model = d_model
        self.dense1 = Dense(units=dff, activation='relu')
        self.dense2 = None  # created in build(), once the input width is known

    def build(self, input_shape):
        units = self.d_model if self.d_model is not None else input_shape[-1]
        if units is None:
            raise ValueError(
                "FFN needs a statically known last input dimension "
                "or an explicit d_model."
            )
        self.dense2 = Dense(units=int(units))
        super(FFN, self).build(input_shape)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return x

In [None]:
class ResidualBlock(Layer):
    """Residual connection followed by layer normalization.

    Returns ``LayerNormalization(sub_layer_out + sublayer_in)``. When a
    padding mask is available, the vectors of padded positions in
    ``sublayer_in`` are zeroed before the addition.
    """

    def __init__(self, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        self.layerNorm = LayerNormalization()

    def call(self, sublayer_in, sub_layer_out, mask=None):
        """Add the residual and normalize.

        Args:
            sublayer_in: tensor that was fed to the sub-layer.
            sub_layer_out: tensor produced by the sub-layer.
            mask: optional padding mask of ``sublayer_in``; ``None`` means no
                position is padded and the input is added unchanged.
        """
        if mask is not None:
            # Padded positions still hold non-zero vectors; zero them so they
            # do not leak into the sum.
            numeric_mask = tf.cast(mask[..., None], tf.float32)
            sublayer_in = sublayer_in * numeric_mask

        x = tf.add(sub_layer_out, sublayer_in)
        x = self.layerNorm(x)
        return x

    

In [None]:
class SubLayer1(Layer):
    """Attention sub-layer: multi-head attention, dropout, residual + norm."""

    def __init__(self, d_model, num_heads, p_drop, **kwargs):
        super().__init__(**kwargs)

        # Keep the mask attached to the output; SubLayer2 relies on it.
        self.supports_masking = True
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.dropout = Dropout(rate=p_drop)
        self.resid = ResidualBlock()

    def call(self, Q,K,V,X, training, causal_mask=None):
        # X is the residual input that gets added back to the attention output.
        attended = self.dropout(self.mha(Q, K, V, causal_mask), training)
        return self.resid(X, attended)
         


In [None]:
class SubLayer2(Layer):
    """Feed-forward sub-layer: FFN, dropout, residual + norm."""

    def __init__(self, dff, p_drop, **kwargs):
        super().__init__(**kwargs)

        self.ffn = FFN(dff)
        self.dropout = Dropout(rate=p_drop)
        self.resid = ResidualBlock()

    def call(self, X, training, mask=None ):
        # `mask` is accepted but not read here; the residual block receives
        # X as its first argument.
        transformed = self.dropout(self.ffn(X), training)
        return self.resid(X, transformed)

In [None]:
#### Encoder
class EncoderLayer(Layer):
    """One encoder layer: self-attention sub-layer, then the feed-forward sub-layer.

    Masking notes (author's observations):
      * ``supports_masking = True`` makes Keras attach the incoming mask to
        this layer's output unchanged, so the next stacked encoder layer (or
        the decoder) receives it.
      * When every argument of a call carries a mask, the ``mask`` argument
        of ``call`` is populated with the mask of the first argument.
      * This layer does not use ``mask`` itself; the sub-layers receive the
        masking information through the masked tensors they are called with.
    """

    def __init__(self, d_model, num_heads, dff, p_drop, **kwargs):
        # Forward **kwargs (e.g. name) to Layer, as the other layers in this
        # notebook do; they used to be silently dropped.
        super(EncoderLayer, self).__init__(**kwargs)
        self.supports_masking = True

        self.self_attention = SubLayer1(d_model, num_heads, p_drop)
        self.position_wise = SubLayer2(dff, p_drop)

    def call(self, encoder_input, training, mask=None):
        """Run the layer.

        Args:
            encoder_input: embeddings of the tokens of a sentence.
            training: training flag handed on to the sub-layers.
            mask: unused here (see the class docstring).
        """
        x = encoder_input
        # Self-attention: Q, K, V and the residual input are all the same tensor.
        x = self.self_attention(x,x,x,x,training)
        x = self.position_wise(x,training)
        return x
     

In [None]:
def getPositionalEmbedding(seq_length=256, feature_dimension=None):
    """Return sinusoidal positional encodings.

    For position ``pos`` and channel pair ``i`` the encoding uses
    ``sin(pos / 10000**(2i / d))`` for the even channel ``2i`` and
    ``cos(pos / 10000**(2i / d))`` for the odd channel ``2i + 1``, i.e. both
    channels of a pair share one frequency.

    The result places all sine channels first and all cosine channels after
    them (concatenated, not interleaved); the author treats this as
    functionally equivalent to interleaving.

    Args:
        seq_length: number of positions (int or scalar tensor).
        feature_dimension: number of channels ``d`` (int or scalar tensor).

    Returns:
        float32 tensor of shape (seq_length, feature_dimension).

    Raises:
        ValueError: if ``feature_dimension`` is not provided.
    """
    if feature_dimension is None:
        raise ValueError("feature_dimension must be provided")

    dimensions = tf.range(feature_dimension, dtype=tf.float32)
    # Trailing axis so that positions broadcast against the channel axis.
    positions = tf.range(seq_length, dtype=tf.float32)[...,None]
    even_dimensions = dimensions[::2]
    odd_dimensions = dimensions[1::2]

    feature_dimension = tf.cast(feature_dimension, tf.float32)
    even = tf.sin( positions/tf.pow(10000., even_dimensions/feature_dimension) )
    # Channel 2i+1 reuses the exponent 2i/d of its paired sine channel.
    odd = tf.cos( positions/tf.pow(10000., (odd_dimensions - 1.)/feature_dimension) )

    sequence_position_embeddings = tf.concat([even,odd], axis=-1)

    return sequence_position_embeddings


In [None]:
class PositionalEncoding(Layer):
    """Adds positional encodings to the output of the embedding layer.

    The input must have shape (batch, seq-len, features(d_model)).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        # Read sequence length and feature width from the runtime shape.
        runtime_shape = tf.shape(inputs)
        encodings = getPositionalEmbedding(runtime_shape[1], runtime_shape[2])

        # Leading batch axis added explicitly; plain broadcasting would give
        # the same result.
        return tf.add(inputs, encodings[None, ...])

        


In [None]:
class DecoderLayer(Layer):
    """One decoder layer.

    Runs masked self-attention, then encoder-decoder attention, then the
    position-wise feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, dff, p_drop, **kwargs):
        super().__init__(**kwargs)

        # Lets the padding mask travel on to the next decoder layer.
        self.supports_masking = True
        self.masked_self_attention = SubLayer1(d_model, num_heads, p_drop)
        self.encoder_decoder_attention = SubLayer1(d_model, num_heads, p_drop)
        self.position_wise = SubLayer2(dff, p_drop)

    # Important: the first argument must be the tensor whose padding mask
    # should be propagated.
    def call(self,  decoder_input, encoder_output,training, causal_mask):
        # Self-attention over the decoder input (Q = K = V), restricted by
        # the causal mask.
        self_attended = self.masked_self_attention(
            decoder_input, decoder_input, decoder_input, decoder_input,
            training, causal_mask,
        )
        # Queries come from the decoder, keys and values from the encoder.
        cross_attended = self.encoder_decoder_attention(
            self_attended, encoder_output, encoder_output, self_attended, training
        )
        return self.position_wise(cross_attended, training)
     

In [None]:
class Encoder(Layer):
    """Stack of N EncoderLayer instances applied one after another."""

    def __init__(self, N, d_model, num_heads, dff, p_drop, **kwargs):
        super().__init__(**kwargs)

        # Makes sure the mask is handed to the layer that follows (the decoder).
        self.supports_masking = True
        self.encoder_layers = [
            EncoderLayer(d_model, num_heads, dff, p_drop) for _ in range(N)
        ]

    def call(self, inputs, training):
        hidden = inputs
        # Every layer receives a masked tensor because the previous layer
        # sets supports_masking.
        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden, training)
        return hidden
    


In [None]:
class Decoder(Layer):
    """Stack of N DecoderLayer instances sharing one causal mask.

    Args:
        N: number of decoder layers.
        d_model, num_heads, dff, p_drop: forwarded to every DecoderLayer.
        seq_len: maximum decoder sequence length; a (seq_len, seq_len)
            causal mask is precomputed once.
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self, N, d_model, num_heads, dff, seq_len, p_drop, **kwargs):
        # Forward **kwargs (e.g. name) to Layer; they used to be dropped.
        super(Decoder, self).__init__(**kwargs)
        self.decoder_layers = []
        # Strict upper triangle is True: position i may not attend to j > i.
        self.causal_mask = tf.constant(np.triu(np.ones((seq_len,seq_len)),k=1 ), dtype=tf.bool)

        for n in range(N):
            self.decoder_layers.append(DecoderLayer(d_model, num_heads,dff, p_drop))

    def call(self, enc_output,  dec_input, training):
        """Decode ``dec_input`` while attending to ``enc_output``.

        The precomputed causal mask is cropped to the runtime length of
        ``dec_input`` (axis 1), so inputs shorter than ``seq_len`` work too.
        Inputs longer than ``seq_len`` are still unsupported.
        """
        dec_len = tf.shape(dec_input)[1]
        causal_mask = self.causal_mask[:dec_len, :dec_len]

        x = dec_input
        for layer in self.decoder_layers:
            x = layer(x , enc_output,training,causal_mask)
        return x
    

In [None]:
class Scaling(Layer):
    """Multiplies its input by a fixed scalar while keeping the Keras mask.

    Exists because a plain scalar multiplication of a tensor that carries a
    `_keras_mask` produced a tensor without a mask.
    """

    def __init__(self, scale, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.scale = scale

    def call(self, inputs):
        scaled = inputs*self.scale
        return scaled
        


In [None]:
class Transformer(Model):
    """Encoder-decoder Transformer with one embedding shared three ways.

    The same embedding matrix embeds encoder tokens, embeds decoder tokens
    and, scaled by sqrt(d_model), serves as the final projection to logits.
    """

    def __init__(self, N, d_model, num_heads,dff,seq_len, vocab_size, p_drop, **kwargs):
        super().__init__(**kwargs)
        # vocab_size+1 rows: word indexes start at 1, index 0 is left for
        # padding (mask_zero=True).
        self.shared_embedding = Embedding(vocab_size+1, d_model, mask_zero=True)
        self.positional_encoding = PositionalEncoding()
        self.encoder = Encoder(N, d_model, num_heads, dff, p_drop)
        self.decoder = Decoder(N, d_model, num_heads, dff, seq_len, p_drop)
        self.dropout1 = Dropout(p_drop)
        self.dropout2 = Dropout(p_drop)
        self.scale_embed = Scaling(d_model**.5)

    def _embed_tokens(self, token_ids, dropout, training):
        # Embedding -> scaling -> positional encoding -> dropout.
        embedded = self.shared_embedding(token_ids)
        embedded = self.scale_embed(embedded)
        embedded = self.positional_encoding(embedded)
        return dropout(embedded, training)

    def call(self, input, training=False):
        """
        input : pair (enc_inputs, dec_inputs)
        enc_inputs : tokenized sequence of shape (batch, enc_seq_len)
        dec_inputs : tokenized sequence of shape (batch, dec_seq_len)
        """
        enc_tokens, dec_tokens = input

        enc_embedded = self._embed_tokens(enc_tokens, self.dropout1, training)
        dec_embedded = self._embed_tokens(dec_tokens, self.dropout2, training)

        enc_output = self.encoder(enc_embedded, training)
        # Shape (batch, dec_seq_len, d_model).
        decoder_output = self.decoder(enc_output, dec_embedded, training)

        # Final linear operation: project onto the scaled embedding matrix,
        # which has vocab_size+1 rows and d_model columns.
        projection = self.scale_embed(self.shared_embedding.weights[0])
        return tf.matmul(decoder_output, projection, transpose_b=True)

In [None]:
# Per-element categorical cross-entropy on logits with label smoothing 0.1.
# reduction='none' keeps the individual losses so loss_func can mask them
# before summing.
loss_obj= tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1, reduction='none', from_logits=True)
# Masking layer with default settings; loss_func uses it only to obtain a
# `_keras_mask` for the labels.
# NOTE(review): presumably flags timesteps whose label vector is all zeros
# (Keras default mask_value) -- confirm against the label encoding.
mask_layer = Masking()

def loss_func(y_true, y_pred):
    """Summed, mask-weighted cross-entropy between labels and logits.

    The labels are passed through `mask_layer` to obtain a mask, the loss is
    computed on every class channel except channel 0 of the last axis, and
    the masked per-position losses are summed into one scalar.
    """
    masked_labels = mask_layer(y_true)
    validity = tf.cast(masked_labels._keras_mask, tf.float32)

    # Channel 0 is left out of both the labels and the predictions.
    per_position = loss_obj(masked_labels[:, :, 1:], y_pred[:, :, 1:])

    return tf.reduce_sum(per_position * validity)

In [None]:
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up learning-rate schedule.

    lrate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    """

    def __init__(self, warmup_steps=4000, d_model=512):
        super().__init__()
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        decay_term = step**-0.5
        warmup_term = step*self.warmup_steps**-1.5
        return self.d_model**-.5 * tf.minimum(decay_term, warmup_term)

    def get_config(self):
        return {
            'warmup_steps': self.warmup_steps,
            'd_model': self.d_model,
        }
