In [None]:
import numpy as np
import pandas as pd
import matplotlib.pylab as plt

from IPython.display import display

from utils import seed_everything, batch_plot


# Show floats with two decimals in NumPy and pandas output so the tables below stay compact.
np.set_printoptions(precision=2)
pd.set_option("display.precision", 2)
# Autoreload mode 2: re-import modules before running each cell, so edits to
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Helper imported from the project's utils module; presumably seeds the random
# number generators for reproducibility -- confirm against utils.
seed_everything()

# Sinusoidal Positional Encoding

Advantages:

1. It provides a unique encoding for each position in the sequence.
2. Every pair of adjacent positions is separated by the same distance (derived below).
3. The encoding can be extended to sequences of arbitrary length.

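$$
PE_{i,j} =
\begin{cases}
\sin\left(i \cdot 10000^{-\frac{j}{d}}\right) & \text{if } j \text{ is even} \\
\cos\left(i \cdot 10000^{-\frac{(j-1)}{d}}\right) & \text{otherwise} \\
\end{cases}
$$

Here $i$ is the position in the sequence and $j$ is the (0-based) index of the embedding dimension.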

$$
PE_t =
\begin{bmatrix}
\sin(\omega_1t) & \cos(\omega_1t) & \sin(\omega_2t) & \cos(\omega_2t) & \cdots & \sin(\omega_{d/2}t) & \cos(\omega_{d/2}t)
\end{bmatrix}_{1 \times d}
$$

where $\omega_n = 10000^{-\frac{2(n-1)}{d}}$ for $n = 1, \dots, d/2$, $t$ is the position, and $d$ is the embedding size.

$$
\begin{aligned}
\Vert PE_{t+1} - PE_t \Vert &= \sqrt{\sum_{n=1}^{d/2} \Bigl(\bigl( \sin(\omega_nt+\omega_n) - \sin(\omega_nt) \bigr)^2 + \bigl(\cos(\omega_nt+\omega_n) - \cos(\omega_nt) \bigr)^2\Bigr)} \\
&= \sqrt{\sum_{n=1}^{d/2} \left( \sin^2(\omega_nt+\omega_n) - 2\sin(\omega_nt+\omega_n)\sin(\omega_nt) + \sin^2(\omega_nt) + \cos^2(\omega_nt+\omega_n) - 2\cos(\omega_nt+\omega_n)\cos(\omega_nt) + \cos^2(\omega_nt) \right)} \\
&= \sqrt{\sum_{n=1}^{d/2} \Bigl( 2 - 2\bigl(\cos(\omega_nt+\omega_n)\cos(\omega_nt) + \sin(\omega_nt+\omega_n)\sin(\omega_nt)\bigr) \Bigr)} \\
&= \sqrt{\sum_{n=1}^{d/2} \bigl( 2 - 2\cos(\omega_n) \bigr)} \\
&= \sqrt{d - 2\sum_{n=1}^{d/2}\cos(\omega_n)}
\end{aligned}
$$

In [None]:
def generate_positional_encoding(max_length, embedding_size):
    """Build the sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(position * w_k)`` and column ``2k + 1`` holds
    ``cos(position * w_k)``, where ``w_k = 10000 ** (-2k / embedding_size)``.

    Args:
        max_length: Number of positions (rows); positions run from 0 to
            ``max_length - 1``.
        embedding_size: Number of encoding dimensions (columns). May be odd,
            in which case the last column is a sine without a cosine partner.

    Returns:
        A float array of shape ``(max_length, embedding_size)``.
    """
    positions = np.arange(max_length)[:, None]  # (max_length, 1)
    frequencies = 10000 ** (-np.arange(0, embedding_size, 2) / embedding_size)
    angles = positions * frequencies  # (max_length, ceil(embedding_size / 2))

    positional_encoding = np.empty((max_length, embedding_size))
    positional_encoding[:, 0::2] = np.sin(angles)
    # With an odd embedding_size there is one fewer odd column than even
    # columns, so only the first embedding_size // 2 frequencies get a cosine.
    positional_encoding[:, 1::2] = np.cos(angles[:, : embedding_size // 2])
    return positional_encoding


def generate_binary_encoding(max_length, width=64):
    """Encode positions ``0 .. max_length - 1`` as rows of binary digits.

    Args:
        max_length: Number of positions (rows) to encode.
        width: Number of bits (columns) per position. Defaults to 64. Bits
            beyond ``width`` are dropped; columns past bit 63 are always zero.

    Returns:
        A ``uint8`` array of shape ``(max_length, width)`` where column ``b``
        holds bit ``b`` of the position, least-significant bit first.
    """
    encoding = np.zeros((max_length, width), dtype=np.uint8)
    # Positions are held as uint64, so only the first 64 columns can carry
    # information; shifting by 64 or more would not be well defined.
    stored_bits = min(width, 64)
    positions = np.arange(max_length, dtype=np.uint64)[:, None]  # (max_length, 1)
    shifts = np.arange(stored_bits, dtype=np.uint64)  # (stored_bits,)
    # Broadcasting extracts every bit of every position at once instead of
    # formatting and parsing a string per row.
    encoding[:, :stored_bits] = (positions >> shifts) & np.uint64(1)
    return encoding


# Side-by-side preview of both encodings for the first few positions.
preview_length, preview_width = 4, 8

binary_preview = pd.DataFrame(
    generate_binary_encoding(preview_length)[:, :preview_width],
    columns=[f"B_{bit}" for bit in range(preview_width)],
)
sinusoidal_preview = pd.DataFrame(
    generate_positional_encoding(preview_length, preview_width),
    columns=[f"P_{dim}" for dim in range(preview_width)],
)

display(binary_preview)
sinusoidal_preview

In [None]:
embedding_size = 500
max_length = 1000

positional_encoding = generate_positional_encoding(max_length, embedding_size)

# `frequencies` only exists as a local inside generate_positional_encoding, so
# rebuild the same per-pair frequencies here instead of relying on kernel state.
frequencies = 10000 ** (-np.arange(0, embedding_size, 2) / embedding_size)
# Closed form of the distance between consecutive positions:
# sqrt(d - 2 * sum_n cos(w_n)), independent of the position itself.
relative_distance = np.sqrt(embedding_size - 2 * np.cos(frequencies).sum())

# Every consecutive pair of rows must be exactly this far apart.
assert np.allclose(
    np.linalg.norm(np.diff(positional_encoding, axis=0), axis=1), relative_distance
), "relative distances should be equal"

In [None]:
# Binary encoding as an image: positions along x, bits along y.
binary_encoding_image = generate_binary_encoding(max_length).T

fig, ax = plt.subplots(figsize=(16, 12))
ax.imshow(binary_encoding_image, cmap="gray_r")
ax.set_xlabel("position")
ax.set_ylabel("bit")
plt.show()

In [None]:
# Put embedding dimensions on the vertical axis and positions on the
# horizontal axis, which reads better as an image.
transposed_encoding = np.transpose(positional_encoding)  # (embedding_size, max_length)

view_dims = 64

fig, axes = plt.subplots(
    nrows=2, ncols=1, figsize=(16, 12), gridspec_kw=dict(height_ratios=[1, 2])
)
ax1, ax2 = axes

# Top panel: the full table. Bottom panel: zoom on the first `view_dims`
# dimensions and positions.
zoomed_encoding = transposed_encoding[:view_dims, :view_dims]
ax1.imshow(transposed_encoding, cmap="gray_r")
ax2.imshow(zoomed_encoding, cmap="gray_r")
ax2.set(xlabel="position", ylabel="embedding dimension")

fig.tight_layout()
plt.show()

In [None]:
# Plot the first `view_dims` encoding values of each of the first 8 positions.
positions_to_plot = range(8)
for position in positions_to_plot:
    encoding_slice = transposed_encoding[:view_dims, position]
    batch_plot(encoding_slice, flatten_layout=True)

In [None]:
# Pairwise Euclidean distances between all position encodings -> (max_length, max_length).
# Subtracting positional_encoding[:, None] - positional_encoding[None, :] would
# materialise a (max_length, max_length, embedding_size) array (about 4 GB of
# float64 for 1000 x 1000 x 500), so use
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b instead.
squared_norms = (positional_encoding**2).sum(axis=1)  # (max_length,)
squared_distances = (
    squared_norms[:, None] + squared_norms[None, :] - 2 * positional_encoding @ positional_encoding.T
)
# Rounding can leave tiny negative values; clamp them before the square root
# and pin the diagonal (distance of a position to itself) to exactly zero.
distances = np.sqrt(np.maximum(squared_distances, 0.0))
np.fill_diagonal(distances, 0.0)

plt.figure(figsize=(12, 12))
plt.imshow(distances[:view_dims, :view_dims], cmap="gray_r")
plt.xlabel("position")
plt.ylabel("position")
plt.show()

In [None]:
plot_depth = 4  # number of subplot rows, one (odd, even) dimension pair per row
plot_step = 4  # row i (1-based) shows dimensions 2*i*plot_step - 1 and 2*i*plot_step
plot_nums = 32  # number of leading positions drawn on each curve


# squeeze=False keeps `axs` two-dimensional even when plot_depth == 1, so the
# row-wise unpacking below always works.
fig, axs = plt.subplots(
    nrows=plot_depth,
    ncols=2,
    figsize=(12, 4 * plot_depth),
    sharex=True,
    sharey=True,
    squeeze=False,
)
for i, (axo, axe) in enumerate(axs, 1):
    even_index, odd_index = 2 * i * plot_step, 2 * i * plot_step - 1
    # Left column: odd (cosine) dimension; right column: the following even (sine) dimension.
    for ax, index in ((axo, odd_index), (axe, even_index)):
        ax.plot(positional_encoding[:plot_nums, index], label=index)
        ax.grid()
        ax.legend(loc="upper right")
fig.tight_layout()
plt.show()

# references

- [Linear Relationships in the Transformer’s Positional Encoding](https://blog.timodenk.com/linear-relationships-in-the-transformers-positional-encoding/)
- [Tutorial 6: Transformers and Multi-Head Attention](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html)
- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)
- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html)
- [Master Positional Encoding: Part I](https://towardsdatascience.com/master-positional-encoding-part-i-63c05d90a0c3)
- [Transformer Architecture: The Positional Encoding](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/)
- [Why multi-head self attention works: math, intuitions and 10+1 hidden insights](https://theaisummer.com/self-attention/)