<a href="https://colab.research.google.com/github/starhou/notebooks/blob/master/ML/Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SelfAttention 理解

感谢
**伟大是熬出来的** 大佬的分享[超详细图解Self-Attention](https://zhuanlan.zhihu.com/p/410776234)

## 原始公式

$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V
$$

  ![](https://pic4.zhimg.com/80/v2-6b6030a342a43d7c220cdc940738b783_720w.jpg)

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and tf.keras
import tensorflow as tf

import os
import io
from tensorflow.keras import layers



# Helper libraries
import imageio
import datetime
import numpy as np
import time
import matplotlib.pyplot as plt
from IPython import display
import PIL
import glob
from scipy import signal
# Print the installed TensorFlow version so the notebook's results can be tied to it.
print(tf.__version__)

2.6.0


In [None]:
class SelfAttention(tf.keras.Model):
  """Single-head scaled dot-product self-attention.

  Computes ``softmax(Q K^T / sqrt(d_k)) V`` where Q, K and V are three
  independent linear projections of the same input.

  Args:
    input_dim: Size of the last (feature) axis of the input.
    dim_q_k: Depth of the query/key projections (``d_k`` in the formula).
    dim_v: Depth of the value projection; also the feature size of the output.
  """

  def __init__(self, input_dim = 3, dim_q_k = 4, dim_v = 5):
    super(SelfAttention, self).__init__()
    self.input_size = input_dim
    self.key_size = dim_q_k
    self.value_size = dim_v
    # Each projection is a plain [input_dim, depth] matrix. A rank-3
    # [input_dim, input_dim, depth] kernel contracted with 'bbc' would only
    # ever read its diagonal and would keep the feature axis in the result,
    # so attention would mix features instead of sequence positions.
    self.query_kernel = self.add_weight(
        name="query_kernel",
        shape=[self.input_size, self.key_size],
        trainable=True,
    )
    self.key_kernel = self.add_weight(
        name="key_kernel",
        shape=[self.input_size, self.key_size],
        trainable=True,
    )
    self.value_kernel = self.add_weight(
        name="value_kernel",
        shape=[self.input_size, self.value_size],
        trainable=True,
    )

  def call(self, x):
    """Apply self-attention over the second-to-last axis of ``x``.

    Args:
      x: Floating-point tensor of shape ``[..., seq_len, input_dim]``.

    Returns:
      Tensor of shape ``[..., seq_len, dim_v]``.
    """
    # einsum letters: n = query positions, m = key/value positions,
    # d = input features, k = query/key depth, v = value depth.
    Q = tf.einsum('...nd,dk->...nk', x, self.query_kernel)
    K = tf.einsum('...md,dk->...mk', x, self.key_kernel)
    V = tf.einsum('...md,dv->...mv', x, self.value_kernel)

    # Scale by 1/sqrt(d_k) so the logits do not grow with the key depth.
    depth = tf.cast(self.key_size, Q.dtype)
    Q /= tf.sqrt(depth)

    # Similarity of every query position with every key position,
    # normalised over the key axis so each row of W sums to 1.
    logits = tf.einsum('...nk,...mk->...nm', Q, K)
    W = tf.nn.softmax(logits, axis=-1)

    # Attention-weighted sum of the value vectors.
    out = tf.einsum('...nm,...mv->...nv', W, V)
    return out

In [None]:
# Build the layer with its default sizes (input_dim=3, dim_q_k=4, dim_v=5).
selfAttention = SelfAttention()

In [None]:
# Toy input of ones; the last axis (3) matches the default input_dim of SelfAttention.
x = tf.ones((10, 3))

In [None]:
# Apply the attention layer to the toy input.
selfAttention(x)

## MultiHeadAttention
将 Scaled Dot-Product Attention 并行地做 H 次（每个 head 使用各自独立的 Q、K、V 线性变换），再把各个 head 的输出合并起来，并通过一次线性投影得到最终输出。

![Multi-Head Attention 示意图](https://pic3.zhimg.com/80/v2-f221c5a13a4e6e3fb84685e0f884b1da_720w.jpg)

In [16]:
class MultiHeadAttention(tf.keras.Model):
  """Multi-head scaled dot-product attention.

  Every head projects query, key and value with its own kernel, runs
  scaled dot-product attention, and the per-head results are merged by a
  single learned projection (kernel plus bias).

  Args:
    head_size: Depth of the per-head query/key/value projections.
    num_heads: Number of attention heads.
    output_size: Feature size of the final output.
    num_query_features: Size of the last axis of the query input.
    num_key_features: Size of the last axis of the key input.
    num_value_features: Size of the last axis of the value input.
  """

  def __init__(self, head_size = 3,
            num_heads = 4,
            output_size = 5,
            num_query_features = 6,
            num_key_features = 6,
            num_value_features = 6,
            ):
    super(MultiHeadAttention, self).__init__()
    self.head_size = head_size
    self.num_heads = num_heads
    self.output_size = output_size
    # One [features, head_size] projection per head for each of Q, K and V.
    self.query_kernel = self.add_weight(
        name="query_kernel",
        shape=[self.num_heads, num_query_features, self.head_size],
    )
    self.key_kernel = self.add_weight(
        name="key_kernel",
        shape=[self.num_heads, num_key_features, self.head_size],
    )
    self.value_kernel = self.add_weight(
        name="value_kernel",
        shape=[self.num_heads, num_value_features, self.head_size],
    )
    # Merges all heads back into a single output_size-wide vector.
    self.projection_kernel = self.add_weight(
        name="projection_kernel",
        shape=[self.num_heads, self.head_size, output_size],
    )
    # Zero-initialised so the bias has no effect until it is trained.
    self.projection_bias = self.add_weight(
        name="projection_bias",
        shape=[output_size],
        initializer="zeros",
    )

  def call(self, inputs):
    """Attend from the query elements to the key/value elements.

    Args:
      inputs: A list or tuple ``[query, key]`` or ``[query, key, value]``.
        ``query`` has shape ``[..., N, num_query_features]``, ``key`` has
        shape ``[..., M, num_key_features]`` and ``value`` has shape
        ``[..., M, num_value_features]``. When ``value`` is omitted the
        key tensor is used as the value as well. Note that a bare tensor
        is indexed along its first axis, so wrap a single tensor as
        ``[x, x]`` or ``[x, x, x]`` for self-attention.

    Returns:
      Tensor of shape ``[..., N, output_size]``.
    """
    # einsum nomenclature
    # ------------------------
    # N = query elements
    # M = key/value elements
    # H = heads
    # I = input features
    # O = output features
    query = inputs[0]
    key = inputs[1]
    value = inputs[2] if len(inputs) > 2 else key

    # Linear transformations, one projection per head.
    query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
    key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
    value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)

    # Scale the queries by 1/sqrt(head_size) before taking dot products.
    depth = tf.constant(self.head_size, dtype=query.dtype)
    query /= tf.sqrt(depth)

    # Dot-product attention logits, normalised over the key axis (M).
    logits = tf.einsum("...NHO,...MHO->...HNM", query, key)
    attn_coef = tf.nn.softmax(logits)

    # Attention-weighted sum of the values, still separate per head.
    multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef, value)

    # Contract over heads and head depth, then add the projection bias.
    output = tf.einsum(
        "...NHI,HIO->...NO", multihead_output, self.projection_kernel
    )
    output = output + self.projection_bias
    return output

In [17]:
# Build the layer with its default sizes (4 heads of depth 3, output_size=5).
multiHeadAttention = MultiHeadAttention()

In [24]:
# Toy input of ones; the last axis (6) matches the default
# num_query_features / num_key_features / num_value_features.
x = tf.ones((10, 7, 20, 3, 6))

In [None]:
# call() unpacks its argument as [query, key, value]. Passing the bare tensor
# would slice x along its first axis (x[0], x[1], x[2]); for self-attention
# pass x explicitly in all three roles instead.
multiHeadAttention([x, x, x])