In [1]:
import sys
import os
sys.path.append(os.path.abspath('../transformers'))
sys.path.append(os.path.abspath('..'))


In [2]:
import tensorflow as tf
from utils import create_masks, positional_encoding

In [3]:
# Hyper-parameters shared by every cell below.
batch_size = 32
seq_len = 128
vocab_size = 30522

d_model = 256        # feature width of the decoder-side tensors
bert_d_model = 768   # feature width of the simulated encoder output
num_heads = 8
depth = d_model // num_heads  # features handled by each attention head

In [4]:
# Random float tensor used as the encoder output; note its last dimension is
# bert_d_model, not d_model.
enc_output = tf.random.uniform(shape=(batch_size, seq_len, bert_d_model))
# Random integer token ids in [0, vocab_size) used as the decoder input.
dec_input = tf.random.uniform(
    shape=(batch_size, seq_len), maxval=vocab_size, dtype=tf.int32
)
print(enc_output.shape)

(32, 128, 768)


In [5]:
# Embedding table: one d_model-wide vector per token id in [0, vocab_size).
embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=d_model)
# NOTE(review): vocab_size is passed as the first argument of positional_encoding,
# yet only the first seq_len positions are sliced out when it is used. Confirm in
# utils what that argument means (it looks like a maximum position count).
pos_encoding = positional_encoding(vocab_size, d_model)

In [6]:
# Look up an embedding vector for every decoder token id.
x = embedding(dec_input)
print(x.shape)
# Scale the embeddings by sqrt(d_model), then add the positional encoding of the
# first seq_len positions (broadcast over the batch axis).
x = x * tf.math.sqrt(tf.cast(d_model, dtype=tf.float32))
x = x + pos_encoding[:, :seq_len, :]
print(x.shape)

(32, 128, 256)
(32, 128, 256)


In [7]:
# Three independent linear projections, each producing d_model output features.
wq, wk, wv = (tf.keras.layers.Dense(units=d_model) for _ in range(3))

# Self-attention case: queries, keys and values are all projected from x,
# so each result is (batch_size, seq_len, d_model).
q, k, v = wq(x), wk(x), wv(x)
print("shape q: {}, k: {}, v: {}".format(q.shape, k.shape, v.shape))
# shape q: (32, 128, 256), k: (32, 128, 256), v: (32, 128, 256)

In [8]:
# Cross-attention case: queries come from the decoder stream x, while keys and
# values are projected from enc_output.
#
# enc_output has bert_d_model (768) features, whereas wk and wv were already
# called on x (d_model = 256 features) in the previous cell. A Keras Dense layer
# fixes its input width on the first call, so reusing wk / wv on enc_output fails
# after Restart Kernel -> Run All. Dedicated layers are used for the encoder-side
# projections instead.
wk_enc = tf.keras.layers.Dense(d_model)
wv_enc = tf.keras.layers.Dense(d_model)

q = wq(x)                # (batch_size, seq_len, d_model)
k = wk_enc(enc_output)   # (batch_size, seq_len, d_model)
v = wv_enc(enc_output)   # (batch_size, seq_len, d_model)
print("shape q: {}, k: {}, v: {}".format(q.shape,k.shape,v.shape))

shape q: (32, 128, 256), k: (32, 128, 256), v: (32, 128, 256)


In [9]:
def split_heads(x):
    """Break the feature axis of x into separate attention heads.

    x is reshaped to (batch_size, seq_len, num_heads, depth) and the head axis is
    then moved ahead of the sequence axis, so the returned tensor has shape
    (batch_size, num_heads, seq_len, depth).

    batch_size, num_heads and depth are read from the notebook-level variables.
    """
    per_head = tf.reshape(x, (batch_size, -1, num_heads, depth))
    return tf.transpose(per_head, perm=[0, 2, 1, 3])

In [12]:
# Split each projection into heads:
#   q -> (batch_size, num_heads, seq_len_q, depth)
#   k -> (batch_size, num_heads, seq_len_k, depth)
#   v -> (batch_size, num_heads, seq_len_v, depth)
# The names q, k, v are rebound to their own split versions, so re-running this
# cell would pass already-split tensors to split_heads a second time.
q, k, v = (split_heads(t) for t in (q, k, v))
print(*(t.shape for t in (q, k, v)), sep="\n")

(32, 8, 128, 32)
(32, 8, 128, 32)
(32, 8, 128, 32)


In [13]:
# Dot product of every query with every key; transpose_b swaps the last two
# axes of k: (..., seq_len_q, depth) x (..., depth, seq_len_k).
matmul_qk = tf.matmul(a=q, b=k, transpose_b=True)
print(matmul_qk.shape)

(32, 8, 128, 128)


In [14]:
# Size of the last axis of k, as a float32 scalar, for use as the scaling factor.
dk = tf.cast(tf.shape(input=k)[-1], dtype=tf.float32)
print(dk.shape)

()


In [15]:
# Divide the raw scores by sqrt(dk).
scaled_attention_logits = tf.math.divide(matmul_qk, tf.math.sqrt(dk))

In [16]:
# dk is a scalar, so the division is element-wise and the shape equals matmul_qk's.
print(scaled_attention_logits.shape)

(32, 8, 128, 128)


Masks: `enc_padding_mask`, `look_ahead_mask`, `dec_padding_mask`

In [17]:
# A second batch of random token ids, passed to create_masks together with dec_input.
demo_input = tf.random.uniform(
    shape=(batch_size, seq_len), maxval=vocab_size, dtype=tf.int32
)
enc_padding_mask, combined_mask, dec_padding_mask = create_masks(demo_input, dec_input)
print(
    *(m.shape for m in (enc_padding_mask, combined_mask, dec_padding_mask)),
    sep="\n",
)

(32, 1, 1, 128)
(32, 1, 128, 128)
(32, 1, 1, 128)


More about the combined (look-ahead) mask

In [18]:
def create_look_ahead_mask(size):
    """Return a (size, size) float mask with 1 strictly above the diagonal, 0 elsewhere.

    band_part(..., -1, 0) keeps the lower triangle (diagonal included) of an
    all-ones matrix; subtracting it from 1 leaves ones only above the diagonal.
    """
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle  # (seq_len, seq_len)

In [21]:
# Build a small 10 x 10 mask so that its pattern can be inspected directly.
look_ahead_mask = create_look_ahead_mask(size=10)
print(look_ahead_mask.shape, look_ahead_mask, sep="\n")

(10, 10)
tf.Tensor(
[[0. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)


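The look-ahead mask marks every position after the current one (the 1s above the diagonal), so each position only sees itself and earlier words — the sequence is revealed one word at a time, much like an LSTM reads it.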

Add the masks to `scaled_attention_logits`

In [22]:
# Entries where dec_padding_mask is 1 get a -1e9 offset; entries where it is 0 are
# left unchanged. The (32, 1, 1, 128) mask broadcasts over heads and query positions.
logits1 = tf.add(scaled_attention_logits, dec_padding_mask * -1e9)
print(logits1.shape)

(32, 8, 128, 128)


In [23]:
# Entries where combined_mask is 1 get a -1e9 offset; entries where it is 0 are
# left unchanged. The (32, 1, 128, 128) mask broadcasts over the head axis.
logits2 = tf.add(scaled_attention_logits, combined_mask * -1e9)
print(logits2.shape)

(32, 8, 128, 128)
