In [32]:
import numpy as np
import tensorflow as tf

class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    q, k and v are each projected to ``d_model`` features, split into
    ``num_heads`` heads of width ``depth = d_model // num_heads``, attended
    per head, re-joined, and passed through a final ``d_model`` projection.

    Args:
        d_model: Projection width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        name: Layer name.
    """

    def __init__(self, d_model, num_heads, name='multi_head_attention'):
        super().__init__(name=name)
        self.num_heads = num_heads
        self.d_model = d_model

        # Every head must receive an equal share of the model width.
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads

        # Query / key / value projections followed by the output projection.
        # Creation order is kept as-is: Keras tracks sub-layers by assignment order.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, depth)."""
        heads_last = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(heads_last, perm=[0, 2, 1, 3])

    def call(self, q, k, v, mask):
        """Attend from ``q`` over ``k`` / ``v``.

        Args:
            q, k, v: Query, key and value inputs (batch-major).
            mask: ``None`` or a tensor broadcastable to the attention logits
                (batch, num_heads, seq_q, seq_k). ``mask * -1e9`` is added to
                the logits, so positions where the mask is 1 are effectively
                excluded by the softmax.

        Returns:
            ``(output, attention_weights)`` where output is
            (batch, seq_q, d_model) and attention_weights is
            (batch, num_heads, seq_q, seq_k).
        """
        batch_size = tf.shape(q)[0]

        query = self.split_heads(self.wq(q), batch_size)
        key = self.split_heads(self.wk(k), batch_size)
        value = self.split_heads(self.wv(v), batch_size)

        # After split_heads the last axis of `key` is `depth`, so this is sqrt(depth).
        scale = tf.math.sqrt(tf.cast(tf.shape(key)[-1], tf.float32))
        logits = tf.matmul(query, key, transpose_b=True) / scale
        if mask is not None:
            logits += (mask * -1e9)
        attention_weights = tf.nn.softmax(logits, axis=-1)

        per_head = tf.matmul(attention_weights, value)          # (batch, heads, seq_q, depth)
        seq_major = tf.transpose(per_head, perm=[0, 2, 1, 3])   # (batch, seq_q, heads, depth)
        merged = tf.reshape(seq_major, (batch_size, -1, self.d_model))
        return self.dense(merged), attention_weights

    
def positional_encoding(position, d_model):
    """Build the sinusoidal positional-encoding table.

    Entry (p, j) is sin(p * r_j) for even j and cos(p * r_j) for odd j, where
    r_j = 1 / 10000 ** (2 * (j // 2) / d_model).

    Args:
        position: Number of positions (rows) in the table.
        d_model: Encoding width (columns).

    Returns:
        float32 tensor of shape (1, position, d_model).
    """
    positions = np.arange(position)[:, np.newaxis]   # (position, 1)
    dims = np.arange(d_model)[np.newaxis, :]          # (1, d_model)
    rates = 1 / np.power(10000, (2 * (dims // 2)) / np.float32(d_model))
    table = positions * rates                         # (position, d_model)

    # Even columns get sine, odd columns get cosine (disjoint slices).
    for offset, wave in ((0, np.sin), (1, np.cos)):
        table[:, offset::2] = wave(table[:, offset::2])

    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)


def point_wise_feed_forward_network(d_model, dff):
    """Two-layer position-wise feed-forward network.

    Args:
        d_model: Output width (the second, linear Dense).
        dff: Hidden width (the first, ReLU-activated Dense).

    Returns:
        A ``tf.keras.Sequential`` mapping (..., features) -> (..., d_model).
    """
    hidden = tf.keras.layers.Dense(dff, activation='relu')
    projection = tf.keras.layers.Dense(d_model)
    return tf.keras.Sequential([hidden, projection])


def create_padding_mask(seq):
    """Mark padding tokens (id 0) so attention can ignore them.

    Args:
        seq: Token ids, (batch_size, seq_len).

    Returns:
        float32 mask of shape (batch_size, 1, 1, seq_len): 1.0 where the token
        id is 0, else 0.0. The two inserted axes let it broadcast against
        attention logits of shape (batch, heads, seq_q, seq_k).
    """
    is_padding = tf.math.equal(seq, 0)
    padding_mask = tf.cast(is_padding, tf.float32)
    return padding_mask[:, tf.newaxis, tf.newaxis, :]


def create_look_ahead_mask(size):
    """Causal mask: 1.0 strictly above the diagonal, 0.0 on and below it.

    Args:
        size: Sequence length; the mask is ``size x size``.

    Returns:
        (size, size) float mask in which row i masks out positions j > i.
    """
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle


class EncoderLayer(tf.keras.layers.Layer):
    """One Transformer encoder block.

    Self-attention followed by a position-wise feed-forward network. Each
    sub-layer output goes through dropout, a residual add, and layer
    normalization.

    Args:
        dff: Hidden width of the feed-forward sub-layer.
        d_model: Model width; the block's input and output width.
        num_heads: Number of attention heads (must divide ``d_model``).
        dropout: Dropout rate for both sub-layer outputs.
        name: Layer name.
    """

    def __init__(self, dff, d_model, num_heads, dropout, name='encoder_layer'):
        super().__init__(name=name)
        self.mha = MultiHeadAttention(d_model, num_heads)
        # point_wise_feed_forward_network takes (d_model, dff). Passing them
        # swapped made the FFN emit dff features, which cannot be added onto the
        # d_model-wide residual in `call`.
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout)
        self.dropout2 = tf.keras.layers.Dropout(dropout)

    def call(self, inputs, training=None, mask=None):
        """Apply the encoder block.

        Args:
            inputs: Hidden states, assumed (batch, seq_len, d_model); the
                residual adds require the last dim to equal ``d_model``.
            training: Enables dropout when True.
            mask: Attention mask forwarded to self-attention (positions where
                it is 1 are excluded).

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        attn_output, _ = self.mha(inputs, inputs, inputs, mask)
        # This layer only defines dropout1/dropout2; `self.dropout` does not exist.
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    
class Encoder(tf.keras.layers.Layer):
    """Token embedding plus positional encoding, followed by a stack of EncoderLayers.

    Args:
        vocab_size: Size of the input vocabulary (Embedding input dim).
        num_layers: Number of stacked EncoderLayers.
        dff: Feed-forward hidden width of each layer.
        d_model: Embedding / model width.
        num_heads: Attention heads per layer.
        dropout: Dropout rate.
        name: Layer name.
        maximum_position_encoding: Rows in the positional-encoding table, i.e.
            the longest input sequence supported. Defaults to ``vocab_size``,
            which is how the table was sized before this argument existed.
    """

    def __init__(self, vocab_size, num_layers, dff, d_model, num_heads, dropout, name='encoder',
                 maximum_position_encoding=None):
        super().__init__(name=name)
        self.num_layers = num_layers
        self.d_model = d_model

        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        # Table length bounds the sequence length, not the vocabulary; keep the
        # old vocab_size default for backward compatibility.
        if maximum_position_encoding is None:
            maximum_position_encoding = vocab_size
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.enc_layers = [
            EncoderLayer(dff=dff, d_model=d_model, num_heads=num_heads, dropout=dropout,
                         name='enc_layer_{}'.format(i + 1))
            for i in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, inputs, training=None, mask=None):
        """Encode a batch of token ids.

        Args:
            inputs: Token ids of shape (batch_size, input_seq_len), or a
                ``[token_ids, padding_mask]`` pair (the list form the Decoder
                also uses). In the pair form the packed mask replaces ``mask``.
            training: Enables dropout when True.
            mask: Padding mask forwarded to every EncoderLayer.

        Returns:
            Encoded states of shape (batch_size, input_seq_len, d_model).
        """
        # A pair used to be embedded wholesale (and its mask dropped); unpack it.
        if isinstance(inputs, (list, tuple)):
            inputs, mask = inputs
        seq_len = tf.shape(inputs)[1]

        x = self.embedding(inputs)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for enc_layer in self.enc_layers:
            # Keyword arguments so Keras routes training/mask unambiguously.
            x = enc_layer(x, training=training, mask=mask)

        return x
    
    
class DecoderLayer(tf.keras.layers.Layer):
    """One Transformer decoder block.

    Masked self-attention, attention over the encoder outputs, then a
    position-wise feed-forward network. Each sub-layer is followed by dropout,
    a residual add and layer normalization.

    Args:
        dff: Hidden width of the feed-forward sub-layer.
        d_model: Model width.
        num_heads: Attention heads for both attention sub-layers.
        dropout: Dropout rate for every sub-layer output.
        name: Layer name.
    """

    def __init__(self, dff, d_model, num_heads, dropout, name='decoder_layer'):
        super().__init__(name=name)

        # Creation order is kept as-is: Keras tracks sub-layers by assignment order.
        self.mha1 = MultiHeadAttention(d_model, num_heads, name='mha1')  # masked self-attention
        self.mha2 = MultiHeadAttention(d_model, num_heads, name='mha2')  # attention over encoder outputs
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layernorm1')
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layernorm2')
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name='layernorm3')

        self.dropout1 = tf.keras.layers.Dropout(dropout, name='dropout1')
        self.dropout2 = tf.keras.layers.Dropout(dropout, name='dropout2')
        self.dropout3 = tf.keras.layers.Dropout(dropout, name='dropout3')

    def call(self, inputs, training=None, mask=None):
        """Apply the decoder block.

        Args:
            inputs: ``(x, enc_outputs, look_ahead_mask, padding_mask)``.
            training: Enables dropout when True.
            mask: Unused; the masks travel inside ``inputs``.

        Returns:
            Tensor shaped like ``x``: (batch_size, target_seq_len, d_model).
        """
        x, enc_outputs, look_ahead_mask, padding_mask = inputs

        self_attn, _ = self.mha1(x, x, x, look_ahead_mask)
        hidden1 = self.layernorm1(x + self.dropout1(self_attn, training=training))

        cross_attn, _ = self.mha2(hidden1, enc_outputs, enc_outputs, padding_mask)
        hidden2 = self.layernorm2(hidden1 + self.dropout2(cross_attn, training=training))

        ffn_out = self.dropout3(self.ffn(hidden2), training=training)
        return self.layernorm3(hidden2 + ffn_out)
        
        
class Decoder(tf.keras.layers.Layer):
    """Token embedding plus positional encoding, followed by a stack of DecoderLayers.

    Args:
        vocab_size: Size of the target vocabulary (Embedding input dim).
        num_layers: Number of stacked DecoderLayers.
        dff: Feed-forward hidden width of each layer.
        d_model: Embedding / model width.
        num_heads: Attention heads per layer.
        dropout: Dropout rate.
        name: Layer name.
        maximum_position_encoding: Rows in the positional-encoding table, i.e.
            the longest target sequence supported. Defaults to ``vocab_size``,
            which is how the table was sized before this argument existed.
    """

    def __init__(self, vocab_size, num_layers, dff, d_model, num_heads, dropout, name='decoder',
                 maximum_position_encoding=None):
        super().__init__(name=name)
        self.num_layers = num_layers
        self.d_model = d_model

        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        # Table length bounds the sequence length, not the vocabulary; keep the
        # old vocab_size default for backward compatibility.
        if maximum_position_encoding is None:
            maximum_position_encoding = vocab_size
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        # Layer names are 0-based here (unlike the encoder); kept so existing
        # weight names do not change.
        self.dec_layers = [
            DecoderLayer(dff=dff, d_model=d_model, num_heads=num_heads, dropout=dropout,
                         name=f'dec_layer_{i}')
            for i in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, inputs, training=None, mask=None):
        """Decode target tokens against encoder outputs.

        Args:
            inputs: ``(dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask)``
                where ``dec_inputs`` holds target token ids.
            training: Enables dropout when True.
            mask: Unused; the masks travel inside ``inputs``.

        Returns:
            Decoded states of shape (batch_size, target_seq_len, d_model).
        """
        dec_inputs, enc_outputs, look_ahead_mask, padding_mask = inputs
        seq_len = tf.shape(dec_inputs)[1]

        x = self.embedding(dec_inputs)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for dec_layer in self.dec_layers:
            # `training` by keyword so Keras routes it unambiguously.
            x = dec_layer([x, enc_outputs, look_ahead_mask, padding_mask], training=training)

        return x


class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer mapping (source ids, target ids) to vocabulary logits.

    Token id 0 is treated as padding (see ``create_padding_mask``).

    Args:
        vocab_size: Vocabulary size shared by encoder, decoder and output layer.
        num_layers: Depth of both the encoder and decoder stacks.
        dff: Feed-forward hidden width.
        d_model: Embedding / model width.
        num_heads: Attention heads per layer.
        dropout: Dropout rate.
        name: Model name.
        **kwargs: Extra ``tf.keras.Model`` arguments (e.g. ``trainable``,
            ``dtype``), so a dict produced by ``get_config`` can be passed back
            to the constructor.
    """

    def __init__(self, vocab_size, num_layers, dff, d_model, num_heads, dropout, name='transformer', **kwargs):
        super().__init__(name=name, **kwargs)
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.dff = dff
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = dropout

        self.encoder = Encoder(vocab_size=vocab_size, num_layers=num_layers, dff=dff, d_model=d_model, num_heads=num_heads, dropout=dropout)
        self.decoder = Decoder(vocab_size=vocab_size, num_layers=num_layers, dff=dff, d_model=d_model, num_heads=num_heads, dropout=dropout)

        # (batch, 1, 1, src_len): hides padded source tokens from encoder self-attention.
        self.enc_padding_mask = tf.keras.layers.Lambda(create_padding_mask, output_shape=(1, 1, None), name='enc_padding_mask')
        # (batch, 1, tgt_len, tgt_len): causal mask merged with target padding.
        self.look_ahead_mask = tf.keras.layers.Lambda(self._combined_look_ahead_mask, output_shape=(1, None, None,), name='look_ahead_mask')
        # (batch, 1, 1, src_len): the decoder's second attention attends over
        # encoder outputs, so this mask must be built from the *source* ids.
        self.dec_padding_mask = tf.keras.layers.Lambda(create_padding_mask, output_shape=(1, 1, None), name='dec_padding_mask')

        # Not `self.outputs`: tf.keras.Model already uses that attribute for its
        # own bookkeeping. The layer name stays 'outputs', so name-based weight
        # matching is unchanged. TF-format checkpoints written before this
        # change key this layer under the old attribute name.
        self.final_layer = tf.keras.layers.Dense(units=vocab_size, name='outputs')

    @staticmethod
    def _combined_look_ahead_mask(dec_inputs):
        """Causal mask merged with the padding mask of ``dec_inputs``.

        ``create_look_ahead_mask`` needs an integer size, so the target length
        is read from ``dec_inputs`` (batch, tgt_len) rather than passing the
        token tensor itself.
        """
        seq_len = tf.shape(dec_inputs)[1]
        look_ahead = create_look_ahead_mask(seq_len)   # (tgt_len, tgt_len)
        padding = create_padding_mask(dec_inputs)      # (batch, 1, 1, tgt_len)
        return tf.maximum(look_ahead, padding)         # (batch, 1, tgt_len, tgt_len)

    def call(self, inputs, training=None):
        """Run the full model.

        Args:
            inputs: ``(enc_inputs, dec_inputs)`` token-id tensors,
                (batch, src_len) and (batch, tgt_len).
            training: Enables dropout when True.

        Returns:
            Unnormalized logits of shape (batch, tgt_len, vocab_size).
        """
        enc_inputs, dec_inputs = inputs

        enc_padding_mask = self.enc_padding_mask(enc_inputs)
        look_ahead_mask = self.look_ahead_mask(dec_inputs)
        dec_padding_mask = self.dec_padding_mask(enc_inputs)

        enc_outputs = self.encoder(enc_inputs, training=training, mask=enc_padding_mask)
        dec_outputs = self.decoder(inputs=[dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask], training=training)

        return self.final_layer(dec_outputs)

    def get_config(self):
        """Serializable constructor arguments merged over the base Keras config."""
        config = {
            'vocab_size': self.vocab_size,
            'num_layers': self.num_layers,
            'dff': self.dff,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dropout': self.dropout,
        }
        try:
            base_config = super().get_config()
        except NotImplementedError:
            # Some tf.keras versions raise here for subclassed models.
            base_config = {'name': self.name}
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild a Transformer from ``get_config()`` output."""
        return cls(**config)


In [34]:
transformer_model = Transformer(
    vocab_size=10000,  # vocabulary shared by encoder, decoder and output layer
    num_layers=6,      # depth of both the encoder and decoder stacks
    dff=2048,          # feed-forward hidden width
    d_model=512,       # embedding / model width
    num_heads=8,       # attention heads per layer
    dropout=0.1,       # dropout rate
)

In [36]:
# For future use, it would be better to split this code into modules as follows:
# 1. multi_head_attention.py
# 2. positional_encoding.py
# 3. encoder.py
# 4. decoder.py
# 5. transformer.py