# An RNN version of the attention mechanism

To keep things simple we start with a two-item sequence.  So, we have two row vectors $x_1$ and $x_2$, both in
$\mathbb{R}^{1 \times k}$.

Classically, we would stack them in a matrix $X \in \mathbb{R}^{2 \times k}$.

$$
X = \begin{bmatrix}
x_1 \\
x_2 
\end{bmatrix}
$$

Looking at the attention equations we get

![](../attachments/2023-03-22-08-14-48.png)
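
For reference, and assuming the figure shows the standard scaled dot-product attention, the equations are usually written as below. The symbols $W_Q$, $W_K$, $W_V$ and $d_k$ are the conventional names, not names taken from this note.

$$
\begin{align*}
Q = X W_Q, \quad K &= X W_K, \quad V = X W_V \\
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
\end{align*}
$$

Only the score term $Q K^T$ matters for what follows.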

The tricky bit for an RNN to simulate is the dot product $Q K^T$, but we have a trick up our sleeve.  A dot product is just a Hadamard product followed by a sum: multiply the two vectors element-wise and then add up the entries.  (OK, the first draft of that last sentence was written by Copilot... that is a bit scary 🫢)
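
As a quick sanity check of that identity, here is a minimal sketch in plain Python (no dependencies). The vectors `q` and `k` and the helper names are made up for illustration.

```python
# Dot product as "Hadamard product, then sum" (plain Python lists).

def hadamard(u, v):
    """Element-wise product of two equal-length vectors."""
    assert len(u) == len(v)
    return [a * b for a, b in zip(u, v)]

def dot(u, v):
    """Ordinary dot product, used only for comparison."""
    return sum(a * b for a, b in zip(u, v))

q = [1.0, 2.0, 3.0]    # hypothetical query vector
k = [4.0, -1.0, 0.5]   # hypothetical key vector

elementwise = hadamard(q, k)     # element-wise products
via_hadamard = sum(elementwise)  # summing them gives the dot product

print(elementwise, via_hadamard, dot(q, k))
```

The two routes agree, which is all the construction below relies on.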

Our goal is to write the attention mechanism in terms of 

$$
W_1 x \odot W_2 x
$$
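
To make the shape of this update concrete, here is a small sketch of a single step $W_1 x \odot W_2 x$ in plain Python. The names `matvec` and `step` and the tiny matrices are illustrative assumptions, not part of the construction itself.

```python
# One update of the form (W1 x) ⊙ (W2 x), using plain Python lists.

def matvec(W, x):
    """Matrix-vector product; W is a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def step(W1, W2, x):
    """Apply both linear maps to x, then multiply the results element-wise."""
    a = matvec(W1, x)
    b = matvec(W2, x)
    return [ai * bi for ai, bi in zip(a, b)]

# Hypothetical 2x2 example: W1 copies x, W2 swaps its two entries.
W1 = [[1.0, 0.0],
      [0.0, 1.0]]
W2 = [[0.0, 1.0],
      [1.0, 0.0]]
x = [2.0, 3.0]

out = step(W1, W2, x)  # each entry is x[0] * x[1]
print(out)
```

The point of the example is that the output is quadratic in `x`: each entry is a product of two entries of the input, which is exactly what a dot product needs.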


We begin with the first item/word/symbol in the sequence, which we call $x_1$.  We use the same $Q$ and $K$ matrices for every item in the sequence, and we write $Q(x_1^T)$ for the product $Q x_1^T$.  We also define the "ones function" $\mathbf{1}(x)$, which returns a vector of ones of the same size as $x$.

$$
\begin{align*}
\begin{bmatrix}
0 & \cdots \\
Q & 0 & \cdots \\
K & 0 & \cdots \\
\vdots \\
\end{bmatrix}
\begin{bmatrix}
x_1^T \\
0 \\
\vdots \\
\end{bmatrix}
&
= 
&
\begin{bmatrix}
0 \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
\odot
&
\begin{bmatrix}
\mathbf{1}(x_1^T) \\
\mathbf{1}(x_1^T) \\
\mathbf{1}(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
=
&
\begin{bmatrix}
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\vdots \\
\end{bmatrix}
\begin{bmatrix}
x_1^T \\
0 \\
\vdots \\
\end{bmatrix} 
\end{align*}
$$

which is

$$
\begin{bmatrix}
0 \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
$$

Now we feed in $x_2$ by placing it in the first slot of the state, which gives

$$
\begin{bmatrix}
x_2^T \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
$$

$$
\begin{align*}
\begin{bmatrix}
0 & \cdots \\
Q & 0 & 0 & \cdots \\
K & 0 & 0 & \cdots \\
0 & I & 0 & \cdots \\
0 & 0 & I & 0 & \cdots \\
\vdots
\end{bmatrix}
\begin{bmatrix}
x_2^T \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
=
& 
\begin{bmatrix}
0 \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
\odot
&
\begin{bmatrix}
\mathbf{1}(x_2^T) \\
\mathbf{1}(x_2^T) \\
\mathbf{1}(x_2^T) \\
\mathbf{1}(x_2^T) \\
\mathbf{1}(x_2^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
=
&
\begin{bmatrix}
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\vdots \\
\end{bmatrix}
\begin{bmatrix}
x_2^T \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
\end{align*}
$$

which is

$$
\begin{bmatrix}
0 \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
$$

OK, up to this point we have not really used $W_2$ (the right-hand matrix) in any essential way, since it has only produced ones.  However, now that changes! 🤯

$$
\begin{align*}
\begin{bmatrix}
0 & \cdots \\
Q & 0 & 0 & \cdots \\
K & 0 & 0 & \cdots \\
0 & I & 0 & \cdots \\
0 & 0 & I & 0 & \cdots \\
0 & I & 0 & \cdots \\
0 & I & 0 & \cdots \\
0 & 0 & 0 & I & 0 & \cdots \\
0 & 0 & 0 & I & 0 & \cdots \\
\vdots
\end{bmatrix}
\begin{bmatrix}
x_3^T \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
=
& 
\begin{bmatrix}
0 \\
Q(x_3^T) \\
K(x_3^T) \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_2^T) \\
Q(x_2^T) \\
Q(x_1^T) \\
Q(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
\odot
&
\begin{bmatrix}
\mathbf{1}(x_3^T) \\
\mathbf{1}(x_3^T) \\
\mathbf{1}(x_3^T) \\
\mathbf{1}(x_3^T) \\
\mathbf{1}(x_3^T) \\
K(x_2^T) \\
K(x_1^T) \\
K(x_2^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
=
&
\begin{bmatrix}
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
0 & 0 & I & 0 & \cdots \\
0 & 0 & 0 & 0 & I & 0 & \cdots \\
0 & 0 & I & 0 & \cdots \\
0 & 0 & 0 & 0 & I & 0 & \cdots \\
\vdots \\
\end{bmatrix}
\begin{bmatrix}
x_3^T \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_1^T) \\
K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
\end{align*}
$$

which is

$$
\begin{bmatrix}
0 \\
Q(x_3^T) \\
K(x_3^T) \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_2^T) \odot K(x_2^T) \\
Q(x_2^T) \odot K(x_1^T) \\
Q(x_1^T) \odot K(x_2^T) \\
Q(x_1^T) \odot K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
$$
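
The three steps so far can be replayed numerically. The sketch below works at the block level: instead of building the full matrices, each block row of $W_1$ and $W_2$ is stored as a `(source block, matrix)` pair, `None` stands for a zero row, and the string `"ones"` stands for $\mathbf{1}(x)$. It also applies the final 9-block layout at every step, rather than the growing matrices shown above. The values of `Q`, `K`, `x1`, `x2`, `x3` and all helper names are assumptions chosen for illustration, with $k = 2$.

```python
# Block-level replay of the first three recurrent steps (k = 2).

I2 = [[1.0, 0.0], [0.0, 1.0]]
Q = [[1.0, 2.0], [0.0, 1.0]]    # hypothetical query weights
K = [[0.5, 0.0], [1.0, -1.0]]   # hypothetical key weights
ZERO = [0.0, 0.0]
ONES = [1.0, 1.0]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def hadamard(u, v):
    return [a * b for a, b in zip(u, v)]

# Block rows of W1: zero, Q x, K x, two shifts, then Q copies for the products.
W1 = [None, (0, Q), (0, K), (1, I2), (2, I2),
      (1, I2), (1, I2), (3, I2), (3, I2)]
# Block rows of W2: ones for the pass-through rows, K copies for the products.
W2 = ["ones"] * 5 + [(2, I2), (4, I2), (2, I2), (4, I2)]

def apply(spec, state):
    """Apply a block-row description to the list of state blocks."""
    out = []
    for entry in spec:
        if entry is None:
            out.append(ZERO[:])
        elif entry == "ones":
            out.append(ONES[:])
        else:
            src, M = entry
            out.append(matvec(M, state[src]))
    return out

def step(state, x):
    """Put x in block 0, then compute (W1 h) ⊙ (W2 h) block by block."""
    state = [x] + state[1:]
    a = apply(W1, state)
    b = apply(W2, state)
    return [hadamard(u, v) for u, v in zip(a, b)]

x1, x2, x3 = [1.0, 2.0], [3.0, -1.0], [0.0, 1.0]

state = [ZERO[:] for _ in range(9)]
for x in (x1, x2, x3):
    state = step(state, x)

# Blocks 5..8 hold Q x2 ⊙ K x2, Q x2 ⊙ K x1, Q x1 ⊙ K x2, Q x1 ⊙ K x1.
products = state[5:9]
print(products)
```

After the third step, blocks 5 to 8 hold the four element-wise products in the same order as the vector above.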

All that we need now is to do a summation

$$
\begin{align*}
\begin{bmatrix}
0 & \cdots \\
Q & 0 & 0 & \cdots \\
K & 0 & 0 & \cdots \\
0 & I & 0 & \cdots \\
0 & 0 & I & 0 & \cdots \\
0 & I & 0 & \cdots \\
0 & I & 0 & \cdots \\
0 & 0 & 0 & I & 0 & \cdots \\
0 & 0 & 0 & I & 0 & \cdots \\
0 & 0 & 0 & 0 & 0 & \sum & 0 & \cdots \\
0 & 0 & 0 & 0 & 0 & 0 & \sum & 0 & \cdots \\
0 & 0 & 0 & 0 & 0 & 0 & 0 & \sum & 0 & \cdots \\
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & \sum & 0 & \cdots \\
\vdots
\end{bmatrix}
\begin{bmatrix}
x_4^T \\
Q(x_3^T) \\
K(x_3^T) \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_2^T) \odot K(x_2^T) \\
Q(x_2^T) \odot K(x_1^T) \\
Q(x_1^T) \odot K(x_2^T) \\
Q(x_1^T) \odot K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
=
& 
\begin{bmatrix}
0 \\
Q(x_4^T) \\
K(x_4^T) \\
Q(x_3^T) \\
K(x_3^T) \\
Q(x_3^T) \\
Q(x_3^T) \\
Q(x_2^T) \\
Q(x_2^T) \\
\sum Q(x_2^T) \odot K(x_2^T) \\
\sum Q(x_2^T) \odot K(x_1^T) \\
\sum Q(x_1^T) \odot K(x_2^T) \\
\sum Q(x_1^T) \odot K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
&
\odot
&
\begin{bmatrix}
\mathbf{1}(x_4^T) \\
\mathbf{1}(x_4^T) \\
\mathbf{1}(x_4^T) \\
\mathbf{1}(x_4^T) \\
\mathbf{1}(x_4^T) \\
K(x_3^T) \\
K(x_2^T) \\
K(x_3^T) \\
K(x_2^T) \\
\mathbf{1} \\
\mathbf{1} \\
\mathbf{1} \\
\mathbf{1} \\
0 \\
\vdots \\
\end{bmatrix}
&
=
&
\begin{bmatrix}
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
\mathbf{1} & 0 & \cdots \\
0 & 0 & I & 0 & \cdots \\
0 & 0 & 0 & 0 & I & 0 & \cdots \\
0 & 0 & I & 0 & \cdots \\
0 & 0 & 0 & 0 & I & 0 & \cdots \\
0 & \cdots & 0 & \mathbf{1} & 0 & \cdots \\
0 & \cdots & 0 & \mathbf{1} & 0 & \cdots \\
0 & \cdots & 0 & \mathbf{1} & 0 & \cdots \\
0 & \cdots & 0 & \mathbf{1} & 0 & \cdots \\
\vdots \\
\end{bmatrix}
\begin{bmatrix}
x_4^T \\
Q(x_3^T) \\
K(x_3^T) \\
Q(x_2^T) \\
K(x_2^T) \\
Q(x_2^T) \odot K(x_2^T) \\
Q(x_2^T) \odot K(x_1^T) \\
Q(x_1^T) \odot K(x_2^T) \\
Q(x_1^T) \odot K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
\end{align*}
$$

which is

$$
\begin{bmatrix}
0 \\
Q(x_4^T) \\
K(x_4^T) \\
Q(x_3^T) \\
K(x_3^T) \\
Q(x_3^T) \odot K(x_3^T) \\
Q(x_3^T) \odot K(x_2^T) \\
Q(x_2^T) \odot K(x_3^T) \\
Q(x_2^T) \odot K(x_2^T) \\
\sum Q(x_2^T) \odot K(x_2^T) \\
\sum Q(x_2^T) \odot K(x_1^T) \\
\sum Q(x_1^T) \odot K(x_2^T) \\
\sum Q(x_1^T) \odot K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
= 
\begin{bmatrix}
0 \\
Q(x_4^T) \\
K(x_4^T) \\
Q(x_3^T) \\
K(x_3^T) \\
Q(x_3^T) \odot K(x_3^T) \\
Q(x_3^T) \odot K(x_2^T) \\
Q(x_2^T) \odot K(x_3^T) \\
Q(x_2^T) \odot K(x_2^T) \\
Q(x_2^T)^T K(x_2^T) \\
Q(x_2^T)^T K(x_1^T) \\
Q(x_1^T)^T K(x_2^T) \\
Q(x_1^T)^T K(x_1^T) \\
0 \\
\vdots \\
\end{bmatrix}
$$
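
As a check on this last step, the sketch below treats each $\sum$ block as a $1 \times k$ row of ones, applies it to the Hadamard products, and compares the result with the dot products computed directly. The weights, inputs and helper names are illustrative assumptions, with $k = 2$.

```python
# The Σ block as a row of ones: summing Hadamard products gives dot products.

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def hadamard(u, v):
    return [a * b for a, b in zip(u, v)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

Q = [[1.0, 2.0], [0.0, 1.0]]    # hypothetical query weights
K = [[0.5, 0.0], [1.0, -1.0]]   # hypothetical key weights
x1, x2 = [1.0, 2.0], [3.0, -1.0]

queries = [matvec(Q, x) for x in (x1, x2)]
keys = [matvec(K, x) for x in (x1, x2)]

sum_row = [1.0, 1.0]  # the Σ block: a 1 x k row of ones

def summed(v):
    """Apply the Σ block to a k-vector, giving a single number."""
    return dot(sum_row, v)

# Entry [i][j] is the score between query i and key j.
scores_rnn = [[summed(hadamard(q, kk)) for kk in keys] for q in queries]
scores_direct = [[dot(q, kk) for kk in keys] for q in queries]

print(scores_rnn)
print(scores_direct)
```

Both tables contain the same four numbers, one per (query, key) pair, which are the entries of the score matrix that attention needs.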

😎