# Transformer 中除 Self-Attention 以外的其他部分

上一期：

LiveCoding-Transformer-Attention部分

https://www.bilibili.com/video/BV1nV411a74n

In [26]:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

 The versions of TensorFlow you are currently using is 2.3.1 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [2]:
class MultiHeadSelfAttention(tf.keras.Model):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with learned Dense
    layers, split into `heads` heads of size `vector_size // heads`, attended
    per head, and the heads are concatenated back to `vector_size` features.

    Args:
        vector_size: feature size of the input and of the output; must be
            divisible by `heads`.
        heads: number of attention heads.

    Raises:
        ValueError: if `vector_size` is not divisible by `heads`.
    """

    def __init__(self, vector_size, heads=1):
        super(MultiHeadSelfAttention, self).__init__()
        if vector_size % heads != 0:
            raise ValueError(
                f"vector_size ({vector_size}) must be divisible by heads ({heads})")
        self.vector_size = vector_size
        self.heads = heads
        self.head_size = vector_size // heads
        # Project to the full vector_size and split afterwards, so every head
        # has its own projection weights (standard multi-head attention)
        # instead of one small projection shared by all heads.
        self.key = tf.keras.layers.Dense(vector_size)
        self.value = tf.keras.layers.Dense(vector_size)
        self.query = tf.keras.layers.Dense(vector_size)

    def _split_heads(self, t, batch_size, sequence_length):
        """Reshape (batch, seq, vector_size) -> (batch, heads, seq, head_size)."""
        t = tf.reshape(t, [batch_size, sequence_length, self.heads, self.head_size])
        return tf.transpose(t, [0, 2, 1, 3])

    def call(self, x):
        """Self-attend over `x`.

        Assumes `x` is (batch, seq, vector_size). Returns a tensor of the same shape.
        """
        # Dynamic shapes: x.shape[0] is None when the batch size is unknown at
        # trace time (tf.function, Keras functional models), which would break
        # tf.reshape. tf.shape always yields concrete runtime sizes.
        dynamic_shape = tf.shape(x)
        batch_size, sequence_length = dynamic_shape[0], dynamic_shape[1]

        # Use the layers owned by this model. Dense layers created inside call()
        # get fresh random weights on every invocation and are never trained.
        query = self._split_heads(self.query(x), batch_size, sequence_length)
        key = self._split_heads(self.key(x), batch_size, sequence_length)
        value = self._split_heads(self.value(x), batch_size, sequence_length)

        # Scale by sqrt(d_k), where d_k is the per-head key size, not the model width.
        atten = tf.matmul(query, key, transpose_b=True)
        atten = atten / tf.sqrt(tf.cast(self.head_size, atten.dtype))
        atten = tf.nn.softmax(atten, axis=-1)
        y = tf.matmul(atten, value)  # (batch, heads, seq, head_size)

        # Merge heads: (batch, heads, seq, head_size) -> (batch, seq, vector_size).
        y = tf.transpose(y, [0, 2, 1, 3])
        return tf.reshape(y, [batch_size, sequence_length, self.vector_size])

In [3]:
# Toy configuration for a shape check of the Transformer block.
SEED = 42
# Seed TF's global RNG so that x and the later weight initialisations are
# reproducible on Restart & Run All.
tf.random.set_seed(SEED)

batch_size = 4
sequence_length = 10
vector_size = 32
heads = 4

assert vector_size % heads == 0, "vector_size must be divisible by heads"

x = tf.random.uniform((batch_size, sequence_length, vector_size))

In [4]:
# LayerNormalization 借鉴自 BatchNormalization：BN 沿 batch 维度做归一化，而 LN 对每个位置沿最后一个（特征）维度做归一化，因此不依赖 batch 大小

In [27]:
class Transformer(tf.keras.Model):
    """A single post-LayerNorm Transformer encoder block.

    Computes
        z = LayerNorm(x + MultiHeadSelfAttention(x))
        y = LayerNorm(z + FFN(z))
    where FFN is Dense(vector_size * ffn_multiplier) -> GELU -> Dense(vector_size).
    Both sub-layers are residual, so the output has the same shape as the input.

    Args:
        vector_size: feature size of the input and output.
        heads: number of attention heads (must divide vector_size).
        ffn_multiplier: width of the FFN hidden layer relative to vector_size;
            defaults to 4, the previously hard-coded value.
    """

    def __init__(self, vector_size, heads=1, ffn_multiplier=4):
        super(Transformer, self).__init__()
        self.ln0 = tf.keras.layers.LayerNormalization()
        self.ln1 = tf.keras.layers.LayerNormalization()
        self.mh_atten = MultiHeadSelfAttention(vector_size, heads)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(vector_size * ffn_multiplier),
            # GELU comes from tensorflow_addons; NOTE(review): the kernel here
            # reports TF 2.3.1, where core tf.keras presumably lacks gelu — confirm.
            tf.keras.layers.Activation(tfa.activations.gelu),
            tf.keras.layers.Dense(vector_size),
        ])

    def call(self, x):
        # Attention sub-layer: residual connection, then LayerNorm (post-norm).
        z = self.ln0(x + self.mh_atten(x))
        # Position-wise feed-forward sub-layer, same residual + LayerNorm pattern.
        y = self.ln1(z + self.ffn(z))
        return y

In [28]:
# Build one encoder block using the toy configuration defined above.
transformer = Transformer(vector_size=vector_size, heads=heads)

In [31]:
encoded = transformer(x)
# Both sub-layers are residual, so the block must preserve the input shape.
assert encoded.shape == x.shape, (encoded.shape, x.shape)
encoded.shape

TensorShape([4, 10, 32])