<a href="https://colab.research.google.com/github/Hamza-Kamran/EECS-C106A-Project/blob/main/Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import tensorflow as tf
import tensorflow.keras.layers
from tensorflow import Tensor
from tensorflow.linalg import matmul
from tensorflow.nn import softmax
from tensorflow.math import *
import collections
import logging
import os
import pathlib
import re
import string
import sys
import time

import numpy as np
import matplotlib.pyplot as plt


In [None]:
#Hamza's model
#todos: 
#confirm unit tests for each unit
#fix dimension problems with input in encoder layer
#add any other missing pieces
#add data processing step to convert data e.g language to input tensor
#add decoder

#model taken from https://data-science-blog.com/blog/2021/04/07/multi-head-attention-mechanism/
def create_padding_mask(seq):
  """Build an additive padding mask that marks positions where `seq` equals 0.

  The result is float32, 1.0 at padded (zero) positions and 0.0 elsewhere,
  with two singleton axes inserted after the batch axis
  (batch_size, 1, 1, seq_len) so it broadcasts against attention logits.
  """
  is_pad = tf.math.equal(seq, 0)
  pad_mask = tf.cast(is_pad, tf.float32)
  return pad_mask[:, tf.newaxis, tf.newaxis, :]
def create_look_ahead_mask(size):
  """Return a (size, size) mask that is 1.0 strictly above the diagonal.

  A 1.0 at [i, j] marks a later position j that query i must not see; the
  lower triangle, including the diagonal, is 0.0.
  """
  lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return 1 - lower_triangle
def scaled_dot_product_attention(q, k, v, mask):
  """Compute softmax(q k^T / sqrt(depth) + mask * -1e9) v.

  q, k and v must share their leading dimensions, and k and v must have the
  same sequence length (seq_len_k == seq_len_v). When `mask` is not None it
  must be broadcastable to (..., seq_len_q, seq_len_k); entries equal to 1
  are pushed to -1e9 before the softmax so they receive ~zero weight.

  Args:
    q: queries, shape (..., seq_len_q, depth).
    k: keys, shape (..., seq_len_k, depth).
    v: values, shape (..., seq_len_v, depth_v).
    mask: float tensor, or None for no masking.

  Returns:
    (output, attention_weights) of shapes (..., seq_len_q, depth_v) and
    (..., seq_len_q, seq_len_k).
  """
  key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
  logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(key_depth)

  if mask is not None:
    logits += mask * -1e9

  # Normalize over the key axis so each query's weights sum to 1.
  weights = tf.nn.softmax(logits, axis=-1)
  return tf.matmul(weights, v), weights
class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention over (batch_size, seq_len, d_model) tensors.

  Projects v, k and q to d_model with separate Dense layers. It then splits
  each into num_heads heads of width depth = d_model // num_heads and runs
  scaled_dot_product_attention on every head. Finally it concatenates the
  heads and applies a final Dense(d_model).

  Args:
    d_model: model width; must be divisible by num_heads.
    num_heads: number of attention heads.
    verbose: if True, print intermediate tensor shapes on every call (handy
      while debugging dimension problems). Defaults to False so the layer does
      not flood the output on every forward pass during training.
  """

  def __init__(self, d_model, num_heads, verbose=False):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    self.verbose = verbose

    assert d_model % self.num_heads == 0, "d_model must be divisible by num_heads"

    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  def _print_shapes(self, header, **tensors):
    """When verbose, print a blank line, an optional header, then each tensor's static shape."""
    if not self.verbose:
      return
    print()
    if header:
      print(header)
    for name, tensor in tensors.items():
      print("The shape of '" + name + "' is " + str(tensor.shape))

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).

    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth).
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    """Attend from q over (k, v).

    Args:
      v, k, q: value, key and query tensors of shape (batch_size, seq_len, features).
      mask: mask forwarded to scaled_dot_product_attention, or None.

    Returns:
      (output, attention_weights): output is (batch_size, seq_len_q, d_model);
      attention_weights is (batch_size, num_heads, seq_len_q, seq_len_k).
    """
    if self.verbose:
      print("Inside 'MultiHeadAttention' class...")
    batch_size = tf.shape(q)[0]
    self._print_shapes(None, q=q, k=k, v=v)

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)
    self._print_shapes(
        "After passing 'q', 'k', 'v' through densely connected layers....",
        q=q, k=k, v=v)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
    self._print_shapes("After splitting the heads....", q=q, k=k, v=v)

    # scaled_attention: (batch_size, num_heads, seq_len_q, depth)
    # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)
    self._print_shapes(None, attention_weights=attention_weights,
                       scaled_attention=scaled_attention)

    # Move heads next to depth so they can be concatenated: (batch_size, seq_len_q, num_heads, depth)
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
    self._print_shapes("After transposing....", scaled_attention=scaled_attention)

    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
    self._print_shapes(None, concat_attention=concat_attention)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
    self._print_shapes(None, output=output)

    return output, attention_weights

# Computes Attn(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, where Q, K, V are the query,
# key and value matrices, d_k is the dimensionality of a key, and K^T is K transposed.
def scaled_dot_product_attention_mine(K: Tensor, Q: Tensor, V: Tensor):
  """Unmasked scaled dot-product attention.

  All inputs are cast to float32 first.

  Args:
    K: keys, shape (..., seq_len_k, d_k).
    Q: queries, shape (..., seq_len_q, d_k).
    V: values, shape (..., seq_len_k, d_v).

  Returns:
    Tensor of shape (..., seq_len_q, d_v).
  """
  V = tf.cast(V, tf.float32)
  K = tf.cast(K, tf.float32)
  Q = tf.cast(Q, tf.float32)
  # transpose_b swaps only the last two axes, so batched (rank > 2) inputs work;
  # tf.transpose(K) would reverse *every* axis.
  scores = matmul(Q, K, transpose_b=True)
  d_k = tf.cast(tf.shape(K)[-1], tf.float32)
  # Divide by sqrt(d_k) to keep the logits' scale independent of the key width
  # (dividing by d_k**-0.5 would multiply by sqrt(d_k) instead).
  scores = scores / tf.math.sqrt(d_k)
  weights = softmax(scores, axis=-1)
  return matmul(weights, V)


# Position-wise feed-forward sub-layer, including its residual connection and layer norm.
# Maps width d_model -> d_hidden -> d_model; the output is
#   LayerNorm(x + (ReLU(x W_1 + b_1) W_2 + b_2)).
class FeedForward(tf.keras.layers.Layer):
  """Two-layer ReLU MLP followed by a residual add and layer normalization.

  Args:
    d_model: input and output width.
    d_hidden: width of the ReLU hidden layer.
  """

  def __init__(self, d_model, d_hidden) -> None:
    super().__init__()
    hidden = tf.keras.layers.Dense(d_hidden, activation='relu')
    projection = tf.keras.layers.Dense(d_model)
    self.layer = tf.keras.Sequential([hidden, projection])
    self.layer_norm = tf.keras.layers.LayerNormalization()
    self.Add = tf.keras.layers.Add()

  def call(self, x):
    with_residual = self.Add([self.layer(x), x])
    return self.layer_norm(with_residual)



# A single attention head; multiheadattention is built from several of these.
class attentionhead(tf.keras.layers.Layer):
  """One attention head: linear Q/K/V projections, then scaled dot-product attention.

  Args:
    d_in: input feature width (given to each Dense as input_shape).
    d_k: output width of the key and value projections.
    d_q: output width of the query projection.
      NOTE(review): Q K^T needs d_q == d_k, so callers should pass equal values.
  """

  def __init__(self, d_in: int, d_k: int, d_q: int):
    super().__init__()
    in_shape = (d_in, )
    self.q = tf.keras.layers.Dense(d_q, input_shape=in_shape, activation=None)
    self.k = tf.keras.layers.Dense(d_k, input_shape=in_shape, activation=None)
    self.v = tf.keras.layers.Dense(d_k, input_shape=in_shape, activation=None)

  def call(self, key: Tensor, query: Tensor, value: Tensor):
    projected_key = self.k(key)
    projected_query = self.q(query)
    projected_value = self.v(value)
    return scaled_dot_product_attention_mine(projected_key, projected_query, projected_value)
     



# Each head projects the input-width keys, queries and values to dim_key / dim_query
# through its own linear layers and attends over them. The concatenated head
# outputs are then projected back to the input width dim_in, so the result can
# be added to the residual stream.
class multiheadattention(tf.keras.layers.Layer):
  """Multi-head attention built from independent `attentionhead` layers.

  Args:
    num_heads: number of heads.
    dim_in: width of the key/query/value inputs and of the output.
    dim_key: per-head key/value width.
    dim_query: per-head query width (should equal dim_key).
    mask: stored as `use_mask`; masking is not implemented yet (TODO).
  """

  def __init__(self, num_heads: int, dim_in: int, dim_key: int, dim_query: int, mask: bool = False) -> None:
    super().__init__()
    # Keras tracks layers kept in a plain Python list attribute, so no special
    # "module list" container is needed.
    self.heads = [attentionhead(dim_in, dim_key, dim_query) for _ in range(num_heads)]
    # Output projection W^O: (num_heads * dim_key) -> dim_in. It must map back to
    # dim_in rather than a fixed 3 * dim_key, otherwise residual adds only line up
    # when num_heads == 3.
    self.lin = tf.keras.layers.Dense(dim_in)

    # TODO: implement masking; the flag is recorded but not applied anywhere yet.
    self.use_mask = mask

  def call(self, key: Tensor, query: Tensor, value: Tensor) -> Tensor:
    head_outputs = [head(key, query, value) for head in self.heads]
    return self.lin(tf.concat(head_outputs, axis=-1))
  
# Multi-head self-attention -> dropout -> add & norm, then the FeedForward sub-layer
# (which does its own add & norm). Dropout lives here because EncoderLayer
# combines all the sub-layers.
class EncoderLayer(tf.keras.layers.Layer):
  """One transformer encoder layer.

  Args:
    d_model: model width.
    num_heads: number of attention heads; each head gets d_model // num_heads (min 1).
    d_hidden: hidden width of the feed-forward sub-layer.
    dropout: dropout rate.
  """

  def __init__(self, d_model, num_heads, d_hidden, dropout=0.3) -> None:
    super().__init__()
    self.d_k = max(d_model // num_heads, 1)
    self.mha = multiheadattention(num_heads, d_model, self.d_k, self.d_k, False)
    self.ff = FeedForward(d_model, d_hidden)
    self.layer_norm = tf.keras.layers.LayerNormalization()
    self.Add = tf.keras.layers.Add()
    self.dropout = tf.keras.layers.Dropout(dropout)

  def call(self, input):
    # Self-attention: the same tensor is key, query and value. Pass it directly;
    # copying through .numpy() is unnecessary and fails inside tf.function.
    x = self.mha(input, input, input)
    # Dropout on the sub-layer output *before* the residual add, i.e.
    # LayerNorm(x + Dropout(Sublayer(x))), so the residual path is not dropped.
    x = self.dropout(x)
    x = self.layer_norm(self.Add([x, input]))
    x = self.ff(x)
    x = self.dropout(x)
    return x


def split_heads(x, num_heads, batch_dimensions): 
  """Unfinished helper, presumably meant to split `x` into attention heads.

  NOTE(review): as written it only unpacks `x.shape` into two names and then
  implicitly returns None. A tensor whose rank is not 2 makes the unpacking
  raise ValueError. `num_heads` and `batch_dimensions` are unused. The
  tutorial-style MultiHeadAttention class defines its own `split_heads` method.
  """
  dim_model, dim_seq = x.shape

  

class DecoderLayer(tf.keras.layers.Layer):
  """Transformer decoder layer: self-attention, cross-attention, feed-forward.

  Each attention sub-layer is followed by dropout, a residual add and layer
  normalization. The FeedForward sub-layer applies its own residual and norm,
  then dropout is applied to its output.

  Args:
    d_model: width of the decoder stream.
    num_heads: number of attention heads; each head gets d_model // num_heads (min 1).
    d_hidden: hidden width of the feed-forward sub-layer.
    dropout: dropout rate.
  """

  def __init__(self, d_model, num_heads, d_hidden, dropout=0.3):
    # Keras layers must run the base initializer before attributes are assigned.
    super().__init__()

    # Builtin max(): TensorFlow has no tf.max.
    self.d_k = max(d_model // num_heads, 1)

    # NOTE(review): multiheadattention accepts a `mask` flag but does not apply
    # it yet, so this self-attention is not causal until masking is implemented.
    self.masked_mha = multiheadattention(num_heads, d_model, self.d_k, self.d_k, mask=True)
    # Cross-attention over cross_input is not masked.
    self.mha = multiheadattention(num_heads, d_model, self.d_k, self.d_k, mask=False)
    self.ff = FeedForward(d_model, d_hidden)
    # One norm per attention sub-layer; sharing a single LayerNormalization would
    # tie their scale/offset weights together.
    self.layer_norm1 = tf.keras.layers.LayerNormalization()
    self.layer_norm2 = tf.keras.layers.LayerNormalization()
    self.Add = tf.keras.layers.Add()
    self.dropout = tf.keras.layers.Dropout(dropout)

  def call(self, input, cross_input):
    """Run the layer.

    Args:
      input: decoder-side tensor of width d_model.
      cross_input: tensor attended to in the second sub-layer; presumably the
        encoder output (TODO confirm against the caller).

    Returns:
      The transformed decoder tensor.
    """
    # 1) Self-attention over the decoder input (key = query = value).
    x = self.masked_mha(input, input, input)
    x = self.dropout(x)
    x = self.layer_norm1(self.Add([x, input]))

    # 2) Cross-attention: queries come from the decoder stream, keys and values
    #    from cross_input. multiheadattention.call takes (key, query, value).
    attended = self.mha(cross_input, x, cross_input)
    attended = self.dropout(attended)
    x = self.layer_norm2(self.Add([attended, x]))

    # 3) Feed-forward sub-layer (adds its own residual and layer norm).
    x = self.ff(x)
    return self.dropout(x)




# Sinusoidal positional encoding, added to the tensor that is fed into the transformer.
def create_positional_encoding(dim_model, dim_seq):
  """Return a (1, dim_seq, dim_model) float32 tensor of positional encodings.

  Entry [0, p, i] is sin(p / 10000**(i / dim_model)) for even i and
  cos(p / 10000**(i / dim_model)) for odd i.

  NOTE(review): the original Transformer paper uses the exponent
  2 * (i // 2) / dim_model, so each sin/cos pair shares a frequency. This keeps
  the i / dim_model variant used by the PyTorch reference cell in this notebook.
  """
  # Positions run along axis 1 and feature indices along axis 2, so the result
  # broadcasts against (batch, seq, dim_model) inputs.
  pos = tf.reshape(tf.range(dim_seq, dtype=tf.float32), (1, -1, 1))
  dim = tf.reshape(tf.range(dim_model, dtype=tf.float32), (1, 1, -1))
  phase = pos / 1e4 ** (dim / dim_model)
  return tf.where(dim % 2 == 0, tf.sin(phase), tf.cos(phase))


class Encoder(tf.keras.layers.Layer):
  """Stack of EncoderLayers applied in sequence.

  Args:
    d_model, num_heads, d_hidden: forwarded to every EncoderLayer.
    num_layers: number of stacked layers (6 in the original Transformer).
  """

  def __init__(self, d_model, num_heads, d_hidden, num_layers=6):
    super().__init__()
    self.l = [EncoderLayer(d_model, num_heads, d_hidden) for _ in range(num_layers)]

  def call(self, input: Tensor):
    x = input
    for layer in self.l:
      x = layer(x)
    return x


class Decoder(tf.keras.layers.Layer):
  """Stack of DecoderLayers followed by a softmax output projection.

  Args:
    d_model, num_heads, d_hidden: forwarded to every DecoderLayer.
    num_layers: number of stacked layers (6 in the original Transformer).
    d_vocab: width of the final softmax projection. Defaults to d_model, which
      keeps the previous output size; pass the target vocabulary size to get
      next-token probabilities.
  """

  def __init__(self, d_model, num_heads, d_hidden, num_layers=6, d_vocab=None):
    super().__init__()
    # Build DecoderLayers, not Decoders: constructing Decoder here recursed forever.
    self.l = [DecoderLayer(d_model, num_heads, d_hidden) for _ in range(num_layers)]

    out_dim = d_model if d_vocab is None else d_vocab
    self.linear = tf.keras.layers.Dense(out_dim, activation='softmax')

  def call(self, input: Tensor, cross_input: Tensor = None):
    """Run every DecoderLayer on `input`, each attending to `cross_input`.

    Raises:
      ValueError: if `cross_input` (presumably the encoder output) is missing;
        every DecoderLayer requires it.
    """
    if cross_input is None:
      raise ValueError("Decoder needs `cross_input` (the tensor to cross-attend to)")

    x = input
    for layer in self.l:
      x = layer(x, cross_input)

    return self.linear(x)

    

#print(input_tensor.shape)
#print(rslt.shape)

SyntaxError: ignored

In [None]:
d_model = 192
num_heads = 3
d_hidden = 1024
enc = Encoder(d_model, num_heads, d_hidden)


# Create a random (1, d_model) input. It is named `sample_input` rather than
# `input` so the builtin input() is not shadowed for the rest of the notebook.
sample_input = tf.expand_dims(tf.random.normal([d_model]), axis=0)
print(sample_input)
print(enc(sample_input))

tf.Tensor(
[[-0.32209584 -0.06927506 -0.57095814  0.7766935  -0.7868154  -0.39203525
  -0.04428896 -0.60786307  0.05801601 -0.2546975  -0.92768675  0.72436327
   1.1825615   1.2638073   1.2687995   0.9103692   1.298516    1.9871203
  -0.09246769  1.6756735  -0.9789442   1.2061933  -0.652072    0.5392815
   1.8566234  -0.41659877  1.5074799  -1.5703791  -0.23919009  0.53183734
  -0.5769419  -0.7483095   0.73877513 -1.1111399  -0.14601016 -0.40166846
  -0.55308616 -0.23467162  1.1752588   0.2001493   0.48362836 -0.6464562
  -0.92099637  2.1429296   0.55001235  0.5439586   1.4686579   1.433936
   1.2073331  -0.64747274  0.48129022  0.6143392  -1.5393376  -0.31541872
  -0.6150933  -1.0597662  -0.13111484 -1.1910559   0.4392466   0.42834628
   0.53299004 -0.3839335   2.7596025   1.15551     0.37441167 -0.20678559
   0.6748126  -0.76579565 -0.1849377   2.021187   -1.0069524   1.6959141
   1.3094616  -0.81022644 -0.8127374  -0.9872704  -0.9345691  -0.137123
  -0.17527446  0.59964484  1.376598

In [None]:
seq_len, dim_model = 100, 200
# PyTorch reference encoding: position along axis 1, feature index along axis 2,
# sin on even feature indices and cos on odd ones.
pos = torch.arange(seq_len, dtype=torch.float)[None, :, None]
dim = torch.arange(dim_model, dtype=torch.float)[None, None, :]
phase = pos / (1e4 ** (dim / dim_model))
even_feature = dim.long() % 2 == 0
print(torch.where(even_feature, torch.sin(phase), torch.cos(phase)))

# The TensorFlow implementation should print the same values.
print(create_positional_encoding(dim_model, seq_len))

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.7761e-01,  7.9074e-01,  ...,  1.0000e+00,
           1.0965e-04,  1.0000e+00],
         [ 9.0930e-01, -3.3272e-01,  9.6811e-01,  ...,  1.0000e+00,
           2.1930e-04,  1.0000e+00],
         ...,
         [ 3.7961e-01, -4.2693e-02,  4.7983e-01,  ...,  9.9994e-01,
           1.0636e-02,  9.9995e-01],
         [-5.7338e-01,  7.9091e-01,  9.8749e-01,  ...,  9.9994e-01,
           1.0745e-02,  9.9995e-01],
         [-9.9921e-01,  9.5637e-01,  7.2918e-01,  ...,  9.9994e-01,
           1.0855e-02,  9.9995e-01]]])
tf.Tensor(
[[[ 0.          1.          0.         ...  1.          0.
    1.        ]
  [ 0.9092974  -0.33272225  0.9681094  ...  0.99973637  0.0219278
    0.9997807 ]
  [-0.7568025  -0.7785918  -0.48507634 ...  0.9989456   0.04384506
    0.9991229 ]
  ...
  [-0.7023863  -0.99635464  0.8419628  ... -0.6104443   0.8491771
   -0.44451517]
  [ 0.9395301  

In [None]:

#TESTING CELL
def check_attention_head():
  """Run one attentionhead on a tiny 2x3 integer input and print the result."""
  sample = tf.constant([[1, 2, 3], [4, 5, 6]])
  K, V, Q = sample, sample, sample

  print("Running  attention function...")
  head = attentionhead(3, K.shape[-1], Q.shape[-1])
  print(head(K, Q, V))


def self_attention_check():
  """Smoke-test both multi-head attention implementations on the same small input.

  The hand-rolled `multiheadattention` gets the (2, 3) matrix directly. The
  tutorial-style `MultiHeadAttention` expects (batch, seq_len, d_model), so it
  gets the same two rows as a single batch of 2 positions. Given a 2-D tensor,
  it would treat each row as its own length-1 sequence, and every attention
  weight would trivially be 1.
  """
  K = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
  V = K
  Q = K

  print("Running multi headattention function...")
  head = multiheadattention(3, 3, K.shape[-1], Q.shape[-1], False)
  print(head(K, Q, V))

  # Add a batch axis: (2, 3) -> (1, 2, 3).
  batched = tf.expand_dims(K, axis=0)
  temp_mha = MultiHeadAttention(d_model=K.shape[-1], num_heads=3)
  out, attn = temp_mha(v=batched, k=batched, q=batched, mask=None)
  print(out, attn)


# The built-in Keras layer can serve as a reference. Its call order is
# (query, value, key) and it expects float (batch, seq, dim) inputs, e.g.
#   layer = tf.keras.layers.MultiHeadAttention(num_heads=3, key_dim=6, use_bias=False)
#   print(layer(x, x, x))   # x = a (1, 2, 3) float tensor

self_attention_check()

Running multi headattention function...
3
tf.Tensor(
[[-0.4375267 -4.169918 ]
 [ 0.5893545 -4.651104 ]], shape=(2, 2), dtype=float32)
3
tf.Tensor(
[[-1.343729  -3.0636737]
 [-1.9823676 -5.6841383]], shape=(2, 2), dtype=float32)
3
tf.Tensor(
[[ 7.5099554 11.30494  ]
 [22.314337  36.26994  ]], shape=(2, 2), dtype=float32)
tf.Tensor(
[[-3.583356    2.7362645   3.0014017  -6.147074   -3.6801023  -4.738505
  -4.1570363  -4.6377773   0.02600089]
 [-3.6483214   2.5922365   3.0082812  -6.184674   -3.5438685  -4.6698055
  -4.180213   -4.4322987   0.08446279]], shape=(2, 9), dtype=float32)
Inside 'MultiHeadAttention' class...

The shape of 'q' is (2, 3)
The shape of 'k' is (2, 3)
The shape of 'v' is (2, 3)

After passing 'q', 'k', 'v' through densely connected layers....
The shape of 'q' is (2, 3)
The shape of 'k' is (2, 3)
The shape of 'v' is (2, 3)

After splitting the heads....
The shape of 'q' is (2, 3, 1, 1)
The shape of 'k' is (2, 3, 1, 1)
The shape of 'v' is (2, 3, 1, 1)

The shape of 'at

In [None]:
# Greg's model

# Define core layers

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, initializers, activations
import numpy as np

class FeedForwardNetwork(layers.Layer):
    """Position-wise feed-forward block: Dense(inner_dim, relu) -> Dense(outer_dim).

    The Dense sub-layers are created in build(). No residual connection or
    normalization is applied here.
    """

    def __init__(self, inner_dim=2048, outer_dim=512):
        super().__init__()
        self.inner_dim = inner_dim
        self.outer_dim = outer_dim

    def build(self, input_shape):
        # Recorded for reference only; call() does not use them.
        self.batch_dim, self.input_dim = input_shape[0], input_shape[1:]

        self.d1 = layers.Dense(self.inner_dim, activation="relu")
        self.d2 = layers.Dense(self.outer_dim, activation=None)

    def call(self, inputs, *args, **kwargs):
        hidden = self.d1(inputs)
        return self.d2(hidden)


# Modified from www.tensorflow.org/text/tutorials/transformer
def positional_encoding(length, depth):
    """Return a (length, depth) float32 tensor of sinusoidal position encodings.

    Column 2i holds sin(pos / 10000**(2i / depth)) and column 2i + 1 holds the
    matching cos. sin and cos are interleaved with a shared frequency per pair,
    as in "Attention Is All You Need". The TF tutorial instead concatenates all
    sin columns followed by all cos columns; both are valid encodings as long as
    the same layout is used consistently. For odd ``depth``, one extra pair is
    computed and the final cos column is dropped, so the width is exactly ``depth``.
    """
    half = (depth + 1) // 2

    positions = np.arange(length)[:, np.newaxis]  # (length, 1)
    # 2i / depth; for even depth this equals i / (depth // 2) exactly.
    depths = (2 * np.arange(half))[np.newaxis, :] / depth  # (1, half)

    angle_rates = 1 / (10000 ** depths)
    angle_rads = positions * angle_rates  # (length, half)

    # Stack on a new last axis, then flatten it: [sin_0, cos_0, sin_1, cos_1, ...].
    pos_encoding = np.stack(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1,
    )
    pos_encoding = np.reshape(pos_encoding, (length, 2 * half))[:, :depth]
    return tf.constant(pos_encoding, dtype=tf.float32)


class PositionalEncoding(layers.Layer):
    """Add sinusoidal positional encodings to a (seq, batch, d_model) tensor.

    Args:
        seq_size: sequence length used when the input's length is not
            statically known.
        d_model: model width. It is stored, but the encoding width is taken from
            the input's last dimension.
    """

    def __init__(self, seq_size, d_model):
        super().__init__()
        self.seq_dim = seq_size
        self.d_model = d_model

    def build(self, input_shape):
        pass

    def call(self, inputs, *args, **kwargs):
        seq_dim, batch_dim, model_dim = inputs.shape

        # Use the input's actual length when known, so inputs shorter or longer
        # than seq_size still line up; fall back to the configured size otherwise.
        length = seq_dim if seq_dim is not None else self.seq_dim

        pos = positional_encoding(length, model_dim)
        # (seq_dim, model_dim)

        pos = tf.expand_dims(pos, axis=-2)
        # (seq_dim, 1, model_dim) -- broadcasts to (seq_dim, batch_dim, model_dim)

        return inputs + pos


class MultiHeadAttention(layers.Layer):
    """Multi-head scaled dot-product attention on sequence-first tensors.

    Inputs are laid out as (seq, batch, d_model). Each of the num_heads heads
    has its own query/key/value projections (WQ, WK: d_model -> d_key;
    WV: d_model -> d_value). The concatenated head outputs are projected back to
    d_model with WO. Dropout is applied to the attention weights.

    Args:
        d_key: per-head query/key width.
        d_value: per-head value width.
        d_model: model width of the inputs and output.
        mask: stored on the layer; the mask actually applied is the optional
            fourth element of the call inputs.
        num_heads: number of heads.
        dropout_rate: dropout rate on the attention weights.
    """

    def __init__(self, d_key=64, d_value=64, d_model=512, mask=False, num_heads=8, dropout_rate=0.1):
        super().__init__()
        self.d_key = d_key
        self.d_value = d_value
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.mask = mask

    def build(self, input_shape):
        self.WQ = self.add_weight(
            name="WQ",
            shape=(self.num_heads, self.d_model, self.d_key),
            initializer=initializers.GlorotNormal(),
            trainable=True,
        )
        self.WK = self.add_weight(
            name="WK",
            shape=(self.num_heads, self.d_model, self.d_key),
            initializer=initializers.GlorotNormal(),
            trainable=True,
        )
        self.WV = self.add_weight(
            name="WV",
            shape=(self.num_heads, self.d_model, self.d_value),
            initializer=initializers.GlorotNormal(),
            trainable=True,
        )
        self.WO = self.add_weight(
            name="WO",
            shape=(self.num_heads * self.d_value, self.d_model),
            initializer=initializers.GlorotNormal(),
            trainable=True,
        )
        self.dropout_layer = layers.Dropout(self.dropout_rate)

    def call(self, inputs, *args, **kwargs):
        """Attend from Q over (K, V).

        Args:
            inputs: (Q, K, V) or (Q, K, V, mask). Q is (seq_q, batch, d_model);
                K and V are (seq_k, batch, d_model). The mask must be
                broadcastable to (seq_q, seq_k), e.g. a (seq_k,) padding mask
                or a (seq_q, seq_k) look-ahead mask. A non-zero entry means
                attend; zero means masked out.

        Returns:
            Tensor of shape (seq_q, batch, d_model).

        Raises:
            ValueError: if ``inputs`` does not have 3 or 4 elements.
        """
        if len(inputs) == 3:
            Q, K, V = inputs
            mask = None
        elif len(inputs) == 4:
            Q, K, V, mask = inputs
        else:
            raise ValueError(
                "MultiHeadAttention expects (Q, K, V) or (Q, K, V, mask), "
                f"got {len(inputs)} inputs"
            )

        # Per-head projections. Index letters: s/t = query/key position,
        # b = batch, h = head, m = d_model, k = d_key, v = d_value.
        q = tf.einsum("sbm,hmk->bhsk", Q, self.WQ)
        k = tf.einsum("tbm,hmk->bhtk", K, self.WK)
        v = tf.einsum("tbm,hmv->bhtv", V, self.WV)

        # Score every query position against every key position. An element-wise
        # Q * K would only compare each position with itself.
        logits = tf.einsum("bhsk,bhtk->bhst", q, k)
        logits = logits / tf.math.sqrt(tf.cast(self.d_key, logits.dtype))

        if mask is not None:
            keep = tf.cast(mask, tf.bool)
            # A large finite negative instead of -inf keeps a fully-masked row from
            # turning into NaN after the softmax.
            logits = tf.where(keep, logits, tf.constant(-1e9, dtype=logits.dtype))

        # Normalize over key positions.
        weights = tf.nn.softmax(logits, axis=-1)
        weights = self.dropout_layer(weights)

        heads = tf.einsum("bhst,bhtv->sbhv", weights, v)
        # (seq_q, batch, heads, d_value)

        # WO rows are ordered head-major (head * d_value + value), matching the
        # concatenation of the heads, so it can be viewed as (heads, d_value, d_model).
        wo = tf.reshape(self.WO, (self.num_heads, self.d_value, self.d_model))
        return tf.einsum("sbhv,hvm->sbm", heads, wo)
        # (seq_q, batch, d_model)

class EncoderLayer(layers.Layer):
    """Post-norm transformer encoder layer for (seq, batch, d_model) inputs.

    Computes h = LN1(x + Dropout(SelfAttention(x))), then
    returns LN2(h + Dropout(FFN(h))). Sub-layers are created in build() once
    d_model is known from the input shape.
    """

    def __init__(self, dropout_rate=0.1):
        super().__init__()
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        _, _, model_dim = input_shape
        self.MA = MultiHeadAttention(d_model=model_dim)

        self.LN1 = layers.LayerNormalization(axis=-1)
        self.LN2 = layers.LayerNormalization(axis=-1)
        self.FFN = FeedForwardNetwork(outer_dim=model_dim)
        self.dropout_layer1 = layers.Dropout(self.dropout_rate)
        self.dropout_layer2 = layers.Dropout(self.dropout_rate)

    def call(self, inputs, *args, **kwargs):
        attended = self.dropout_layer1(self.MA((inputs, inputs, inputs)))
        hidden = self.LN1(attended + inputs)

        fed_forward = self.dropout_layer2(self.FFN(hidden))
        return self.LN2(fed_forward + hidden)

class Encoder(layers.Layer):
    """Stack of EncoderLayers for (seq, batch, d_model) inputs.

    Args:
        dropout_rate: forwarded to every EncoderLayer.
        num_layers: number of stacked layers (6 in the original Transformer).
    """

    def __init__(self, dropout_rate=0.1, num_layers=6):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers

    def build(self, input_shape):
        # Unpacking requires a rank-3 (seq, batch, d_model) input shape.
        seq_dim, batch_dim, model_dim = input_shape
        self.lay = [EncoderLayer(self.dropout_rate) for _ in range(self.num_layers)]

    def call(self, inputs, *args, **kwargs):
        x = inputs
        for layer in self.lay:
            x = layer(x)
        return x


class DecoderLayer(layers.Layer):
    """Post-norm transformer decoder layer for (seq, batch, d_model) tensors.

    Call with (x, y, mask):
      - x is the decoder stream;
      - y is the tensor to cross-attend to (presumably the encoder output);
      - mask is forwarded to the self-attention (non-zero = attend).
    """

    def __init__(self, dropout_rate=0.1):
        super().__init__()
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        (
            (seq_dim, batch_dim, model_dim),
            (input_seq_dim, input_batch_dim, input_model_dim),
            mask_shape
        ) = input_shape
        # Pass d_model by keyword: MultiHeadAttention's first positional parameter
        # is d_key, so MultiHeadAttention(model_dim) would set the per-head key
        # width to model_dim and leave d_model at its default.
        self.MA = MultiHeadAttention(d_model=model_dim)
        self.CrossAttention = MultiHeadAttention(d_model=model_dim)
        self.FFN = FeedForwardNetwork(outer_dim=model_dim)
        self.LN1 = layers.LayerNormalization(axis=-1)
        self.LN2 = layers.LayerNormalization(axis=-1)
        self.LN3 = layers.LayerNormalization(axis=-1)
        self.dropout_layer1 = layers.Dropout(self.dropout_rate)
        self.dropout_layer2 = layers.Dropout(self.dropout_rate)
        self.dropout_layer3 = layers.Dropout(self.dropout_rate)

    def call(self, inputs, *args, **kwargs):
        # Masked self-attention, dropout, add & norm
        # Cross-attention, dropout, add & norm
        # FFN, dropout, add & norm
        x0, y, mask = inputs
        x = self.MA((x0, x0, x0, mask))
        x = self.dropout_layer1(x)
        x1 = self.LN1(x + x0)

        # Cross-attention gets its own weights: queries from the decoder stream,
        # keys and values from y.
        x = self.CrossAttention((x1, y, y))
        x = self.dropout_layer2(x)
        x2 = self.LN2(x + x1)

        x = self.FFN(x2)
        x = self.dropout_layer3(x)
        return self.LN3(x + x2)

class Decoder(layers.Layer):
    """Stack of DecoderLayers followed by a softmax projection onto the vocabulary.

    Call with (x, y, mask), with the same meaning as for DecoderLayer. Returns
    probabilities over d_vocab along the last axis.

    Args:
        d_vocab: output vocabulary size.
        dropout_rate: forwarded to every DecoderLayer.
        num_layers: number of stacked layers (6 in the original Transformer).
    """

    def __init__(self, d_vocab, dropout_rate=0.1, num_layers=6):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.d_vocab = d_vocab
        self.num_layers = num_layers

    def build(self, input_shape):
        (seq_dim, batch_dim, model_dim), enc_shape, mask_shape = input_shape
        self.lay = [DecoderLayer(self.dropout_rate) for _ in range(self.num_layers)]
        # Linear logits only. The softmax is applied once in call(); giving this
        # Dense a softmax activation as well would normalize the output twice.
        self.linear = layers.Dense(self.d_vocab)

    def call(self, inputs, *args, **kwargs):
        x, y, mask = inputs
        for layer in self.lay:
            x = layer((x, y, mask))
        return activations.softmax(self.linear(x), axis=-1)


In [None]:
# Load data and set constants

import tensorflow.keras as keras
import pathlib
import random

# Download dataset
# NOTE(review): newer Keras releases may return the extracted directory (rather
# than the archive path) from get_file(extract=True); re-check the path below
# if open() fails after a Keras upgrade.
text_file = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"

# Extract eng, spa, text_pairs (one tab-separated English/Spanish pair per line).
# Decode explicitly as UTF-8: the Spanish side contains accented letters and ¿/¡,
# and the platform default encoding is not UTF-8 everywhere.
with open(text_file, encoding="utf-8") as f:
    # Iterating the file does not assume a trailing newline (split("\n")[:-1]
    # dropped the last pair when one was missing). Blank lines are skipped
    # instead of crashing the tab split below.
    lines = [line.rstrip("\n") for line in f if line.strip()]
text_pairs = []
for line in lines:
    eng, spa = line.split("\t")
    eng = eng.lower()
    spa = spa.lower()
    text_pairs.append((eng, spa))

# Shuffle, train, val, test split
random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

# Set eng, spa train samples
eng_samples = [text_pair[0] for text_pair in train_pairs]
spa_samples = [text_pair[1] for text_pair in train_pairs]

# Set model constants
BATCH_SIZE = 1
EPOCHS = 10  # This should be at least 10 for convergence
STEPS_PER_EPOCH = 1
MAX_SEQUENCE_LENGTH = 40
ENG_VOCAB_SIZE = 10000
SPA_VOCAB_SIZE = 10000

KEY_DIM = 64
EMBED_DIM = 512
INTERMEDIATE_DIM = 2048
NUM_HEADS = 8

input_dim = ENG_VOCAB_SIZE
output_dim = SPA_VOCAB_SIZE
model_dim = 512
SHARED_SEQUENCE_SIZE = 8
input_seq_size = SHARED_SEQUENCE_SIZE
output_seq_size = SHARED_SEQUENCE_SIZE
# output_seq_size = 40
# seq_size = output_seq_size
# seq_size = output_seq_size


In [None]:
# Clean and tokenize data

import string
from itertools import chain, repeat
import numpy as np

# Special tokens, in id order: <UNK>=0, <START>=1, <END>=2, <PAD>=3.
RESERVED_TOKENS = ["<UNK>", "<START>", "<END>", "<PAD>"]


class StringLookup():
    """Word-level vocabulary mapping whitespace-separated tokens to integer ids.

    The reserved tokens always get the first ids, and words seen later by
    adapt() are numbered in order of first appearance. Unknown words map to 0
    (<UNK>).

    Args:
        vocab_size: maximum number of ids, including the reserved tokens.
    """

    def __init__(self, vocab_size=10000):
        self.lookup = {}
        self.rlookup = {}
        self.vocab = []
        self.vocab_size = vocab_size
        self.adapt(RESERVED_TOKENS)

    def adapt(self, texts):
        """Add each unseen token of ``texts`` to the vocabulary until it holds ``vocab_size`` ids."""
        for text in texts:
            for token in self.tokenize(text):
                if token in self.lookup:
                    continue
                new_id = len(self.vocab)
                if new_id >= self.vocab_size:
                    return  # vocabulary is full; nothing more can be added
                self.vocab.append(token)
                self.lookup[token] = new_id
                self.rlookup[new_id] = token

    def convert(self, texts, seq_size=100):
        """Encode one text, or a list of texts, with convert_one; always returns a list of id lists."""
        if not isinstance(texts, list):
            texts = [texts]
        return [self.convert_one(text, seq_size=seq_size) for text in texts]

    def convert_one(self, text, seq_size=100):
        """Encode ``text`` as ``seq_size`` ids.

        The sequence is <START>, the first ``seq_size - 2`` tokens, then <END>,
        padded with <PAD>. Everything is truncated to ``seq_size``, so
        ``seq_size < 2`` drops <END>. Unknown tokens map to 0.
        """
        tokens = ["<START>", *self.tokenize(text)[:seq_size - 2], "<END>"][:seq_size]
        tokens += ["<PAD>"] * (seq_size - len(tokens))
        return [self.lookup.get(token, 0) for token in tokens]

    def rconvert(self, vs):
        """Decode a list of id sequences back into space-joined token strings."""
        if not isinstance(vs, list):
            vs = [vs]
        return [" ".join([self.rlookup.get(i, "<UNK>") for i in v]) for v in vs]

    def tokenize(self, text):
        """Split ``text`` on whitespace."""
        return text.split()
    
# Characters stripped from both languages: ASCII punctuation plus the Spanish
# inverted question and exclamation marks.
spanish_punctuation = string.punctuation + "¡¿"
# Deletion table built once; str.translate then strips all of them in one pass.
_PUNCTUATION_TABLE = str.maketrans("", "", spanish_punctuation)


def remove_punctuation(text):
    """Return ``text`` with every character in ``spanish_punctuation`` removed."""
    return text.translate(_PUNCTUATION_TABLE)

def _strip_punctuation_pairs(pairs):
    """Return ``pairs`` with punctuation removed from both the English and Spanish side."""
    return [(remove_punctuation(e), remove_punctuation(s)) for e, s in pairs]


train_pairs_no_punc = _strip_punctuation_pairs(train_pairs)
val_pairs_no_punc = _strip_punctuation_pairs(val_pairs)
test_pairs_no_punc = _strip_punctuation_pairs(test_pairs)

# eng_samples / spa_samples are exactly the two sides of train_pairs, so reuse the
# cleaned pairs instead of stripping the same training text a second time.
eng_samples_no_punc = [e for e, _ in train_pairs_no_punc]
spa_samples_no_punc = [s for _, s in train_pairs_no_punc]

en_lookup = StringLookup(vocab_size=ENG_VOCAB_SIZE)
en_lookup.adapt(eng_samples_no_punc)

es_lookup = StringLookup(vocab_size=SPA_VOCAB_SIZE)
es_lookup.adapt(spa_samples_no_punc)


def _encode_pairs(pairs):
    """Map each cleaned (english, spanish) pair to a pair of fixed-length id lists."""
    return [
        (
            en_lookup.convert_one(en, seq_size=input_seq_size),
            es_lookup.convert_one(es, seq_size=output_seq_size),
        )
        for en, es in pairs
    ]


train_pairs_int = _encode_pairs(train_pairs_no_punc)
val_pairs_int = _encode_pairs(val_pairs_no_punc)
test_pairs_int = _encode_pairs(test_pairs_no_punc)

# (num_pairs, 2, seq_size) arrays; this relies on input_seq_size == output_seq_size.
train_pairs_np = np.array(train_pairs_int)
val_pairs_np = np.array(val_pairs_int)
test_pairs_np = np.array(test_pairs_int)

# Decode the training ids back to text as a round-trip sanity check.
train_pairs_back = [(en_lookup.rconvert([en])[0], es_lookup.rconvert([es])[0]) for en, es in train_pairs_int]
train_pairs_back