<a href="https://colab.research.google.com/github/pixqc/einsum-puzzles/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Welcome to pixqc/einsum-puzzles!

# Einsum (short for Einstein summation) is a powerful tool for tensor multiplication, summation, and permutation.
# Multi-step tensor operations can be concisely expressed with a single einsum subscript.

import numpy as np

a, b = np.random.randn(2, 3, 4), np.random.randn(2, 3, 4)
out1 = np.transpose(np.tensordot(a, b, axes=([2], [2])), (0, 1, 3, 2))
out2 = np.einsum("ijm,lkm->ijkl", a, b)
print(np.allclose(out1, out2))
# ijm,lkm->ijkl... what? This notebook attempts to demystify einsum.

# We'll start with multiplying two vectors,
# we'll end with implementing multi-head attention.

# Einsum is available on numpy, torch, jax, mlx, and tinygrad.
# We'll use numpy, but the core idea applies to all tensor libraries.

# (Einsum cheatsheet available at the end of this notebook!)

In [None]:
# Introduction

# In einsum-land, we deal with tensor shapes.
# We tell einsum what the output should look like,
# and it does the operations for us.
# np.einsum("INPUT->OUTPUT", a)
#            │      │
#            │      └─── desired output shape
#            └────────── a.shape

# We use 'ijk' to represent dimensions.
# When our input has one dimension,
# we use one letter to represent that dimension.
a = np.array([1, 2, 3])
out = np.einsum("i->i", a)  # identity function
#                │  │
#                │  └─── output: (3,)
#                └────── a.shape = (3,); i = 3
# 'i->i' means "take the input and keep it the same"
print(np.allclose(a, out))

# When our input has two dimensions,
# we use two letters to represent each dimension.
a = np.array([[1, 2, 3], [4, 5, 6]])
out = np.einsum("ij->ij", a)  # identity function
#                │   │
#                │   └─── output: (2, 3)
#                └─────── a.shape = (2, 3); i=2, j=3
print(np.allclose(a, out))

# Likewise with three dims.
a = np.arange(24).reshape(2, 3, 4)
out = np.einsum("ijk->ijk", a)  # identity function
#                │    │
#                │    └─── output: (2, 3, 4)
#                └──────── a.shape = (2, 3, 4); i=2, j=3, k=4
print(np.allclose(a, out))

# Notes:
# - We can use any letter, not only 'ijk'.
#   E.g. np.einsum('abc->abc', a) is valid.
# - The arrow and output subscript can be omitted.
#   E.g. np.einsum('ijk', a) is valid.
# - This 'ijk->ijk' string is called the subscripts string.

In [None]:
# Puzzle #1
# Compute identity of vector a.
a = np.arange(10)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(a, out))

In [None]:
# Puzzle #2
# Compute identity of tensor a.
a = np.arange(120).reshape(5, 4, 3, 2)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(a, out))

In [None]:
# Multiplication

# We separate multiple inputs with comma.
# np.einsum("INPUT1,INPUT2->OUTPUT", a, b)
#            │      │       │
#            │      │       └─── desired output shape
#            │      └─────────── b.shape
#            └────────────────── a.shape

# When we repeat indices in the input,
# einsum elementwise-multiplies along that dimension.
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
out = np.einsum("i,i->i", a, b)
#                │ │  │
#                │ │  └──── output: (3,)
#                │ │        "keep original shape"
#                │ │
#                │ └─────── b.shape = (3,); i = 3
#                ├───────── a.shape = (3,); i = 3
#                │
#                ├───────── 'i' appears twice before the arrow
#                └───────── they are elementwise multiplied
print(np.allclose(a * b, out))

In [None]:
# Puzzle #3
# Multiply vector a and b.
a = np.arange(10)
b = np.arange(10)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a * b, out))

In [None]:
# Puzzle #4
# Multiply 3d tensors a and b elementwise.
a = np.arange(24).reshape(2, 3, 4)
b = np.arange(24).reshape(2, 3, 4)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a * b, out))

In [None]:
# Summation

# When we omit an index in the output,
# we're telling einsum to sum over that dimension.
a = np.array([1, 2, 3])
out = np.einsum("i->", a)
#                │  │
#                │  └─── output: scalar (all dims summed)
#                └────── a.shape = (3,); i = 3
print(np.allclose(np.sum(a), out))

# 'ijk->' will reduce a 3d tensor into a scalar.
a = np.arange(24).reshape(4, 3, 2)
out = np.einsum("ijk->", a)
#                │    │
#                │    └─── output: scalar (all dims summed)
#                └──────── a.shape = (4, 3, 2); i=4, j=3, k=2
print(np.allclose(np.sum(a), out))

In [None]:
# Puzzle #5
# Compute sum of vector a.
a = np.arange(10)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(np.sum(a), out))

In [None]:
# Puzzle #6
# Multiply vector a and b.
a = np.arange(10)
b = np.arange(10)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a * b, out))

In [None]:
# Puzzle #7
# Compute dot product of vector a and b.
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.sum(a * b), out))

In [171]:
# Check our understanding

# Q1: 'i->i' produces identity, 'i,i->i' multiplies. Why?
# Q2: What will 'ij->ji' do to a matrix?
# Q3: Is 'ijk' a valid input for a 2d matrix?
# Q4: Can we elementwise add with einsum?

# (Answers are available at the end of this notebook.)

In [None]:
# Sum over axis

# Einsum output:
# Whatever's gone is summed over,
# whatever's left is kept.

# What we learned we can do:
# Omit an index from the output to sum over it.
# What we can also do:
# Keep an index to preserve that dimension.
a = np.arange(6).reshape(2, 3)
out = np.einsum("ij->i", a)
#                │   │
#                │   └─── output: (2,)
#                │        "sum over j, keep i"
#                │
#                └─────── a.shape = (2, 3); i=2, j=3
print(np.allclose(np.sum(a, axis=1), out))

# Same thing but on axis=0
a = np.arange(6).reshape(2, 3)
out = np.einsum("ij->j", a)
#                │   │
#                │   └─── output: (3,)
#                │        "sum over i, keep j"
#                │
#                └─────── a.shape = (2, 3); i=2, j=3
print(np.allclose(np.sum(a, axis=0), out))

In [None]:
# Puzzle #8
# Compute np.sum(a, axis=1)
a = np.arange(10).reshape(2, 5)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(np.sum(a, axis=1), out))

In [None]:
# Puzzle #9
# Compute np.sum(a*b, axis=0)
a = np.arange(10).reshape(2, 5)
b = np.arange(10).reshape(2, 5)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.sum(a * b, axis=0), out))

In [None]:
# Puzzle #10
# Compute np.sum(a, axis=1)
a = np.arange(24).reshape(2, 3, 4)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(np.sum(a, axis=1), out))

In [None]:
# Diagonal and trace

# Diagonal: we use the same index for both dimensions.
a = np.arange(9).reshape(3, 3)
out = np.einsum("ii->i", a)
#                │   │
#                │   └─── output: (3,)
#                └─────── a.shape = (3, 3); i=3
# 'ii' is saying: "get items where row and column is the same".
print(np.allclose(np.diag(a), out))

# Trace: it's diagonal but the output is summed.
a = np.arange(9).reshape(3, 3)
out = np.einsum("ii->", a)
#                │   │
#                │   └─── output: scalar
#                └─────── a.shape = (3, 3); i=3
print(np.allclose(np.trace(a), out))

In [None]:
# Puzzle #11
# Compute np.diag(a)
a = np.arange(9).reshape(3, 3)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(np.diag(a), out))

In [None]:
# Puzzle #12
# Compute np.trace(a)
a = np.arange(9).reshape(3, 3)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(np.trace(a), out))

In [None]:
# Broadcasting

# We can use einsum to broadcast.

# Let’s do a quick refresher on tensor broadcasting:
# - Right-align shapes, prepend 1s if needed.
# - Dims are compatible if equal or 1.
# - 1s are broadcast to match the other size.
# - Result shape is max size for each dimension.

# Broadcasting along an axis
a = np.arange(6).reshape(2, 3)
b = np.arange(3)
out = np.einsum("ij,j->ij", a, b)
#                │  │  │
#                │  │  └─── output: (2, 3)
#                │  │       'i' and 'j' are kept
#                │  │
#                │  └────── b.shape = (3,); j=3
#                └───────── a.shape = (2, 3); i=2, j=3
# One way to look at it:
# 'ij,j->ij' is broadcasted to 'ij,ij->ij' by einsum.
print(np.allclose(a * b, out))

# Broadcasting with higher dimensions
a = np.arange(24).reshape(2, 3, 4)
b = np.arange(3)
out = np.einsum("ijk,j->ijk", a, b)
#                │   │     │
#                │   │     └─── output: (2, 3, 4)
#                │   │          'i', 'j', and 'k' are kept
#                │   │
#                │   └───────── b.shape = (3,); j=3
#                └───────────── a.shape = (2, 3, 4); i=2, j=3, k=4
# Likewise: 'ijk,j->ijk' is broadcasted to 'ijk,ijk->ijk'.
# Plain numpy aligns shapes from the right, so the equivalent a * b
# needs b reshaped to (3, 1) via b[:, None]; einsum matches 'j' by name.
print(np.allclose(a * b[:, None], out))

In [None]:
# Puzzle #13
# Compute np.sum(a*b,axis=1)
# b must be broadcasted.
a = np.arange(6).reshape(2, 3)
b = np.arange(3)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.sum(a * b, axis=1), out))

In [None]:
# Puzzle #14
# Compute np.sum(a*b, axis=(1, 2))
# b must be broadcasted.
a = np.arange(24).reshape(2, 3, 4)
b = np.arange(12).reshape(3, 4)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.sum(a * b, axis=(1, 2)), out))

In [None]:
# Outer product

# When we use different indices for each input,
# and combine them all in the output,
# we're telling einsum to compute the outer product.
a = np.array([1, 2, 3])
b = np.array([4, 5, 6, 7])
out = np.einsum("i,j->ij", a, b)
#                │ │  │
#                │ │  │
#                │ │  └──── output: (3, 4)
#                │ │        we create all combinations of i and j
#                │ │
#                │ └─────── b.shape = (4,); j = 4
#                └───────── a.shape = (3,); i = 3
print(np.allclose(np.outer(a, b), out))

# This creates a 4d array by combining all elements of both inputs.
a = np.arange(6).reshape(2, 3)
b = np.arange(8).reshape(2, 4)
out = np.einsum("ij,kl->ijkl", a, b)
#                │  │   │
#                │  │   │
#                │  │   └── output: (2, 3, 2, 4)
#                │  │       we create all combinations of i, j, k, l
#                │  │
#                │  └────── b.shape = (2, 4); k = 2, l = 4
#                └───────── a.shape = (2, 3); i = 2, j = 3
print(np.allclose(np.outer(a, b).reshape(2, 3, 2, 4), out))

In [None]:
# Puzzle #15
# Compute the outer product of vectors a and b
a = np.arange(3)
b = np.arange(4)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.outer(a, b), out))

In [None]:
# Puzzle #16
# Compute the outer product of vector a and matrix b.
a = np.arange(3)
b = np.arange(4).reshape(2, 2)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.outer(a, b.flatten()).reshape(3, 2, 2), out))

In [None]:
# Transposition and permutation

# This is where einsum shines.
# Einsum permutations are self-documenting.

# Transpose: we swap the order of indices in the output.
a = np.arange(6).reshape(2, 3)
out = np.einsum("ij->ji", a)
#                │   │
#                │   └─── output: (3, 2)
#                │        we swap i and j
#                │
#                └─────── a.shape = (2, 3); i=2, j=3
print(np.allclose(a.T, out))

# Permute: we can rearrange dimensions in any order we want.
a = np.arange(24).reshape(2, 3, 4)
out = np.einsum("ijk->kji", a)
#                │    │
#                │    └─── output: (4, 3, 2)
#                │         we reverse the order of indices
#                │
#                └─────── a.shape = (2, 3, 4); i=2, j=3, k=4
print(np.allclose(np.transpose(a, (2, 1, 0)), out))

# Even with higher dimensions, it's clear what we're doing.
a = np.arange(120).reshape(2, 3, 4, 5)
out = np.einsum("ijkl->jilk", a)
#                │     │
#                │     └─── output: (3, 2, 5, 4)
#                │          we swap i-j and k-l
#                │
#                └───────── a.shape = (2, 3, 4, 5); i=2, j=3, k=4, l=5
print(np.allclose(np.transpose(a, (1, 0, 3, 2)), out))

# We can multiply elementwise then transpose.
a = np.arange(6).reshape(2, 3)
b = np.arange(6).reshape(2, 3)
out = np.einsum("ij,ij->ji", a, b)
#                │  │   │
#                │  │   └─── output: (3, 2)
#                │  │        we multiply elementwise, then transpose
#                │  │
#                │  └─────── b.shape = (2, 3); i=2, j=3
#                └────────── a.shape = (2, 3); i=2, j=3
print(np.allclose((a * b).T, out))

# And even permute.
a = np.arange(24).reshape(2, 3, 4)
b = np.arange(24).reshape(2, 3, 4)
out = np.einsum("ijk,ijk->kji", a, b)
#                │   │     │
#                │   │     └─ output: (4, 3, 2)
#                │   │        we multiply elementwise, then rearrange
#                │   │
#                │   └─────── b.shape = (2, 3, 4); i=2, j=3, k=4
#                └─────────── a.shape = (2, 3, 4); i=2, j=3, k=4
print(np.allclose((a * b).transpose(2, 1, 0), out))

# Note: when multiplying and permuting,
# the multiplication takes precedence.

In [None]:
# Puzzle #17
# Transpose the 2d matrix a.
a = np.arange(6).reshape(2, 3)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(a.T, out))

In [None]:
# Puzzle #18
# Permute the tensor from shape (2,3,4) to (4,2,3)
a = np.arange(24).reshape(2, 3, 4)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
# Axes order (2, 0, 1) maps shape (2, 3, 4) -> (4, 2, 3), matching the task above.
print(np.allclose(np.transpose(a, (2, 0, 1)), out))

In [None]:
# Puzzle #19
# Permute the tensor from shape (2,3,4,5) to (5,3,2,4)
a = np.arange(120).reshape(2, 3, 4, 5)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(np.transpose(a, (3, 1, 0, 2)), out))

In [None]:
# Puzzle #20
# Multiply a and b elementwise, then permute to (4,2,3)
a = np.arange(24).reshape(2, 3, 4)
b = np.arange(24).reshape(2, 3, 4)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose((a * b).transpose(2, 0, 1), out))

In [None]:
# Matrix multiplication

# Repeat on input to multiply, omit from output to sum.
# Combine both and we get matrix multiplication.
a = np.arange(6).reshape(3, 2)
b = np.arange(6).reshape(2, 3)
out = np.einsum("ij,jk->ik", a, b)
#                │  │   │
#                │  │   ├──── output: (3, 3)
#                │  │   │     i and k remain, j is summed over
#                │  │   │
#                │  └───┼──── second input: b.shape == (2, 3)
#                │      │     j = 2; k = 3
#                │      │
#                ├──────┼──── first input: a.shape == (3, 2)
#                │      │     i = 3; j = 2
#                │      │
#                ├──────┼──── 'j' appears twice on input
#                └──────┼──── they are elementwise multiplied
#                       │
#                       ├──── 'j' is omitted from output
#                       └──── they are summed over
print(np.allclose(a @ b, out))

In [None]:
# Puzzle #21
# Compute a @ b
a = np.arange(6).reshape(3, 2)
b = np.arange(6).reshape(2, 3)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a @ b, out))

In [None]:
# Puzzle #22
# Compute (a @ b).T
a = np.arange(6).reshape(3, 2)
b = np.arange(6).reshape(2, 3)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose((a @ b).T, out))

In [None]:
# Puzzle #23
# Compute a @ b.T
a = np.arange(6).reshape(3, 2)
b = np.arange(6).reshape(3, 2)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a @ b.T, out))

In [None]:
# Puzzle #24
# Compute a @ b
a = np.arange(24).reshape(2, 3, 4)
b = np.arange(20).reshape(4, 5)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a @ b, out))

In [None]:
# Puzzle #25
# Compute a @ b (batched matmul)
a = np.random.randn(10, 3, 4)
b = np.random.randn(10, 4, 5)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a @ b, out))

In [None]:
# Puzzle #26
# Compute np.tensordot(a,b,axes=([1],[0]))
# The output should be (3, 4, 6)
a = np.random.randn(3, 5, 4)
b = np.random.randn(5, 6)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.tensordot(a, b, axes=([1], [0])), out))

In [None]:
# Puzzle #27
# Compute np.tensordot(a,b,axes=([0],[1])).transpose(1,0,2)
# The output should be (4, 3, 6)
a = np.random.randn(5, 3, 4)
b = np.random.randn(6, 5)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.tensordot(a, b, axes=([0], [1])).transpose(1, 0, 2), out))

In [None]:
# Puzzle #28
# Compute np.tensordot(a,b,axes=([1],[0])).sum(axis=(1,2)).T
# The output should be (6, 2)
a = np.random.randn(2, 3, 4)
b = np.random.randn(3, 5, 6)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(np.tensordot(a, b, axes=([1], [0])).sum(axis=(1, 2)).T, out))

In [None]:
# Check our understanding

# Q5: Let a=(3,2); b=(3,2), why is a@b.T expressed as 'ij,kj->ik'?
# Q6: Let a=(3,2); b=(3,2), what's the difference between
#     np.einsum('ij,jk->ik', a, b.T) and np.einsum('ij,kj->ik', a, b)?

# (Answers are available at the end of this notebook.)

In [None]:
# Arrowless subscript

# np.einsum('ijk,ijk->', a, b) is a valid einsum.
# np.einsum('ijk,ijk', a, b) is also a valid einsum.

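# When we omit arrow and output subscript from einsum (implicit mode):
# - Indices that appear more than once are multiplied and summed over.
# - Indices that appear exactly once remain in the output,
#   sorted alphabetically ('ji' alone returns the transpose, not identity).
# - For a single input with distinct, alphabetically ordered indices,
#   it acts as an identity operation.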

# Single input: the identity operation.
a = np.arange(6).reshape(2, 3)
out = np.einsum("ij", a)
#                │
#                └─────── a.shape = (2, 3); i=2, j=3
#                         no repeated indices, nothing to multiply
#                         both 'i' and 'j' remain in output
print(np.allclose(np.einsum("ij->ij", a), out))

# Multiple inputs, repeated indices.
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
out = np.einsum("i,i", a, b)
#                │ │
#                │ └─────── b.shape = (3,); i=3
#                └───────── a.shape = (3,); i=3
#                           'i' repeated across inputs, so it's multiplied
#                           no indices remain, so output is scalar
print(np.allclose(np.einsum("i,i->", a, b), out))

# Matrix multiplication.
a = np.arange(6).reshape(2, 3)
b = np.arange(6).reshape(3, 2)
out = np.einsum("ij,jk", a, b)
#                │  │
#                │  └─────── b.shape = (3, 2); j=3, k=2
#                └────────── a.shape = (2, 3); i=2, j=3
#                            'j' repeated across inputs, so it's multiplied
#                            'i' and 'k' remain, forming output shape
print(np.allclose(np.einsum("ij,jk->ik", a, b), out))

# Multiple inputs, no repeated indices: outer product.
a = np.array([1, 2])
b = np.array([3, 4, 5])
out = np.einsum("i,j", a, b)
#                │ │
#                │ └─────── b.shape = (3,); j=3
#                └───────── a.shape = (2,); i=2
#                           no repeated indices, so nothing is summed over;
#                           every a[i] is multiplied with every b[j]
#                           both 'i' and 'j' remain, forming output shape
print(np.allclose(np.einsum("i,j->ij", a, b), out))

In [None]:
# Ellipsis

# We can use '...' to represent leftover dimension.
# It's like a placeholder for indices.
# Say we have 'ijkl', and we're only interested in 'kl'.
# Doing '...kl' means the '...' represents 'ij'.

# Transposing [:-2].
a = np.random.randn(2, 3, 4, 5)
out = np.einsum("...ij->...ji", a)
#                 │ │    │  │
#                 │ │    │  └─── output: (2, 3, 5, 4)
#                 │ │    │       we swap i and j at the end
#                 │ │    └────── '...' = (2, 3)
#                 │ │
#                 │ └─────────── a.shape[-2:] = (4, 5); i=4, j=5
#                 └───────────── a.shape[:-2] = '...'=(2, 3)
print(np.allclose(a.transpose(0, 1, 3, 2), out))

# Transposing head and tail.
a = np.random.randn(2, 3, 4, 5)
out = np.einsum("i...k->k...i", a)
#                │ │ │  │ │ │
#                │ │ │  │ │ └─── out.shape[-1] = a.shape[0] = 2
#                │ │ │  │ └───── '...' = (3, 4)
#                │ │ │  └─────── out.shape[0] = a.shape[-1] = 5
#                │ │ └────────── a.shape[-1] = 5
#                │ └──────────── a.shape[1:3] = '...' = (3, 4)
#                └────────────── a.shape[0] = 2
# It's like saying:
# We only care about 'ik' dim, the '...' represents (3,4).
# When we use '...' on output, it means (3,4).
print(np.allclose(a.transpose(3, 1, 2, 0), out))

# Batched matmul.
a = np.random.randn(2, 3, 4, 5)
b = np.random.randn(5, 6)
out = np.einsum("...j,jk->...k", a, b)
#                 │ │ │    │ │
#                 │ │ │    │ └─── output: (2, 3, 4, 6); k=6
#                 │ │ │    └───── '...' = (2, 3, 4)
#                 │ │ │
#                 │ │ └────────── b = (5, 6); j=5, k=6
#                 │ └──────────── a[-1] = 5; j=5
#                 └────────────── a[:-1] = (2, 3, 4); ...=(2, 3, 4)
print(np.allclose(a @ b, out))

In [None]:
# Puzzle #29
# Transpose the last two dimension of tensor a.
a = np.random.randn(2, 3, 4, 5, 6, 7)
out = np.einsum("<YOUR_ANSWER_HERE>", a)
print(np.allclose(a.transpose(0, 1, 2, 3, 5, 4), out))

In [None]:
# Puzzle #30
# Compute a @ b using an ellipsis subscript.
a = np.random.randn(3, 4, 5)
b = np.random.randn(5, 6)
out = np.einsum("<YOUR_ANSWER_HERE>", a, b)
print(np.allclose(a @ b, out))

In [203]:
# Check our understanding

# Q7: What are the pros/cons of arrowless subscript?
# Q8: When is ellipsis useful? When should we use it vs. not use it?

# (Answers are available at the end of this notebook.)

In [None]:
# Attention... is all we need.
# Let's implement an attention head with einsum!

# B: batch size
# L: sequence length
# D: model dimension
# H: number of attention heads in a layer
# K: size of each attention key or value
b, l, d, h, k = 16, 10, 32, 4, 8
# In Karpathy's/llm.c's lingo: BLD = BTC


def softmax(x, axis=-1):
  """Numerically stable softmax of `x` along `axis`."""
  # Subtracting the per-axis max does not change the result mathematically,
  # but it keeps np.exp from overflowing on large inputs.
  shifted = x - np.max(x, axis=axis, keepdims=True)
  exps = np.exp(shifted)
  denom = exps.sum(axis=axis, keepdims=True)
  return exps / denom


w_q_dk = np.random.randn(d, k)
w_k_dk = np.random.randn(d, k)
w_v_dk = np.random.randn(d, k)
mask = np.where(np.tril(np.ones((l, l))) == 1, 0.0, -np.inf)


def head_np(input_bld):
  """Reference single attention head written with plain numpy matmuls.

  Projects `input_bld` with the module-level `w_q_dk`, `w_k_dk` and `w_v_dk`,
  adds the module-level `mask` to the scaled query/key scores, applies
  `softmax` over the last axis and returns the weighted sum of the values.
  """
  q_blk, k_blk, v_blk = (input_bld @ w_dk for w_dk in (w_q_dk, w_k_dk, w_v_dk))
  # Scale by sqrt of the key size (last axis of the projected keys).
  scale = np.sqrt(k_blk.shape[-1])
  # Swap the last two axes of the keys so the matmul computes q @ k^T per batch.
  k_bkl = k_blk.transpose(0, 2, 1)
  scores_bll = (q_blk @ k_bkl) / scale + mask
  weights_bll = softmax(scores_bll, axis=-1)
  return weights_bll @ v_blk


# Puzzle #31
def head(input_bld):
  """Einsum version of `head_np` (left for the reader to implement).

  The check below this function compares the return value against
  `head_np(input_bld)` with np.allclose, so both must produce the same array.
  """
  # <YOUR_ANSWER_HERE> implement a single attention head with einsum.
  pass


input_bld = np.random.randn(b, l, d)
ho = head(input_bld)
ho_np = head_np(input_bld)
print(np.allclose(ho, ho_np))

In [None]:
# Attention... is all we need... part 2.
# Let's implement multi-head attention in einsum!


# B: batch size
# L: sequence length
# D: model dimension
# H: number of attention heads in a layer
# K: size of each attention key or value
b, l, d, h, k = 16, 10, 32, 4, 8
# In Karpathy's/llm.c's lingo: BLD = BTC


w_q_dhk = np.random.randn(d, h, k)
w_k_dhk = np.random.randn(d, h, k)
w_v_dhk = np.random.randn(d, h, k)
w_o_hkd = np.random.randn(h, k, d)

mask = np.where(np.tril(np.ones((l, l))) == 1, 0.0, -np.inf)


def softmax(x, axis=-1):
  """Numerically stable softmax of `x` along `axis`."""
  # Subtracting the per-axis max does not change the result mathematically,
  # but it keeps np.exp from overflowing on large inputs.
  shifted = x - np.max(x, axis=axis, keepdims=True)
  exps = np.exp(shifted)
  denom = exps.sum(axis=axis, keepdims=True)
  return exps / denom


def attention_np(input_bld):
  """Reference multi-head attention written with plain numpy ops.

  Uses the module-level weights `w_q_dhk`, `w_k_dhk`, `w_v_dhk`, `w_o_hkd`
  and the module-level additive `mask`.

  Args:
    input_bld: 3d array; the first two axes are kept, the last axis is
      contracted against the first axis of the q/k/v weights.

  Returns:
    Array whose first two axes match `input_bld` and whose last axis is the
    last axis of `w_o_hkd`.

  All sizes are read from the array shapes instead of the notebook globals
  `b, l, d, h, k`: `b` in particular is rebound to an array by most other
  cells, which would break the reshapes below.
  """
  n_b, n_l = input_bld.shape[:2]
  n_d, n_h, n_k = w_q_dhk.shape

  def project_blhk(w_dhk):
    # Fold (h, k) into one axis so every head is projected by a single dot,
    # then split the heads back out.
    flat = np.dot(input_bld, w_dhk.reshape(n_d, n_h * n_k))
    return flat.reshape(n_b, n_l, n_h, n_k)

  q_blhk = project_blhk(w_q_dhk)
  k_blhk = project_blhk(w_k_dhk)
  v_blhk = project_blhk(w_v_dhk)
  q_bhlk = q_blhk.transpose(0, 2, 1, 3)
  k_bhkl = k_blhk.transpose(0, 2, 3, 1)  # keys laid out for q @ k^T
  v_bhlk = v_blhk.transpose(0, 2, 1, 3)
  scores_bhll = np.matmul(q_bhlk, k_bhkl) / np.sqrt(n_k)
  # mask is added before the softmax and broadcasts over the leading axes.
  scores_bhll = softmax(scores_bhll + mask, axis=-1)
  out_bhlk = np.matmul(scores_bhll, v_bhlk)
  out_blhk = out_bhlk.transpose(0, 2, 1, 3)
  # Merge the heads again and apply the output projection.
  out_bld = np.dot(out_blhk.reshape(n_b, n_l, n_h * n_k), w_o_hkd.reshape(n_h * n_k, -1))
  return out_bld


# Puzzle #32
def attention(input_bld):
  """Einsum version of `attention_np` (left for the reader to implement).

  The check below this function compares the return value against
  `attention_np(input_bld)` with np.allclose, so both must produce the same array.
  """
  # <YOUR_ANSWER_HERE> implement multi-head attention with einsum.
  pass


input_bld = np.random.randn(b, l, d)
output_np = attention_np(input_bld)
output_einsum = attention(input_bld)
print(np.allclose(output_np, output_einsum))

In [206]:
# Check our understanding: answered

# Q1A: 'i,i->i' has repeated 'i' on inputs, they're multiplied.
#      'i->i' has no repeated inputs.
# Q2A: 'ij->ji' will transpose the matrix.
# Q3A: 'ijk' is not a valid subscript against 2d matrix.
# Q4A: no, the only elementwise operation supported by einsum is multiply.
# Q5A: 'ij,kj->ik' performs a@b.T because the axis being contracted
#      is at shape[1] for both matrices, and it's denoted by 'j'. It doesn't
#      matter where the mul+sum happens. 'ik' remains after matmul.
# Q6A: np.einsum('ij,jk->ik', a, b.T) transposes b before doing matmul.
#      np.einsum('ij,kj->ik', a, b) handles the transposition for us.
# Q7A: pros of arrowless subscript: shorter; cons: can't transpose.
# Q8A: ellipsis can be useful when working with batches of high-dim tensor.

In [None]:
# Einsum cheatsheet

import numpy as np


def compare(einsum_expr, np_fn, *xs):
  """Return True when np.einsum(einsum_expr, *xs) is allclose to np_fn(*xs)."""
  return np.allclose(np.einsum(einsum_expr, *xs), np_fn(*xs))


# Every line below is asserted: a bare `compare(...)` call would discard its
# result, and a notebook cell only displays the value of its last expression.

# Vector operations
xs = np.random.randn(3)
assert compare("i", lambda x: x, xs)
assert compare("i->", lambda x: np.sum(x), xs)

# Matrix operations
xs = np.random.randn(3, 3)
assert compare("ij->ji", lambda x: x.T, xs)
assert compare("ij->", lambda x: np.sum(x), xs)
assert compare("ii->i", lambda x: np.diag(x), xs)
assert compare("ii->", lambda x: np.trace(x), xs)
assert compare("ij->j", lambda x: np.sum(x, axis=0), xs)
assert compare("ij->i", lambda x: np.sum(x, axis=1), xs)

# Elementwise operations with two vectors
xs = np.random.randn(3), np.random.randn(3)
assert compare("i,i->i", lambda x1, x2: x1 * x2, *xs)
assert compare("i,i->", lambda x1, x2: np.sum(x1 * x2), *xs)
assert compare("i,j->ij", lambda x1, x2: np.outer(x1, x2), *xs)
assert compare("i,j->ji", lambda x1, x2: np.outer(x1, x2).T, *xs)

# Elementwise operations with two 3d tensors
xs = np.random.randn(2, 3, 4), np.random.randn(2, 3, 4)
assert compare("ijk,ijk->ijk", lambda x1, x2: x1 * x2, *xs)
assert compare("ijk,ijk->", lambda x1, x2: np.sum(x1 * x2), *xs)
assert compare("ijk,ijk->ij", lambda x1, x2: np.sum(x1 * x2, axis=2), *xs)
assert compare("ijk,ijk->k", lambda x1, x2: np.sum(x1 * x2, axis=(0, 1)), *xs)

# Broadcasting vector to matrix
xs = np.random.randn(3, 4), np.random.randn(4)
assert compare("ij,j->ij", lambda x1, x2: x1 * x2[:, None].T, *xs)
assert compare("ij,j->i", lambda x1, x2: np.sum(x1 * x2, axis=1), *xs)

# Broadcasting matrix to 3d tensor
xs = np.random.randn(2, 3, 4), np.random.randn(3, 4)
assert compare("ijk,jk->ijk", lambda x1, x2: x1 * x2, *xs)
assert compare("ijk,jk->ik", lambda x1, x2: np.sum(x1 * x2, axis=1), *xs)

# Basic ellipsis operations
xs = np.random.randn(2, 3, 4, 5)
assert compare("...ij->...ji", lambda x: np.swapaxes(x, -2, -1), xs)
assert compare("i...->...", lambda x: np.sum(x, axis=0), xs)
assert compare("ij...->...ij", lambda x: np.moveaxis(x, [0, 1], [-2, -1]), xs)

# Ellipsis with broadcasting
xs = np.random.randn(2, 3, 4, 5), np.random.randn(4, 5)
assert compare("...ij,ij->...i", lambda x1, x2: np.sum(x1 * x2, axis=-1), *xs)

# Element-wise operations and reductions
xs = np.random.randn(3, 2), np.random.randn(3, 2)
assert compare("ij,ij->ij", lambda x1, x2: x1 * x2, *xs)
assert compare("ij,ij->", lambda x1, x2: np.sum(x1 * x2), *xs)
assert compare("ij,ij->j", lambda x1, x2: np.sum(x1 * x2, axis=0), *xs)
assert compare("ij,ij->i", lambda x1, x2: np.sum(x1 * x2, axis=1), *xs)

# Matrix multiplications
xs = np.random.randn(3, 2), np.random.randn(3, 2)
assert compare("ki,kj->ij", lambda x1, x2: x1.T @ x2, *xs)
assert compare("ik,jk->ij", lambda x1, x2: x1 @ x2.T, *xs)
assert compare("ik,jk->", lambda x1, x2: np.sum(x1 @ x2.T), *xs)
assert compare("ik,jk->j", lambda x1, x2: np.sum(x1 @ x2.T, axis=0), *xs)
assert compare("ik,jk->i", lambda x1, x2: np.sum(x1 @ x2.T, axis=1), *xs)

# Combining summation, multiplication, transposition
xs = np.random.randn(2, 3, 4), np.random.randn(2, 3, 4)
assert compare("ijk,ijk->ijk", lambda x1, x2: x1 * x2, *xs)
assert compare("ijk,ijk->", lambda x1, x2: np.sum(x1 * x2), *xs)
assert compare("ijk,ijk->ij", lambda x1, x2: np.sum(x1 * x2, axis=2), *xs)
assert compare("ijk,ijk->k", lambda x1, x2: np.sum(x1 * x2, axis=(0, 1)), *xs)
assert compare("ikl,jkl->ij", lambda x1, x2: np.tensordot(x1, x2, axes=([1, 2], [1, 2])), *xs)  # fmt: skip
assert compare("ijm,lkm->", lambda x1, x2: np.sum(np.tensordot(x1, x2, axes=([2], [2]))), *xs)  # fmt: skip
assert compare("ikl,jkl->j", lambda x1, x2: np.sum(np.tensordot(x1, x2, axes=([1, 2], [1, 2])),axis=0), *xs)  # fmt: skip
assert compare("ijm,lkm->ijkl", lambda x1, x2: np.transpose(np.tensordot(x1, x2, axes=([2], [2])), (0, 1, 3, 2)), *xs)  # fmt: skip
assert compare("ijm,lkm->il", lambda x1, x2: np.sum(np.transpose(np.tensordot(x1, x2, axes=([2], [2])), (0, 1, 3, 2)), axis=(1,2)), *xs)  # fmt: skip