### Algorithm: Basic Single-Query Attention

**Input**:  
- $e \in \mathbb{R}^{d_{\text{in}}}$, vector representation of the current token  
- $e_t \in \mathbb{R}^{d_{\text{in}}}$, vector representations of context tokens $t \in [T]$  

**Output**:  
- $\tilde{v} \in \mathbb{R}^{d_{\text{out}}}$, vector representation of the token and context combined.  

**Parameters**:  
- $W_q, W_k \in \mathbb{R}^{d_{\text{attn}} \times d_{\text{in}}}$, $b_q, b_k \in \mathbb{R}^{d_{\text{attn}}}$: the query and key linear projections.  
- $W_v \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$, $b_v \in \mathbb{R}^{d_{\text{out}}}$: the value linear projection.

---

**Steps**:

1. $q \gets W_q e + b_q$
2. $\forall t: k_t \gets W_k e_t + b_k$
3. $\forall t: v_t \gets W_v e_t + b_v$
4. $\forall t: \alpha_t = \frac{\exp(q^\top k_t / \sqrt{d_{\text{attn}}})}{\sum_{u} \exp(q^\top k_u / \sqrt{d_{\text{attn}}})}$
5. Return $\tilde{v} = \sum_{t=1}^{T} \alpha_t v_t$
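
The five steps above map almost line for line onto NumPy. The following is a minimal sketch rather than a reference implementation: the function name `single_query_attention_sketch`, the small dimensions, and the scaled-down Gaussian initialisation are assumptions chosen for illustration. Subtracting the maximum score before exponentiating is a standard stability trick that leaves the softmax unchanged.

```python
import numpy as np

def single_query_attention_sketch(e, E, Wq, bq, Wk, bk, Wv, bv):
    """e: (d_in,) current token; E: (T, d_in) context tokens, one per row."""
    d_attn = Wq.shape[0]
    q = Wq @ e + bq                     # step 1: query, shape (d_attn,)
    K = E @ Wk.T + bk                   # step 2: keys, shape (T, d_attn)
    V = E @ Wv.T + bv                   # step 3: values, shape (T, d_out)
    scores = K @ q / np.sqrt(d_attn)    # scaled dot products, shape (T,)
    scores = scores - scores.max()      # shift for stability; softmax is unchanged
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum()         # step 4: attention weights, sum to 1
    return alpha @ V, alpha             # step 5: weighted sum of the values

# Hypothetical small example (dimensions and initialisation are illustrative).
rng = np.random.default_rng(0)
d_in, d_attn, d_out, T = 8, 4, 6, 5
Wq = rng.normal(scale=0.1, size=(d_attn, d_in))
Wk = rng.normal(scale=0.1, size=(d_attn, d_in))
Wv = rng.normal(scale=0.1, size=(d_out, d_in))
bq, bk, bv = np.zeros(d_attn), np.zeros(d_attn), np.zeros(d_out)

e = rng.normal(size=d_in)
E = rng.normal(size=(T, d_in))
v_tilde, alpha = single_query_attention_sketch(e, E, Wq, bq, Wk, bk, Wv, bv)
print(v_tilde.shape, alpha.sum())
```

Because the weights `alpha` are positive and sum to one, the output is a convex combination of the value vectors.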


In [21]:
din = 100
T = 10
dout = 100
dattn = 50

In [72]:
import numpy as np
from math import sqrt

In [3]:
Wq = np.random.rand(dattn, din)
Wk = np.random.rand(dattn, din)
Wv = np.random.rand(dout, din)
bq = np.zeros(dattn)
bk = np.zeros(dattn)
bv = np.zeros(dout)

In [5]:
Wq.shape

(50, 100)

In [56]:
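def normalise(a):
    # Divide by the sum along axis 0 so that entries along that axis sum to 1.
    # Extra step, not part of the algorithm above: without it the scores from
    # these uniform random weights are large enough to overflow np.exp.
    col_sum = np.sum(a, axis=0)
    return a/col_sum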

In [71]:
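from math import sqrt
def single_query_attention(e, E):
    # Steps 1-3: project the current token to a query and each context token
    # to a key and a value.
    q = np.matmul(Wq, e) + bq
    K = [np.matmul(Wk, et) + bk for et in E]
    V = [np.matmul(Wv, et) + bv for et in E]
    # Extra rescaling, not part of the algorithm; K and V become (T, d) arrays.
    q, K, V = normalise(q), normalise(K), normalise(V)
    # Step 4: softmax over the scaled scores q.k_t / sqrt(dattn).
    su = sum([np.exp(np.dot(q, kt)/sqrt(dattn)) for kt in K])
    a = [np.exp(np.dot(q, kt)/sqrt(dattn))/su for kt in K]
    # Step 5: attention-weighted sum of the value vectors.
    # Loop over the actual context length instead of the global T.
    Att = a[0]*V[0]
    for i in range(1, len(V)):
        Att = Att + a[i]*V[i]
    return Att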

In [51]:
e = np.random.rand(din)

In [52]:
len(e.shape)

1

In [53]:
E = [np.random.rand(din) for _ in range(T)]

In [62]:
Att = single_query_attention(e, E)

In [65]:
Att

array([0.10002971, 0.10002864, 0.10003348, 0.10004385, 0.10002453,
       0.10004076, 0.10002649, 0.10004898, 0.10003641, 0.10003001,
       0.10003688, 0.10004502, 0.10002855, 0.10004628, 0.10002975,
       0.10003923, 0.10004023, 0.10003436, 0.10002789, 0.10003176,
       0.10004013, 0.10003578, 0.10001883, 0.10004355, 0.10003441,
       0.1000269 , 0.10004233, 0.10002851, 0.10002821, 0.1000321 ,
       0.1000355 , 0.10004034, 0.10003454, 0.10003524, 0.10003045,
       0.10002056, 0.10003468, 0.10003826, 0.1000327 , 0.10003903,
       0.10002632, 0.10002373, 0.10003514, 0.10003568, 0.10003533,
       0.10002659, 0.10003964, 0.10003845, 0.10003135, 0.10004984,
       0.10003829, 0.100042  , 0.1000361 , 0.10003198, 0.10003086,
       0.10004494, 0.10003964, 0.10002907, 0.10004636, 0.10003095,
       0.10003546, 0.10002636, 0.10002401, 0.10003907, 0.10003731,
       0.10003925, 0.10003595, 0.10003267, 0.10003538, 0.10004106,
       0.10003958, 0.10003122, 0.10004513, 0.10004382, 0.10003

### Algorithm: $\tilde{V} \gets \text{Attention}(X, Z \mid \mathcal{W}_{qkv}, \text{Mask})$

**/* Computes a single (masked) self- or cross-attention head. */**

---

**Input**:  
- $X \in \mathbb{R}^{d_x \times \ell_x}$, $Z \in \mathbb{R}^{d_z \times \ell_z}$: vector representations of the primary and context sequence.

**Output**:  
- $\tilde{V} \in \mathbb{R}^{d_{\text{out}} \times \ell_x}$: updated representations of tokens in $X$, folding in information from tokens in $Z$.

**Parameters**:  
- $\mathcal{W}_{qkv}$ consisting of:  
  - $W_q \in \mathbb{R}^{d_{\text{attn}} \times d_x}$, $b_q \in \mathbb{R}^{d_{\text{attn}}}$  
  - $W_k \in \mathbb{R}^{d_{\text{attn}} \times d_z}$, $b_k \in \mathbb{R}^{d_{\text{attn}}}$  
  - $W_v \in \mathbb{R}^{d_{\text{out}} \times d_z}$, $b_v \in \mathbb{R}^{d_{\text{out}}}$  

**Hyperparameters**:  
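- $\text{Mask} \in \{0, 1\}^{\ell_z \times \ell_x}$: $\text{Mask}[t_z, t_x] = 1$ if the token at position $t_x$ of $X$ may attend to the token at position $t_z$ of $Z$, and $0$ otherwise.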
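
The algorithm takes $\text{Mask}$ as given. One common choice, assumed here rather than stated in the text, is the unidirectional (causal) mask with $\text{Mask}[t_z, t_x] = 1$ when $t_z \le t_x$ and $0$ otherwise, so that each token attends to itself and to earlier positions only. The helper name `causal_mask` is hypothetical.

```python
import numpy as np

def causal_mask(l_z, l_x):
    """Return Mask[t_z, t_x] = 1 if t_z <= t_x else 0 (assumed convention)."""
    t_z = np.arange(l_z)[:, np.newaxis]   # context positions as a column
    t_x = np.arange(l_x)[np.newaxis, :]   # primary positions as a row
    return (t_z <= t_x).astype(int)

mask_example = causal_mask(4, 4)
print(mask_example)
```

For square inputs this is the upper-triangular matrix of ones, diagonal included, so every column keeps at least one unmasked entry.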

---

**Steps**:

1. $Q \gets W_q X + b_q 1^\top \quad$ [Query $\in \mathbb{R}^{d_{\text{attn}} \times \ell_x}$]
2. $K \gets W_k Z + b_k 1^\top \quad$ [Key $\in \mathbb{R}^{d_{\text{attn}} \times \ell_z}$]
3. $V \gets W_v Z + b_v 1^\top \quad$ [Value $\in \mathbb{R}^{d_{\text{out}} \times \ell_z}$]
4. $S \gets K^\top Q \quad$ [Score $\in \mathbb{R}^{\ell_z \times \ell_x}$]
5. $\forall t_z, t_x, \text{ if } \neg \text{Mask}[t_z, t_x] \text{ then } S[t_z, t_x] \gets -\infty$
6. Return $\tilde{V} = V \cdot \text{softmax}(S / \sqrt{d_{\text{attn}}})$
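
Steps 1 to 6 can be sketched in matrix form as follows. This is an illustrative sketch under assumptions, not a definitive implementation: the names `attention_sketch` and `column_softmax`, the small dimensions, and the Gaussian initialisation are hypothetical, and the softmax assumes every column keeps at least one unmasked entry (true for a causal mask that includes the diagonal).

```python
import numpy as np

def column_softmax(S):
    # Softmax over each column; assumes each column has at least one finite entry.
    S = S - S.max(axis=0, keepdims=True)          # shift for numerical stability
    exp_S = np.exp(S)                             # exp(-inf) evaluates to 0
    return exp_S / exp_S.sum(axis=0, keepdims=True)

def attention_sketch(X, Z, Wq, bq, Wk, bk, Wv, bv, mask):
    d_attn = Wq.shape[0]
    Q = Wq @ X + bq[:, np.newaxis]                # step 1: (d_attn, l_x)
    K = Wk @ Z + bk[:, np.newaxis]                # step 2: (d_attn, l_z)
    V = Wv @ Z + bv[:, np.newaxis]                # step 3: (d_out, l_z)
    S = K.T @ Q                                   # step 4: (l_z, l_x)
    S = np.where(mask.astype(bool), S, -np.inf)   # step 5: mask is in {0, 1}
    return V @ column_softmax(S / np.sqrt(d_attn))  # step 6: (d_out, l_x)

# Hypothetical self-attention example (Z = X) with a causal mask.
rng = np.random.default_rng(0)
d_x, d_attn, d_out, length = 8, 4, 6, 5
Wq = rng.normal(scale=0.1, size=(d_attn, d_x))
Wk = rng.normal(scale=0.1, size=(d_attn, d_x))
Wv = rng.normal(scale=0.1, size=(d_out, d_x))
bq, bk, bv = np.zeros(d_attn), np.zeros(d_attn), np.zeros(d_out)

X = rng.normal(size=(d_x, length))
mask = np.triu(np.ones((length, length), dtype=int))
V_tilde = attention_sketch(X, X, Wq, bq, Wk, bk, Wv, bv, mask)

# The first token can only attend to itself, so its output is its own value vector.
first_value = Wv @ X[:, 0] + bv
print(V_tilde.shape, np.allclose(V_tilde[:, 0], first_value))
```

Each column of the result is single-query attention applied to the corresponding column of $X$, restricted to the context positions the mask allows.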


In [85]:
dx = 100 # input token encoding dimension
lx = 20 # input length
dz = 100 # context token encoding dimension
lz = 20 # context window length
dout = 100 # output dimension
dattn = 50 # attention dimension

In [111]:
Wq = np.random.rand(dattn, dx)
Wk = np.random.rand(dattn, dz)
Wv = np.random.rand(dout, dz)
bq = np.random.rand(dattn)
bk = np.random.rand(dattn)
bv = np.random.rand(dout)

In [163]:
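def softmax(a):
    # Column-wise softmax: each column of `a` holds the scores for one query.
    a = np.exp(a)
    col_sums = np.sum(a, axis=0)
    # A fully masked column is all -inf, so its exponentials sum to 0;
    # divide by 1 instead so the column comes out as zeros rather than NaN.
    col_sums[col_sums == 0] = 1
    a = a/col_sums
    return a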

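def masked_attention(X, Z, mask):
    # Steps 1-3: the biases are broadcast across columns (one column per token).
    Q = np.matmul(Wq, X) + bq[:, np.newaxis]
    K = np.matmul(Wk, Z) + bk[:, np.newaxis]
    V = np.matmul(Wv, Z) + bv[:, np.newaxis]
    # Extra rescaling, not part of the algorithm, to keep np.exp from overflowing.
    Q, K, V = normalise(Q), normalise(K), normalise(V)
    # Step 4: S[tz, tx] scores context token tz against primary token tx.
    S = np.matmul(K.T, Q)
    # Step 5: set masked-out scores to -inf. `S *= mask` only works when every
    # score is strictly positive and masked entries are -inf; selecting with
    # np.where also handles a {0, 1} mask.
    S = np.where(mask == 1, S, -np.inf)
    # Step 6: column-wise softmax, then mix the value vectors.
    S = softmax(S/sqrt(dattn))
    Vout = np.matmul(V,S)
    return Vout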
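
Exponentiating raw scores overflows in float64 once they exceed roughly 709, which is one reason to rescale or shift them. A hedged alternative to the `softmax` above is to subtract each column's maximum before exponentiating, which leaves the result mathematically unchanged. The name `stable_softmax` and the test values are assumptions for illustration.

```python
import numpy as np

def stable_softmax(scores):
    """Column-wise softmax that tolerates large scores and fully masked columns."""
    col_max = np.max(scores, axis=0, keepdims=True)
    # A fully masked column has max -inf; shift it by 0 to avoid inf - inf.
    col_max = np.where(np.isfinite(col_max), col_max, 0.0)
    exps = np.exp(scores - col_max)
    denom = exps.sum(axis=0, keepdims=True)
    # Columns whose exponentials sum to 0 are left as zeros instead of NaN.
    return np.divide(exps, denom, out=np.zeros_like(exps), where=denom > 0)

scores = np.array([[1000.0, -np.inf],
                   [1001.0, -np.inf]])
weights = stable_softmax(scores)
print(weights)
```

The first column has scores far beyond the overflow threshold yet still yields finite weights that sum to one; the second column is fully masked and comes out as zeros.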

In [135]:
X = np.random.rand(dx, lx)

In [89]:
Z = np.random.rand(dz, lz)

In [90]:
lz,lx

(20, 20)

In [123]:
mask = np.ones((lz,lx))

In [124]:
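m,n = mask.shape
for i in range(m):
    for j in range(n):
        # i indexes the context position t_z, j the primary position t_x.
        # Masked entries are marked with -inf rather than 0. Using i >= j also
        # masks the diagonal, so each token sees only strictly earlier
        # positions and column 0 has nothing to attend to.
        if i>=j:
            mask[i,j] = -np.inf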

In [151]:
mask[8,:]

array([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,   1.,   1.,
         1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.])

In [164]:
mask_Att = masked_attention(X, Z, mask)

In [166]:
mask_Att

array([[0.        , 0.00921922, 0.00925963, ..., 0.00935413, 0.00936589,
        0.0093575 ],
       [0.        , 0.0089599 , 0.00919631, ..., 0.00899994, 0.00898594,
        0.00896961],
       [0.        , 0.0106756 , 0.01061519, ..., 0.01036584, 0.01034745,
        0.01034312],
       ...,
       [0.        , 0.01091861, 0.01055166, ..., 0.01092468, 0.01088339,
        0.010875  ],
       [0.        , 0.01128569, 0.01098261, ..., 0.01075181, 0.01073719,
        0.01075071],
       [0.        , 0.01079543, 0.01090798, ..., 0.01078401, 0.01080107,
        0.01077932]])