In [None]:
import json
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pylab as plt

from utils import seed_everything, heatmap

np.set_printoptions(precision=2)
pd.set_option("display.precision", 2)
%load_ext autoreload
%autoreload 2

seed_everything()

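$$
\text{attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left( \frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} \right)\mathbf{V}
$$

where $d_k$ is the dimension of the query/key vectors. Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the dimension, which would otherwise push the softmax toward a one-hot distribution.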

In [None]:
# Colormap for the attention heatmaps: a light-gray -> dark-green ramp named "github".
# Low attention weights render pale, high weights render dark green.
github_cp = mpl.colors.LinearSegmentedColormap.from_list(
    "github",
    ["#ebedf0", "#9be9a8", "#40c463", "#30a14e", "#216e39"],
)


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large scores
    do not overflow; the shift cancels out in the normalization.

    Args:
        x: Input scores.
        axis: Axis to normalize over (default: the last axis).

    Returns:
        Array with the same shape as ``x`` whose slices along ``axis`` sum to 1.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    exponentials = np.exp(x - peak)
    normalizer = np.sum(exponentials, axis=axis, keepdims=True)
    return exponentials / normalizer


def single_head_self_attention(q, k, v, with_scale=True, draw=True, sentences=None):
    """Batched scaled dot-product attention, optionally plotting each attention map.

    Computes ``softmax(q @ k^T / sqrt(d_k)) @ v`` where ``d_k`` is the last
    dimension of ``q`` and ``k``.

    Args:
        q: Queries, shape ``(batch, query_len, d_k)``.
        k: Keys, shape ``(batch, key_len, d_k)``.
        v: Values, shape ``(batch, key_len, d_v)``.
        with_scale: If True, divide the logits by ``sqrt(d_k)``; otherwise use
            the raw dot products.
        draw: If True, draw one heatmap per batch item that has a sentence.
        sentences: Whitespace-separated sentences used as heatmap tick labels,
            one per batch item (extra sentences beyond the batch are ignored).
            Defaults to none, in which case nothing is drawn.

    Returns:
        Attention output of shape ``(batch, query_len, d_v)``.
    """
    if sentences is None:
        sentences = []

    batch_size, query_sentence_size, d_k = q.shape
    key_sentence_size = k.shape[1]
    assert k.shape == (batch_size, key_sentence_size, d_k), "q and k must share batch size and d_k"
    assert v.shape[:2] == (batch_size, key_sentence_size), "k and v must share batch size and length"

    # The formula scales by the query/key dimension d_k, not the value dimension.
    scale = np.sqrt(d_k) if with_scale else 1

    corr = softmax(q @ k.transpose((0, 2, 1)) / scale)
    # Rows index queries, columns index keys.
    assert corr.shape == (batch_size, query_sentence_size, key_sentence_size)

    if draw:
        for i, sentence in zip(range(batch_size), sentences):
            # Self-attention: the same words label both queries (rows) and keys (columns).
            words = sentence.split()
            corr_df = pd.DataFrame(corr[i], columns=words, index=words)

            fig, ax = plt.subplots(figsize=(16, 16))
            heatmap(corr_df, words, words, ax=ax, cmap=github_cp)
            plt.show()
            # Release the figure so plotting many batch items does not accumulate memory.
            plt.close(fig)
    return corr @ v

In [None]:
sentences = ["I love this movie", "This movie is bad"]

# Word vectors; presumably a subset of GloVe: https://nlp.stanford.edu/projects/glove/
with open("sample_words.json", "r") as f:
    sample_words = json.load(f)

# np.asarray only builds a rectangular (N, T, E) tensor when every sentence has the same word count.
assert len({len(sentence.split()) for sentence in sentences}) == 1, "all sentences must have the same number of words"
x = np.asarray([[sample_words[w] for w in sentence.lower().split()] for sentence in sentences])  # (N, T, E)
batch_size, _, embedding_size = x.shape

# Identity projections for better illustration: q, k and v all equal x,
# so the heatmaps reflect similarity between the raw word vectors.
wq = np.eye(embedding_size)
wk = np.eye(embedding_size)
wv = np.eye(embedding_size)

# `@` already returns a new array, so x needs no defensive copy.
q = x @ wq
k = x @ wk
v = x @ wv

# Single head

## Without scaling

In [None]:
# Unscaled logits: the raw q·k dot products go straight into the softmax.
_ = single_head_self_attention(
    q,
    k,
    v,
    with_scale=False,
    sentences=sentences,
)

## With scaling by $1/\sqrt{d_k}$

In [None]:
# Scaled logits: dividing by sqrt(d_k) shrinks the scores, giving a flatter softmax than the unscaled run.
_ = single_head_self_attention(
    q,
    k,
    v,
    with_scale=True,
    sentences=sentences,
)

# Multi-head attention

In [None]:
def multi_head_self_attention(q, k, v, num_heads=2, with_scale=True, draw=True, sentences=None):
    """Batched multi-head scaled dot-product attention, optionally plotting every head.

    The embedding dimension is split into ``num_heads`` equal chunks. Each head
    runs scaled dot-product attention in its own chunk, and the head outputs are
    concatenated back into the full embedding.

    Args:
        q: Queries, shape ``(batch, query_len, embedding)``.
        k: Keys, shape ``(batch, key_len, embedding)``.
        v: Values, same shape as ``k``.
        num_heads: Number of heads; must evenly divide ``embedding``.
        with_scale: If True, divide each head's logits by ``sqrt(embedding // num_heads)``.
        draw: If True, draw one figure (one heatmap per head) per batch item that
            has a sentence.
        sentences: Whitespace-separated sentences used as heatmap tick labels,
            one per batch item (extra sentences are ignored). Defaults to none,
            in which case nothing is drawn.

    Returns:
        Concatenated head outputs, shape ``(batch, query_len, embedding)``.
    """
    if sentences is None:
        sentences = []

    assert num_heads >= 1, "num_heads must be positive"
    batch_size, query_sentence_size, embedding_size = q.shape
    assert k.shape == v.shape, "k and v must have the same shape"
    assert v.shape[0] == batch_size and v.shape[2] == embedding_size, "q, k, v must share batch size and embedding size"
    value_sentence_size = v.shape[1]
    assert embedding_size % num_heads == 0, "embedding size must be divisible by num_heads"

    single_head_embedding_size = embedding_size // num_heads
    # Each head attends in its own subspace, so d_k is the per-head size.
    scale = np.sqrt(single_head_embedding_size) if with_scale else 1

    q = _split_heads(q, num_heads)
    k = _split_heads(k, num_heads)
    v = _split_heads(v, num_heads)
    assert q.shape == (batch_size, num_heads, query_sentence_size, single_head_embedding_size)
    assert k.shape == v.shape == (batch_size, num_heads, value_sentence_size, single_head_embedding_size)

    corr = softmax(q @ k.transpose((0, 1, 3, 2)) / scale)
    # Rows index queries, columns index keys.
    assert corr.shape == (batch_size, num_heads, query_sentence_size, value_sentence_size)

    if draw:
        for i, sentence in zip(range(batch_size), sentences):
            _plot_head_attention(corr[i], sentence.split(), num_heads)

    attention = corr @ v
    assert attention.shape == (batch_size, num_heads, query_sentence_size, single_head_embedding_size)
    # Merge heads back: (batch, heads, T_q, d_head) -> (batch, T_q, heads * d_head).
    return attention.transpose((0, 2, 1, 3)).reshape((batch_size, query_sentence_size, embedding_size))


def _split_heads(x, num_heads):
    """Reshape ``(batch, seq, embedding)`` into ``(batch, num_heads, seq, embedding // num_heads)``."""
    batch_size, sentence_size, embedding_size = x.shape
    head_size = embedding_size // num_heads
    return x.reshape((batch_size, sentence_size, num_heads, head_size)).transpose((0, 2, 1, 3))


def _plot_head_attention(head_corrs, words, num_heads):
    """Draw one heatmap per head for a single batch item's ``(heads, T_q, T_k)`` attention weights.

    Both axes are labeled with ``words``, which assumes self-attention (the
    query and key sequences are the same sentence).
    """
    # Two columns per row (or one for a single head); this reproduces the original
    # 1x2 / (16, 8) layout for 2 heads and 2x2 / (16, 16) for 4 heads.
    ncols = 1 if num_heads == 1 else 2
    nrows = (num_heads + ncols - 1) // ncols
    # squeeze=False always returns a 2-D array of Axes. Flattening it yields individual
    # Axes; iterating a 2-D grid directly would yield whole rows instead.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8 * ncols, 8 * nrows), squeeze=False)
    flat_axes = axs.ravel()

    for ax, head_corr in zip(flat_axes, head_corrs):
        head_df = pd.DataFrame(head_corr, columns=words, index=words)
        heatmap(head_df, words, words, ax=ax, cmap=github_cp)

    # Hide grid cells left empty when num_heads does not fill the grid (e.g. 3 heads in 2x2).
    for ax in flat_axes[num_heads:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# Two heads (the default), each attending over half of the embedding dimensions.
_ = multi_head_self_attention(
    q,
    k,
    v,
    num_heads=2,
    sentences=sentences,
)

# References

- [The Transformer Family Version 2.0](https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/)
- [Tutorial 6: Transformers and Multi-Head Attention](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html)
- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)
- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html)
- [Master Positional Encoding: Part I](https://towardsdatascience.com/master-positional-encoding-part-i-63c05d90a0c3)
- [Transformer Architecture: The Positional Encoding](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/)
- [Why multi-head self attention works: math, intuitions and 10+1 hidden insights](https://theaisummer.com/self-attention/)