# What is Einsum?

* Let's say we want to multiply two matrices `A` and `B` and then sum each column of the product, giving a vector `c`. In Einstein summation notation every index that does not appear in the output (here `i` and `k`) is summed over, so we can write this as

$$c_{j} = \sum_{i} \sum_{k} A_{ik}B_{kj} = A_{ik}B_{kj}$$


In [54]:
import jax
import jax.numpy as jnp
import optax

In [9]:
a = jnp.arange(6).reshape((2, 3)) # 2x3 example matrix [[0, 1, 2], [3, 4, 5]] used by the cells below

In [16]:
# Display the 2x3 example matrix.
print(a)

[[0 1 2]
 [3 4 5]]


In [17]:
# `b` must be defined before it is printed; this 3x2 matrix is the one
# shown in the output below.
b = jnp.arange(6).reshape((3, 2))
print(b)

[[0 1]
 [2 3]
 [4 5]]


# Transpose

In [8]:
jnp.einsum("ij->ji", a) # Transpose: the element at position (i, j) of `a` is written to position (j, i) of the output.

Array([[0, 3],
       [1, 4],
       [2, 5]], dtype=int32)

# Matrix Sum

In [10]:
jnp.einsum("ij->", a) # Total sum: nothing follows "->", so both i and j are summed over and the result is a scalar.

Array(15, dtype=int32)

# Column Sum

In [13]:
jnp.einsum("ij->j", a) # Column sums: only j is kept in the output, so the row index i is summed over (one value per column).

Array([3, 5, 7], dtype=int32)

# Row Sum

In [18]:
jnp.einsum("ij->i", a) # Row sums: only i is kept in the output, so the column index j is summed over (one value per row).

Array([ 3, 12], dtype=int32)

# Matrix-Vector Multiplication




In [20]:
b = jnp.arange(3) # vector of shape (3,), matching the 3 columns of `a`
jnp.einsum("ij,j->i", a, b) # same as jnp.dot(a, b): y[i] = sum over j of a[i, j] * b[j]

Array([ 5, 14], dtype=int32)

# Matrix-Matrix Multiplication

In [31]:
a = jnp.arange(6).reshape(2, 3)  # shape (2, 3)
b = jnp.arange(15).reshape(3, 5)  # shape (3, 5)
# Matrix product, same as a @ b: c[i, j] = sum over k of a[i, k] * b[k, j]; result shape (2, 5).
jnp.einsum('ik,kj->ij', a, b)

Array([[ 25,  28,  31,  34,  37],
       [ 70,  82,  94, 106, 118]], dtype=int32)

In [32]:
jnp.einsum('ik,kj->j', a, b) # Matrix product followed by column sums: i and k are both summed, only j is kept.

Array([ 95, 110, 125, 140, 155], dtype=int32)

In [33]:
jnp.einsum('ik,kj->i', a, b) # Matrix product followed by row sums: k and j are both summed, only i is kept.

Array([155, 470], dtype=int32)

# Dot Product

In [34]:
a = jnp.arange(3)  # [0, 1, 2]
b = jnp.arange(3,6)  # [3, 4, 5]
jnp.einsum('i,i->', a, b) # Dot product: elements of `a` and `b` sharing index i are multiplied, and i is summed away (nothing after "->").

Array(14, dtype=int32)

# Element-wise Multiplication of Two Matrices and Its Sum

In [48]:
a = jnp.arange(6).reshape(2, 3)  # shape (2, 3)
b = jnp.arange(6,12).reshape(2, 3)  # same shape (2, 3)
# Element-wise (Hadamard) product, same as a * b: both indices appear in the output, so nothing is summed.
jnp.einsum('ij,ij->ij', a, b)

Array([[ 0,  7, 16],
       [27, 40, 55]], dtype=int32)

In [49]:
jnp.einsum('ij,ij->', a, b) # Element-wise product summed over all positions, same as jnp.sum(a * b).

Array(145, dtype=int32)

# Outer Product

Given two column vectors $u$ and $v$ of sizes $m \times 1$ and $n \times 1$ respectively (here $m = 4$ and $n = 3$),

$$ u = \begin{bmatrix}
  u_{1} \\
  u_{2} \\
  u_{3} \\
  u_{4} \\
\end{bmatrix}
\qquad
v = \begin{bmatrix}
  v_{1} \\
  v_{2} \\
  v_{3} \\
\end{bmatrix}$$

their outer product is the $m \times n$ matrix

$$ u \otimes v = uv^{T} = \begin{bmatrix}
  u_{1}v_{1} & u_{1}v_{2} & u_{1}v_{3}\\
  u_{2}v_{1} & u_{2}v_{2} & u_{2}v_{3}\\
  u_{3}v_{1} & u_{3}v_{2} & u_{3}v_{3}\\
  u_{4}v_{1} & u_{4}v_{2} & u_{4}v_{3}\\
\end{bmatrix}$$

In [51]:
a = jnp.arange(3)  # shape (3,)
b = jnp.arange(3,7)  # shape (4,)
# Outer product: out[i, j] = a[i] * b[j]; no index is shared, so nothing is summed and the result has shape (3, 4).
jnp.einsum('i,j->ij', a, b)

Array([[ 0,  0,  0,  0],
       [ 3,  4,  5,  6],
       [ 6,  8, 10, 12]], dtype=int32)

# Dot-Product Attention

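$$ \mathrm{Attention}(q, K, V) = \mathrm{softmax}\left(q K^{T}\right) V $$

This single-query variant is unscaled; the $1/\sqrt{d_k}$ factor is added in the scaled version below.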

In [73]:
def dot_product_attention(q, K, V):
    """Unscaled dot-product attention for a single query vector.

    No 1/sqrt(k) scaling is applied; the logits are the raw dot products.

    Args:
        q: query vector with shape [k].
        K: key matrix with shape [m, k] (one row per memory position).
        V: value matrix with shape [m, v].

    Returns:
        A vector with shape [v]: the rows of V averaged with weights
        softmax(q . K^T).
    """
    # One similarity score per memory position: scores[m] = sum_k q[k] * K[m, k].
    scores = jnp.einsum("k,mk->m", q, K)
    # Turn the scores into weights that sum to 1 over the memory axis.
    probs = jax.nn.softmax(scores)
    # Weighted sum of value rows: out[v] = sum_m probs[m] * V[m, v].
    out = jnp.einsum("m,mv->v", probs, V)
    return out
    

In [109]:
key = jax.random.PRNGKey(13)  # base PRNG key for the random demo tensors
d_model = 512  # model (embedding) dimension
seq_len = 128  # number of tokens per sequence in the demos
num_heads = 8  # heads in the multi-head demos; per-head size is d_model // num_heads

In [110]:
# Use a separate key per tensor: drawing K and V from the same key with the
# same shape would make them identical.
key_q, key_k, key_v = jax.random.split(key, 3)
head_dim = d_model // num_heads  # per-head size (64 with the settings above)
q = jax.random.uniform(key_q, shape=(head_dim,))  # one query token, already projected to a single head
K = jax.random.uniform(key_k, shape=(seq_len, head_dim))  # keys for a seq_len-token sequence, projected to one head
V = jax.random.uniform(key_v, shape=(seq_len, head_dim))  # values for the same sequence, projected to one head

In [103]:
# Show the shapes of the single-query attention inputs.
print(f"q: {q.shape}")
print(f"K: {K.shape}")
print(f"V: {V.shape}")

q: (64,)
K: (128, 64)
V: (128, 64)


In [104]:
# Note: despite the name, this holds the attention *output* (a weighted sum of
# the rows of V, shape [v]), not the per-position scores.
attention_scores = dot_product_attention(q, K, V)
attention_scores.shape

(128,)


(64,)

# Scaled Dot-Product Attention

$$ Attention(Q, K, V) = softmax(\dfrac {Q \cdot K^{T}}{\sqrt d_{k}})V$$

In [105]:
def scaled_dot_product_attention(Q, K, V):
    """Scaled dot-product attention for a sequence of queries.

    Computes softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: query matrix with shape [n, k] (one row per query position).
        K: key matrix with shape [m, k] (one row per memory position).
        V: value matrix with shape [m, v].

    Returns:
        y: a matrix with shape [n, v]; row i is the attention-weighted sum of
        the rows of V for query i.
    """
    d_k = Q.shape[-1]
    # logits[n, m] = <Q[n], K[m]> / sqrt(d_k). The contraction must run over
    # the shared feature axis k, not over the sequence axis.
    logits = jnp.einsum("nk,mk->nm", Q, K) / (d_k ** 0.5)
    # Normalise over the memory axis so each query's weights sum to 1.
    weights = jax.nn.softmax(logits, axis=-1)
    # y[n, v] = sum_m weights[n, m] * V[m, v]
    return jnp.einsum("nm,mv->nv", weights, V)
    

In [111]:
# Separate keys per tensor: reusing one key with the same shape would make
# Q, K and V identical, which hides mistakes in the attention code.
key_q, key_k, key_v = jax.random.split(key, 3)
Q = jax.random.uniform(key_q, shape=(seq_len, d_model // num_heads))  # seq_len query tokens, each projected to one head
K = jax.random.uniform(key_k, shape=(seq_len, d_model // num_heads))  # seq_len key tokens, each projected to one head
V = jax.random.uniform(key_v, shape=(seq_len, d_model // num_heads))  # seq_len value tokens, each projected to one head

In [107]:
# Attention output for every query position: shape (seq_len, head size).
attention_scores = scaled_dot_product_attention(Q, K, V)
attention_scores.shape

(128, 64)

In [123]:
def multi_head_attention(X, M, P_q, P_k, P_v, P_o):
    """Multi-head (unscaled) dot-product attention over a sequence of queries.

    Args:
        X: query-side inputs with shape [n, d].
        M: memory inputs (source of keys and values) with shape [m, d].
        P_q: query projections with shape [h, d, k].
        P_k: key projections with shape [h, d, k].
        P_v: value projections with shape [h, d, v].
        P_o: output projections with shape [h, d, v].

    Returns:
        A tensor with shape [n, d]: the per-head outputs projected back to d
        and summed over heads.
    """
    # Project inputs into each head's query/key/value space.
    queries = jnp.einsum("nd,hdk->hnk", X, P_q)
    keys = jnp.einsum("md,hdk->hmk", M, P_k)
    values = jnp.einsum("md,hdv->hmv", M, P_v)

    # Per-head attention weights over the memory axis (last axis).
    attn = jax.nn.softmax(jnp.einsum("hnk,hmk->hnm", queries, keys))

    # Mix the values, then map every head back to d and sum over heads.
    mixed = jnp.einsum("hnm,hmv->hnv", attn, values)
    return jnp.einsum("hnv,hdv->nd", mixed, P_o)

In [125]:
# One key per tensor so that X, M and the four projections are not identical.
key_x, key_m, key_q, key_k, key_v, key_o = jax.random.split(key, 6)
X = jax.random.uniform(key_x, shape=(seq_len, d_model))  # query-side input: seq_len tokens with d_model-dim embeddings
M = jax.random.uniform(key_m, shape=(seq_len, d_model))  # memory input the keys and values are computed from

# Projection weights (what a per-head dense layer would hold). The leading axis
# is the head axis, so it must be num_heads; each head maps a d_model-dim
# embedding to head_size = d_model // num_heads, giving per-head outputs of
# shape (num_heads, seq_len, head_size).
head_size = d_model // num_heads
P_q = jax.random.uniform(key_q, shape=(num_heads, d_model, head_size))
P_k = jax.random.uniform(key_k, shape=(num_heads, d_model, head_size))
P_v = jax.random.uniform(key_v, shape=(num_heads, d_model, head_size))
P_o = jax.random.uniform(key_o, shape=(num_heads, d_model, head_size))

attention_scores = multi_head_attention(X, M, P_q, P_k, P_v, P_o)
attention_scores.shape

(128, 512)

In [128]:
def batched_multi_head_attention(
    X, M, P_q, P_k, P_v, P_o):
    """Multi-head (unscaled) dot-product attention over a batch of sequences.

    Projection tensors may be shared by every example in the batch (the usual
    case for learned weights: shape [h, d, k] / [h, d, v]) or supplied per
    example with a leading batch axis (shape [b, h, d, k] / [b, h, d, v]).
    Each projection is dispatched on its own rank, so the two forms can be mixed.

    Args:
        X: query-side inputs with shape [b, n, d].
        M: memory inputs (source of keys and values) with shape [b, m, d].
        P_q: query projections, shape [h, d, k] or [b, h, d, k].
        P_k: key projections, shape [h, d, k] or [b, h, d, k].
        P_v: value projections, shape [h, d, v] or [b, h, d, v].
        P_o: output projections, shape [h, d, v] or [b, h, d, v].

    Returns:
        Y: a tensor with shape [b, n, d].
    """

    def _project(spec_shared, spec_batched, inputs, proj):
        # Rank-3 weights are shared across the batch; otherwise a batch axis leads.
        spec = spec_shared if proj.ndim == 3 else spec_batched
        return jnp.einsum(spec, inputs, proj)

    Q = _project("bnd,hdk->bhnk", "bnd,bhdk->bhnk", X, P_q)
    K = _project("bmd,hdk->bhmk", "bmd,bhdk->bhmk", M, P_k)
    V = _project("bmd,hdv->bhmv", "bmd,bhdv->bhmv", M, P_v)

    logits = jnp.einsum("bhnk,bhmk->bhnm", Q, K)
    # Normalise over the memory axis m so each query's weights sum to 1.
    weights = jax.nn.softmax(logits, axis=-1)
    O = jnp.einsum("bhnm,bhmv->bhnv", weights, V)
    # Map each head back to d and sum over heads.
    Y = _project("bhnv,hdv->bnd", "bhnv,bhdv->bnd", O, P_o)

    return Y

In [129]:
batch_size = 32  # number of sequences processed together

# One key per tensor so that X, M and the projections are not identical.
key_x, key_m, key_q, key_k, key_v, key_o = jax.random.split(key, 6)
X = jax.random.uniform(key_x, shape=(batch_size, seq_len, d_model))  # batch of query-side inputs
M = jax.random.uniform(key_m, shape=(batch_size, seq_len, d_model))  # batch of memory inputs

# Projection weights are learned parameters, so every example in the batch uses
# the same ones. Draw one (num_heads, d_model, head_size) tensor per projection
# (the head axis is num_heads, not seq_len) and broadcast it to the
# (batch_size, num_heads, d_model, head_size) layout.
head_size = d_model // num_heads
P_q, P_k, P_v, P_o = (
    jnp.broadcast_to(
        jax.random.uniform(k, shape=(num_heads, d_model, head_size)),
        (batch_size, num_heads, d_model, head_size),
    )
    for k in (key_q, key_k, key_v, key_o)
)

attention_scores = batched_multi_head_attention(X, M, P_q, P_k, P_v, P_o)
attention_scores.shape

(32, 128, 512)

# References:

1. https://rockt.github.io/2018/04/30/einsum
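2. https://arxiv.org/pdf/1706.03762
3. https://arxiv.org/abs/1911.02150 (Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need" — source of the einsum-style attention functions)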