In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist

import numpy as np

In [2]:
# Fix TensorFlow's global random seed before any pipeline is created.
tf.random.set_seed(42)

# Load the MNIST train/test splits.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Training pipeline: per-example slices -> shuffle (1024-element buffer) -> batches of 256.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024)
train_dataset = train_dataset.batch(batch_size=256)

# Test pipeline: same slicing and batch size, but no shuffling.
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.batch(batch_size=256)

In [3]:
# Sinusoidal position-encoding function
def position_encoding(seq_len, d_model):
	"""Build a sinusoidal position-encoding table.

	Column ``i`` uses the divisor ``10000 ** (2 * floor(i / 2) / d_model)``;
	even columns hold the sine and odd columns the cosine of
	``position / divisor``.

	Args:
		seq_len: Number of positions (rows) to encode.
		d_model: Number of encoding columns.

	Returns:
		A ``tf.float32`` tensor of shape ``(seq_len, d_model)``.
	"""
	# Column vector of positions 0 .. seq_len - 1, shape (seq_len, 1).
	positions = np.arange(seq_len).reshape(-1, 1)

	# Each (even, odd) column pair shares the same exponent.
	pair_index = np.floor(np.arange(d_model) / 2)
	div_term = 10000 ** (2 * pair_index / d_model)

	# Angle for every (position, column) combination via broadcasting.
	angles = positions / div_term

	encoding = np.zeros((seq_len, d_model))
	encoding[:, 0::2] = np.sin(angles[:, 0::2])
	encoding[:, 1::2] = np.cos(angles[:, 1::2])

	return tf.cast(encoding, tf.float32)

# Example: encoding table for 4 positions and 10 columns
position_encoding(4, 10)

<tf.Tensor: shape=(4, 10), dtype=float32, numpy=
array([[ 0.0000000e+00,  1.0000000e+00,  0.0000000e+00,  1.0000000e+00,
         0.0000000e+00,  1.0000000e+00,  0.0000000e+00,  1.0000000e+00,
         0.0000000e+00,  1.0000000e+00],
       [ 8.4147096e-01,  5.4030228e-01,  1.5782665e-01,  9.8746681e-01,
         2.5116222e-02,  9.9968451e-01,  3.9810613e-03,  9.9999207e-01,
         6.3095731e-04,  9.9999982e-01],
       [ 9.0929741e-01, -4.1614684e-01,  3.1169716e-01,  9.5018148e-01,
         5.0216600e-02,  9.9873835e-01,  7.9620592e-03,  9.9996829e-01,
         1.2619144e-03,  9.9999923e-01],
       [ 1.4112000e-01, -9.8999250e-01,  4.5775455e-01,  8.8907862e-01,
         7.5285293e-02,  9.9716204e-01,  1.1942931e-02,  9.9992865e-01,
         1.8928709e-03,  9.9999821e-01]], dtype=float32)>

In [4]:
# Multi-head self-attention (MHA) layer

class MHA(tf.keras.layers.Layer):
	"""Multi-head scaled dot-product self-attention.

	The input is projected to queries, keys and values, split into
	``n_head`` heads of width ``d_model // n_head``, attended per head,
	merged back and passed through a final dense projection.

	Args:
		seq_len: Nominal sequence length. Stored on the layer for reference;
			the actual length is taken from the input at call time.
		d_model: Width of the projections and of the output.
		n_head: Number of attention heads; must divide ``d_model`` evenly.

	Raises:
		ValueError: If ``n_head`` is not positive or does not divide
			``d_model``.
	"""

	def __init__(self, seq_len, d_model, n_head):
		# Fail early with a clear message instead of an opaque reshape error
		# on the first call.
		if n_head <= 0 or d_model % n_head != 0:
			raise ValueError(
				f"d_model ({d_model}) must be divisible by a positive n_head ({n_head})"
			)
		super(MHA, self).__init__()
		self.seq_len = seq_len
		self.d_model = d_model
		self.n_head = n_head
		self.d_head = d_model // n_head

		self.Wq = tf.keras.layers.Dense(self.d_model)
		self.Wk = tf.keras.layers.Dense(self.d_model)
		self.Wv = tf.keras.layers.Dense(self.d_model)

		self.dense = tf.keras.layers.Dense(d_model)

	@staticmethod
	def _seq_dim(x):
		"""Return the static length of axis 1 of ``x``, or -1 if it is unknown."""
		static_len = x.shape[1]
		return static_len if static_len is not None else -1

	def split(self, x, batch_size):
		"""Split the last axis of ``x`` into heads.

		Args:
			x: Tensor of shape ``(batch_size, seq, d_model)``.
			batch_size: Size of the leading axis (may be a scalar tensor).

		Returns:
			Tensor of shape ``(batch_size, n_head, seq, d_head)``.
		"""
		x = tf.reshape(x, (batch_size, self._seq_dim(x), self.n_head, self.d_head))

		return tf.transpose(x, (0, 2, 1, 3))

	def call(self, x):
		"""Apply self-attention to ``x`` of shape ``(batch, seq, features)``.

		Returns:
			Tensor of shape ``(batch, seq, d_model)``.
		"""
		batch_size = tf.shape(x)[0]

		q = self.split(self.Wq(x), batch_size) # batch_size, n_head, seq, d_head
		k = self.split(self.Wk(x), batch_size) # batch_size, n_head, seq, d_head
		v = self.split(self.Wv(x), batch_size) # batch_size, n_head, seq, d_head

		# Scaled dot-product scores; the scale is sqrt(d_head) in the score dtype.
		scale = tf.sqrt(tf.cast(self.d_head, q.dtype))
		qk = tf.matmul(q, k, transpose_b=True) / scale # batch_size, n_head, seq, seq
		qk_softmax = tf.nn.softmax(qk, axis=-1)

		res = tf.matmul(qk_softmax, v) # batch_size, n_head, seq, d_head
		res = tf.transpose(res, (0, 2, 1, 3)) # batch_size, seq, n_head, d_head
		# Name the feature width explicitly (n_head * d_head == d_model) so the
		# last dimension stays statically known for the Dense layer below even
		# when the batch size is only known at run time.
		res = tf.reshape(res, (batch_size, self._seq_dim(x), self.d_model))

		return self.dense(res)
	
# Example: apply the layer to a random input of shape (2, 5, 512)

layer = MHA(seq_len=5, d_model=512, n_head=8)
x = tf.random.normal((2, 5, 512))
layer(x)

<tf.Tensor: shape=(2, 5, 512), dtype=float32, numpy=
array([[[-6.74563527e-01, -2.25118384e-01, -2.73084670e-01, ...,
          3.74144316e-03, -1.45801395e-01,  1.79100558e-01],
        [-2.44295210e-01,  2.21353039e-01, -2.83471286e-01, ...,
          3.70672047e-01, -5.63003719e-01, -3.71754676e-01],
        [ 2.48779356e-02,  3.21374118e-01,  8.96210074e-01, ...,
          6.12242401e-01,  6.29714608e-01, -8.74842286e-01],
        [-9.93932724e-01,  1.01466686e-01, -2.61080772e-01, ...,
         -8.55647206e-01, -4.58812773e-01, -2.62865216e-01],
        [-7.72614717e-01, -4.85290945e-01,  3.75694782e-01, ...,
         -3.56333852e-02,  2.76163280e-01,  1.04431227e-01]],

       [[-5.41567862e-01,  4.23610359e-01,  3.58521938e-05, ...,
          8.30780327e-01,  3.47328186e-01,  3.77032101e-01],
        [ 3.72637808e-03,  8.83038878e-01, -6.69075549e-02, ...,
          7.27783442e-01,  3.68383348e-01,  6.68810606e-02],
        [ 9.03029144e-02,  1.09662676e+00,  1.35017142e-01, ...

In [5]:
# Residual connection followed by layer normalization

class AddNorm(tf.keras.layers.Layer):
	"""Add a sub-layer's output to its input and layer-normalize the sum.

	Args:
		epsilon: Value forwarded to ``LayerNormalization(epsilon=...)``.
	"""

	def __init__(self, epsilon=1e-6):
		super().__init__()
		self.norm = tf.keras.layers.LayerNormalization(epsilon=epsilon)

	def call(self, x, sublayer_output):
		"""Return ``norm(x + sublayer_output)``."""
		residual_sum = x + sublayer_output
		return self.norm(residual_sum)

In [6]:
# Feed-forward network (FFN)
class FFN(tf.keras.layers.Layer):
    """Two dense layers applied in sequence: ``d_ff`` units with ReLU, then ``d_model`` units.

    Args:
        d_model: Number of units in the second (output) dense layer.
        d_ff: Number of units in the first (hidden) dense layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        hidden_layer = tf.keras.layers.Dense(d_ff, activation='relu')
        output_layer = tf.keras.layers.Dense(d_model)
        self.ffn = tf.keras.Sequential([hidden_layer, output_layer])

    def call(self, x):
        """Run ``x`` through the hidden layer and then the output layer."""
        projected = self.ffn(x)
        return projected

In [16]:
class EncoderLayer(tf.keras.layers.Layer):
    """Embedding + position encoding followed by one attention/FFN encoder block.

    Pipeline: integer ids -> embedding -> add position encoding ->
    self-attention -> add & norm -> feed-forward -> add & norm.

    Args:
        d_model: Embedding width and model width of every sub-layer.
        n_head: Number of attention heads passed to ``MHA``.
        d_ff: Hidden width of the feed-forward network.
        seq_len: Sequence length used to build the position-encoding table
            and passed to ``MHA``.
        vocab_size: Number of distinct ids the embedding table holds
            (the embedding's ``input_dim``). Defaults to 100, the value that
            was previously hard-coded.
    """

    def __init__(self, d_model, n_head, d_ff, seq_len, vocab_size=100):
        super().__init__()
        # position_encoding already returns a float32 tensor of shape
        # (seq_len, d_model), so no extra tf.constant wrapper is needed.
        self.position = position_encoding(seq_len, d_model)
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.mha = MHA(seq_len, d_model, n_head)
        self.add_norm1 = AddNorm()
        self.ffn = FFN(d_model, d_ff)
        self.add_norm2 = AddNorm()

    def call(self, x):
        """Encode a batch of integer ids.

        Args:
            x: Integer tensor of ids; the example in this notebook uses
                shape ``(batch, seq_len)``.

        Returns:
            Tensor with a trailing ``d_model`` axis appended to the id axes
            (``(batch, seq_len, d_model)`` for the example input).
        """
        x = self.embedding(x)
        # Add a leading axis so the table broadcasts across the batch.
        x = x + self.position[None, :, :]
        attn_output = self.mha(x)
        x = self.add_norm1(x, attn_output)
        ffn_output = self.ffn(x)
        x = self.add_norm2(x, ffn_output)
        return x

In [18]:
# Example: encode a (5, 10) batch of random integer ids (minval=0, maxval=100)

model = EncoderLayer(512, 8, 128, 10)
x = tf.random.uniform((5, 10), minval=0, maxval=100, dtype=tf.int32)
model(x)

<tf.Tensor: shape=(5, 10, 512), dtype=float32, numpy=
array([[[-1.3524554 , -0.07163309, -1.4625372 , ..., -0.46846896,
         -0.76534396,  1.7851676 ],
        [-0.50993747, -0.65109587, -0.5035783 , ..., -0.44811943,
         -0.75598425,  1.6804966 ],
        [-0.57834435, -1.811507  , -0.34832782, ..., -0.34342116,
         -0.7374161 ,  1.5633274 ],
        ...,
        [-0.20739476, -0.5264578 , -1.5128825 , ..., -0.35041055,
         -0.66425246,  1.6913017 ],
        [ 0.32422712, -1.5190591 , -0.81760037, ..., -0.5111117 ,
         -0.69395524,  1.5366465 ],
        [-0.31902927, -2.3752072 , -0.85781276, ..., -0.47169682,
         -0.5960291 ,  1.4610789 ]],

       [[-1.3039428 , -0.0917065 , -1.4280611 , ..., -0.4895043 ,
         -0.705066  ,  1.7918689 ],
        [-0.6019963 , -0.6590475 , -0.48992437, ..., -0.37612402,
         -0.7499124 ,  1.7062058 ],
        [-0.55276835, -1.7759085 , -0.36483955, ..., -0.33802414,
         -0.78432596,  1.6315383 ],
        ...,
